Skip to content

Commit 4fe9279

Browse files
authored
Add binding collection/map for scala (#219)
Signed-off-by: Nelson Arapé <[email protected]>
1 parent 20f90cb commit 4fe9279

File tree

9 files changed

+344
-14
lines changed

9 files changed

+344
-14
lines changed

flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingData.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,9 @@ public abstract class SdkBindingData<T> {
5858
/**
5959
* Returns a version of this {@code SdkBindingData} with a new type.
6060
*
61-
* @param newType the {@link SdkLiteralType} type to be casted to
61+
* @param newType the {@link SdkLiteralType} type to be cast to
6262
* @param castFunction function to apply to the value to be converted to the new type
63-
* @return the type casted version of this instance
63+
* @return the type cast version of this instance
6464
* @param <NewT> the java or scala type for the corresponding to {@code newType}
6565
* @throws UnsupportedOperationException if a cast cannot be performed over this instance.
6666
*/
@@ -196,7 +196,7 @@ private static <T> BindingCollection<T> create(
196196
SdkLiteralType<T> elementType, List<SdkBindingData<T>> bindingCollection) {
197197
checkIncompatibleTypes(elementType, bindingCollection);
198198
return new AutoValue_SdkBindingData_BindingCollection<>(
199-
collections(elementType), bindingCollection);
199+
collections(elementType), List.copyOf(bindingCollection));
200200
}
201201

202202
@Override
@@ -230,7 +230,7 @@ public abstract static class BindingMap<T> extends SdkBindingData<Map<String, T>
230230
private static <T> BindingMap<T> create(
231231
SdkLiteralType<T> valuesType, Map<String, SdkBindingData<T>> bindingMap) {
232232
checkIncompatibleTypes(valuesType, bindingMap.values());
233-
return new AutoValue_SdkBindingData_BindingMap<>(maps(valuesType), bindingMap);
233+
return new AutoValue_SdkBindingData_BindingMap<>(maps(valuesType), Map.copyOf(bindingMap));
234234
}
235235

236236
@Override
@@ -258,7 +258,7 @@ public final String toString() {
258258
}
259259
}
260260

261-
private static <T> void checkIncompatibleTypes(
261+
protected static <T> void checkIncompatibleTypes(
262262
SdkLiteralType<T> elementType, Collection<SdkBindingData<T>> elements) {
263263
List<LiteralType> incompatibleTypes =
264264
elements.stream()

flytekit-java/src/main/java/org/flyte/flytekit/SdkBindingDataFactory.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -286,9 +286,9 @@ public static <T> SdkBindingData<List<T>> ofBindingCollection(
286286
* Creates a {@code SdkBindingData} for a flyte map given a java {@code Map<String,
287287
* SdkBindingData<T>>} and a {@link SdkLiteralType} for the values of the map.
288288
*
289-
* @param valueMap collection to represent on this data.
290289
* @param valuesType a {@link SdkLiteralType} expressing the types for the values of the map. The
291-
* keys are always String. LiteralType.Kind#MAP_VALUE_TYPE}.
290+
* keys are always String.
291+
* @param valueMap map to represent on this data.
292292
* @return the new {@code SdkBindingData}
293293
*/
294294
public static <T> SdkBindingData<Map<String, T>> ofBindingMap(

flytekit-scala-tests/pom.xml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,6 @@
4343

4444
<dependencies>
4545
<!-- test -->
46-
<dependency>
47-
<groupId>org.flyte</groupId>
48-
<artifactId>flytekit-api</artifactId>
49-
<scope>test</scope>
50-
</dependency>
5146
<dependency>
5247
<groupId>org.flyte</groupId>
5348
<artifactId>flytekit-examples</artifactId>
@@ -63,6 +58,11 @@
6358
<artifactId>junit-jupiter</artifactId>
6459
<scope>test</scope>
6560
</dependency>
61+
<dependency>
62+
<groupId>org.hamcrest</groupId>
63+
<artifactId>hamcrest</artifactId>
64+
<scope>test</scope>
65+
</dependency>
6666
</dependencies>
6767

6868
<build>
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
/*
2+
* Copyright 2021 Flyte Authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing,
11+
* software distributed under the License is distributed on an
12+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13+
* KIND, either express or implied. See the License for the
14+
* specific language governing permissions and limitations
15+
* under the License.
16+
*/
17+
package org.flyte.flytekitscala
18+
19+
import org.flyte.flytekit.SdkBindingData
20+
import org.flyte.flytekitscala.SdkLiteralTypes._
21+
import org.hamcrest.MatcherAssert.assertThat
22+
import org.hamcrest.Matchers.equalTo
23+
import org.junit.jupiter.api.Test
24+
25+
import java.time.ZoneOffset.UTC
26+
import java.time.{Duration, LocalDate}
27+
class SdkBindingDataFactoryTest {
28+
@Test
29+
def testOfBindingCollection(): Unit = {
30+
val collection = List(42L, 1337L)
31+
val bindingCollection = collection.map(SdkBindingDataFactory.of)
32+
val output =
33+
SdkBindingDataFactory.ofBindingCollection(integers(), bindingCollection)
34+
assertThat(output.get, equalTo(collection))
35+
assertThat(output.`type`, equalTo(collections(integers())))
36+
}
37+
38+
@Test
39+
def testOfBindingCollection_empty(): Unit = {
40+
val output = SdkBindingDataFactory.ofBindingCollection(integers(), List())
41+
assertThat(output.get, equalTo(List[Long]()))
42+
assertThat(output.`type`, equalTo(collections(integers())))
43+
}
44+
45+
@Test
46+
def testOfStringCollection(): Unit = {
47+
val expectedValue = List("1", "2")
48+
val output = SdkBindingDataFactory.ofStringCollection(expectedValue)
49+
assertThat(output.get, equalTo(expectedValue))
50+
assertThat(output.`type`, equalTo(collections(strings())))
51+
}
52+
53+
@Test
54+
def testOfFloatCollection(): Unit = {
55+
val expectedValue = List(1.1, 1.2)
56+
val output = SdkBindingDataFactory.ofFloatCollection(expectedValue)
57+
assertThat(output.get, equalTo(expectedValue))
58+
assertThat(output.`type`, equalTo(collections(floats())))
59+
}
60+
61+
@Test
62+
def testOfIntegerCollection(): Unit = {
63+
val expectedValue = List(1L, 2L)
64+
val output = SdkBindingDataFactory.ofIntegerCollection(expectedValue)
65+
assertThat(output.get, equalTo(expectedValue))
66+
assertThat(output.`type`, equalTo(collections(integers())))
67+
}
68+
69+
@Test
70+
def testOfBooleanCollection(): Unit = {
71+
val expectedValue = List(true, false)
72+
val output = SdkBindingDataFactory.ofBooleanCollection(expectedValue)
73+
assertThat(output.get, equalTo(expectedValue))
74+
assertThat(output.`type`, equalTo(collections(booleans())))
75+
}
76+
77+
@Test
78+
def testOfDurationCollection(): Unit = {
79+
val expectedValue = List(Duration.ofDays(1), Duration.ofDays(2))
80+
val output = SdkBindingDataFactory.ofDurationCollection(expectedValue)
81+
assertThat(output.get, equalTo(expectedValue))
82+
assertThat(output.`type`, equalTo(collections(durations())))
83+
}
84+
85+
@Test
86+
def testOfDatetimeCollection(): Unit = {
87+
val first = LocalDate.of(2022, 1, 16).atStartOfDay.toInstant(UTC)
88+
val second = LocalDate.of(2022, 1, 17).atStartOfDay.toInstant(UTC)
89+
val expectedValue = List(first, second)
90+
val output = SdkBindingDataFactory.ofDatetimeCollection(expectedValue)
91+
assertThat(output.get, equalTo(expectedValue))
92+
assertThat(output.`type`, equalTo(collections(datetimes())))
93+
}
94+
95+
@Test
96+
def testOfBindingMap(): Unit = {
97+
val input = Map(
98+
"a" -> SdkBindingDataFactory.of(42L),
99+
"b" -> SdkBindingDataFactory.of(1337L)
100+
)
101+
val output = SdkBindingDataFactory.ofBindingMap(integers(), input)
102+
assertThat(output.get, equalTo(Map("a" -> 42L, "b" -> 1337L)))
103+
assertThat(output.`type`, equalTo(maps(integers())))
104+
}
105+
106+
@Test
107+
def testOfBindingMap_empty(): Unit = {
108+
val output = SdkBindingDataFactory.ofBindingMap(
109+
integers(),
110+
Map[String, SdkBindingData[Long]]()
111+
)
112+
assertThat(output.get, equalTo(Map[String, Long]()))
113+
assertThat(output.`type`, equalTo(maps(integers())))
114+
}
115+
116+
@Test def testOfStringMap(): Unit = {
117+
val expectedValue = Map("a" -> "1", "b" -> "2")
118+
val output: SdkBindingData[Map[String, String]] =
119+
SdkBindingDataFactory.ofStringMap(expectedValue)
120+
assertThat(output.get, equalTo(expectedValue))
121+
assertThat(output.`type`, equalTo(maps(strings())))
122+
}
123+
124+
@Test def testOfFloatMap(): Unit = {
125+
val expectedValue = Map("a" -> 1.1, "b" -> 1.2)
126+
val output = SdkBindingDataFactory.ofFloatMap(expectedValue)
127+
assertThat(output.get, equalTo(expectedValue))
128+
assertThat(output.`type`, equalTo(maps(floats())))
129+
}
130+
131+
@Test def testOfIntegerMap(): Unit = {
132+
val expectedValue = Map("a" -> 1L, "b" -> 2L)
133+
val output = SdkBindingDataFactory.ofIntegerMap(expectedValue)
134+
assertThat(output.get, equalTo(expectedValue))
135+
assertThat(output.`type`, equalTo(maps(integers())))
136+
}
137+
138+
@Test def testOfBooleanMap(): Unit = {
139+
val expectedValue = Map("a" -> true, "b" -> false)
140+
val output = SdkBindingDataFactory.ofBooleanMap(expectedValue)
141+
assertThat(output.get, equalTo(expectedValue))
142+
assertThat(output.`type`, equalTo(maps(booleans())))
143+
}
144+
145+
@Test def testOfDurationMap(): Unit = {
146+
val expectedValue =
147+
Map("a" -> Duration.ofDays(1), "b" -> Duration.ofDays(2))
148+
val output = SdkBindingDataFactory.ofDurationMap(expectedValue)
149+
assertThat(output.get, equalTo(expectedValue))
150+
assertThat(output.`type`, equalTo(maps(durations())))
151+
}
152+
153+
@Test def testOfDatetimeMap(): Unit = {
154+
val first = LocalDate.of(2022, 1, 16).atStartOfDay.toInstant(UTC)
155+
val second = LocalDate.of(2022, 1, 17).atStartOfDay.toInstant(UTC)
156+
val expectedValue = Map("a" -> first, "b" -> second)
157+
val output = SdkBindingDataFactory.ofDatetimeMap(expectedValue)
158+
assertThat(output.get, equalTo(expectedValue))
159+
assertThat(output.`type`, equalTo(maps(datetimes())))
160+
}
161+
}

flytekit-scala_2.13/pom.xml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@
4343

4444
<dependencies>
4545
<!-- compile -->
46+
<dependency>
47+
<groupId>org.flyte</groupId>
48+
<artifactId>flytekit-api</artifactId>
49+
</dependency>
4650
<dependency>
4751
<groupId>org.flyte</groupId>
4852
<artifactId>flytekit-java</artifactId>
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*
2+
* Copyright 2021 Flyte Authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing,
11+
* software distributed under the License is distributed on an
12+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13+
* KIND, either express or implied. See the License for the
14+
* specific language governing permissions and limitations
15+
* under the License.
16+
*/
17+
package org.flyte.flytekit
18+
19+
import org.flyte.api.v1.BindingData
20+
import org.flyte.flytekitscala.SdkLiteralTypes.collections
21+
22+
import java.util.function
23+
import scala.collection.JavaConverters._
24+
25+
private[flyte] class BindingCollection[T](
26+
elementType: SdkLiteralType[T],
27+
bindingCollection: List[SdkBindingData[T]]
28+
) extends SdkBindingData[List[T]] {
29+
SdkBindingData.checkIncompatibleTypes(elementType, bindingCollection.asJava)
30+
31+
override def idl: BindingData =
32+
BindingData.ofCollection(bindingCollection.map(_.idl()).asJava)
33+
34+
override def `type`: SdkLiteralType[List[T]] = collections(elementType)
35+
36+
override def get(): List[T] = bindingCollection.map(_.get())
37+
38+
override def as[NewT](
39+
newType: SdkLiteralType[NewT],
40+
castFunction: function.Function[List[T], NewT]
41+
): SdkBindingData[NewT] =
42+
throw new UnsupportedOperationException(
43+
"SdkBindingData of binding collection cannot be casted"
44+
)
45+
}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* Copyright 2021 Flyte Authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing,
11+
* software distributed under the License is distributed on an
12+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13+
* KIND, either express or implied. See the License for the
14+
* specific language governing permissions and limitations
15+
* under the License.
16+
*/
17+
package org.flyte.flytekit
18+
19+
import org.flyte.api.v1.BindingData
20+
import org.flyte.flytekitscala.SdkLiteralTypes.maps
21+
22+
import java.util.function
23+
import scala.collection.JavaConverters._
24+
25+
private[flyte] class BindingMap[T](
26+
valuesType: SdkLiteralType[T],
27+
bindingMap: Map[String, SdkBindingData[T]]
28+
) extends SdkBindingData[Map[String, T]] {
29+
SdkBindingData.checkIncompatibleTypes(
30+
valuesType,
31+
bindingMap.values.toSeq.asJava
32+
)
33+
34+
override def idl: BindingData =
35+
BindingData.ofMap(bindingMap.mapValues(_.idl()).toMap.asJava)
36+
37+
override def `type`: SdkLiteralType[Map[String, T]] = maps(valuesType)
38+
39+
override def get(): Map[String, T] = bindingMap.mapValues(_.get()).toMap
40+
41+
override def as[NewT](
42+
newType: SdkLiteralType[NewT],
43+
castFunction: function.Function[Map[String, T], NewT]
44+
): SdkBindingData[NewT] =
45+
throw new UnsupportedOperationException(
46+
"SdkBindingData of binding map cannot be casted"
47+
)
48+
49+
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/*
2+
* Copyright 2021 Flyte Authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing,
11+
* software distributed under the License is distributed on an
12+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
13+
* KIND, either express or implied. See the License for the
14+
* specific language governing permissions and limitations
15+
* under the License.
16+
*/
17+
package org.flyte
18+
19+
/** Contains subclasses for [[SdkBindingData]]. We are forced to define this
20+
* package here because [[SdkBindingData#idl()]] is package private (we don´t
21+
* want to expose it to users). We cannot make it protected either as it would
22+
* be good for the own object but both implementations deal with list or maps
23+
* of [[SdkBindingData]] and therefore cannot call this method because it is in
24+
* a different class.
25+
*
26+
* This is not ideal because we are splitting the flytekit package in two maven
27+
* modules. This would create problems when we decide to add java 9 style
28+
* modules.
29+
*/
30+
package object flytekit {}

0 commit comments

Comments
 (0)