|
6 | 6 |
|
7 | 7 | from . import cmd_with_io
|
8 | 8 | from .parameters import Parameter
|
9 |
| -from .exception import MetaflowException, MetaflowInternalError, MergeArtifactsException |
| 9 | +from .exception import MetaflowException, MetaflowInternalError, \ |
| 10 | + MissingInMergeArtifactsException, UnhandledInMergeArtifactsException |
10 | 11 | from .graph import FlowGraph
|
11 | 12 |
|
12 | 13 | # For Python 3 compatibility
|
@@ -253,7 +254,7 @@ def _find_input(self, stack_index=None):
|
253 | 254 | frame.index + 1))
|
254 | 255 | return self._cached_input[stack_index]
|
255 | 256 |
|
256 |
| - def merge_artifacts(self, inputs, exclude=[]): |
| 257 | + def merge_artifacts(self, inputs, exclude=[], include=[]): |
257 | 258 | """
|
258 | 259 | Merge the artifacts coming from each merge branch (from inputs)
|
259 | 260 |
|
@@ -288,40 +289,68 @@ def merge_artifacts(self, inputs, exclude=[]):
|
288 | 289 | inputs : List[Steps]
|
289 | 290 | Incoming steps to the join point
|
290 | 291 | exclude : List[str], optional
|
| 292 | + If specified, do not consider merging artifacts with a name in `exclude`. |
| 293 | + Cannot be specified if `include` is also specified. |
| 294 | + include : List[str], optional |
| 295 | + If specified, only merge the artifacts named in `include`. Cannot be specified if `exclude` is |
| 296 | + also specified. |
291 | 297 |
|
292 | 298 | Raises
|
293 | 299 | ------
|
294 | 300 | MetaflowException
|
295 | 301 | This exception is thrown if this is not called in a join step
|
296 |
| - MergeArtifactsException |
| 302 | + UnhandledInMergeArtifactsException |
297 | 303 | This exception is thrown in case of unresolved conflicts
|
| 304 | + MissingInMergeArtifactsException |
| 305 | + This exception is thrown in case an artifact specified in `include` cannot |
| 306 | + be found |
298 | 307 | """
|
299 | 308 | node = self._graph[self._current_step]
|
300 | 309 | if node.type != 'join':
|
301 | 310 | msg = "merge_artifacts can only be called in a join and step *{step}* "\
|
302 | 311 | "is not a join".format(step=self._current_step)
|
303 | 312 | raise MetaflowException(msg)
|
| 313 | + if len(exclude) > 0 and len(include) > 0: |
| 314 | + msg = "`exclude` and `include` are mutually exclusive in merge_artifacts" |
| 315 | + raise MetaflowException(msg) |
304 | 316 |
|
305 | 317 | to_merge = {}
|
306 | 318 | unresolved = []
|
307 | 319 | for inp in inputs:
|
308 | 320 | # available_vars is the list of variables from inp that should be considered
|
309 |
| - available_vars = ((var, sha) for var, sha in inp._datastore.items() |
310 |
| - if (var not in exclude) and (not hasattr(self, var))) |
| 321 | + if include: |
| 322 | + available_vars = ((var, sha) for var, sha in inp._datastore.items() |
| 323 | + if (var in include) and (not hasattr(self, var))) |
| 324 | + else: |
| 325 | + available_vars = ((var, sha) for var, sha in inp._datastore.items() |
| 326 | + if (var not in exclude) and (not hasattr(self, var))) |
311 | 327 | for var, sha in available_vars:
|
312 | 328 | _, previous_sha = to_merge.setdefault(var, (inp, sha))
|
313 | 329 | if previous_sha != sha:
|
314 | 330 | # We have a conflict here
|
315 | 331 | unresolved.append(var)
|
316 |
| - |
| 332 | + # Check if everything in include is present in to_merge |
| 333 | + missing = [] |
| 334 | + for v in include: |
| 335 | + if v not in to_merge and not hasattr(self, v): |
| 336 | + missing.append(v) |
317 | 337 | if unresolved:
|
318 | 338 | # We have unresolved conflicts so we do not set anything and error out
|
319 | 339 | msg = "Step *{step}* cannot merge the following artifacts due to them "\
|
320 | 340 | "having conflicting values:\n[{artifacts}].\nTo remedy this issue, "\
|
321 | 341 | "be sure to explictly set those artifacts (using "\
|
322 | 342 | "self.<artifact_name> = ...) prior to calling merge_artifacts."\
|
323 | 343 | .format(step=self._current_step, artifacts=', '.join(unresolved))
|
324 |
| - raise MergeArtifactsException(msg, unresolved) |
| 344 | + raise UnhandledInMergeArtifactsException(msg, unresolved) |
| 345 | + if missing: |
| 346 | + msg = "Step *{step}* specifies that [{include}] should be merged but "\ |
| 347 | + "[{missing}] are not present.\nTo remedy this issue, make sure "\ |
| 348 | + "that the values specified in only come from at least one branch"\ |
| 349 | + .format( |
| 350 | + step=self._current_step, |
| 351 | + include=', '.join(include), |
| 352 | + missing=', '.join(missing)) |
| 353 | + raise MissingInMergeArtifactsException(msg, missing) |
325 | 354 | # If things are resolved, we go and fetch from the datastore and set here
|
326 | 355 | for var, (inp, _) in to_merge.items():
|
327 | 356 | setattr(self, var, getattr(inp, var))
|
|
0 commit comments