Skip to content

Commit

Permalink
Merge branch 'main' into xiaohan/refactor_spanner
Browse files Browse the repository at this point in the history
  • Loading branch information
snarayan21 committed Sep 26, 2024
2 parents a41f07e + 79c2dfc commit 8a247a4
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 13 deletions.
10 changes: 5 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,19 +63,19 @@
extra_deps = {}

extra_deps['dev'] = [
'datasets>=2.4.0,<3',
'datasets>=2.4.0,<4',
'pyarrow>14.0.0',
'docformatter>=1.4',
'jupyter==1.0.0',
'jupyter==1.1.1',
'pre-commit>=2.18.1,<4',
'pytest==8.3.2',
'pytest==8.3.3',
'pytest_codeblocks==0.17.0',
'pytest-cov>=4,<6',
'toml==0.10.2',
'yamllint==1.35.1',
'moto>=4.0,<6',
'fastapi==0.112.2',
'pydantic==2.8.2',
'fastapi==0.114.2',
'pydantic==2.9.2',
'uvicorn==0.30.6',
'pytest-split==0.9.0',
]
Expand Down
2 changes: 1 addition & 1 deletion streaming/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@

"""The Streaming Version."""

__version__ = '0.8.1'
__version__ = '0.10.0.dev0'
4 changes: 3 additions & 1 deletion streaming/base/converters/dataframe_to_mds.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import (ArrayType, BinaryType, BooleanType, ByteType, DateType,
DayTimeIntervalType, DecimalType, DoubleType, FloatType,
IntegerType, LongType, NullType, ShortType, StringType,
IntegerType, LongType, MapType, NullType, ShortType, StringType,
StructField, StructType, TimestampNTZType, TimestampType)
except ImportError as e:
e.msg = get_import_exception_message(e.name, extra_deps='spark') # pyright: ignore
Expand Down Expand Up @@ -70,6 +70,8 @@ def is_json_compatible(data_type: Any):
return all(is_json_compatible(field.dataType) for field in data_type.fields)
elif isinstance(data_type, ArrayType):
return is_json_compatible(data_type.elementType)
elif isinstance(data_type, MapType):
return is_json_compatible(data_type.keyType) and is_json_compatible(data_type.valueType)
elif isinstance(data_type, (StringType, IntegerType, FloatType, BooleanType, NullType)):
return True
else:
Expand Down
2 changes: 2 additions & 0 deletions streaming/base/format/mds/encodings.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,8 @@ class JSON(Encoding):
"""Store arbitrary data as JSON."""

def encode(self, obj: Any) -> bytes:
if isinstance(obj, np.ndarray):
obj = obj.tolist()
data = json.dumps(obj)
self._is_valid(obj, data)
return data.encode('utf-8')
Expand Down
17 changes: 11 additions & 6 deletions tests/base/converters/test_dataframe_to_mds.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,13 +418,18 @@ def test_is_json_compatible(self):
]), True), True)
])

valid_schemas = [message_schema, prompt_response_schema, combined_schema]
string_map_keys_schema = StructType(
[StructField('map_field', MapType(StringType(), StringType()), nullable=True)])

valid_schemas = [
message_schema, prompt_response_schema, combined_schema, string_map_keys_schema
]

schema_with_binary = StructType([StructField('data', BinaryType(), nullable=True)])

# Schema with MapType having non-string keys
schema_with_non_string_map_keys = StructType(
[StructField('map_field', MapType(IntegerType(), StringType()), nullable=True)])
non_string_map_keys_schema = StructType(
[StructField('map_field', MapType(BinaryType(), StringType()), nullable=True)])

# Schema with DateType and TimestampType
schema_with_date_and_timestamp = StructType([
Expand All @@ -433,14 +438,14 @@ def test_is_json_compatible(self):
])

invalid_schemas = [
schema_with_binary, schema_with_non_string_map_keys, schema_with_date_and_timestamp
schema_with_binary, non_string_map_keys_schema, schema_with_date_and_timestamp
]

for s in valid_schemas:
assert is_json_compatible(s)
assert is_json_compatible(s), str(s)

for s in invalid_schemas:
assert not is_json_compatible(s)
assert not is_json_compatible(s), str(s)

def test_complex_schema(self,
complex_dataframe: Any,
Expand Down
16 changes: 16 additions & 0 deletions tests/test_encodings.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,22 @@ def test_json_encode_decode(self, data: Any):
# Validate data content
assert dec_data == data

@pytest.mark.parametrize('data', [np.array([1]), np.array(['foo']), np.array([{'foo': 1}])])
def test_json_encode_decode_ndarray(self, data: Any):
json_enc = mdsEnc.JSON()
assert json_enc.size is None

# Test encode
enc_data = json_enc.encode(data)
assert isinstance(enc_data, bytes)

# Test decode
dec_data = json_enc.decode(enc_data)
assert isinstance(dec_data, list)

# Validate data content
assert dec_data == data.tolist()

def test_json_invalid_data(self):
wrong_json_with_single_quotes = "{'name': 'streaming'}"
with pytest.raises(json.JSONDecodeError):
Expand Down

0 comments on commit 8a247a4

Please sign in to comment.