Skip to content

Commit aa6440e

Browse files
honnixHongxin Liang
andauthored
Explicitly fall back to real test execution when testing (#218)
Explicitly fall back to real test execution when testing Instead of silently falling back to real task execution when mocked inputs are unmatched, this now requires an explicit mock passing in the task's `run` method. Signed-off-by: Hongxin Liang <[email protected]> Co-authored-by: Hongxin Liang <[email protected]>
1 parent 707398f commit aa6440e

File tree

5 files changed

+60
-12
lines changed

5 files changed

+60
-12
lines changed

flytekit-testing/src/main/java/org/flyte/flytekit/testing/TestingRunnableLaunchPlan.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,14 @@ public class TestingRunnableLaunchPlan<InputT, OutputT>
3535
SdkType<InputT> inputType,
3636
SdkType<OutputT> outputType,
3737
Function<InputT, OutputT> runFn,
38+
boolean runFnProvided,
3839
Map<InputT, OutputT> fixedOutputs) {
3940
super(
4041
launchPlanId,
4142
inputType,
4243
outputType,
4344
runFn,
45+
runFnProvided,
4446
fixedOutputs,
4547
TestingRunnableLaunchPlan::new,
4648
"launch plan",
@@ -52,6 +54,7 @@ static <InputT, OutputT> TestingRunnableLaunchPlan<InputT, OutputT> create(
5254
PartialLaunchPlanIdentifier launchPlanId =
5355
PartialLaunchPlanIdentifier.builder().name(name).build();
5456

55-
return new TestingRunnableLaunchPlan<>(launchPlanId, inputType, outputType, null, emptyMap());
57+
return new TestingRunnableLaunchPlan<>(
58+
launchPlanId, inputType, outputType, null, false, emptyMap());
5659
}
5760
}

flytekit-testing/src/main/java/org/flyte/flytekit/testing/TestingRunnableNode.java

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ public abstract class TestingRunnableNode<
3838

3939
// @Nullable - signal nullable field but without adding the dependency
4040
protected final Function<InputT, OutputT> runFn;
41+
private final boolean runFnProvided;
4142

4243
protected final Map<InputT, OutputT> fixedOutputs;
4344
private final Creator<IdT, InputT, OutputT, T> creatorFn;
@@ -54,6 +55,7 @@ T create(
5455
SdkType<InputT> inputType,
5556
SdkType<OutputT> outputType,
5657
Function<InputT, OutputT> runFn,
58+
boolean runFnProvided,
5759
Map<InputT, OutputT> fixedOutputs);
5860
}
5961

@@ -62,6 +64,7 @@ protected TestingRunnableNode(
6264
SdkType<InputT> inputType,
6365
SdkType<OutputT> outputType,
6466
Function<InputT, OutputT> runFn,
67+
boolean runFnProvided,
6568
Map<InputT, OutputT> fixedOutputs,
6669
Creator<IdT, InputT, OutputT, T> creatorFn,
6770
String type,
@@ -70,6 +73,7 @@ protected TestingRunnableNode(
7073
this.inputType = requireNonNull(inputType, "inputType");
7174
this.outputType = requireNonNull(outputType, "outputType");
7275
this.runFn = runFn; // Nullable
76+
this.runFnProvided = runFnProvided;
7377
this.fixedOutputs = requireNonNull(fixedOutputs, "fixedOutputs");
7478
this.creatorFn = requireNonNull(creatorFn, "creatorFn");
7579
this.type = requireNonNull(type, "type");
@@ -80,19 +84,28 @@ protected TestingRunnableNode(
8084
public Map<String, Literal> run(Map<String, Literal> inputs) {
8185
InputT input = inputType.fromLiteralMap(inputs);
8286

83-
if (fixedOutputs.containsKey(input)) {
84-
return outputType.toLiteralMap(fixedOutputs.get(input));
85-
} else if (runFn != null) {
86-
return outputType.toLiteralMap(runFn.apply(input));
87+
if (fixedOutputs.size() == 0) {
88+
// No mocking via input matching, either run the real thing or run the provided lambda
89+
if (runFn != null) {
90+
return outputType.toLiteralMap(runFn.apply(input));
91+
}
92+
} else {
93+
if (fixedOutputs.containsKey(input)) {
94+
return outputType.toLiteralMap(fixedOutputs.get(input));
95+
}
96+
// Inputs not matching, run the provided lambda
97+
if (runFn != null && runFnProvided) {
98+
return outputType.toLiteralMap(runFn.apply(input));
99+
}
87100
}
88101

89-
// TODO see if we can improve this error message as input is hard to read
90-
// We can improve the SdkBindingData toString method
91102
String message =
92103
String.format(
93104
"Can't find input %s for remote %s [%s] across known %s inputs, "
94105
+ "use %s to provide a test double",
95106
input, type, getName(), type, testingSuggestion);
107+
108+
// Not matching inputs and there is nothing to run
96109
throw new IllegalArgumentException(message);
97110
}
98111

@@ -105,10 +118,10 @@ public T withFixedOutput(InputT input, OutputT output) {
105118
Map<InputT, OutputT> newFixedOutputs = new HashMap<>(fixedOutputs);
106119
newFixedOutputs.put(input, output);
107120

108-
return creatorFn.create(id, inputType, outputType, runFn, newFixedOutputs);
121+
return creatorFn.create(id, inputType, outputType, runFn, runFnProvided, newFixedOutputs);
109122
}
110123

111124
public T withRunFn(Function<InputT, OutputT> runFn) {
112-
return creatorFn.create(id, inputType, outputType, runFn, fixedOutputs);
125+
return creatorFn.create(id, inputType, outputType, runFn, true, fixedOutputs);
113126
}
114127
}

flytekit-testing/src/main/java/org/flyte/flytekit/testing/TestingRunnableTask.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,14 @@ private TestingRunnableTask(
3737
SdkType<InputT> inputType,
3838
SdkType<OutputT> outputType,
3939
Function<InputT, OutputT> runFn,
40+
boolean runFnProvided,
4041
Map<InputT, OutputT> fixedOutputs) {
4142
super(
4243
taskId,
4344
inputType,
4445
outputType,
4546
runFn,
47+
runFnProvided,
4648
fixedOutputs,
4749
TestingRunnableTask::new,
4850
"task",
@@ -54,14 +56,15 @@ static <InputT, OutputT> TestingRunnableTask<InputT, OutputT> create(
5456
PartialTaskIdentifier taskId = PartialTaskIdentifier.builder().name(task.getName()).build();
5557

5658
return new TestingRunnableTask<>(
57-
taskId, task.getInputType(), task.getOutputType(), task::run, emptyMap());
59+
taskId, task.getInputType(), task.getOutputType(), task::run, false, emptyMap());
5860
}
5961

6062
static <InputT, OutputT> TestingRunnableTask<InputT, OutputT> create(
6163
String name, SdkType<InputT> inputType, SdkType<OutputT> outputType) {
6264
PartialTaskIdentifier taskId = PartialTaskIdentifier.builder().name(name).build();
6365

64-
return new TestingRunnableTask<>(taskId, inputType, outputType, /* runFn= */ null, emptyMap());
66+
return new TestingRunnableTask<>(
67+
taskId, inputType, outputType, /* runFn= */ null, false, emptyMap());
6568
}
6669

6770
@Override

flytekit-testing/src/test/java/org/flyte/flytekit/testing/FibonacciWorkflowTest.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ public void testWithTaskOutput_runnableTask() {
7373
new SumTask(),
7474
SumInput.create(SdkBindingDataFactory.of(3L), SdkBindingDataFactory.of(5L)),
7575
SumOutput.create(SdkBindingDataFactory.of(42L)))
76+
.withTask(new SumTask(), new SumTask()::run)
7677
.execute();
7778

7879
assertThat(result.getIntegerOutput("fib2"), equalTo(2L));

flytekit-testing/src/test/java/org/flyte/flytekit/testing/TestingRunnableNodeTest.java

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,17 @@ void testRun_notFound() {
8585
+ "across known test inputs, use a magic wang to provide a test double"));
8686
}
8787

88+
@Test
89+
void testRun_notFoundRunFnProvided() {
90+
Function<Input, Output> fn = in -> Output.create(Long.parseLong(in.in().get()));
91+
Map<Input, Output> fixedOutputs = singletonMap(Input.create("7"), Output.create(7L));
92+
TestNode node = new TestNode(fn, fixedOutputs);
93+
94+
Map<String, Literal> output = node.run(singletonMap("in", Literals.ofString("10")));
95+
96+
assertThat(output, hasEntry("out", Literals.ofInteger(10L)));
97+
}
98+
8899
@Test
89100
void testWithFixedOutput() {
90101
TestNode node =
@@ -105,6 +116,22 @@ void testWithRunFn() {
105116
assertThat(output, hasEntry("out", Literals.ofInteger(7L)));
106117
}
107118

119+
@Test
120+
void testWithoutRunFn() {
121+
TestNode node = new TestNode(null, emptyMap());
122+
123+
IllegalArgumentException ex =
124+
assertThrows(
125+
IllegalArgumentException.class,
126+
() -> node.run(singletonMap("in", Literals.ofString("7"))));
127+
128+
assertThat(
129+
ex.getMessage(),
130+
equalTo(
131+
"Can't find input Input{in=SdkBindingData{type=strings, value=7}} for remote test [TestTask] "
132+
+ "across known test inputs, use a magic wang to provide a test double"));
133+
}
134+
108135
@Test
109136
void testGetNameShouldDeriveFromId() {
110137
TestNode node = new TestNode(null, emptyMap());
@@ -121,8 +148,9 @@ protected TestNode(Function<Input, Output> runFn, Map<Input, Output> fixedOutput
121148
JacksonSdkType.of(Input.class),
122149
JacksonSdkType.of(Output.class),
123150
runFn,
151+
true,
124152
fixedOutputs,
125-
(id, inType, outType, f, m) -> new TestNode(f, m),
153+
(id, inType, outType, f, fProvided, m) -> new TestNode(f, m),
126154
"test",
127155
"a magic wang");
128156
}

0 commit comments

Comments
 (0)