  */
 package io.substrait.spark.logical
 
-import io.substrait.spark.{DefaultRelVisitor, SparkExtension, ToSparkType, ToSubstraitType}
+import io.substrait.spark.{DefaultRelVisitor, FileHolder, SparkExtension, ToSparkType, ToSubstraitType}
 import io.substrait.spark.expression._
 
-import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.{SaveMode, SparkSession}
+import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
 import org.apache.spark.sql.catalyst.analysis.{caseSensitiveResolution, MultiInstanceRelation, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction}
 import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.util.toPrettySQL
 import org.apache.spark.sql.execution.QueryExecution
-import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, InMemoryFileIndex, LogicalRelation}
+import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, DataWritingCommand, LeafRunnableCommand}
+import org.apache.spark.sql.execution.datasources.{FileFormat => SparkFileFormat, HadoopFsRelation, InMemoryFileIndex, InsertIntoHadoopFsRelationCommand, LogicalRelation, V1Writes}
 import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
 import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
@@ -40,8 +42,9 @@ import io.substrait.{expression => exp}
 import io.substrait.expression.{Expression => SExpression}
 import io.substrait.plan.Plan
 import io.substrait.relation
+import io.substrait.relation.{ExtensionWrite, LocalFiles, NamedWrite}
+import io.substrait.relation.AbstractWriteRel.{CreateMode, WriteOp}
 import io.substrait.relation.Expand.{ConsistentField, SwitchingField}
-import io.substrait.relation.LocalFiles
 import io.substrait.relation.Set.SetOp
 import io.substrait.relation.files.FileFormat
 import io.substrait.util.EmptyVisitationContext
@@ -388,7 +391,28 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate())
     if (formats.length != 1) {
       throw new UnsupportedOperationException(s"All files must have the same format")
     }
-    val (format, options) = formats.head match {
+    val (format, options) = convertFileFormat(formats.head)
+    new LogicalRelation(
+      relation = HadoopFsRelation(
+        location = new InMemoryFileIndex(
+          spark,
+          localFiles.getItems.asScala.map(i => new Path(i.getPath.get())),
+          Map(),
+          Some(schema)),
+        partitionSchema = new StructType(),
+        dataSchema = schema,
+        bucketSpec = None,
+        fileFormat = format,
+        options = options
+      )(spark),
+      output = output,
+      catalogTable = None,
+      isStreaming = false
+    )
+  }
+
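+  // Maps a Substrait FileFormat onto the corresponding Spark FileFormat and its data source options.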
+  def convertFileFormat(fileFormat: FileFormat): (SparkFileFormat, Map[String, String]) = {
+    fileFormat match {
       case csv: FileFormat.DelimiterSeparatedTextReadOptions =>
         val opts = scala.collection.mutable.Map[String, String](
           "delimiter" -> csv.getFieldDelimiter,
@@ -409,23 +433,98 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate())
       case format =>
         throw new UnsupportedOperationException(s"File format not currently supported: $format")
     }
-    new LogicalRelation(
-      relation = HadoopFsRelation(
-        location = new InMemoryFileIndex(
-          spark,
-          localFiles.getItems.asScala.map(i => new Path(i.getPath.get())),
-          Map(),
-          Some(schema)),
-        partitionSchema = new StructType(),
-        dataSchema = schema,
-        bucketSpec = None,
-        fileFormat = format,
-        options = options
-      )(spark),
-      output = output,
-      catalogTable = None,
-      isStreaming = false
+  }
+
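+  // Translates a Substrait NamedWrite into the corresponding Spark catalog write command; only CTAS is handled so far.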
+  override def visit(write: NamedWrite, context: EmptyVisitationContext): LogicalPlan = {
+    val child = write.getInput.accept(this, context)
+
+    val (table, database, catalog) = write.getNames.asScala match {
+      case Seq(table) => (table, None, None)
+      case Seq(database, table) => (table, Some(database), None)
+      case Seq(catalog, database, table) => (table, Some(database), Some(catalog))
+      case names =>
+        throw new UnsupportedOperationException(
+          s"NamedWrite requires up to three names ([[catalog,] database,] table): $names")
+    }
+    val id = TableIdentifier(table, database, catalog)
+    val catalogTable = CatalogTable(
+      id,
+      CatalogTableType.MANAGED,
+      CatalogStorageFormat.empty,
+      new StructType(),
+      Some("parquet")
     )
+    write.getOperation match {
+      case WriteOp.CTAS =>
+        withChild(child) {
+          CreateDataSourceTableAsSelectCommand(
+            catalogTable,
+            saveMode(write.getCreateMode),
+            child,
+            write.getTableSchema.names().asScala
+          )
+        }
+      case op => throw new UnsupportedOperationException(s"Write mode $op not supported")
+    }
+  }
+
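+  // Translates a Substrait ExtensionWrite carrying a file detail into an insert into that file's path.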
+  override def visit(write: ExtensionWrite, context: EmptyVisitationContext): LogicalPlan = {
+    val child = write.getInput.accept(this, context)
+    val mode = write.getOperation match {
+      case WriteOp.INSERT => SaveMode.Append
+      case WriteOp.UPDATE => SaveMode.Overwrite
+      case op => throw new UnsupportedOperationException(s"Write mode $op not supported")
+    }
+
+    val file = write.getDetail match {
+      case FileHolder(f) => f
+      case d =>
+        throw new UnsupportedOperationException(s"Unsupported extension detail: ${d.getClass}")
+    }
+
+    if (file.getPath.isEmpty)
+      throw new UnsupportedOperationException("The File extension detail must contain a Path field")
+    if (file.getFileFormat.isEmpty)
+      throw new UnsupportedOperationException(
+        "The File extension detail must contain a FileFormat field")
+
+    val (format, options) = convertFileFormat(file.getFileFormat.get)
+
+    val name = file.getPath.get.split('/').reverse.head
+    val id = TableIdentifier(name)
+    val table = CatalogTable(
+      id,
+      CatalogTableType.MANAGED,
+      CatalogStorageFormat.empty,
+      new StructType(),
+      None
+    )
+
+    withChild(child) {
+      V1Writes.apply(
+        InsertIntoHadoopFsRelationCommand(
+          outputPath = new Path(file.getPath.get),
+          staticPartitions = Map(),
+          ifPartitionNotExists = false,
+          partitionColumns = Seq.empty,
+          bucketSpec = None,
+          fileFormat = format,
+          options = options,
+          query = child,
+          mode = mode,
+          catalogTable = Some(table),
+          fileIndex = None,
+          outputColumnNames = write.getTableSchema.names.asScala
+        ))
+    }
+  }
+
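+  // Maps the Substrait CreateMode of a write onto the equivalent Spark SaveMode.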
+  private def saveMode(mode: CreateMode): SaveMode = mode match {
+    case CreateMode.APPEND_IF_EXISTS => SaveMode.Append
+    case CreateMode.REPLACE_IF_EXISTS => SaveMode.Overwrite
+    case CreateMode.ERROR_IF_EXISTS => SaveMode.ErrorIfExists
+    case CreateMode.IGNORE_IF_EXISTS => SaveMode.Ignore
+    case _ => throw new UnsupportedOperationException(s"Unsupported mode: $mode")
   }
 
   private def withChild(child: LogicalPlan*)(body: => LogicalPlan): LogicalPlan = {
@@ -501,6 +600,9 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate())
           aggregate.groupingExpressions,
           renameAndCastExprs(aggregate.aggregateExpressions),
           aggregate.child)
+      // if the plan represents a 'write' command, then leave as is
+      case _: DataWritingCommand => logicalPlan
+      case _: LeafRunnableCommand => logicalPlan
       // Otherwise we add a project to enforce correct names in the output
       case _ => Project(renameAndCastExprs(logicalPlan.output), logicalPlan)
     }
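
For orientation, a minimal sketch (not part of this change) of the Spark statements whose V1 logical plans end in the same write commands that these new visitors construct, assuming the default session catalog and the V1 file write path; table and path names are illustrative:

    import org.apache.spark.sql.{SaveMode, SparkSession}

    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    spark.range(3).toDF("id").createOrReplaceTempView("src")

    // A datasource CTAS plans to a CreateDataSourceTableAsSelectCommand,
    // the command built by visit(NamedWrite, _)
    spark.sql("CREATE TABLE demo_target USING parquet AS SELECT id FROM src")

    // A DataFrameWriter file write plans to an InsertIntoHadoopFsRelationCommand,
    // the command built by visit(ExtensionWrite, _)
    spark.range(3).toDF("id").write.mode(SaveMode.Append).parquet("/tmp/substrait_write_demo")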