Commit 142c574
feat(spark): add LogicalRDD support (#451)
Support conversion of DataFrames that are created using the Spark createDataFrame() method. This produces a LogicalRDD in the query plan, which can be converted to a Substrait VirtualTableScan. Introduces an overridable `rddLimit` to guard against serialising very large datasets.

Signed-off-by: Andrew Coleman <[email protected]>
1 parent 7cf1ccf · commit 142c574
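For context, a minimal sketch of what this commit enables, assuming a running `SparkSession` named `spark` (the schema and rows are purely illustrative):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

import io.substrait.spark.logical.ToSubstraitRel

// A DataFrame built with createDataFrame() over an RDD: its optimized
// plan contains a LogicalRDD rather than a LocalRelation.
val schema = StructType(
  List(StructField("id", IntegerType), StructField("value", StringType)))
val df = spark.createDataFrame(
  spark.sparkContext.parallelize(Seq(Row(1, "one"), Row(2, "two"))),
  schema)

// With this change, the LogicalRDD node is converted to a
// Substrait VirtualTableScan.
val toSubstrait = new ToSubstraitRel
val rel = toSubstrait.visit(df.queryExecution.optimizedPlan)
```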

3 files changed (+91 / -10 lines)

spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala

Lines changed: 22 additions & 7 deletions
```diff
@@ -21,11 +21,13 @@ import io.substrait.spark.expression._
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SaveMode
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.catalog.{CatalogTable, HiveTableRelation}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, Sum}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.execution.LogicalRDD
 import org.apache.spark.sql.execution.command.CreateDataSourceTableAsSelectCommand
 import org.apache.spark.sql.execution.datasources.{FileFormat => DSFileFormat, HadoopFsRelation, InsertIntoHadoopFsRelationCommand, LogicalRelation, V1WriteCommand, WriteFiles}
 import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
@@ -52,7 +54,7 @@ import io.substrait.utils.Util
 import java.util
 import java.util.{Collections, Optional}
 
-import scala.collection.JavaConverters.{asJavaIterableConverter, seqAsJavaList}
+import scala.collection.JavaConverters.asJavaIterableConverter
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
@@ -64,6 +66,10 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
 
   private val existenceJoins = scala.collection.mutable.Map[Long, SExpression.InPredicate]()
 
+  private var _rddLimit = 100
+  def rddLimit: Int = _rddLimit
+  def rddLimit_=(rddLimit: Int): Unit = _rddLimit = rddLimit
+
   def getExistenceJoin(id: Long): Option[SExpression.InPredicate] = existenceJoins.get(id)
 
   override def default(p: LogicalPlan): relation.Rel = p match {
@@ -439,23 +445,25 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
       .build
     namedScan
   }
-  private def buildVirtualTableScan(localRelation: LocalRelation): relation.AbstractReadRel = {
-    val namedStruct = ToSubstraitType.toNamedStruct(localRelation.schema)
+  private def buildVirtualTableScan(
+      schema: StructType,
+      data: Seq[InternalRow]): relation.AbstractReadRel = {
+    val namedStruct = ToSubstraitType.toNamedStruct(schema)
 
-    if (localRelation.data.isEmpty) {
+    if (data.isEmpty) {
       relation.EmptyScan.builder().initialSchema(namedStruct).build()
     } else {
       relation.VirtualTableScan
         .builder()
         .initialSchema(namedStruct)
         .addAllRows(
-          localRelation.data
+          data
            .map(
              row => {
                var idx = 0
                val buf = new ArrayBuffer[SExpression.Literal](row.numFields)
                while (idx < row.numFields) {
-                  val dt = localRelation.schema(idx).dataType
+                  val dt = schema(idx).dataType
                   val l = Literal.apply(row.get(idx, dt), dt)
                   buf += ToSubstraitLiteral.apply(l)
                   idx += 1
@@ -528,7 +536,14 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
      case hiveTableRelation: HiveTableRelation =>
        tableNames = hiveTableRelation.tableMeta.identifier.unquotedString.split("\\.").toList
        buildNamedScan(hiveTableRelation.schema, tableNames)
-      case localRelation: LocalRelation => buildVirtualTableScan(localRelation)
+      case localRelation: LocalRelation =>
+        buildVirtualTableScan(localRelation.schema, localRelation.data)
+      case rdd: LogicalRDD =>
+        if (rdd.rdd.count() > _rddLimit) {
+          logWarning(
+            s"LogicalRDD relation contains ${rdd.rdd.count()} rows. Truncating to ${_rddLimit}. This limit can be changed by setting the `rddLimit` property on this ToSubstraitRel instance.")
+        }
+        buildVirtualTableScan(rdd.schema, rdd.rdd.take(_rddLimit))
      case logicalRelation: LogicalRelation =>
        logicalRelation.relation match {
          case fsRelation: HadoopFsRelation =>
```
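A usage note on the new guard (a sketch, not part of the diff): the `LogicalRDD` branch above runs `rdd.rdd.count()`, a Spark action, before conversion, and then serialises at most `rddLimit` rows via `take`. The limit defaults to 100 rows and can be raised per converter instance, reusing the `df` from the earlier sketch:

```scala
val toSubstrait = new ToSubstraitRel
toSubstrait.rddLimit = 1000 // accept up to 1000 rows before truncating
// RDDs larger than the limit are truncated, with a logged warning.
val rel = toSubstrait.visit(df.queryExecution.optimizedPlan)
```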

spark/src/test/scala/io/substrait/spark/RelationsSuite.scala

Lines changed: 61 additions & 0 deletions
```diff
@@ -1,7 +1,9 @@
 package io.substrait.spark
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
 
 class RelationsSuite extends SparkFunSuite with SharedSparkSession with SubstraitPlanTestBase {
 
@@ -38,4 +40,63 @@ class RelationsSuite extends SparkFunSuite with SharedSparkSession with Substrai
      "select * from (values (1, cast(struct(1, 'a') as struct<f1: int, f2: string>)) as table(int_col, col))"
    )
  }
+
+  test("create_dataset - LocalRelation") {
+    val spark = this.spark
+    import spark.implicits._
+
+    val df = Seq(
+      (1, "one"),
+      (2, "two"),
+      (3, "three")
+    ).toDF("id", "value")
+
+    assertSparkSubstraitRelRoundTrip(df.queryExecution.optimizedPlan)
+  }
+
+  test("createdataframe - LogicalRDD") {
+    val data = Seq(
+      Row(1, "one"),
+      Row(2, "two"),
+      Row(3, "three")
+    )
+
+    val schema = StructType(
+      List(
+        StructField("id", IntegerType, true),
+        StructField("value", StringType, true)
+      ))
+
+    val df = spark.createDataFrame(
+      spark.sparkContext.parallelize(data),
+      schema
+    )
+
+    assertSparkSubstraitRelRoundTrip(df.queryExecution.optimizedPlan)
+  }
+
+  test("Limit RDD size") {
+    val data = Seq(
+      Row(1, "one"),
+      Row(2, "two"),
+      Row(3, "three"),
+      Row(4, "four")
+    )
+
+    val schema = StructType(
+      List(
+        StructField("id", IntegerType, true),
+        StructField("value", StringType, true)
+      ))
+
+    val df = spark.createDataFrame(
+      spark.sparkContext.parallelize(data),
+      schema
+    )
+
+    assertResult(4)(df.count())
+
+    val plan = assertSparkSubstraitRelRoundTrip(df.queryExecution.optimizedPlan, 2)
+    assertResult(2)(plan.count())
+  }
 }
```

spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala

Lines changed: 8 additions & 3 deletions
```diff
@@ -57,9 +57,14 @@ trait SubstraitPlanTestBase { self: SharedSparkSession =>
 
   def assertSqlSubstraitRelRoundTrip(query: String): LogicalPlan = {
     val sparkPlan = plan(query)
+    assertSparkSubstraitRelRoundTrip(sparkPlan)
+  }
 
+  def assertSparkSubstraitRelRoundTrip(sparkPlan: LogicalPlan, rddLimit: Int = 10): LogicalPlan = {
     // convert spark logical plan to substrait
-    val substraitRel = new ToSubstraitRel().visit(sparkPlan)
+    val toSubstrait = new ToSubstraitRel
+    toSubstrait.rddLimit = rddLimit
+    val substraitRel = toSubstrait.visit(sparkPlan)
 
     // Serialize to protobuf byte array
     val extensionCollector = new ExtensionCollector
@@ -77,15 +82,15 @@ trait SubstraitPlanTestBase { self: SharedSparkSession =>
     require(sparkPlan2.resolved)
 
     // and back to substrait again
-    val substraitRel3 = new ToSubstraitRel().visit(sparkPlan2)
+    val substraitRel3 = toSubstrait.visit(sparkPlan2)
 
     // compare with original substrait plan to ensure it round-tripped (via proto bytes) correctly
     substraitRel3.shouldEqualPlainly(substraitRel)
 
     // Do one more roundtrip, this time with Substrait Plan object which contains also names,
     // to test that the Spark schemas match. This in some cases adds an extra Project
     // to rename fields, which then would break the round trip test we do above.
-    val substraitPlan = new ToSubstraitRel().convert(sparkPlan)
+    val substraitPlan = toSubstrait.convert(sparkPlan)
     val sparkPlan3 = toLogicalPlan.convert(substraitPlan);
     require(sparkPlan3.resolved);
 
```