Skip to content

Commit 66efab8

Browse files
authored
Fetch task template in dynamic workflow task (#254)
* Fetch task template Signed-off-by: Hongxin Liang <[email protected]> * Update integration test Signed-off-by: Hongxin Liang <[email protected]> * Deploy scala examples to staging Signed-off-by: Hongxin Liang <[email protected]> * Expose it Signed-off-by: Hongxin Liang <[email protected]> --------- Signed-off-by: Hongxin Liang <[email protected]>
1 parent 03b87f2 commit 66efab8

File tree

12 files changed

+254
-64
lines changed

12 files changed

+254
-64
lines changed

flytekit-examples-scala/src/main/resources/META-INF/services/org.flyte.flytekit.SdkRunnableTask

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ org.flyte.examples.flytekitscala.HelloWorldTask
22
org.flyte.examples.flytekitscala.SumTask
33
org.flyte.examples.flytekitscala.GreetTask
44
org.flyte.examples.flytekitscala.AddQuestionTask
5+
org.flyte.examples.flytekitscala.NoInputsTask

flytekit-examples/src/main/java/org/flyte/examples/DynamicFibonacciWorkflowTask.java

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@
1616
*/
1717
package org.flyte.examples;
1818

19-
import static org.flyte.examples.FlyteEnvironment.DOMAIN;
19+
import static org.flyte.examples.FlyteEnvironment.DEVELOPMENT_DOMAIN;
2020
import static org.flyte.examples.FlyteEnvironment.PROJECT;
21+
import static org.flyte.examples.FlyteEnvironment.STAGING_DOMAIN;
2122

2223
import com.google.auto.service.AutoService;
2324
import com.google.auto.value.AutoValue;
@@ -65,22 +66,34 @@ public Output run(SdkWorkflowBuilder builder, Input input) {
6566
} else if (input.n().get() == 0) {
6667
return Output.create(SdkBindingDataFactory.of(0));
6768
} else {
69+
// remote task that is discoverable in current classpath
6870
SdkNode<Void> hello =
6971
builder.apply(
7072
"hello",
7173
SdkRemoteTask.create(
72-
DOMAIN,
74+
DEVELOPMENT_DOMAIN,
7375
PROJECT,
7476
HelloWorldTask.class.getName(),
7577
SdkTypes.nulls(),
7678
SdkTypes.nulls()));
79+
// a fully remote task
80+
SdkNode<Void> world =
81+
builder.apply(
82+
"world",
83+
SdkRemoteTask.create(
84+
STAGING_DOMAIN,
85+
PROJECT,
86+
"org.flyte.examples.flytekitscala.NoInputsTask",
87+
SdkTypes.nulls(),
88+
SdkTypes.nulls())
89+
.withUpstreamNode(hello));
7790
@Var SdkBindingData<Long> prev = SdkBindingDataFactory.of(0);
7891
@Var SdkBindingData<Long> value = SdkBindingDataFactory.of(1);
7992
for (int i = 2; i <= input.n().get(); i++) {
8093
SdkBindingData<Long> next =
8194
builder
8295
.apply(
83-
"fib-" + i, new SumTask().withUpstreamNode(hello), SumInput.create(value, prev))
96+
"fib-" + i, new SumTask().withUpstreamNode(world), SumInput.create(value, prev))
8497
.getOutputs();
8598
prev = value;
8699
value = next;

flytekit-examples/src/main/java/org/flyte/examples/FlyteEnvironment.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818

1919
public final class FlyteEnvironment {
2020

21-
public static final String DOMAIN = "development";
21+
public static final String DEVELOPMENT_DOMAIN = "development";
22+
public static final String STAGING_DOMAIN = "staging";
2223
public static final String PROJECT = "flytesnacks";
2324

2425
private FlyteEnvironment() {

integration-tests/pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@
5858
<artifactId>flytekit-examples</artifactId>
5959
<scope>test</scope>
6060
</dependency>
61+
<dependency>
62+
<groupId>org.flyte</groupId>
63+
<artifactId>flytekit-examples-scala_2.13</artifactId>
64+
<scope>test</scope>
65+
</dependency>
6166
<dependency>
6267
<groupId>org.flyte</groupId>
6368
<artifactId>jflyte</artifactId>

integration-tests/src/test/java/org/flyte/JavaExamplesIT.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.flyte;
1818

1919
import static org.flyte.FlyteContainer.CLIENT;
20+
import static org.flyte.examples.FlyteEnvironment.STAGING_DOMAIN;
2021
import static org.flyte.utils.Literal.ofIntegerMap;
2122
import static org.hamcrest.MatcherAssert.assertThat;
2223
import static org.hamcrest.Matchers.equalTo;
@@ -29,11 +30,13 @@
2930

3031
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
3132
public class JavaExamplesIT {
32-
private static final String CLASSPATH = "flytekit-examples/target/lib";
33+
private static final String CLASSPATH_EXAMPLES = "flytekit-examples/target/lib";
34+
private static final String CLASSPATH_EXAMPLES_SCALA = "flytekit-examples-scala/target/lib";
3335

3436
@BeforeAll
3537
public static void beforeAll() {
36-
CLIENT.registerWorkflows(CLASSPATH);
38+
CLIENT.registerWorkflows(CLASSPATH_EXAMPLES);
39+
CLIENT.registerWorkflows(CLASSPATH_EXAMPLES_SCALA, STAGING_DOMAIN);
3740
}
3841

3942
@Test

integration-tests/src/test/java/org/flyte/utils/FlyteSandboxClient.java

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
*/
1717
package org.flyte.utils;
1818

19-
import static org.flyte.examples.FlyteEnvironment.DOMAIN;
19+
import static org.flyte.examples.FlyteEnvironment.DEVELOPMENT_DOMAIN;
2020
import static org.flyte.examples.FlyteEnvironment.PROJECT;
2121

2222
import flyteidl.admin.ExecutionOuterClass;
@@ -59,7 +59,7 @@ public Literals.LiteralMap createTaskExecution(String name, Literals.LiteralMap
5959
return createExecution(
6060
IdentifierOuterClass.Identifier.newBuilder()
6161
.setResourceType(IdentifierOuterClass.ResourceType.TASK)
62-
.setDomain(DOMAIN)
62+
.setDomain(DEVELOPMENT_DOMAIN)
6363
.setProject(PROJECT)
6464
.setName(name)
6565
.setVersion(version)
@@ -71,7 +71,7 @@ public Literals.LiteralMap createExecution(String name, Literals.LiteralMap inpu
7171
return createExecution(
7272
IdentifierOuterClass.Identifier.newBuilder()
7373
.setResourceType(IdentifierOuterClass.ResourceType.LAUNCH_PLAN)
74-
.setDomain(DOMAIN)
74+
.setDomain(DEVELOPMENT_DOMAIN)
7575
.setProject(PROJECT)
7676
.setName(name)
7777
.setVersion(version)
@@ -84,7 +84,7 @@ private Literals.LiteralMap createExecution(
8484
ExecutionOuterClass.ExecutionCreateResponse response =
8585
stub.createExecution(
8686
ExecutionOuterClass.ExecutionCreateRequest.newBuilder()
87-
.setDomain(DOMAIN)
87+
.setDomain(DEVELOPMENT_DOMAIN)
8888
.setProject(PROJECT)
8989
.setInputs(inputs)
9090
.setSpec(ExecutionOuterClass.ExecutionSpec.newBuilder().setLaunchPlan(id).build())
@@ -148,21 +148,25 @@ private boolean isRunning(Execution.WorkflowExecution.Phase phase) {
148148
return false;
149149
}
150150

151-
public void registerWorkflows(String classpath) {
151+
public void registerWorkflows(String classpath, String domain) {
152152
try {
153153
jflyte(
154154
"jflyte",
155155
"register",
156156
"workflows",
157157
"-p=" + PROJECT,
158-
"-d=" + DOMAIN,
158+
"-d=" + domain,
159159
"-v=" + version,
160160
"-cp=" + classpath);
161161
} catch (Exception e) {
162162
throw new RuntimeException("Could not register workflows from: " + classpath, e);
163163
}
164164
}
165165

166+
public void registerWorkflows(String classpath) {
167+
registerWorkflows(classpath, DEVELOPMENT_DOMAIN);
168+
}
169+
166170
public void serializeWorkflows(String classpath, String folder) {
167171
jflyte("jflyte", "serialize", "workflows", "-cp=" + classpath, "-f=" + folder);
168172
}

jflyte-utils/src/main/java/org/flyte/jflyte/utils/FlyteAdminClient.java

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import flyteidl.admin.LaunchPlanOuterClass;
2626
import flyteidl.admin.TaskOuterClass;
2727
import flyteidl.admin.WorkflowOuterClass;
28-
import flyteidl.core.IdentifierOuterClass;
2928
import flyteidl.service.AdminServiceGrpc;
3029
import io.grpc.Channel;
3130
import io.grpc.ClientInterceptor;
@@ -185,34 +184,38 @@ public TaskIdentifier fetchLatestTaskId(NamedEntityIdentifier taskId) {
185184
return fetchLatestResource(
186185
taskId,
187186
request -> stub.listTasks(request).getTasksList(),
188-
TaskOuterClass.Task::getId,
189-
ProtoUtil::deserializeTaskId);
187+
task -> ProtoUtil.deserializeTaskId(task.getId()));
188+
}
189+
190+
@Nullable
191+
public TaskTemplate fetchLatestTaskTemplate(NamedEntityIdentifier taskId) {
192+
return fetchLatestResource(
193+
taskId,
194+
request -> stub.listTasks(request).getTasksList(),
195+
task -> ProtoUtil.deserialize(task.getClosure().getCompiledTask().getTemplate()));
190196
}
191197

192198
@Nullable
193199
public WorkflowIdentifier fetchLatestWorkflowId(NamedEntityIdentifier workflowId) {
194200
return fetchLatestResource(
195201
workflowId,
196202
request -> stub.listWorkflows(request).getWorkflowsList(),
197-
WorkflowOuterClass.Workflow::getId,
198-
ProtoUtil::deserializeWorkflowId);
203+
workflow -> ProtoUtil.deserializeWorkflowId(workflow.getId()));
199204
}
200205

201206
@Nullable
202207
public LaunchPlanIdentifier fetchLatestLaunchPlanId(NamedEntityIdentifier launchPlanId) {
203208
return fetchLatestResource(
204209
launchPlanId,
205210
request -> stub.listLaunchPlans(request).getLaunchPlansList(),
206-
LaunchPlanOuterClass.LaunchPlan::getId,
207-
ProtoUtil::deserializeLaunchPlanId);
211+
launchPlan -> ProtoUtil.deserializeLaunchPlanId(launchPlan.getId()));
208212
}
209213

210214
@Nullable
211215
private <T, RespT> T fetchLatestResource(
212216
NamedEntityIdentifier nameId,
213217
Function<ResourceListRequest, List<RespT>> performRequestFn,
214-
Function<RespT, IdentifierOuterClass.Identifier> extractIdFn,
215-
Function<IdentifierOuterClass.Identifier, T> deserializeFn) {
218+
Function<RespT, T> deserializeFn) {
216219
ResourceListRequest request =
217220
ResourceListRequest.newBuilder()
218221
.setLimit(1)
@@ -230,8 +233,7 @@ private <T, RespT> T fetchLatestResource(
230233
return null;
231234
}
232235

233-
IdentifierOuterClass.Identifier id = extractIdFn.apply(list.get(0));
234-
return deserializeFn.apply(id);
236+
return deserializeFn.apply(list.get(0));
235237
}
236238

237239
private <T> void idempotentCreate(String label, Object id, GrpcRetries.Retryable<T> retryable) {

jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
import java.util.Set;
4848
import java.util.concurrent.atomic.AtomicInteger;
4949
import java.util.function.BiConsumer;
50+
import java.util.function.Function;
5051
import java.util.function.Supplier;
5152
import java.util.stream.Stream;
5253
import org.flyte.api.v1.Container;
@@ -295,8 +296,8 @@ static void checkCycles(Map<WorkflowIdentifier, WorkflowTemplate> allWorkflows)
295296
checkCycles(
296297
workflowId,
297298
allWorkflows,
298-
/*beingVisited=*/ new HashSet<>(),
299-
/*visited=*/ new HashSet<>()))
299+
/* beingVisited= */ new HashSet<>(),
300+
/* visited= */ new HashSet<>()))
300301
.findFirst();
301302
if (cycle.isPresent()) {
302303
throw new IllegalArgumentException(
@@ -374,8 +375,10 @@ public static Map<WorkflowIdentifier, WorkflowTemplate> collectSubWorkflows(
374375
.collect(toUnmodifiableMap());
375376
}
376377

377-
public static Map<TaskIdentifier, TaskTemplate> collectTasks(
378-
List<Node> rewrittenNodes, Map<TaskIdentifier, TaskTemplate> allTasks) {
378+
public static Map<TaskIdentifier, TaskTemplate> collectDynamicWorkflowTasks(
379+
List<Node> rewrittenNodes,
380+
Map<TaskIdentifier, TaskTemplate> allTasks,
381+
Function<TaskIdentifier, TaskTemplate> remoteTaskTemplateFetcher) {
379382
return collectTaskIds(rewrittenNodes).stream()
380383
// all identifiers should be rewritten at this point
381384
.map(
@@ -389,7 +392,9 @@ public static Map<TaskIdentifier, TaskTemplate> collectTasks(
389392
.distinct()
390393
.map(
391394
taskId -> {
392-
TaskTemplate taskTemplate = allTasks.get(taskId);
395+
TaskTemplate taskTemplate =
396+
Optional.ofNullable(allTasks.get(taskId))
397+
.orElseGet(() -> remoteTaskTemplateFetcher.apply(taskId));
393398

394399
if (taskTemplate == null) {
395400
throw new NoSuchElementException("Can't find referenced task " + taskId);
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/*
2+
* Copyright 2023 Flyte Authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing,
11+
* software distributed under the License is distributed on an
12+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13+
* KIND, either express or implied. See the License for the
14+
* specific language governing permissions and limitations
15+
* under the License.
16+
*/
17+
package org.flyte.jflyte.utils;
18+
19+
import static java.util.Collections.emptyMap;
20+
21+
import com.google.common.collect.ImmutableList;
22+
import com.google.common.collect.ImmutableMap;
23+
import org.flyte.api.v1.Container;
24+
import org.flyte.api.v1.KeyValuePair;
25+
import org.flyte.api.v1.RetryStrategy;
26+
import org.flyte.api.v1.SimpleType;
27+
import org.flyte.api.v1.Struct;
28+
import org.flyte.api.v1.TaskTemplate;
29+
import org.flyte.api.v1.TypedInterface;
30+
31+
final class Fixtures {
32+
static final String IMAGE_NAME = "alpine:latest";
33+
static final String COMMAND = "date";
34+
35+
static final Container CONTAINER =
36+
Container.builder()
37+
.command(ImmutableList.of(COMMAND))
38+
.args(ImmutableList.of())
39+
.image(IMAGE_NAME)
40+
.env(ImmutableList.of(KeyValuePair.of("key", "value")))
41+
.build();
42+
static final TypedInterface INTERFACE_ =
43+
TypedInterface.builder()
44+
.inputs(ImmutableMap.of("x", ApiUtils.createVar(SimpleType.STRING)))
45+
.outputs(ImmutableMap.of("y", ApiUtils.createVar(SimpleType.INTEGER)))
46+
.build();
47+
static final RetryStrategy RETRIES = RetryStrategy.builder().retries(4).build();
48+
static final TaskTemplate TASK_TEMPLATE =
49+
TaskTemplate.builder()
50+
.container(CONTAINER)
51+
.type("custom-task")
52+
.interface_(INTERFACE_)
53+
.custom(Struct.of(emptyMap()))
54+
.retries(RETRIES)
55+
.discoverable(false)
56+
.cacheSerializable(false)
57+
.build();
58+
59+
private Fixtures() {
60+
throw new UnsupportedOperationException();
61+
}
62+
}

0 commit comments

Comments
 (0)