Commit 1954fc8

feat(spark): support insert/append operations (#429)
Partial implementation of the Substrait WriteRel relation to support conversion of Spark insert and append operations. Spark does not support update or delete operations. Added tests for inserting into in-memory database tables and appending to local files. Signed-off-by: Andrew Coleman <[email protected]>
1 parent 76684d8 commit 1954fc8
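For context, the two scenarios named in the commit message correspond to ordinary Spark write operations. The sketch below shows what those operations look like on the Spark side and which logical-plan commands they produce; the session setup, table name, and output path are illustrative assumptions rather than code from this commit's tests.

// Hedged sketch of Spark write operations of the kind this commit converts.
// The session setup, table name, and output path are illustrative assumptions.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").getOrCreate()
import spark.implicits._

val df = Seq(("alice", 30), ("bob", 25)).toDF("name", "age")

// Writing to a catalog table plans (on Spark's V1 path) as a
// CreateDataSourceTableAsSelectCommand, the command that ToLogicalPlan below
// rebuilds from a Substrait NamedWrite with WriteOp.CTAS.
df.write.mode("errorifexists").saveAsTable("people")

// Appending to local files plans as an InsertIntoHadoopFsRelationCommand,
// the command that ToLogicalPlan rebuilds from a Substrait ExtensionWrite whose
// FileHolder detail carries the path and file format.
df.write.mode("append").csv("/tmp/people_csv")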

4 files changed: +366 -64 lines changed

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+package io.substrait.spark
+
+import com.google.protobuf
+import io.substrait.extension.ExtensionLookup
+import io.substrait.relation.{ProtoRelConverter, RelProtoConverter}
+import io.substrait.relation.Extension.WriteExtensionObject
+import io.substrait.relation.files.FileOrFiles
+
+case class FileHolder(fileOrFiles: FileOrFiles) extends WriteExtensionObject {
+
+  override def toProto(converter: RelProtoConverter): protobuf.Any = {
+    protobuf.Any.pack(fileOrFiles.toProto)
+  }
+}
+
+class FileHolderHandlingProtoRelConverter(lookup: ExtensionLookup)
+  extends ProtoRelConverter(lookup) {
+  override def detailFromWriteExtensionObject(any: protobuf.Any): WriteExtensionObject = {
+    FileHolder(
+      newFileOrFiles(any.unpack(classOf[io.substrait.proto.ReadRel.LocalFiles.FileOrFiles])))
+  }
+}

spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala

Lines changed: 124 additions & 22 deletions
@@ -16,19 +16,21 @@
  */
 package io.substrait.spark.logical
 
-import io.substrait.spark.{DefaultRelVisitor, SparkExtension, ToSparkType, ToSubstraitType}
+import io.substrait.spark.{DefaultRelVisitor, FileHolder, SparkExtension, ToSparkType, ToSubstraitType}
 import io.substrait.spark.expression._
 
-import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.{SaveMode, SparkSession}
+import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
 import org.apache.spark.sql.catalyst.analysis.{caseSensitiveResolution, MultiInstanceRelation, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction}
 import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.util.toPrettySQL
 import org.apache.spark.sql.execution.QueryExecution
-import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, InMemoryFileIndex, LogicalRelation}
+import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, DataWritingCommand, LeafRunnableCommand}
+import org.apache.spark.sql.execution.datasources.{FileFormat => SparkFileFormat, HadoopFsRelation, InMemoryFileIndex, InsertIntoHadoopFsRelationCommand, LogicalRelation, V1Writes}
 import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
 import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
@@ -40,8 +42,9 @@ import io.substrait.{expression => exp}
 import io.substrait.expression.{Expression => SExpression}
 import io.substrait.plan.Plan
 import io.substrait.relation
+import io.substrait.relation.{ExtensionWrite, LocalFiles, NamedWrite}
+import io.substrait.relation.AbstractWriteRel.{CreateMode, WriteOp}
 import io.substrait.relation.Expand.{ConsistentField, SwitchingField}
-import io.substrait.relation.LocalFiles
 import io.substrait.relation.Set.SetOp
 import io.substrait.relation.files.FileFormat
 import io.substrait.util.EmptyVisitationContext
@@ -388,7 +391,28 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate())
     if (formats.length != 1) {
       throw new UnsupportedOperationException(s"All files must have the same format")
     }
-    val (format, options) = formats.head match {
+    val (format, options) = convertFileFormat(formats.head)
+    new LogicalRelation(
+      relation = HadoopFsRelation(
+        location = new InMemoryFileIndex(
+          spark,
+          localFiles.getItems.asScala.map(i => new Path(i.getPath.get())),
+          Map(),
+          Some(schema)),
+        partitionSchema = new StructType(),
+        dataSchema = schema,
+        bucketSpec = None,
+        fileFormat = format,
+        options = options
+      )(spark),
+      output = output,
+      catalogTable = None,
+      isStreaming = false
+    )
+  }
+
+  def convertFileFormat(fileFormat: FileFormat): (SparkFileFormat, Map[String, String]) = {
+    fileFormat match {
       case csv: FileFormat.DelimiterSeparatedTextReadOptions =>
         val opts = scala.collection.mutable.Map[String, String](
           "delimiter" -> csv.getFieldDelimiter,
@@ -409,23 +433,98 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate())
       case format =>
         throw new UnsupportedOperationException(s"File format not currently supported: $format")
     }
-    new LogicalRelation(
-      relation = HadoopFsRelation(
-        location = new InMemoryFileIndex(
-          spark,
-          localFiles.getItems.asScala.map(i => new Path(i.getPath.get())),
-          Map(),
-          Some(schema)),
-        partitionSchema = new StructType(),
-        dataSchema = schema,
-        bucketSpec = None,
-        fileFormat = format,
-        options = options
-      )(spark),
-      output = output,
-      catalogTable = None,
-      isStreaming = false
+  }
+
+  override def visit(write: NamedWrite, context: EmptyVisitationContext): LogicalPlan = {
+    val child = write.getInput.accept(this, context)
+
+    val (table, database, catalog) = write.getNames.asScala match {
+      case Seq(table) => (table, None, None)
+      case Seq(database, table) => (table, Some(database), None)
+      case Seq(catalog, database, table) => (table, Some(database), Some(catalog))
+      case names =>
+        throw new UnsupportedOperationException(
+          s"NamedWrite requires up to three names ([[catalog,] database,] table): $names")
+    }
+    val id = TableIdentifier(table, database, catalog)
+    val catalogTable = CatalogTable(
+      id,
+      CatalogTableType.MANAGED,
+      CatalogStorageFormat.empty,
+      new StructType(),
+      Some("parquet")
     )
+    write.getOperation match {
+      case WriteOp.CTAS =>
+        withChild(child) {
+          CreateDataSourceTableAsSelectCommand(
+            catalogTable,
+            saveMode(write.getCreateMode),
+            child,
+            write.getTableSchema.names().asScala
+          )
+        }
+      case op => throw new UnsupportedOperationException(s"Write mode $op not supported")
+    }
+  }
+
+  override def visit(write: ExtensionWrite, context: EmptyVisitationContext): LogicalPlan = {
+    val child = write.getInput.accept(this, context)
+    val mode = write.getOperation match {
+      case WriteOp.INSERT => SaveMode.Append
+      case WriteOp.UPDATE => SaveMode.Overwrite
+      case op => throw new UnsupportedOperationException(s"Write mode $op not supported")
+    }
+
+    val file = write.getDetail match {
+      case FileHolder(f) => f
+      case d =>
+        throw new UnsupportedOperationException(s"Unsupported extension detail: ${d.getClass}")
+    }
+
+    if (file.getPath.isEmpty)
+      throw new UnsupportedOperationException("The File extension detail must contain a Path field")
+    if (file.getFileFormat.isEmpty)
+      throw new UnsupportedOperationException(
+        "The File extension detail must contain a FileFormat field")
+
+    val (format, options) = convertFileFormat(file.getFileFormat.get)
+
+    val name = file.getPath.get.split('/').reverse.head
+    val id = TableIdentifier(name)
+    val table = CatalogTable(
+      id,
+      CatalogTableType.MANAGED,
+      CatalogStorageFormat.empty,
+      new StructType(),
+      None
+    )
+
+    withChild(child) {
+      V1Writes.apply(
+        InsertIntoHadoopFsRelationCommand(
+          outputPath = new Path(file.getPath.get),
+          staticPartitions = Map(),
+          ifPartitionNotExists = false,
+          partitionColumns = Seq.empty,
+          bucketSpec = None,
+          fileFormat = format,
+          options = options,
+          query = child,
+          mode = mode,
+          catalogTable = Some(table),
+          fileIndex = None,
+          outputColumnNames = write.getTableSchema.names.asScala
+        ))
+    }
+  }
+
+  private def saveMode(mode: CreateMode): SaveMode = mode match {
+    case CreateMode.APPEND_IF_EXISTS => SaveMode.Append
+    case CreateMode.REPLACE_IF_EXISTS => SaveMode.Overwrite
+    case CreateMode.ERROR_IF_EXISTS => SaveMode.ErrorIfExists
+    case CreateMode.IGNORE_IF_EXISTS => SaveMode.Ignore
+    case _ => throw new UnsupportedOperationException(s"Unsupported mode: $mode")
   }
 
   private def withChild(child: LogicalPlan*)(body: => LogicalPlan): LogicalPlan = {
@@ -501,6 +600,9 @@ class ToLogicalPlan(spark: SparkSession = SparkSession.builder().getOrCreate())
           aggregate.groupingExpressions,
           renameAndCastExprs(aggregate.aggregateExpressions),
           aggregate.child)
+      // if the plan represents a 'write' command, then leave as is
+      case _: DataWritingCommand => logicalPlan
+      case _: LeafRunnableCommand => logicalPlan
       // Otherwise we add a project to enforce correct names in the output
       case _ => Project(renameAndCastExprs(logicalPlan.output), logicalPlan)
     }
