
Clean up and simpler comparison of StructType
odeke-em committed Aug 3, 2023
1 parent 02fc710 commit 1394021
Showing 7 changed files with 74 additions and 139 deletions.

This file was deleted.

@@ -15,24 +15,27 @@
package com.google.cloud.spark.spanner;

import java.util.HashMap;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.sources.BaseRelation;
import org.apache.spark.sql.sources.Filter;
import org.apache.spark.sql.types.StructType;
import scala.collection.immutable.Map;

/*
* SpannerBaseRelation implements BaseRelation.
*/
public class SpannerBaseRelation extends BaseRelation {
private final SQLContext sqlContext;
private final StructType schema;
private final SpannerScanner scan;
private final Dataset<Row> dataToWrite;

public SpannerBaseRelation(SQLContext sqlContext, Map<String, String> opts, StructType schema) {
this.scan = new SpannerScanner(scalaToJavaMap(opts));
public SpannerBaseRelation(
SQLContext sqlContext,
SaveMode mode,
scala.collection.immutable.Map<String, String> parameters,
Dataset<Row> data) {
this.scan = new SpannerScanner(scalaToJavaMap(parameters));
this.sqlContext = sqlContext;
this.schema = schema;
this.dataToWrite = data;
}

/*
@@ -43,7 +46,7 @@ public SpannerBaseRelation(SQLContext sqlContext, Map<String, String> opts, StructType schema) {
*/
@Override
public boolean needConversion() {
return true;
return false;
}

@Override
@@ -67,9 +70,6 @@ public SQLContext sqlContext() {

@Override
public StructType schema() {
if (this.schema == null) {
return this.schema;
}
return this.scan.readSchema();
}

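Both the old and the new constructor above pass the Scala options map through scalaToJavaMap before building the SpannerScanner. That helper is not part of this diff; a minimal sketch of such a conversion, assuming a plain copy into a java.util.HashMap (the class and method names here are illustrative, not the connector's actual code), could look like this:

import java.util.HashMap;
import java.util.Map;

final class ScalaMapSketch {
  // Hypothetical stand-in for the connector's scalaToJavaMap helper:
  // walks the Scala iterator of key/value tuples and copies each entry
  // into a mutable Java map.
  static Map<String, String> scalaToJavaMap(
      scala.collection.immutable.Map<String, String> opts) {
    Map<String, String> out = new HashMap<>();
    scala.collection.Iterator<scala.Tuple2<String, String>> it = opts.iterator();
    while (it.hasNext()) {
      scala.Tuple2<String, String> kv = it.next();
      out.put(kv._1(), kv._2());
    }
    return out;
  }
}

Scala's JavaConverters would be an equally reasonable way to perform the same copy.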
@@ -26,7 +26,7 @@ public class SpannerScanner implements Scan {
private SpannerTable spannerTable;

public SpannerScanner(Map<String, String> opts) {
this.spannerTable = new SpannerTable(opts);
this.spannerTable = new SpannerTable(null, opts);
}

@Override
@@ -29,7 +29,6 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
@@ -40,20 +39,12 @@
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.connector.catalog.SupportsRead;
import org.apache.spark.sql.connector.catalog.Table;
import org.apache.spark.sql.connector.catalog.TableCapability;
import org.apache.spark.sql.connector.catalog.TableProvider;
import org.apache.spark.sql.connector.expressions.Transform;
import org.apache.spark.sql.connector.read.ScanBuilder;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;

/*
* SpannerSpark is the main entry point to
* connect Cloud Spanner with Apache Spark.
*/
public class SpannerSpark implements TableProvider, SupportsRead {
public class SpannerSpark {
private BatchClient batchClient;
private Map<String, String> properties;

@@ -175,51 +166,4 @@ private Row resultSetIndexToRow(ResultSet rs) {

return RowFactory.create(objects.toArray(new Object[0]));
}

@Override
public Transform[] inferPartitioning(CaseInsensitiveStringMap options) {
return null;
}

@Override
public StructType inferSchema(CaseInsensitiveStringMap options) {
SpannerTable st = new SpannerTable(properties);
return st.schema();
}

@Override
public boolean supportsExternalMetadata() {
return false;
}

@Override
public Table getTable(
StructType schema, Transform[] partitioning, Map<String, String> properties) {
return new SpannerTable(properties);
}

@Override
public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) {
return new SpannerScanBuilder(options);
}

@Override
public Set<TableCapability> capabilities() {
return null;
}

@Override
public Map<String, String> properties() {
return this.properties;
}

@Override
public String name() {
return "cloud-spanner";
}

@Override
public StructType schema() {
return null;
}
}
@@ -22,20 +22,24 @@
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.apache.spark.sql.connector.catalog.SupportsRead;
import org.apache.spark.sql.connector.catalog.Table;
import org.apache.spark.sql.connector.catalog.TableCapability;
import org.apache.spark.sql.connector.read.ScanBuilder;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;

/*
* SpannerTable implements Table.
*/
public class SpannerTable implements Table {
public class SpannerTable implements Table, SupportsRead {
private String tableName;
private StructType tableSchema;

public SpannerTable(Map<String, String> properties) {
public SpannerTable(StructType providedSchema, Map<String, String> properties) {
// TODO: Use providedSchema in building the SpannerTable.
String connUriPrefix = "cloudspanner:";
String emulatorHost = properties.get("emulatorHost");
if (emulatorHost != null) {
@@ -89,7 +93,7 @@ public StructType createSchema(String tableName, ResultSet rs) {
// Integer ordinalPosition = column.getInt(1);
boolean isNullable = row.getBoolean(2);
DataType catalogType = SpannerTable.ofSpannerStrType(row.getString(3), isNullable);
schema = schema.add(columnName, catalogType, isNullable);
schema = schema.add(columnName, catalogType, isNullable, "" /* No comments for the text */);
}
this.tableSchema = schema;
return schema;
@@ -180,4 +184,9 @@ public Set<TableCapability> capabilities() {
public String name() {
return this.tableName;
}

@Override
public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) {
return new SpannerScanBuilder(options);
}
}
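SpannerTable now implements SupportsRead, so the table object itself hands out the ScanBuilder, and its constructor takes an optional provided schema ahead of the options map. A minimal sketch of driving it directly follows; only the emulatorHost key is visible in this change, so every other detail (host value, missing project/instance/database options) is an assumption:

import com.google.cloud.spark.spanner.SpannerTable;
import java.util.HashMap;
import java.util.Map;
import org.apache.spark.sql.connector.read.ScanBuilder;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;

public class SpannerTableSketch {
  public static void main(String[] args) {
    Map<String, String> props = new HashMap<>();
    // "emulatorHost" is the only option key shown in this diff; a real
    // deployment would also need its project/instance/database options.
    props.put("emulatorHost", "localhost:9010");

    // A null schema asks SpannerTable to derive one from Spanner metadata,
    // mirroring how SpannerScanner now calls new SpannerTable(null, opts).
    SpannerTable table = new SpannerTable(null, props);
    System.out.println(table.schema().simpleString());

    // SupportsRead: the table produces its own ScanBuilder.
    ScanBuilder builder = table.newScanBuilder(new CaseInsensitiveStringMap(props));
    System.out.println(builder.getClass().getSimpleName());
  }
}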
@@ -16,33 +16,69 @@
package com.google.cloud.spark.spanner;

import java.util.Map;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.connector.catalog.Table;
import org.apache.spark.sql.connector.catalog.TableProvider;
import org.apache.spark.sql.connector.expressions.Transform;
import org.apache.spark.sql.sources.BaseRelation;
import org.apache.spark.sql.sources.CreatableRelationProvider;
import org.apache.spark.sql.sources.DataSourceRegister;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;

public class Spark31SpannerTableProvider implements DataSourceRegister, TableProvider {
public class Spark31SpannerTableProvider
implements CreatableRelationProvider, DataSourceRegister, TableProvider {

/*
* Infers the schema of the table identified by the given options.
*/
@Override
public StructType inferSchema(CaseInsensitiveStringMap options) {
return null;
SpannerTable st = new SpannerTable(null, options);
return st.schema();
}

/*
* Returns a Table instance with the specified table schema,
* partitioning and properties to perform a read or write.
*/
@Override
public Table getTable(
StructType schema, Transform[] partitioning, Map<String, String> properties) {
return null;
return new SpannerTable(schema, properties);
}

/*
* Returns true if the source has the ability of
* accepting external table metadata when getting tables.
*/
@Override
public boolean supportsExternalMetadata() {
return false;
}

/*
* Implements DataSourceRegister.shortName(). This method allows Spark to match
* the DataSource when spark.read(...).format("cloud-spanner") is invoked.
*/
@Override
public String shortName() {
return "cloud-spanner";
}

/*
 * Implements CreatableRelationProvider.createRelation which essentially saves
* a DataFrame to the destination using the data-source specific parameters.
*/
@Override
public BaseRelation createRelation(
SQLContext sqlContext,
SaveMode mode,
scala.collection.immutable.Map<String, String> parameters,
Dataset<Row> data) {
return new SpannerBaseRelation(sqlContext, mode, parameters, data);
}
}
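With shortName() registering the source as "cloud-spanner" and createRelation handling saves, a call site would normally go through the DataFrame reader and writer. The sketch below is illustrative only: the "table" option key is an assumption, and only the table name ATable is grounded in this change (via the test query):

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.SparkSession;

public class CloudSpannerUsageSketch {
  public static void main(String[] args) {
    SparkSession spark =
        SparkSession.builder().appName("cloud-spanner-sketch").master("local[*]").getOrCreate();

    // Read path: format("cloud-spanner") resolves to Spark31SpannerTableProvider,
    // which infers the schema through SpannerTable.
    Dataset<Row> df =
        spark.read().format("cloud-spanner").option("table", "ATable").load();

    // Write path: expected to fall back to the createRelation method added in
    // this change, which wraps the frame in a SpannerBaseRelation.
    df.write().format("cloud-spanner").mode(SaveMode.Append).save();

    spark.stop();
  }
}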
@@ -12,15 +12,12 @@
import com.google.cloud.spanner.InstanceInfo;
import com.google.cloud.spanner.Spanner;
import com.google.cloud.spanner.SpannerOptions;
import com.google.cloud.spark.spanner.SpannerSpark;
import com.google.cloud.spark.spanner.SpannerTable;
import com.google.spanner.admin.database.v1.CreateDatabaseMetadata;
import com.google.spanner.admin.instance.v1.CreateInstanceMetadata;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
@@ -31,7 +28,7 @@
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
public class SpannerSparkTest {
public class SpannerTableTest {

String databaseId = System.getenv("SPANNER_DATABASE_ID");
String instanceId = System.getenv("SPANNER_INSTANCE_ID");
@@ -108,9 +105,9 @@ private Map<String, String> connectionProperties() {
}

@Test
public void testSpannerTable() {
public void createSchema() {
Map<String, String> props = connectionProperties();
SpannerTable st = new SpannerTable(props);
SpannerTable st = new SpannerTable(null, props);
StructType actualSchema = st.schema();
StructType expectSchema =
new StructType(
@@ -123,16 +120,12 @@ public void testSpannerTable() {
new StructField("E", DataTypes.createDecimalType(38, 9), true, null),
new StructField(
"F", DataTypes.createArrayType(DataTypes.StringType, true), true, null))
.toArray(new StructField[6]));
.toArray(new StructField[0]));

assertEquals(expectSchema, actualSchema);
}

@Test
public void testSpannerSpark() {
Map<String, String> props = connectionProperties();
SpannerSpark sp = new SpannerSpark(props);

Dataset<Row> data = sp.execute(null, "SELECT * FROM ATable");
// Object.equals fails for StructType with fields so we'll
// firstly compare lengths, then fieldNames then the simpleString.
assertEquals(expectSchema.length(), actualSchema.length());
assertEquals(expectSchema.fieldNames(), actualSchema.fieldNames());
assertEquals(expectSchema.simpleString(), actualSchema.simpleString());
}
}
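The new assertions sidestep StructType equality (the test comment notes Object.equals fails once fields are present, here built with null metadata) and instead compare the length, the field names, and the rendered simpleString. The same idiom outside JUnit, as a small sketch with an arbitrary example field:

import java.util.Arrays;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class StructTypeComparisonSketch {
  public static void main(String[] args) {
    StructType expected =
        new StructType().add(new StructField("A", DataTypes.LongType, true, Metadata.empty()));
    StructType actual =
        new StructType().add(new StructField("A", DataTypes.LongType, true, Metadata.empty()));

    // Compare shape rather than relying on StructType.equals:
    boolean sameLength = expected.length() == actual.length();
    boolean sameNames = Arrays.equals(expected.fieldNames(), actual.fieldNames());
    boolean sameRendering = expected.simpleString().equals(actual.simpleString());
    System.out.println(sameLength && sameNames && sameRendering);
  }
}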
