@@ -926,6 +926,7 @@ def _parse_conditional_branches(self):
926
926
self .conditional_nodes = set ()
927
927
self .conditional_join_nodes = set ()
928
928
self .matching_conditional_join_dict = {}
929
+ self .recursive_nodes = set ()
929
930
930
931
node_conditional_parents = {}
931
932
node_conditional_branches = {}
@@ -948,6 +949,12 @@ def _visit(node, seen, conditional_branch, conditional_parents=None):
948
949
)
949
950
node_conditional_parents [node .name ] = conditional_parents
950
951
952
+ # check for recursion. this split is recursive if any of its out functions are itself.
953
+ if any (
954
+ out_func for out_func in node .out_funcs if out_func == node .name
955
+ ):
956
+ self .recursive_nodes .add (node .name )
957
+
951
958
if conditional_parents and not node .type == "split-switch" :
952
959
node_conditional_parents [node .name ] = conditional_parents
953
960
conditional_branch = conditional_branch + [node .name ]
@@ -1033,6 +1040,9 @@ def _is_conditional_node(self, node):
1033
1040
def _is_conditional_join_node (self , node ):
1034
1041
return node .name in self .conditional_join_nodes
1035
1042
1043
+ def _is_recursive_node (self , node ):
1044
+ return node .name in self .recursive_nodes
1045
+
1036
1046
def _matching_conditional_join (self , node ):
1037
1047
return self .matching_conditional_join_dict .get (node .name , None )
1038
1048
@@ -1044,6 +1054,7 @@ def _visit(
1044
1054
templates = None ,
1045
1055
dag_tasks = None ,
1046
1056
parent_foreach = None ,
1057
+ seen = None ,
1047
1058
): # Returns Tuple[List[Template], List[DAGTask]]
1048
1059
""" """
1049
1060
# Every for-each node results in a separate subDAG and an equivalent
@@ -1053,13 +1064,22 @@ def _visit(
1053
1064
# of the for-each node.
1054
1065
1055
1066
# Emit if we have reached the end of the sub workflow
1067
+ if seen is None :
1068
+ seen = []
1056
1069
if dag_tasks is None :
1057
1070
dag_tasks = []
1058
1071
if templates is None :
1059
1072
templates = []
1060
1073
1061
1074
if exit_node is not None and exit_node is node .name :
1062
1075
return templates , dag_tasks
1076
+ if node .name in seen :
1077
+ return templates , dag_tasks
1078
+
1079
+ seen .append (node .name )
1080
+
1081
+ # helper variable for recursive conditional inputs
1082
+ has_foreach_inputs = False
1063
1083
if node .name == "start" :
1064
1084
# Start node has no dependencies.
1065
1085
dag_task = DAGTask (self ._sanitize (node .name )).template (
@@ -1073,9 +1093,10 @@ def _visit(
1073
1093
# vs what is a "num_parallel" based foreach (i.e. something that follows gang semantics.)
1074
1094
# A `regular` foreach is basically any arbitrary kind of foreach.
1075
1095
):
1096
+ # helper variable for recursive conditional inputs
1097
+ has_foreach_inputs = True
1076
1098
# Child of a foreach node needs input-paths as well as split-index
1077
1099
# This child is the first node of the sub workflow and has no dependency
1078
-
1079
1100
parameters = [
1080
1101
Parameter ("input-paths" ).value ("{{inputs.parameters.input-paths}}" ),
1081
1102
Parameter ("split-index" ).value ("{{inputs.parameters.split-index}}" ),
@@ -1253,22 +1274,118 @@ def _visit(
1253
1274
templates ,
1254
1275
dag_tasks ,
1255
1276
parent_foreach ,
1277
+ seen ,
1256
1278
)
1257
1279
return _visit (
1258
1280
self .graph [node .matching_join ],
1259
1281
exit_node ,
1260
1282
templates ,
1261
1283
dag_tasks ,
1262
1284
parent_foreach ,
1285
+ seen ,
1263
1286
)
1264
1287
elif node .type == "split-switch" :
1288
+ if self ._is_recursive_node (node ):
1289
+ # we need an additional recursive template if the step is recursive
1290
+ # NOTE: in the recursive case, the original step is renamed in the container templates to 'recursive-<step_name>'
1291
+ # so that we do not have to touch the step references in the DAG.
1292
+ #
1293
+ # NOTE: The way that recursion in Argo Workflows is achieved is with the following structure:
1294
+ # - the usual 'example-step' template which would match example_step in flow code is renamed to 'recursive-example-step'
1295
+ # - templates has another template with the original task name: 'example-step'
1296
+ # - the template 'example-step' in turn has steps
1297
+ # - 'example-step-internal' which uses the metaflow step executing template 'recursive-example-step'
1298
+ # - 'example-step-recursion' which calls the parent template 'example-step' if switch-step output from 'example-step-internal' matches the condition.
1299
+ sanitized_name = self ._sanitize (node .name )
1300
+ templates .append (
1301
+ Template (sanitized_name )
1302
+ .steps (
1303
+ [
1304
+ WorkflowStep ()
1305
+ .name ("%s-internal" % sanitized_name )
1306
+ .template ("recursive-%s" % sanitized_name )
1307
+ .arguments (
1308
+ Arguments ().parameters (
1309
+ [
1310
+ Parameter ("input-paths" ).value (
1311
+ "{{inputs.parameters.input-paths}}"
1312
+ )
1313
+ ]
1314
+ # Add the additional inputs required by specific node types.
1315
+ # We do not need to cover joins or @parallel, as a split-switch step can not be either one of these.
1316
+ + (
1317
+ [
1318
+ Parameter ("split-index" ).value (
1319
+ "{{inputs.parameters.split-index}}"
1320
+ )
1321
+ ]
1322
+ if has_foreach_inputs
1323
+ else []
1324
+ )
1325
+ )
1326
+ )
1327
+ ]
1328
+ )
1329
+ .steps (
1330
+ [
1331
+ WorkflowStep ()
1332
+ .name ("%s-recursion" % sanitized_name )
1333
+ .template (sanitized_name )
1334
+ .when (
1335
+ "{{steps.%s-internal.outputs.parameters.switch-step}}==%s"
1336
+ % (sanitized_name , node .name )
1337
+ )
1338
+ .arguments (
1339
+ Arguments ().parameters (
1340
+ [
1341
+ Parameter ("input-paths" ).value (
1342
+ "argo-{{workflow.name}}/%s/{{steps.%s-internal.outputs.parameters.task-id}}"
1343
+ % (node .name , sanitized_name )
1344
+ )
1345
+ ]
1346
+ + (
1347
+ [
1348
+ Parameter ("split-index" ).value (
1349
+ "{{inputs.parameters.split-index}}"
1350
+ )
1351
+ ]
1352
+ if has_foreach_inputs
1353
+ else []
1354
+ )
1355
+ )
1356
+ ),
1357
+ ]
1358
+ )
1359
+ .inputs (Inputs ().parameters (parameters ))
1360
+ .outputs (
1361
+ # NOTE: We try to read the output parameters from the recursive template call first (<step>-recursion), and the internal step second (<step>-internal).
1362
+ # This guarantees that we always get the output parameters of the last recursive step that executed.
1363
+ Outputs ().parameters (
1364
+ [
1365
+ Parameter ("task-id" ).valueFrom (
1366
+ {
1367
+ "expression" : "(steps['%s-recursion']?.outputs ?? steps['%s-internal']?.outputs).parameters['task-id']"
1368
+ % (sanitized_name , sanitized_name )
1369
+ }
1370
+ ),
1371
+ Parameter ("switch-step" ).valueFrom (
1372
+ {
1373
+ "expression" : "(steps['%s-recursion']?.outputs ?? steps['%s-internal']?.outputs).parameters['switch-step']"
1374
+ % (sanitized_name , sanitized_name )
1375
+ }
1376
+ ),
1377
+ ]
1378
+ )
1379
+ )
1380
+ )
1265
1381
for n in node .out_funcs :
1266
1382
_visit (
1267
1383
self .graph [n ],
1268
1384
self ._matching_conditional_join (node ),
1269
1385
templates ,
1270
1386
dag_tasks ,
1271
1387
parent_foreach ,
1388
+ seen ,
1272
1389
)
1273
1390
1274
1391
return _visit (
@@ -1277,6 +1394,7 @@ def _visit(
1277
1394
templates ,
1278
1395
dag_tasks ,
1279
1396
parent_foreach ,
1397
+ seen ,
1280
1398
)
1281
1399
# For foreach nodes generate a new sub DAGTemplate
1282
1400
# We do this for "regular" foreaches (ie. `self.next(self.a, foreach=)`)
@@ -1367,6 +1485,7 @@ def _visit(
1367
1485
templates ,
1368
1486
[],
1369
1487
node .name ,
1488
+ seen ,
1370
1489
)
1371
1490
1372
1491
# How do foreach's work on Argo:
@@ -1500,6 +1619,7 @@ def _visit(
1500
1619
templates ,
1501
1620
dag_tasks ,
1502
1621
parent_foreach ,
1622
+ seen ,
1503
1623
)
1504
1624
# For linear nodes continue traversing to the next node
1505
1625
if node .type in ("linear" , "join" , "start" ):
@@ -1509,6 +1629,7 @@ def _visit(
1509
1629
templates ,
1510
1630
dag_tasks ,
1511
1631
parent_foreach ,
1632
+ seen ,
1512
1633
)
1513
1634
else :
1514
1635
raise ArgoWorkflowsException (
@@ -2290,8 +2411,13 @@ def _container_templates(self):
2290
2411
)
2291
2412
)
2292
2413
else :
2414
+ template_name = self ._sanitize (node .name )
2415
+ if self ._is_recursive_node (node ):
2416
+ # The recursive template has the original step name,
2417
+ # this becomes a template within the recursive ones 'steps'
2418
+ template_name = self ._sanitize ("recursive-%s" % node .name )
2293
2419
yield (
2294
- Template (self . _sanitize ( node . name ) )
2420
+ Template (template_name )
2295
2421
# Set @timeout values
2296
2422
.active_deadline_seconds (run_time_limit )
2297
2423
# Set service account
@@ -3750,6 +3876,10 @@ def template(self, template):
3750
3876
self .payload ["template" ] = str (template )
3751
3877
return self
3752
3878
3879
+ def arguments (self , arguments ):
3880
+ self .payload ["arguments" ] = arguments .to_json ()
3881
+ return self
3882
+
3753
3883
def when (self , condition ):
3754
3884
self .payload ["when" ] = str (condition )
3755
3885
return self
0 commit comments