diff --git a/spark-3.1-spanner-lib/src/main/java/com/google/cloud/spark/spanner/DefaultSource.java b/spark-3.1-spanner-lib/src/main/java/com/google/cloud/spark/spanner/DefaultSource.java deleted file mode 100644 index f34891fa..00000000 --- a/spark-3.1-spanner-lib/src/main/java/com/google/cloud/spark/spanner/DefaultSource.java +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright 2023 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package com.google.cloud.spark.spanner; - -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.sources.BaseRelation; -import org.apache.spark.sql.sources.DataSourceRegister; -import org.apache.spark.sql.sources.RelationProvider; -import org.apache.spark.sql.sources.SchemaRelationProvider; -import org.apache.spark.sql.types.StructType; -import scala.collection.immutable.Map; - -public class DefaultSource implements DataSourceRegister, RelationProvider, SchemaRelationProvider { - @Override - public String shortName() { - return "cloud-spanner"; - } - - /* - * This method overrides SchemaRelationProvider.createRelation - */ - @Override - public BaseRelation createRelation( - SQLContext sqlContext, Map parameters, StructType schema) { - return new SpannerBaseRelation(sqlContext, parameters, schema); - } - - /* - * This method overrides RelationProvider.createRelation - */ - @Override - public BaseRelation createRelation(SQLContext sqlContext, Map parameters) { - return new SpannerBaseRelation(sqlContext, parameters, null); - } -} diff --git a/spark-3.1-spanner-lib/src/main/java/com/google/cloud/spark/spanner/SpannerBaseRelation.java b/spark-3.1-spanner-lib/src/main/java/com/google/cloud/spark/spanner/SpannerBaseRelation.java index d9e90c00..b1fa4ab9 100644 --- a/spark-3.1-spanner-lib/src/main/java/com/google/cloud/spark/spanner/SpannerBaseRelation.java +++ b/spark-3.1-spanner-lib/src/main/java/com/google/cloud/spark/spanner/SpannerBaseRelation.java @@ -15,24 +15,27 @@ package com.google.cloud.spark.spanner; import java.util.HashMap; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SaveMode; import org.apache.spark.sql.sources.BaseRelation; import org.apache.spark.sql.sources.Filter; import org.apache.spark.sql.types.StructType; -import scala.collection.immutable.Map; -/* - * SpannerBaseRelation implements BaseRelation. - */ public class SpannerBaseRelation extends BaseRelation { private final SQLContext sqlContext; - private final StructType schema; private final SpannerScanner scan; + private final Dataset dataToWrite; - public SpannerBaseRelation(SQLContext sqlContext, Map opts, StructType schema) { - this.scan = new SpannerScanner(scalaToJavaMap(opts)); + public SpannerBaseRelation( + SQLContext sqlContext, + SaveMode mode, + scala.collection.immutable.Map parameters, + Dataset data) { + this.scan = new SpannerScanner(scalaToJavaMap(parameters)); this.sqlContext = sqlContext; - this.schema = schema; + this.dataToWrite = data; } /* @@ -43,7 +46,7 @@ public SpannerBaseRelation(SQLContext sqlContext, Map opts, Stru */ @Override public boolean needConversion() { - return true; + return false; } @Override @@ -67,9 +70,6 @@ public SQLContext sqlContext() { @Override public StructType schema() { - if (this.schema == null) { - return this.schema; - } return this.scan.readSchema(); } diff --git a/spark-3.1-spanner-lib/src/main/java/com/google/cloud/spark/spanner/SpannerScanner.java b/spark-3.1-spanner-lib/src/main/java/com/google/cloud/spark/spanner/SpannerScanner.java index 6a68d7da..1bf2c544 100644 --- a/spark-3.1-spanner-lib/src/main/java/com/google/cloud/spark/spanner/SpannerScanner.java +++ b/spark-3.1-spanner-lib/src/main/java/com/google/cloud/spark/spanner/SpannerScanner.java @@ -26,7 +26,7 @@ public class SpannerScanner implements Scan { private SpannerTable spannerTable; public SpannerScanner(Map opts) { - this.spannerTable = new SpannerTable(opts); + this.spannerTable = new SpannerTable(null, opts); } @Override diff --git a/spark-3.1-spanner-lib/src/main/java/com/google/cloud/spark/spanner/SpannerSpark.java b/spark-3.1-spanner-lib/src/main/java/com/google/cloud/spark/spanner/SpannerSpark.java index c34663ef..d4df7347 100644 --- a/spark-3.1-spanner-lib/src/main/java/com/google/cloud/spark/spanner/SpannerSpark.java +++ b/spark-3.1-spanner-lib/src/main/java/com/google/cloud/spark/spanner/SpannerSpark.java @@ -29,7 +29,6 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -40,20 +39,12 @@ import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SparkSession; -import org.apache.spark.sql.connector.catalog.SupportsRead; -import org.apache.spark.sql.connector.catalog.Table; -import org.apache.spark.sql.connector.catalog.TableCapability; -import org.apache.spark.sql.connector.catalog.TableProvider; -import org.apache.spark.sql.connector.expressions.Transform; -import org.apache.spark.sql.connector.read.ScanBuilder; -import org.apache.spark.sql.types.StructType; -import org.apache.spark.sql.util.CaseInsensitiveStringMap; /* * SpannerSpark is the main entry point to * connect Cloud Spanner with Apache Spark. */ -public class SpannerSpark implements TableProvider, SupportsRead { +public class SpannerSpark { private BatchClient batchClient; private Map properties; @@ -175,51 +166,4 @@ private Row resultSetIndexToRow(ResultSet rs) { return RowFactory.create(objects.toArray(new Object[0])); } - - @Override - public Transform[] inferPartitioning(CaseInsensitiveStringMap options) { - return null; - } - - @Override - public StructType inferSchema(CaseInsensitiveStringMap options) { - SpannerTable st = new SpannerTable(properties); - return st.schema(); - } - - @Override - public boolean supportsExternalMetadata() { - return false; - } - - @Override - public Table getTable( - StructType schema, Transform[] partitioning, Map properties) { - return new SpannerTable(properties); - } - - @Override - public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { - return new SpannerScanBuilder(options); - } - - @Override - public Set capabilities() { - return null; - } - - @Override - public Map properties() { - return this.properties; - } - - @Override - public String name() { - return "cloud-spanner"; - } - - @Override - public StructType schema() { - return null; - } } diff --git a/spark-3.1-spanner-lib/src/main/java/com/google/cloud/spark/spanner/SpannerTable.java b/spark-3.1-spanner-lib/src/main/java/com/google/cloud/spark/spanner/SpannerTable.java index 202612e7..1eab0120 100644 --- a/spark-3.1-spanner-lib/src/main/java/com/google/cloud/spark/spanner/SpannerTable.java +++ b/spark-3.1-spanner-lib/src/main/java/com/google/cloud/spark/spanner/SpannerTable.java @@ -22,20 +22,24 @@ import java.util.HashSet; import java.util.Map; import java.util.Set; +import org.apache.spark.sql.connector.catalog.SupportsRead; import org.apache.spark.sql.connector.catalog.Table; import org.apache.spark.sql.connector.catalog.TableCapability; +import org.apache.spark.sql.connector.read.ScanBuilder; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; /* * SpannerTable implements Table. */ -public class SpannerTable implements Table { +public class SpannerTable implements Table, SupportsRead { private String tableName; private StructType tableSchema; - public SpannerTable(Map properties) { + public SpannerTable(StructType providedSchema, Map properties) { + // TODO: Use providedSchema in building the SpannerTable. String connUriPrefix = "cloudspanner:"; String emulatorHost = properties.get("emulatorHost"); if (emulatorHost != null) { @@ -89,7 +93,7 @@ public StructType createSchema(String tableName, ResultSet rs) { // Integer ordinalPosition = column.getInt(1); boolean isNullable = row.getBoolean(2); DataType catalogType = SpannerTable.ofSpannerStrType(row.getString(3), isNullable); - schema = schema.add(columnName, catalogType, isNullable); + schema = schema.add(columnName, catalogType, isNullable, "" /* No comments for the text */); } this.tableSchema = schema; return schema; @@ -180,4 +184,9 @@ public Set capabilities() { public String name() { return this.tableName; } + + @Override + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { + return new SpannerScanBuilder(options); + } } diff --git a/spark-3.1-spanner-lib/src/main/java/com/google/cloud/spark/spanner/Spark31SpannerTableProvider.java b/spark-3.1-spanner-lib/src/main/java/com/google/cloud/spark/spanner/Spark31SpannerTableProvider.java index 614c4bac..a3bdc7e3 100644 --- a/spark-3.1-spanner-lib/src/main/java/com/google/cloud/spark/spanner/Spark31SpannerTableProvider.java +++ b/spark-3.1-spanner-lib/src/main/java/com/google/cloud/spark/spanner/Spark31SpannerTableProvider.java @@ -16,33 +16,69 @@ package com.google.cloud.spark.spanner; import java.util.Map; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.SaveMode; import org.apache.spark.sql.connector.catalog.Table; import org.apache.spark.sql.connector.catalog.TableProvider; import org.apache.spark.sql.connector.expressions.Transform; +import org.apache.spark.sql.sources.BaseRelation; +import org.apache.spark.sql.sources.CreatableRelationProvider; import org.apache.spark.sql.sources.DataSourceRegister; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.util.CaseInsensitiveStringMap; -public class Spark31SpannerTableProvider implements DataSourceRegister, TableProvider { +public class Spark31SpannerTableProvider + implements CreatableRelationProvider, DataSourceRegister, TableProvider { + /* + * Infers the schema of the table identified by the given options. + */ @Override public StructType inferSchema(CaseInsensitiveStringMap options) { - return null; + SpannerTable st = new SpannerTable(null, options); + return st.schema(); } + /* + * Returns a Table instance with the specified table schema, + * partitioning and properties to perform a read or write. + */ @Override public Table getTable( StructType schema, Transform[] partitioning, Map properties) { - return null; + return new SpannerTable(schema, properties); } + /* + * Returns true if the source has the ability of + * accepting external table metadata when getting tables. + */ @Override public boolean supportsExternalMetadata() { return false; } + /* + * Implements DataSourceRegister.shortName(). This method allows Spark to match + * the DataSource when spark.read(...).format("cloud-spanner") is invoked. + */ @Override public String shortName() { return "cloud-spanner"; } + + /* + * Implements CreateRelationProvider.createRelation which essentially saves + * a DataFrame to the destination using the data-source specific parameters. + */ + @Override + public BaseRelation createRelation( + SQLContext sqlContext, + SaveMode mode, + scala.collection.immutable.Map parameters, + Dataset data) { + return new SpannerBaseRelation(sqlContext, mode, parameters, data); + } } diff --git a/spark-3.1-spanner-lib/src/test/java/com/google/cloud/spark/spanner/SpannerSparkTest.java b/spark-3.1-spanner-lib/src/test/java/com/google/cloud/spark/spanner/SpannerTableTest.java similarity index 88% rename from spark-3.1-spanner-lib/src/test/java/com/google/cloud/spark/spanner/SpannerSparkTest.java rename to spark-3.1-spanner-lib/src/test/java/com/google/cloud/spark/spanner/SpannerTableTest.java index 17f5e80e..39bf5c25 100644 --- a/spark-3.1-spanner-lib/src/test/java/com/google/cloud/spark/spanner/SpannerSparkTest.java +++ b/spark-3.1-spanner-lib/src/test/java/com/google/cloud/spark/spanner/SpannerTableTest.java @@ -12,15 +12,12 @@ import com.google.cloud.spanner.InstanceInfo; import com.google.cloud.spanner.Spanner; import com.google.cloud.spanner.SpannerOptions; -import com.google.cloud.spark.spanner.SpannerSpark; import com.google.cloud.spark.spanner.SpannerTable; import com.google.spanner.admin.database.v1.CreateDatabaseMetadata; import com.google.spanner.admin.instance.v1.CreateInstanceMetadata; import java.util.Arrays; import java.util.HashMap; import java.util.Map; -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.Row; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; @@ -31,7 +28,7 @@ import org.junit.runners.JUnit4; @RunWith(JUnit4.class) -public class SpannerSparkTest { +public class SpannerTableTest { String databaseId = System.getenv("SPANNER_DATABASE_ID"); String instanceId = System.getenv("SPANNER_INSTANCE_ID"); @@ -108,9 +105,9 @@ private Map connectionProperties() { } @Test - public void testSpannerTable() { + public void createSchema() { Map props = connectionProperties(); - SpannerTable st = new SpannerTable(props); + SpannerTable st = new SpannerTable(null, props); StructType actualSchema = st.schema(); StructType expectSchema = new StructType( @@ -123,16 +120,12 @@ public void testSpannerTable() { new StructField("E", DataTypes.createDecimalType(38, 9), true, null), new StructField( "F", DataTypes.createArrayType(DataTypes.StringType, true), true, null)) - .toArray(new StructField[6])); + .toArray(new StructField[0])); - assertEquals(expectSchema, actualSchema); - } - - @Test - public void testSpannerSpark() { - Map props = connectionProperties(); - SpannerSpark sp = new SpannerSpark(props); - - Dataset data = sp.execute(null, "SELECT * FROM ATable"); + // Object.equals fails for StructType with fields so we'll + // firstly compare lengths, then fieldNames then the simpleString. + assertEquals(expectSchema.length(), actualSchema.length()); + assertEquals(expectSchema.fieldNames(), actualSchema.fieldNames()); + assertEquals(expectSchema.simpleString(), actualSchema.simpleString()); } }