Add ReadWriterOptionType (#363)
zero323 authored Feb 3, 2020
1 parent d7f9181 commit 6fc6da4
Showing 5 changed files with 41 additions and 16 deletions.
22 changes: 21 additions & 1 deletion test-data/unit/sql-session.test
@@ -66,7 +66,7 @@ schema = StructType([
])

# Invalid product should have StructType schema
-spark.createDataFrame(data, IntegerType()) # E: Argument 1 to "createDataFrame" of "SparkSession" has incompatible type "List[Tuple[str, int]]"; expected "Union[RDD[Union[Union[datetime, date], Union[bool, int, float, str], Decimal]], Iterable[Union[Union[datetime, date], Union[bool, int, float, str], Decimal]]]"
+spark.createDataFrame(data, IntegerType()) # E: Argument 1 to "createDataFrame" of "SparkSession" has incompatible type "List[Tuple[str, int]]"; expected "Union[RDD[Union[Union[datetime, date], Union[bool, float, int, str], Decimal]], Iterable[Union[Union[datetime, date], Union[bool, float, int, str], Decimal]]]"

# This shouldn't type check, though is technically speaking valid
# because samplingRatio is ignored
@@ -76,3 +76,23 @@ spark.createDataFrame(data, schema, samplingRatio=0.1) # E: No overload variant
# N: def [RowLike in (List[Any], Tuple[Any, ...], Row)] createDataFrame(self, data: Union[RDD[RowLike], Iterable[RowLike]], schema: Union[List[str], Tuple[str, ...]] = ..., verifySchema: bool = ...) -> DataFrame \
# N: <4 more similar overloads not shown, out of 6 total overloads>
[out]
+
+
+[case readWriterOptions]
+from pyspark.sql import SparkSession
+
+spark = SparkSession.builder.getOrCreate()
+
+spark.read.option("foo", True).option("foo", 1).option("foo", 1.0).option("foo", "1")
+spark.readStream.option("foo", True).option("foo", 1).option("foo", 1.0).option("foo", "1")
+
+spark.read.options(foo=True, bar=1).options(foo=1.0, bar="1")
+spark.readStream.options(foo=True, bar=1).options(foo=1.0, bar="1")
+
+spark.read.load(foo=True)
+spark.readStream.load(foo=True)
+
+spark.read.load(foo=["a"]) # E: Argument "foo" to "load" of "DataFrameReader" has incompatible type "List[str]"; expected "Union[bool, float, int, str]"
+spark.read.option("foo", (1, )) # E: Argument 2 to "option" of "DataFrameReader" has incompatible type "Tuple[int]"; expected "Union[bool, float, int, str]"
+spark.read.options(bar={1}) # E: Argument "bar" to "options" of "DataFrameReader" has incompatible type "Set[int]"; expected "Union[bool, float, int, str]"
+[out]
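
For orientation, here is a standalone sketch of batch-reader code that the updated stubs are meant to accept; under the old signatures the **options of options() and load() only admitted str. The CSV path and option values are illustrative.

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

df = (
    spark.read
    .format("csv")
    .option("header", True)                        # bool: option() already accepted the union
    .options(inferSchema=True, samplingRatio=0.1)  # bool and float: newly typed via **options
    .load("/tmp/people.csv", multiLine=False)      # bool through load()'s **options
)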
2 changes: 2 additions & 0 deletions third_party/3/pyspark/_typing.pyi
@@ -3,6 +3,8 @@ from typing_extensions import Protocol

T = TypeVar('T', covariant=True)

+PrimitiveType = Union[bool, float, int, str]
+
class SupportsIAdd(Protocol):
def __iadd__(self, other: SupportsIAdd) -> SupportsIAdd: ...

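PrimitiveType simply names the union of Python's scalar types. A rough sketch of how such an alias is consumed by annotations (set_option is a hypothetical function, not part of pyspark):

from typing import Union

PrimitiveType = Union[bool, float, int, str]

def set_option(key: str, value: PrimitiveType) -> None:
    ...  # illustrative stand-in for a reader/writer option setter

set_option("header", True)   # OK: bool
set_option("sep", ",")       # OK: str
set_option("cols", ["a"])    # flagged by mypy: List[str] is not a PrimitiveType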
4 changes: 3 additions & 1 deletion third_party/3/pyspark/sql/_typing.pyi
@@ -5,6 +5,7 @@ from types import FunctionType
import datetime
import decimal

+from pyspark._typing import PrimitiveType
import pyspark.sql.column
import pyspark.sql.types
from pyspark.sql.column import Column
@@ -16,9 +17,10 @@ import pandas.core.series # type: ignore
ColumnOrName = Union[pyspark.sql.column.Column, str]
DecimalLiteral = decimal.Decimal
DateTimeLiteral = Union[datetime.datetime, datetime.date]
-LiteralType = Union[bool, int, float, str]
+LiteralType = PrimitiveType
AtomicDataTypeOrString = Union[pyspark.sql.types.AtomicType, str]
DataTypeOrString = Union[pyspark.sql.types.DataType, str]
+ReadWriterOptionType = PrimitiveType

RowLike = TypeVar("RowLike", List[Any], Tuple[Any, ...], pyspark.sql.types.Row)

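Note that LiteralType keeps the same members; it is now just spelled through PrimitiveType, whose declaration order is bool, float, int, str. mypy renders unions in declaration order, which is why the expected error messages in sql-session.test above changed while the accepted values did not. A minimal illustration:

from typing import Union

A = Union[bool, int, float, str]  # the old LiteralType spelling
B = Union[bool, float, int, str]  # the PrimitiveType spelling

x: A = 1.0
y: B = x  # accepted: member order does not change the type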
15 changes: 8 additions & 7 deletions third_party/3/pyspark/sql/readwriter.pyi
@@ -4,6 +4,7 @@
from typing import overload
from typing import Any, Dict, List, Optional, Tuple, Union

+from pyspark.sql._typing import ReadWriterOptionType
from pyspark.sql.dataframe import DataFrame
from pyspark.rdd import RDD
from pyspark.sql.context import SQLContext
@@ -19,11 +20,11 @@ class DataFrameReader(OptionUtils):
def format(self, source: str) -> DataFrameReader: ...
def schema(self, schema: Union[StructType, str]) -> DataFrameReader: ...
def option(self, key: str, value: Union[bool, float, int, str]) -> DataFrameReader: ...
-def options(self, **options: str) -> DataFrameReader: ...
-def load(self, path: Optional[PathOrPaths] = ..., format: Optional[str] = ..., schema: Optional[StructType] = ..., **options: str) -> DataFrame: ...
+def options(self, **options: ReadWriterOptionType) -> DataFrameReader: ...
+def load(self, path: Optional[PathOrPaths] = ..., format: Optional[str] = ..., schema: Optional[StructType] = ..., **options: ReadWriterOptionType) -> DataFrame: ...
def json(self, path: Union[str, List[str], RDD[str]], schema: Optional[StructType] = ..., primitivesAsString: Optional[Union[bool, str]] = ..., prefersDecimal: Optional[Union[bool, str]] = ..., allowComments: Optional[Union[bool, str]] = ..., allowUnquotedFieldNames: Optional[Union[bool, str]] = ..., allowSingleQuotes: Optional[Union[bool, str]] = ..., allowNumericLeadingZero: Optional[Union[bool, str]] = ..., allowBackslashEscapingAnyCharacter: Optional[Union[bool, str]] = ..., mode: Optional[str] = ..., columnNameOfCorruptRecord: Optional[str] = ..., dateFormat: Optional[str] = ..., timestampFormat: Optional[str] = ..., multiLine: Optional[Union[bool, str]] = ..., allowUnquotedControlChars: Optional[Union[bool, str]] = ..., lineSep: Optional[str] = ..., samplingRatio: Optional[Union[float, str]] = ..., dropFieldIfAllNull: Optional[Union[bool, str]] = ..., encoding: Optional[str] = ..., locale: Optional[str] = ..., recursiveFileLookup: Optional[bool] = ...) -> DataFrame: ...
def table(self, tableName: str) -> DataFrame: ...
-def parquet(self, *paths: str, **options: str) -> DataFrame: ...
+def parquet(self, *paths: str, **options: ReadWriterOptionType) -> DataFrame: ...
def text(self, paths: PathOrPaths, wholetext: bool = ..., lineSep: Optional[str] = ..., recursiveFileLookup: Optional[bool] = ...) -> DataFrame: ...
def csv(self, path: PathOrPaths, schema: Optional[StructType] = ..., sep: Optional[str] = ..., encoding: Optional[str] = ..., quote: Optional[str] = ..., escape: Optional[str] = ..., comment: Optional[str] = ..., header: Optional[Union[bool, str]] = ..., inferSchema: Optional[Union[bool, str]] = ..., ignoreLeadingWhiteSpace: Optional[Union[bool, str]] = ..., ignoreTrailingWhiteSpace: Optional[Union[bool, str]] = ..., nullValue: Optional[str] = ..., nanValue: Optional[str] = ..., positiveInf: Optional[str] = ..., negativeInf: Optional[str] = ..., dateFormat: Optional[str] = ..., timestampFormat: Optional[str] = ..., maxColumns: Optional[int] = ..., maxCharsPerColumn: Optional[int] = ..., maxMalformedLogPerPartition: Optional[int] = ..., mode: Optional[str] = ..., columnNameOfCorruptRecord: Optional[str] = ..., multiLine: Optional[Union[bool, str]] = ..., charToEscapeQuoteEscaping: Optional[str] = ..., samplingRatio: Optional[Union[float, str]] = ..., enforceSchema: Optional[Union[bool, str]] = ..., emptyValue: Optional[str] = ..., locale: Optional[str] = ..., lineSep: Optional[str] = ...) -> DataFrame: ...
def orc(self, path: PathOrPaths, mergeSchema: Optional[bool] = ..., recursiveFileLookup: Optional[bool] = ...) -> DataFrame: ...
@@ -38,8 +39,8 @@ class DataFrameWriter(OptionUtils):
def __init__(self, df: DataFrame) -> None: ...
def mode(self, saveMode: str) -> DataFrameWriter: ...
def format(self, source: str) -> DataFrameWriter: ...
-def option(self, key: str, value: Union[bool, float, int, str]) -> DataFrameWriter: ...
-def options(self, **options: str) -> DataFrameWriter: ...
+def option(self, key: str, value: ReadWriterOptionType) -> DataFrameWriter: ...
+def options(self, **options: ReadWriterOptionType) -> DataFrameWriter: ...
@overload
def partitionBy(self, *cols: str) -> DataFrameWriter: ...
@overload
@@ -52,9 +53,9 @@
def sortBy(self, col: str, *cols: str) -> DataFrameWriter: ...
@overload
def sortBy(self, col: TupleOrListOfString) -> DataFrameWriter: ...
-def save(self, path: Optional[str] = ..., format: Optional[str] = ..., mode: Optional[str] = ..., partitionBy: Optional[List[str]] = ..., **options: str) -> None: ...
+def save(self, path: Optional[str] = ..., format: Optional[str] = ..., mode: Optional[str] = ..., partitionBy: Optional[List[str]] = ..., **options: ReadWriterOptionType) -> None: ...
def insertInto(self, tableName: str, overwrite: Optional[bool] = ...) -> None: ...
-def saveAsTable(self, name: str, format: Optional[str] = ..., mode: Optional[str] = ..., partitionBy: Optional[List[str]] = ..., **options: str) -> None: ...
+def saveAsTable(self, name: str, format: Optional[str] = ..., mode: Optional[str] = ..., partitionBy: Optional[List[str]] = ..., **options: ReadWriterOptionType) -> None: ...
def json(self, path: str, mode: Optional[str] = ..., compression: Optional[str] = ..., dateFormat: Optional[str] = ..., timestampFormat: Optional[str] = ..., lineSep: Optional[str] = ..., encoding: Optional[str] = ..., ignoreNullFields: Optional[bool] = ...) -> None: ...
def parquet(self, path: str, mode: Optional[str] = ..., partitionBy: Optional[List[str]] = ..., compression: Optional[str] = ...) -> None: ...
def text(self, path: str, compression: Optional[str] = ..., lineSep: Optional[str] = ...) -> None: ...
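A sketch of writer-side code that these signatures now admit; the output path is hypothetical and the option names are Spark's CSV sink options:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("a", 1)], ["k", "v"])

(
    df.write
    .format("csv")
    .option("header", True)            # bool: option() already took the union
    .options(sep=",", quoteAll=False)  # bool via **options, previously str only
    .save("/tmp/out", maxRecordsPerFile=10_000)  # int via save()'s **options
)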
14 changes: 7 additions & 7 deletions third_party/3/pyspark/sql/streaming.pyi
@@ -4,7 +4,7 @@
from typing import overload
from typing import Any, Callable, Dict, List, Optional, Union

-from pyspark.sql._typing import SupportsProcess
+from pyspark.sql._typing import SupportsProcess, ReadWriterOptionType
from pyspark.sql.context import SQLContext
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.readwriter import OptionUtils
@@ -47,9 +47,9 @@ class DataStreamReader(OptionUtils):
def __init__(self, spark: SQLContext) -> None: ...
def format(self, source: str) -> DataStreamReader: ...
def schema(self, schema: Union[StructType, str]) -> DataStreamReader: ...
-def option(self, key: str, value: Union[bool, float, int, str]) -> DataStreamReader: ...
-def options(self, **options: str) -> DataStreamReader: ...
-def load(self, path: Optional[str] = ..., format: Optional[str] = ..., schema: Optional[StructType] = ..., **options: str) -> DataFrame: ...
+def option(self, key: str, value: ReadWriterOptionType) -> DataStreamReader: ...
+def options(self, **options: ReadWriterOptionType) -> DataStreamReader: ...
+def load(self, path: Optional[str] = ..., format: Optional[str] = ..., schema: Optional[StructType] = ..., **options: ReadWriterOptionType) -> DataFrame: ...
def json(self, path: str, schema: Optional[str] = ..., primitivesAsString: Optional[Union[bool, str]] = ..., prefersDecimal: Optional[Union[bool, str]] = ..., allowComments: Optional[Union[bool, str]] = ..., allowUnquotedFieldNames: Optional[Union[bool, str]] = ..., allowSingleQuotes: Optional[Union[bool, str]] = ..., allowNumericLeadingZero: Optional[Union[bool, str]] = ..., allowBackslashEscapingAnyCharacter: Optional[Union[bool, str]] = ..., mode: Optional[str] = ..., columnNameOfCorruptRecord: Optional[str] = ..., dateFormat: Optional[str] = ..., timestampFormat: Optional[str] = ..., multiLine: Optional[Union[bool, str]] = ..., allowUnquotedControlChars: Optional[Union[bool, str]] = ..., lineSep: Optional[str] = ..., locale: Optional[str] = ..., dropFieldIfAllNull: Optional[Union[bool, str]] = ..., encoding: Optional[str] = ..., recursiveFileLookup: Optional[bool] = ...) -> DataFrame: ...
def orc(self, path: str, mergeSchema: Optional[bool] = ..., recursiveFileLookup: Optional[bool] = ...) -> DataFrame: ...
def parquet(self, path: str, mergeSchema: Optional[bool] = ..., recursiveFileLookup: Optional[bool] = ...) -> DataFrame: ...
@@ -60,8 +60,8 @@ class DataStreamWriter:
def __init__(self, df: DataFrame) -> None: ...
def outputMode(self, outputMode: str) -> DataStreamWriter: ...
def format(self, source: str) -> DataStreamWriter: ...
-def option(self, key: str, value: Union[bool, float, int, str]) -> DataStreamWriter: ...
-def options(self, **options: str) -> DataStreamWriter: ...
+def option(self, key: str, value: ReadWriterOptionType) -> DataStreamWriter: ...
+def options(self, **options: ReadWriterOptionType) -> DataStreamWriter: ...
@overload
def partitionBy(self, *cols: str) -> DataStreamWriter: ...
@overload
@@ -73,7 +73,7 @@
def trigger(self, once: bool) -> DataStreamWriter: ...
@overload
def trigger(self, continuous: bool) -> DataStreamWriter: ...
-def start(self, path: Optional[str] = ..., format: Optional[str] = ..., outputMode: Optional[str] = ..., partitionBy: Optional[Union[str, List[str]]] = ..., queryName: Optional[str] = ..., **options: str) -> StreamingQuery: ...
+def start(self, path: Optional[str] = ..., format: Optional[str] = ..., outputMode: Optional[str] = ..., partitionBy: Optional[Union[str, List[str]]] = ..., queryName: Optional[str] = ..., **options: ReadWriterOptionType) -> StreamingQuery: ...
@overload
def foreach(self, f: Callable[[Row], None]) -> DataStreamWriter: ...
@overload
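And the streaming counterpart; rowsPerSecond, truncate, and numRows are real options of the rate source and console sink, while the query itself is illustrative:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

stream = (
    spark.readStream
    .format("rate")
    .option("rowsPerSecond", 10)  # int instead of "10"
    .load()
)

query = (
    stream.writeStream
    .format("console")
    .options(truncate=False, numRows=5)  # bool and int via **options
    .queryName("rwo_demo")
    .start()
)
query.stop()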
