Skip to content

Commit 5866d4e

Browse files
authored
feature: support recursive conditionals on Argo Workflows (#2562)
adds support for recursive conditionals on argo workflows
1 parent a3d62bd commit 5866d4e

File tree

1 file changed

+132
-2
lines changed

1 file changed

+132
-2
lines changed

metaflow/plugins/argo/argo_workflows.py

Lines changed: 132 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -926,6 +926,7 @@ def _parse_conditional_branches(self):
926926
self.conditional_nodes = set()
927927
self.conditional_join_nodes = set()
928928
self.matching_conditional_join_dict = {}
929+
self.recursive_nodes = set()
929930

930931
node_conditional_parents = {}
931932
node_conditional_branches = {}
@@ -948,6 +949,12 @@ def _visit(node, seen, conditional_branch, conditional_parents=None):
948949
)
949950
node_conditional_parents[node.name] = conditional_parents
950951

952+
# check for recursion. this split is recursive if any of its out functions are itself.
953+
if any(
954+
out_func for out_func in node.out_funcs if out_func == node.name
955+
):
956+
self.recursive_nodes.add(node.name)
957+
951958
if conditional_parents and not node.type == "split-switch":
952959
node_conditional_parents[node.name] = conditional_parents
953960
conditional_branch = conditional_branch + [node.name]
@@ -1033,6 +1040,9 @@ def _is_conditional_node(self, node):
10331040
def _is_conditional_join_node(self, node):
10341041
return node.name in self.conditional_join_nodes
10351042

1043+
def _is_recursive_node(self, node):
1044+
return node.name in self.recursive_nodes
1045+
10361046
def _matching_conditional_join(self, node):
10371047
return self.matching_conditional_join_dict.get(node.name, None)
10381048

@@ -1044,6 +1054,7 @@ def _visit(
10441054
templates=None,
10451055
dag_tasks=None,
10461056
parent_foreach=None,
1057+
seen=None,
10471058
): # Returns Tuple[List[Template], List[DAGTask]]
10481059
""" """
10491060
# Every for-each node results in a separate subDAG and an equivalent
@@ -1053,13 +1064,22 @@ def _visit(
10531064
# of the for-each node.
10541065

10551066
# Emit if we have reached the end of the sub workflow
1067+
if seen is None:
1068+
seen = []
10561069
if dag_tasks is None:
10571070
dag_tasks = []
10581071
if templates is None:
10591072
templates = []
10601073

10611074
if exit_node is not None and exit_node is node.name:
10621075
return templates, dag_tasks
1076+
if node.name in seen:
1077+
return templates, dag_tasks
1078+
1079+
seen.append(node.name)
1080+
1081+
# helper variable for recursive conditional inputs
1082+
has_foreach_inputs = False
10631083
if node.name == "start":
10641084
# Start node has no dependencies.
10651085
dag_task = DAGTask(self._sanitize(node.name)).template(
@@ -1073,9 +1093,10 @@ def _visit(
10731093
# vs what is a "num_parallel" based foreach (i.e. something that follows gang semantics.)
10741094
# A `regular` foreach is basically any arbitrary kind of foreach.
10751095
):
1096+
# helper variable for recursive conditional inputs
1097+
has_foreach_inputs = True
10761098
# Child of a foreach node needs input-paths as well as split-index
10771099
# This child is the first node of the sub workflow and has no dependency
1078-
10791100
parameters = [
10801101
Parameter("input-paths").value("{{inputs.parameters.input-paths}}"),
10811102
Parameter("split-index").value("{{inputs.parameters.split-index}}"),
@@ -1253,22 +1274,118 @@ def _visit(
12531274
templates,
12541275
dag_tasks,
12551276
parent_foreach,
1277+
seen,
12561278
)
12571279
return _visit(
12581280
self.graph[node.matching_join],
12591281
exit_node,
12601282
templates,
12611283
dag_tasks,
12621284
parent_foreach,
1285+
seen,
12631286
)
12641287
elif node.type == "split-switch":
1288+
if self._is_recursive_node(node):
1289+
# we need an additional recursive template if the step is recursive
1290+
# NOTE: in the recursive case, the original step is renamed in the container templates to 'recursive-<step_name>'
1291+
# so that we do not have to touch the step references in the DAG.
1292+
#
1293+
# NOTE: The way that recursion in Argo Workflows is achieved is with the following structure:
1294+
# - the usual 'example-step' template which would match example_step in flow code is renamed to 'recursive-example-step'
1295+
# - templates has another template with the original task name: 'example-step'
1296+
# - the template 'example-step' in turn has steps
1297+
# - 'example-step-internal' which uses the metaflow step executing template 'recursive-example-step'
1298+
# - 'example-step-recursion' which calls the parent template 'example-step' if switch-step output from 'example-step-internal' matches the condition.
1299+
sanitized_name = self._sanitize(node.name)
1300+
templates.append(
1301+
Template(sanitized_name)
1302+
.steps(
1303+
[
1304+
WorkflowStep()
1305+
.name("%s-internal" % sanitized_name)
1306+
.template("recursive-%s" % sanitized_name)
1307+
.arguments(
1308+
Arguments().parameters(
1309+
[
1310+
Parameter("input-paths").value(
1311+
"{{inputs.parameters.input-paths}}"
1312+
)
1313+
]
1314+
# Add the additional inputs required by specific node types.
1315+
# We do not need to cover joins or @parallel, as a split-switch step can not be either one of these.
1316+
+ (
1317+
[
1318+
Parameter("split-index").value(
1319+
"{{inputs.parameters.split-index}}"
1320+
)
1321+
]
1322+
if has_foreach_inputs
1323+
else []
1324+
)
1325+
)
1326+
)
1327+
]
1328+
)
1329+
.steps(
1330+
[
1331+
WorkflowStep()
1332+
.name("%s-recursion" % sanitized_name)
1333+
.template(sanitized_name)
1334+
.when(
1335+
"{{steps.%s-internal.outputs.parameters.switch-step}}==%s"
1336+
% (sanitized_name, node.name)
1337+
)
1338+
.arguments(
1339+
Arguments().parameters(
1340+
[
1341+
Parameter("input-paths").value(
1342+
"argo-{{workflow.name}}/%s/{{steps.%s-internal.outputs.parameters.task-id}}"
1343+
% (node.name, sanitized_name)
1344+
)
1345+
]
1346+
+ (
1347+
[
1348+
Parameter("split-index").value(
1349+
"{{inputs.parameters.split-index}}"
1350+
)
1351+
]
1352+
if has_foreach_inputs
1353+
else []
1354+
)
1355+
)
1356+
),
1357+
]
1358+
)
1359+
.inputs(Inputs().parameters(parameters))
1360+
.outputs(
1361+
# NOTE: We try to read the output parameters from the recursive template call first (<step>-recursion), and the internal step second (<step>-internal).
1362+
# This guarantees that we always get the output parameters of the last recursive step that executed.
1363+
Outputs().parameters(
1364+
[
1365+
Parameter("task-id").valueFrom(
1366+
{
1367+
"expression": "(steps['%s-recursion']?.outputs ?? steps['%s-internal']?.outputs).parameters['task-id']"
1368+
% (sanitized_name, sanitized_name)
1369+
}
1370+
),
1371+
Parameter("switch-step").valueFrom(
1372+
{
1373+
"expression": "(steps['%s-recursion']?.outputs ?? steps['%s-internal']?.outputs).parameters['switch-step']"
1374+
% (sanitized_name, sanitized_name)
1375+
}
1376+
),
1377+
]
1378+
)
1379+
)
1380+
)
12651381
for n in node.out_funcs:
12661382
_visit(
12671383
self.graph[n],
12681384
self._matching_conditional_join(node),
12691385
templates,
12701386
dag_tasks,
12711387
parent_foreach,
1388+
seen,
12721389
)
12731390

12741391
return _visit(
@@ -1277,6 +1394,7 @@ def _visit(
12771394
templates,
12781395
dag_tasks,
12791396
parent_foreach,
1397+
seen,
12801398
)
12811399
# For foreach nodes generate a new sub DAGTemplate
12821400
# We do this for "regular" foreaches (ie. `self.next(self.a, foreach=)`)
@@ -1367,6 +1485,7 @@ def _visit(
13671485
templates,
13681486
[],
13691487
node.name,
1488+
seen,
13701489
)
13711490

13721491
# How do foreach's work on Argo:
@@ -1500,6 +1619,7 @@ def _visit(
15001619
templates,
15011620
dag_tasks,
15021621
parent_foreach,
1622+
seen,
15031623
)
15041624
# For linear nodes continue traversing to the next node
15051625
if node.type in ("linear", "join", "start"):
@@ -1509,6 +1629,7 @@ def _visit(
15091629
templates,
15101630
dag_tasks,
15111631
parent_foreach,
1632+
seen,
15121633
)
15131634
else:
15141635
raise ArgoWorkflowsException(
@@ -2290,8 +2411,13 @@ def _container_templates(self):
22902411
)
22912412
)
22922413
else:
2414+
template_name = self._sanitize(node.name)
2415+
if self._is_recursive_node(node):
2416+
# The recursive template has the original step name,
2417+
# this becomes a template within the recursive ones 'steps'
2418+
template_name = self._sanitize("recursive-%s" % node.name)
22932419
yield (
2294-
Template(self._sanitize(node.name))
2420+
Template(template_name)
22952421
# Set @timeout values
22962422
.active_deadline_seconds(run_time_limit)
22972423
# Set service account
@@ -3750,6 +3876,10 @@ def template(self, template):
37503876
self.payload["template"] = str(template)
37513877
return self
37523878

3879+
def arguments(self, arguments):
3880+
self.payload["arguments"] = arguments.to_json()
3881+
return self
3882+
37533883
def when(self, condition):
37543884
self.payload["when"] = str(condition)
37553885
return self

0 commit comments

Comments
 (0)