Skip to content

Commit 0b1fde4

Browse files
author
maxtext authors
committed
Merge pull request #1404 from AI-Hypercomputer:sujinesh/add_disruption_manager_sigterm
PiperOrigin-RevId: 742898218
2 parents a06dc23 + 1a71ed1 commit 0b1fde4

File tree

5 files changed

+406
-74
lines changed

5 files changed

+406
-74
lines changed

benchmarks/disruption_management/disruption_handler.py

Lines changed: 43 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@
1515
import dataclasses
1616
import enum
1717
import os
18-
import subprocess
1918
import sys
2019

2120
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
2221
sys.path.append(parent_dir)
2322

23+
24+
from disruption_management.disruption_utils import execute_command_as_subprocess
25+
from disruption_management.disruption_utils import get_pod_name_from_regex
2426
from xpk_configs import XpkClusterConfig
2527

2628

@@ -29,8 +31,6 @@
2931
PATHWAYS_STANDARD_TARGET_POD_REGEX_SUFFIX = ".*worker-0-0.*"
3032
PATHWAYS_STANDARD_STEP_POD_REGEX_SUFFIX = ".*main-0-0.*"
3133

32-
STANDARD_STEP_LOG_REGEX = "completed step: (\\d+)"
33-
3434
PATHWAYS_WORKER_CONTAINER_NAME = "pathways-worker"
3535
MCJAX_WORKER_CONTAINER_NAME = "jax-tpu"
3636

@@ -71,6 +71,9 @@ class DisruptionConfig:
7171
# Target pod regex needed for triggering disruption.
7272
target_pod_regex: str = MCJAX_STANDARD_TARGET_POD_REGEX_SUFFIX
7373

74+
# Step pod regex needed for step-based disruption.
75+
step_pod_regex: str = MCJAX_STANDARD_STEP_POD_REGEX_SUFFIX
76+
7477

7578
class DisruptionHandler(abc.ABC):
7679
"""Abstract interface for disruption handlers."""
@@ -103,73 +106,50 @@ def trigger_disruption(
103106
f"🔥🔥🔥 Beginning SIGILL for workload: {workload_name} with pod regex:"
104107
f" {target_pod_regex} 🔥🔥🔥"
105108
)
106-
pod_name_command = [
107-
"kubectl",
108-
"get",
109-
"pods",
110-
"-o=name",
111-
"--no-headers",
112-
f"| grep -E '{target_pod_regex}'",
113-
]
114-
pod_name_command_str = " ".join(pod_name_command)
115109
container_name = disruption_config.worker_container_name
116110

117-
try:
118-
pod_name_process = subprocess.run(
119-
pod_name_command_str,
120-
shell=True,
121-
check=True,
122-
capture_output=True,
123-
text=True,
124-
)
125-
pod_name = pod_name_process.stdout.strip()
126-
if not pod_name:
127-
print(
128-
f"Warning: No pod found matching regex: '{target_pod_regex}' for"
129-
f" workload '{workload_name}'"
130-
)
131-
return
132-
133-
print(f"🔍 Found pod: {pod_name}")
134-
kill_command = [
135-
"kubectl",
136-
"exec",
137-
"-it",
138-
pod_name,
139-
"-c",
140-
container_name,
141-
"--",
142-
"/bin/bash",
143-
"-c",
144-
"\"kill -s SIGILL 1\"",
145-
]
146-
kill_command_str = " ".join(kill_command)
147-
print(f"🔥🔥🔥 Executing command in pod: {kill_command_str} 🔥🔥🔥")
148-
subprocess.run(
149-
kill_command_str,
150-
shell=True,
151-
check=True,
152-
capture_output=True,
153-
text=True,
154-
)
155-
print(
156-
f"✅ Successfully sent SIGILL to pod: {pod_name} in container:"
157-
f" {container_name}"
158-
)
159-
160-
except subprocess.CalledProcessError as e:
161-
print(
162-
"❌ Error sending SIGILL to pod(s) matching regex"
163-
f" '{target_pod_regex}' for workload '{workload_name}'"
164-
)
165-
print(f"Return code: {e.returncode}")
166-
print(f"error: {e}")
111+
pod_name = get_pod_name_from_regex(workload_name, target_pod_regex)
112+
if not pod_name:
113+
return
114+
115+
kill_command = (
116+
f"kubectl exec -it {pod_name} -c {container_name} -- /bin/sh -c "
117+
f'"kill -s SIGILL 1"'
118+
)
119+
print(f"🔥🔥🔥 Executing command in pod: {kill_command} 🔥🔥🔥")
120+
execute_command_as_subprocess(kill_command)
121+
122+
123+
class SIGTERMHandler(DisruptionHandler):
  """Disruption handler that sends SIGTERM to process 1 inside a pod container."""

  def trigger_disruption(
      self,
      workload_name: str,
      cluster_config: XpkClusterConfig,
      disruption_config,
      target_pod_regex: str,
  ) -> None:
    """Runs `kill -s SIGTERM 1` in the first pod matching the regex.

    Args:
      workload_name: Name of the workload being disrupted (used for lookup
        and log messages).
      cluster_config: Cluster configuration; not read by this handler.
      disruption_config: Disruption settings; only `worker_container_name`
        is read here.
      target_pod_regex: Regex used to locate the pod to signal.
    """
    print(
        f"🔥🔥🔥 Beginning SIGTERM for workload: {workload_name} with pod regex:"
        f" {target_pod_regex} 🔥🔥🔥"
    )
    worker_container = disruption_config.worker_container_name

    target_pod = get_pod_name_from_regex(workload_name, target_pod_regex)
    if target_pod:
      # Only attempt the kill when the lookup produced a pod name.
      sigterm_command = (
          f"kubectl exec -it {target_pod} -c {worker_container} -- /bin/sh -c "
          f'"kill -s SIGTERM 1"'
      )
      print(f"🔥🔥🔥 Executing command in pod: {sigterm_command} 🔥🔥🔥")
      execute_command_as_subprocess(sigterm_command)
167147

168148

169149
def create_disruption_handler(disruption_config):
170150
"""Factory function to create the appropriate disruption handler."""
171151
if disruption_config.disruption_method == DisruptionMethod.SIGTERM:
172-
raise NotImplementedError("SIGTERM Disruption Handler not implemented yet.")
152+
return SIGTERMHandler()
173153
elif disruption_config.disruption_method == DisruptionMethod.SIGILL:
174154
return SIGILLHandler()
175155
else:

benchmarks/disruption_management/disruption_manager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,11 @@ def _monitor_and_disrupt_workload(
113113
) -> None:
114114
"""Monitors workload progress, triggers disruptions, and recoveries."""
115115
target_pod_regex = f"{workload_name}{disruption_config.target_pod_regex}"
116+
step_pod_regex = f"{workload_name}{disruption_config.step_pod_regex}"
116117

117118
# Create Monitor based on trigger type
118119
monitor: Monitor = create_monitor(
119-
workload_name, disruption_config, target_pod_regex
120+
workload_name, disruption_config, step_pod_regex
120121
)
121122
disruption_handler: DisruptionHandler = create_disruption_handler(
122123
disruption_config
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
https://www.apache.org/licenses/LICENSE-2.0
7+
Unless required by applicable law or agreed to in writing, software
8+
distributed under the License is distributed on an "AS IS" BASIS,
9+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
See the License for the specific language governing permissions and
11+
limitations under the License.
12+
"""
13+
14+
"""Utility functions for disrupting workloads."""
15+
16+
import subprocess
17+
import time
18+
19+
POLLING_INTERVAL_SECONDS = 10
20+
21+
22+
def execute_command_as_subprocess(
    command: str,
) -> None:
  """Runs a shell command and reports its outcome on stdout.

  The command is executed through the shell with its output captured.
  Failures are reported rather than raised, so a failing command does not
  abort the caller.

  Args:
    command: The full shell command line to execute.
  """
  print(f"Executing command: {command}")
  try:
    process = subprocess.run(
        command,
        shell=True,
        check=True,
        capture_output=True,
        text=True,
    )
  except subprocess.CalledProcessError as e:
    print(f"❌ Error executing command: {command}")
    print(f"Return code: {e.returncode}")
    print(f"Error: {e}")
    # str(e) only carries the exit status; the captured streams hold the
    # actual failure reason, so surface them as well.
    if e.stdout:
      print(e.stdout)
    if e.stderr:
      print(f"stderr: {e.stderr}")
    return

  if process.stdout:
    print(process.stdout)
  if process.stderr:
    print(f"stderr: {process.stderr}")
  print(f"✅ Successfully executed command: {command}")
44+
45+
46+
def get_pod_name_from_regex(workload_name: str, pod_regex: str) -> str | None:
47+
"""Returns the name of the first pod matching the regex."""
48+
print(
49+
f"Workload '{workload_name}': Getting pod name matching"
50+
f" '{pod_regex}'..."
51+
)
52+
pod_name_command = [
53+
"kubectl",
54+
"get",
55+
"pods",
56+
"-o=custom-columns=NAME:.metadata.name",
57+
"--no-headers",
58+
f"| grep -E '{pod_regex}'",
59+
]
60+
pod_name_command_str = " ".join(pod_name_command)
61+
try:
62+
process = subprocess.run(
63+
pod_name_command_str,
64+
shell=True,
65+
check=True,
66+
capture_output=True,
67+
text=True,
68+
)
69+
pod_names = process.stdout.strip().splitlines()
70+
if pod_names:
71+
# Assuming there's only one step pod.
72+
pod_name = pod_names[0]
73+
print(f"Workload '{workload_name}': Found pod: {pod_name}")
74+
return pod_name
75+
else:
76+
print(
77+
f"Workload '{workload_name}': No pod found matching"
78+
f" regex '{pod_regex}'."
79+
)
80+
except subprocess.CalledProcessError as e:
81+
print(
82+
f"Workload '{workload_name}': Error getting pod information:"
83+
f" {e}"
84+
)
85+
86+
87+
def get_pod_status(workload_name: str, pod_name: str) -> str | None:
88+
"""Returns the status of the pod."""
89+
print(
90+
f"Workload '{workload_name}': Getting status of pod '{pod_name}'..."
91+
)
92+
pod_status_command = [
93+
"kubectl",
94+
"get",
95+
"pod",
96+
pod_name,
97+
"-o=jsonpath='{.status.phase}'",
98+
]
99+
pod_status_command_str = " ".join(pod_status_command)
100+
status_process = subprocess.run(
101+
pod_status_command_str,
102+
shell=True,
103+
check=True,
104+
capture_output=True,
105+
text=True,
106+
)
107+
pod_status = status_process.stdout.strip()
108+
print(
109+
f"Workload '{workload_name}': Pod '{pod_name}' is in '{pod_status}'"
110+
" state."
111+
)
112+
return pod_status
113+
114+
115+
def wait_for_pod_to_start(
    workload_name: str,
    pod_regex: str,
    timeout_seconds: float | None = None,
) -> str | None:
  """Polls until a pod matching the regex is 'Running' and returns its name.

  Args:
    workload_name: Workload name, used for lookup and log messages.
    pod_regex: Regex used to locate the pod.
    timeout_seconds: Maximum time to wait. None (the default) waits
      indefinitely, which matches the previous behaviour.

  Returns:
    The name of the pod once it is in the 'Running' state, or None if
    `timeout_seconds` elapsed first.
  """
  print(
      f"Workload '{workload_name}': Waiting for pod matching"
      f" '{pod_regex}' to be in 'Running' state..."
  )
  # A monotonic clock keeps the deadline immune to wall-clock adjustments.
  deadline = (
      None if timeout_seconds is None else time.monotonic() + timeout_seconds
  )
  while True:
    pod_name = get_pod_name_from_regex(workload_name, pod_regex)
    if pod_name:
      pod_status = get_pod_status(workload_name, pod_name)
      if pod_status == "Running":
        print(
            f"Workload '{workload_name}': Step pod '{pod_name}'"
            f" is now in 'Running' state."
        )
        return pod_name

    if deadline is None:
      time.sleep(POLLING_INTERVAL_SECONDS)
      continue

    remaining = deadline - time.monotonic()
    if remaining <= 0:
      break
    # Never sleep past the deadline.
    time.sleep(min(POLLING_INTERVAL_SECONDS, remaining))

  print(
      f"Workload '{workload_name}': Timed out waiting for step pod"
      f" matching '{pod_regex}' to reach 'Running' state."
  )
  return None

0 commit comments

Comments
 (0)