|
15 | 15 | import dataclasses
|
16 | 16 | import enum
|
17 | 17 | import os
|
18 |
| -import subprocess |
19 | 18 | import sys
|
20 | 19 |
|
21 | 20 | parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
22 | 21 | sys.path.append(parent_dir)
|
23 | 22 |
|
| 23 | + |
| 24 | +from disruption_management.disruption_utils import execute_command_as_subprocess |
| 25 | +from disruption_management.disruption_utils import get_pod_name_from_regex |
24 | 26 | from xpk_configs import XpkClusterConfig
|
25 | 27 |
|
26 | 28 |
|
|
29 | 31 | PATHWAYS_STANDARD_TARGET_POD_REGEX_SUFFIX = ".*worker-0-0.*"
|
30 | 32 | PATHWAYS_STANDARD_STEP_POD_REGEX_SUFFIX = ".*main-0-0.*"
|
31 | 33 |
|
32 |
| -STANDARD_STEP_LOG_REGEX = "completed step: (\\d+)" |
33 |
| - |
34 | 34 | PATHWAYS_WORKER_CONTAINER_NAME = "pathways-worker"
|
35 | 35 | MCJAX_WORKER_CONTAINER_NAME = "jax-tpu"
|
36 | 36 |
|
@@ -71,6 +71,9 @@ class DisruptionConfig:
|
71 | 71 | # Target pod regex needed for triggering disruption.
|
72 | 72 | target_pod_regex: str = MCJAX_STANDARD_TARGET_POD_REGEX_SUFFIX
|
73 | 73 |
|
| 74 | + # Step pod regex needed for step-based disruption. |
| 75 | + step_pod_regex: str = MCJAX_STANDARD_STEP_POD_REGEX_SUFFIX |
| 76 | + |
74 | 77 |
|
75 | 78 | class DisruptionHandler(abc.ABC):
|
76 | 79 | """Abstract interface for disruption handlers."""
|
@@ -103,73 +106,50 @@ def trigger_disruption(
|
103 | 106 | f"🔥🔥🔥 Beginning SIGILL for workload: {workload_name} with pod regex:"
|
104 | 107 | f" {target_pod_regex} 🔥🔥🔥"
|
105 | 108 | )
|
106 |
| - pod_name_command = [ |
107 |
| - "kubectl", |
108 |
| - "get", |
109 |
| - "pods", |
110 |
| - "-o=name", |
111 |
| - "--no-headers", |
112 |
| - f"| grep -E '{target_pod_regex}'", |
113 |
| - ] |
114 |
| - pod_name_command_str = " ".join(pod_name_command) |
115 | 109 | container_name = disruption_config.worker_container_name
|
116 | 110 |
|
117 |
| - try: |
118 |
| - pod_name_process = subprocess.run( |
119 |
| - pod_name_command_str, |
120 |
| - shell=True, |
121 |
| - check=True, |
122 |
| - capture_output=True, |
123 |
| - text=True, |
124 |
| - ) |
125 |
| - pod_name = pod_name_process.stdout.strip() |
126 |
| - if not pod_name: |
127 |
| - print( |
128 |
| - f"Warning: No pod found matching regex: '{target_pod_regex}' for" |
129 |
| - f" workload '{workload_name}'" |
130 |
| - ) |
131 |
| - return |
132 |
| - |
133 |
| - print(f"🔍 Found pod: {pod_name}") |
134 |
| - kill_command = [ |
135 |
| - "kubectl", |
136 |
| - "exec", |
137 |
| - "-it", |
138 |
| - pod_name, |
139 |
| - "-c", |
140 |
| - container_name, |
141 |
| - "--", |
142 |
| - "/bin/bash", |
143 |
| - "-c", |
144 |
| - "\"kill -s SIGILL 1\"", |
145 |
| - ] |
146 |
| - kill_command_str = " ".join(kill_command) |
147 |
| - print(f"🔥🔥🔥 Executing command in pod: {kill_command_str} 🔥🔥🔥") |
148 |
| - subprocess.run( |
149 |
| - kill_command_str, |
150 |
| - shell=True, |
151 |
| - check=True, |
152 |
| - capture_output=True, |
153 |
| - text=True, |
154 |
| - ) |
155 |
| - print( |
156 |
| - f"✅ Successfully sent SIGILL to pod: {pod_name} in container:" |
157 |
| - f" {container_name}" |
158 |
| - ) |
159 |
| - |
160 |
| - except subprocess.CalledProcessError as e: |
161 |
| - print( |
162 |
| - "❌ Error sending SIGILL to pod(s) matching regex" |
163 |
| - f" '{target_pod_regex}' for workload '{workload_name}'" |
164 |
| - ) |
165 |
| - print(f"Return code: {e.returncode}") |
166 |
| - print(f"error: {e}") |
| 111 | + pod_name = get_pod_name_from_regex(workload_name, target_pod_regex) |
| 112 | + if not pod_name: |
| 113 | + return |
| 114 | + |
| 115 | + kill_command = ( |
| 116 | + f"kubectl exec -it {pod_name} -c {container_name} -- /bin/sh -c " |
| 117 | + f'"kill -s SIGILL 1"' |
| 118 | + ) |
| 119 | + print(f"🔥🔥🔥 Executing command in pod: {kill_command} 🔥🔥🔥") |
| 120 | + execute_command_as_subprocess(kill_command) |
| 121 | + |
| 122 | + |
| 123 | +class SIGTERMHandler(DisruptionHandler): |
| 124 | + """Handles SIGTERM disruption by sending a SIGTERM signal to the pod.""" |
| 125 | + |
| 126 | + def trigger_disruption( |
| 127 | + self, workload_name: str, cluster_config: XpkClusterConfig, |
| 128 | + disruption_config, target_pod_regex: str |
| 129 | + ) -> None: |
| 130 | + """Triggers SIGTERM disruption by executing kill -s SIGTERM 1 in the pod.""" |
| 131 | + print( |
| 132 | + f"🔥🔥🔥 Beginning SIGTERM for workload: {workload_name} with pod regex:" |
| 133 | + f" {target_pod_regex} 🔥🔥🔥" |
| 134 | + ) |
| 135 | + container_name = disruption_config.worker_container_name |
| 136 | + |
| 137 | + pod_name = get_pod_name_from_regex(workload_name, target_pod_regex) |
| 138 | + if not pod_name: |
| 139 | + return |
| 140 | + |
| 141 | + kill_command = ( |
| 142 | + f"kubectl exec -it {pod_name} -c {container_name} -- /bin/sh -c " |
| 143 | + f'"kill -s SIGTERM 1"' |
| 144 | + ) |
| 145 | + print(f"🔥🔥🔥 Executing command in pod: {kill_command} 🔥🔥🔥") |
| 146 | + execute_command_as_subprocess(kill_command) |
167 | 147 |
|
168 | 148 |
|
169 | 149 | def create_disruption_handler(disruption_config):
|
170 | 150 | """Factory function to create the appropriate disruption handler."""
|
171 | 151 | if disruption_config.disruption_method == DisruptionMethod.SIGTERM:
|
172 |
| - raise NotImplementedError("SIGTERM Disruption Handler not implemented yet.") |
| 152 | + return SIGTERMHandler() |
173 | 153 | elif disruption_config.disruption_method == DisruptionMethod.SIGILL:
|
174 | 154 | return SIGILLHandler()
|
175 | 155 | else:
|
|
0 commit comments