Skip to content

Commit 3341831

Browse files
authored
Add 'include' parameter to merge_artifacts for easier control of artifacts to merge (#287)
Co-authored-by: Romain Cledat <[email protected]>
1 parent c9716b0 commit 3341831

File tree

4 files changed

+109
-11
lines changed

4 files changed

+109
-11
lines changed

metaflow/exception.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,16 @@ class CommandException(MetaflowException):
105105
class MetaflowDataMissing(MetaflowException):
106106
headline = "Data missing"
107107

108-
class MergeArtifactsException(MetaflowException):
108+
class UnhandledInMergeArtifactsException(MetaflowException):
109109
headline = "Unhandled artifacts in merge"
110110

111111
def __init__(self, msg, unhandled):
112-
super(MergeArtifactsException, self).__init__(msg)
112+
super(UnhandledInMergeArtifactsException, self).__init__(msg)
113+
self.artifact_names = unhandled
114+
115+
class MissingInMergeArtifactsException(MetaflowException):
116+
headline = "Missing artifacts in merge"
117+
118+
def __init__(self, msg, unhandled):
119+
super(MissingInMergeArtifactsException, self).__init__(msg)
113120
self.artifact_names = unhandled

metaflow/flowspec.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66

77
from . import cmd_with_io
88
from .parameters import Parameter
9-
from .exception import MetaflowException, MetaflowInternalError, MergeArtifactsException
9+
from .exception import MetaflowException, MetaflowInternalError, \
10+
MissingInMergeArtifactsException, UnhandledInMergeArtifactsException
1011
from .graph import FlowGraph
1112

1213
# For Python 3 compatibility
@@ -253,7 +254,7 @@ def _find_input(self, stack_index=None):
253254
frame.index + 1))
254255
return self._cached_input[stack_index]
255256

256-
def merge_artifacts(self, inputs, exclude=[]):
257+
def merge_artifacts(self, inputs, exclude=[], include=[]):
257258
"""
258259
Merge the artifacts coming from each merge branch (from inputs)
259260
@@ -288,40 +289,68 @@ def merge_artifacts(self, inputs, exclude=[]):
288289
inputs : List[Steps]
289290
Incoming steps to the join point
290291
exclude : List[str], optional
292+
If specified, do not consider merging artifacts with a name in `exclude`.
293+
Cannot specify if `include` is also specified
294+
include : List[str], optional
295+
If specified, only merge artifacts specified. Cannot specify if `exclude` is
296+
also specified
291297
292298
Raises
293299
------
294300
MetaflowException
295301
This exception is thrown if this is not called in a join step
296-
MergeArtifactsException
302+
UnhandledInMergeArtifactsException
297303
This exception is thrown in case of unresolved conflicts
304+
MissingInMergeArtifactsException
305+
This exception is thrown in case an artifact specified in `include` cannot
306+
be found
298307
"""
299308
node = self._graph[self._current_step]
300309
if node.type != 'join':
301310
msg = "merge_artifacts can only be called in a join and step *{step}* "\
302311
"is not a join".format(step=self._current_step)
303312
raise MetaflowException(msg)
313+
if len(exclude) > 0 and len(include) > 0:
314+
msg = "`exclude` and `include` are mutually exclusive in merge_artifacts"
315+
raise MetaflowException(msg)
304316

305317
to_merge = {}
306318
unresolved = []
307319
for inp in inputs:
308320
# available_vars is the list of variables from inp that should be considered
309-
available_vars = ((var, sha) for var, sha in inp._datastore.items()
310-
if (var not in exclude) and (not hasattr(self, var)))
321+
if include:
322+
available_vars = ((var, sha) for var, sha in inp._datastore.items()
323+
if (var in include) and (not hasattr(self, var)))
324+
else:
325+
available_vars = ((var, sha) for var, sha in inp._datastore.items()
326+
if (var not in exclude) and (not hasattr(self, var)))
311327
for var, sha in available_vars:
312328
_, previous_sha = to_merge.setdefault(var, (inp, sha))
313329
if previous_sha != sha:
314330
# We have a conflict here
315331
unresolved.append(var)
316-
332+
# Check if everything in include is present in to_merge
333+
missing = []
334+
for v in include:
335+
if v not in to_merge and not hasattr(self, v):
336+
missing.append(v)
317337
if unresolved:
318338
# We have unresolved conflicts so we do not set anything and error out
319339
msg = "Step *{step}* cannot merge the following artifacts due to them "\
320340
"having conflicting values:\n[{artifacts}].\nTo remedy this issue, "\
321341
"be sure to explicitly set those artifacts (using "\
322342
"self.<artifact_name> = ...) prior to calling merge_artifacts."\
323343
.format(step=self._current_step, artifacts=', '.join(unresolved))
324-
raise MergeArtifactsException(msg, unresolved)
344+
raise UnhandledInMergeArtifactsException(msg, unresolved)
345+
if missing:
346+
msg = "Step *{step}* specifies that [{include}] should be merged but "\
347+
"[{missing}] are not present.\nTo remedy this issue, make sure "\
348+
"that the values specified in `include` come from at least one branch"\
349+
.format(
350+
step=self._current_step,
351+
include=', '.join(include),
352+
missing=', '.join(missing))
353+
raise MissingInMergeArtifactsException(msg, missing)
325354
# If things are resolved, we go and fetch from the datastore and set here
326355
for var, (inp, _) in to_merge.items():
327356
setattr(self, var, getattr(inp, var))

test/core/tests/merge_artifacts.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,20 @@ def modify_things(self):
2222
@steps(0, ['join'], required=True)
2323
def merge_things(self, inputs):
2424
from metaflow.current import current
25-
from metaflow.exception import MergeArtifactsException
25+
from metaflow.exception import UnhandledInMergeArtifactsException, MetaflowException
2626

2727
# Test to make sure non-merged values are reported
28-
assert_exception(lambda: self.merge_artifacts(inputs), MergeArtifactsException)
28+
assert_exception(lambda: self.merge_artifacts(inputs), UnhandledInMergeArtifactsException)
29+
30+
# Test to make sure nothing is set if failed merge_artifacts
31+
assert(not hasattr(self, 'non_modified_passdown'))
32+
assert(not hasattr(self, 'manual_merge_required'))
33+
34+
# Test to make sure that only one of exclude/include is used
35+
assert_exception(lambda: self.merge_artifacts(
36+
inputs,
37+
exclude=['ignore_me'],
38+
include=['non_modified_passdown']), MetaflowException)
2939

3040
# Test to make sure nothing is set if failed merge_artifacts
3141
assert(not hasattr(self, 'non_modified_passdown'))
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from metaflow_test import MetaflowTest, ExpectationFailed, steps
2+
3+
class MergeArtifactsIncludeTest(MetaflowTest):
4+
PRIORITY = 1
5+
6+
@steps(0, ['start'])
7+
def start(self):
8+
self.non_modified_passdown = 'a'
9+
self.modified_to_same_value = 'b'
10+
self.manual_merge_required = 'c'
11+
self.ignore_me = 'd'
12+
13+
@steps(2, ['linear'])
14+
def modify_things(self):
15+
# Set to different things
16+
from metaflow.current import current
17+
self.manual_merge_required = current.task_id
18+
self.ignore_me = current.task_id
19+
self.modified_to_same_value = 'e'
20+
assert_equals(self.non_modified_passdown, 'a')
21+
22+
@steps(0, ['join'], required=True)
23+
def merge_things(self, inputs):
24+
from metaflow.current import current
25+
from metaflow.exception import MissingInMergeArtifactsException
26+
27+
self.manual_merge_required = current.task_id
28+
# Test to see if we raise an exception if include specifies non-merged things
29+
assert_exception(lambda: self.merge_artifacts(
30+
inputs, include=['manual_merge_required', 'foobar']), MissingInMergeArtifactsException)
31+
32+
# Test to make sure nothing is set if failed merge_artifacts
33+
assert(not hasattr(self, 'non_modified_passdown'))
34+
35+
# Merge include non_modified_passdown
36+
self.merge_artifacts(inputs, include=['non_modified_passdown'])
37+
38+
# Ensure that everything we expect is passed down
39+
assert_equals(self.non_modified_passdown, 'a')
40+
assert_equals(self.manual_merge_required, current.task_id)
41+
assert(not hasattr(self, 'ignore_me'))
42+
assert(not hasattr(self, 'modified_to_same_value'))
43+
44+
@steps(0, ['end'])
45+
def end(self):
46+
# Check that all values made it through
47+
assert_equals(self.non_modified_passdown, 'a')
48+
assert(hasattr(self, 'manual_merge_required'))
49+
50+
@steps(3, ['all'])
51+
def step_all(self):
52+
assert_equals(self.non_modified_passdown, 'a')

0 commit comments

Comments
 (0)