@@ -21,11 +21,13 @@ import io.substrait.spark.expression._
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SaveMode
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.catalog.{CatalogTable, HiveTableRelation}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Average, Sum}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.execution.LogicalRDD
 import org.apache.spark.sql.execution.command.CreateDataSourceTableAsSelectCommand
 import org.apache.spark.sql.execution.datasources.{FileFormat => DSFileFormat, HadoopFsRelation, InsertIntoHadoopFsRelationCommand, LogicalRelation, V1WriteCommand, WriteFiles}
 import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
@@ -52,7 +54,7 @@ import io.substrait.utils.Util
 import java.util
 import java.util.{Collections, Optional}
 
-import scala.collection.JavaConverters.{asJavaIterableConverter, seqAsJavaList}
+import scala.collection.JavaConverters.asJavaIterableConverter
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 
@@ -64,6 +66,10 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
 
   private val existenceJoins = scala.collection.mutable.Map[Long, SExpression.InPredicate]()
 
+  private var _rddLimit = 100
+  def rddLimit: Int = _rddLimit
+  def rddLimit_=(rddLimit: Int): Unit = _rddLimit = rddLimit
+
   def getExistenceJoin(id: Long): Option[SExpression.InPredicate] = existenceJoins.get(id)
 
   override def default(p: LogicalPlan): relation.Rel = p match {
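
The `_rddLimit` field paired with `rddLimit`/`rddLimit_=` is the idiomatic Scala property pattern: plain assignment on the caller's side desugars to a call to the `rddLimit_=` setter. A minimal usage sketch (the variable name `toSubstrait` is invented for illustration):

```scala
val toSubstrait = new ToSubstraitRel

// Plain assignment compiles to toSubstrait.rddLimit_=(10000).
toSubstrait.rddLimit = 10000
assert(toSubstrait.rddLimit == 10000) // the getter reflects the new cap
```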
@@ -439,23 +445,25 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
       .build
     namedScan
   }
-  private def buildVirtualTableScan(localRelation: LocalRelation): relation.AbstractReadRel = {
-    val namedStruct = ToSubstraitType.toNamedStruct(localRelation.schema)
+  private def buildVirtualTableScan(
+      schema: StructType,
+      data: Seq[InternalRow]): relation.AbstractReadRel = {
+    val namedStruct = ToSubstraitType.toNamedStruct(schema)
 
-    if (localRelation.data.isEmpty) {
+    if (data.isEmpty) {
       relation.EmptyScan.builder().initialSchema(namedStruct).build()
     } else {
       relation.VirtualTableScan
         .builder()
         .initialSchema(namedStruct)
         .addAllRows(
-          localRelation.data
+          data
            .map(
              row => {
                var idx = 0
                val buf = new ArrayBuffer[SExpression.Literal](row.numFields)
                while (idx < row.numFields) {
-                  val dt = localRelation.schema(idx).dataType
+                  val dt = schema(idx).dataType
                   val l = Literal.apply(row.get(idx, dt), dt)
                   buf += ToSubstraitLiteral.apply(l)
                   idx += 1
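
Widening the signature from `LocalRelation` to a bare `(schema, data)` pair lets any plan node that can produce a `StructType` and a `Seq[InternalRow]` share this conversion path. A hedged sketch of what those two inputs look like, built from standard Catalyst types (the field names and values are invented):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String

// The schema names each field and its type; every InternalRow carries one row.
// Note that string columns are stored as UTF8String inside an InternalRow.
val schema = StructType(Seq(
  StructField("id", IntegerType),
  StructField("name", StringType)))
val rows: Seq[InternalRow] = Seq(
  InternalRow(1, UTF8String.fromString("a")),
  InternalRow(2, UTF8String.fromString("b")))
// buildVirtualTableScan(schema, rows) yields a VirtualTableScan;
// an empty Seq yields an EmptyScan instead (the method itself is private).
```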
@@ -528,7 +536,14 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
       case hiveTableRelation: HiveTableRelation =>
         tableNames = hiveTableRelation.tableMeta.identifier.unquotedString.split("\\.").toList
         buildNamedScan(hiveTableRelation.schema, tableNames)
-      case localRelation: LocalRelation => buildVirtualTableScan(localRelation)
+      case localRelation: LocalRelation =>
+        buildVirtualTableScan(localRelation.schema, localRelation.data)
+      case rdd: LogicalRDD =>
+        if (rdd.rdd.count() > _rddLimit) {
+          logWarning(
+            s"LogicalRDD relation contains ${rdd.rdd.count()} rows. Truncating to ${_rddLimit}. This limit can be changed by setting the `rddLimit` property on this ToSubstraitRel instance.")
+        }
+        buildVirtualTableScan(rdd.schema, rdd.rdd.take(_rddLimit))
       case logicalRelation: LogicalRelation =>
         logicalRelation.relation match {
           case fsRelation: HadoopFsRelation =>
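
With the new `LogicalRDD` case, a DataFrame built directly from an RDD (which typically plans as a `LogicalRDD`) can be converted by materializing at most `rddLimit` rows into a virtual table; note that the `count()` guard runs a full Spark job over the RDD before `take` fetches the bounded sample. A usage sketch, assuming the class's existing conversion entry point (called `convert` here; it is not part of this diff):

```scala
// Hypothetical end-to-end usage; `spark` is an active SparkSession.
val df = spark.createDataFrame(
  spark.sparkContext.parallelize(Seq((1, "a"), (2, "b"))))

val toSubstrait = new ToSubstraitRel
toSubstrait.rddLimit = 1000 // raise the default cap of 100 rows

// Assumed entry point that visits the optimized plan and returns a Substrait plan.
val substraitPlan = toSubstrait.convert(df.queryExecution.optimizedPlan)
```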