From bc65aaf24265f6843f40aab49cb9f1ac0ee36e4e Mon Sep 17 00:00:00 2001
From: Jay Chia <17691182+jaychia@users.noreply.github.com>
Date: Thu, 3 Aug 2023 10:47:46 -0700
Subject: [PATCH] [CHORE] Add unit tests for int96 timestamps (#1229)

Closes: #1215

---------

Co-authored-by: Jay Chia
---
 tests/table/table_io/test_parquet.py        | 143 +++++++++++++-------
 tests/table/table_io/test_read_time_cast.py |   8 +-
 2 files changed, 96 insertions(+), 55 deletions(-)

diff --git a/tests/table/table_io/test_parquet.py b/tests/table/table_io/test_parquet.py
index e9dda78c59..e810f1e403 100644
--- a/tests/table/table_io/test_parquet.py
+++ b/tests/table/table_io/test_parquet.py
@@ -1,15 +1,16 @@
 from __future__ import annotations
 
+import contextlib
 import datetime
-import io
 import pathlib
+import tempfile
 
 import pyarrow as pa
 import pyarrow.parquet as papq
 import pytest
 
 import daft
-from daft.datatype import DataType
+from daft.datatype import DataType, TimeUnit
 from daft.logical.schema import Schema
 from daft.runners.partitioning import TableReadOptions
 from daft.table import Table, schema_inference, table_io
@@ -31,11 +32,11 @@ def test_read_input(tmpdir):
     assert table_io.read_parquet(f, schema=schema).to_arrow() == data
 
 
-def _parquet_write_helper(data: pa.Table, row_group_size: int = None):
-    f = io.BytesIO()
-    papq.write_table(data, f, row_group_size=row_group_size)
-    f.seek(0)
-    return f
+@contextlib.contextmanager
+def _parquet_write_helper(data: pa.Table, row_group_size: int | None = None, papq_write_table_kwargs: dict | None = None):
+    with tempfile.NamedTemporaryFile() as tmpfile:
+        papq.write_table(data, tmpfile.name, row_group_size=row_group_size, **(papq_write_table_kwargs or {}))
+        yield tmpfile.name
 
 
 @pytest.mark.parametrize(
@@ -54,18 +55,18 @@ def _parquet_write_helper(data: pa.Table, row_group_size: int = None):
         ([1, None, 2], DataType.list("item", DataType.int64())),
     ],
 )
-def test_parquet_infer_schema(data, expected_dtype):
-    f = _parquet_write_helper(
+@pytest.mark.parametrize("use_native_downloader", [True, False])
+def test_parquet_infer_schema(data, expected_dtype, use_native_downloader):
+    with _parquet_write_helper(
         pa.Table.from_pydict(
             {
                 "id": [1, 2, 3],
                 "data": [data, data, None],
             }
         )
-    )
-
-    schema = schema_inference.from_parquet(f)
-    assert schema == Schema._from_field_name_and_types([("id", DataType.int64()), ("data", expected_dtype)])
+    ) as f:
+        schema = schema_inference.from_parquet(f, use_native_downloader=use_native_downloader)
+        assert schema == Schema._from_field_name_and_types([("id", DataType.int64()), ("data", expected_dtype)])
 
 
 @pytest.mark.parametrize(
@@ -85,30 +86,33 @@ def test_parquet_infer_schema(data, expected_dtype):
         ({"foo": 1}, daft.Series.from_pylist([{"foo": 1}, {"foo": 1}, None])),
     ],
 )
-def test_parquet_read_data(data, expected_data_series):
-    f = _parquet_write_helper(
+@pytest.mark.parametrize("use_native_downloader", [True, False])
+def test_parquet_read_data(data, expected_data_series, use_native_downloader):
+    with _parquet_write_helper(
         pa.Table.from_pydict(
             {
                 "id": [1, 2, 3],
                 "data": [data, data, None],
             }
         )
-    )
-
-    schema = Schema._from_field_name_and_types([("id", DataType.int64()), ("data", expected_data_series.datatype())])
-    expected = Table.from_pydict(
-        {
-            "id": [1, 2, 3],
-            "data": expected_data_series,
-        }
-    )
-    table = table_io.read_parquet(f, schema)
-    assert table.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}"
+    ) as f:
+        schema = Schema._from_field_name_and_types(
+            [("id", DataType.int64()), ("data", expected_data_series.datatype())]
+        )
+        expected = Table.from_pydict(
+            {
+                "id": [1, 2, 3],
+                "data": expected_data_series,
+            }
+        )
+        table = table_io.read_parquet(f, schema, use_native_downloader=use_native_downloader)
+        assert table.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}"
 
 
 @pytest.mark.parametrize("row_group_size", [None, 1, 3])
-def test_parquet_read_data_limit_rows(row_group_size):
-    f = _parquet_write_helper(
+@pytest.mark.parametrize("use_native_downloader", [True, False])
+def test_parquet_read_data_limit_rows(row_group_size, use_native_downloader):
+    with _parquet_write_helper(
         pa.Table.from_pydict(
             {
                 "id": [1, 2, 3],
@@ -116,34 +120,71 @@ def test_parquet_read_data_limit_rows(row_group_size):
             }
         ),
         row_group_size=row_group_size,
-    )
-
-    schema = Schema._from_field_name_and_types([("id", DataType.int64()), ("data", DataType.int64())])
-    expected = Table.from_pydict(
-        {
-            "id": [1, 2],
-            "data": [1, 2],
-        }
-    )
-    table = table_io.read_parquet(f, schema, read_options=TableReadOptions(num_rows=2))
-    assert table.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}"
+    ) as f:
+        schema = Schema._from_field_name_and_types([("id", DataType.int64()), ("data", DataType.int64())])
+        expected = Table.from_pydict(
+            {
+                "id": [1, 2],
+                "data": [1, 2],
+            }
+        )
+        table = table_io.read_parquet(
+            f, schema, read_options=TableReadOptions(num_rows=2), use_native_downloader=use_native_downloader
+        )
+        assert table.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}"
 
 
-def test_parquet_read_data_select_columns():
-    f = _parquet_write_helper(
+@pytest.mark.parametrize("use_native_downloader", [True, False])
+def test_parquet_read_data_select_columns(use_native_downloader):
+    with _parquet_write_helper(
         pa.Table.from_pydict(
             {
                 "id": [1, 2, 3],
                 "data": [1, 2, None],
             }
         )
-    )
-
-    schema = Schema._from_field_name_and_types([("id", DataType.int64()), ("data", DataType.int64())])
-    expected = Table.from_pydict(
-        {
-            "data": [1, 2, None],
-        }
-    )
-    table = table_io.read_parquet(f, schema, read_options=TableReadOptions(column_names=["data"]))
-    assert table.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}"
+    ) as f:
+        schema = Schema._from_field_name_and_types([("id", DataType.int64()), ("data", DataType.int64())])
+        expected = Table.from_pydict(
+            {
+                "data": [1, 2, None],
+            }
+        )
+        table = table_io.read_parquet(
+            f, schema, read_options=TableReadOptions(column_names=["data"]), use_native_downloader=use_native_downloader
+        )
+        assert table.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}"
+
+
+@pytest.mark.parametrize("use_native_downloader", [True, False])
+@pytest.mark.parametrize("use_deprecated_int96_timestamps", [True, False])
+def test_parquet_read_timestamps(use_native_downloader, use_deprecated_int96_timestamps):
+    data = {
+        "timestamp_ms": pa.array([1, 2, 3], pa.timestamp("ms")),
+        "timestamp_us": pa.array([1, 2, 3], pa.timestamp("us")),
+    }
+    schema = [
+        ("timestamp_ms", DataType.timestamp(TimeUnit.ms())),
+        ("timestamp_us", DataType.timestamp(TimeUnit.us())),
+    ]
+    # Nanosecond resolution survives only via int96; the int64 path below coerces timestamps to microseconds
+    if use_deprecated_int96_timestamps:
+        data["timestamp_ns"] = pa.array([1, 2, 3], pa.timestamp("ns"))
+        schema.append(("timestamp_ns", DataType.timestamp(TimeUnit.ns())))
+
+    with _parquet_write_helper(
+        pa.Table.from_pydict(data),
+        papq_write_table_kwargs={
+            "use_deprecated_int96_timestamps": use_deprecated_int96_timestamps,
+            "coerce_timestamps": "us" if not use_deprecated_int96_timestamps else None,
+        },
+    ) as f:
+        schema = Schema._from_field_name_and_types(schema)
+        expected = Table.from_pydict(data)
+        table = table_io.read_parquet(
+            f,
+            schema,
+            read_options=TableReadOptions(column_names=schema.column_names()),
+            use_native_downloader=use_native_downloader,
+        )
+        assert table.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}"

diff --git a/tests/table/table_io/test_read_time_cast.py b/tests/table/table_io/test_read_time_cast.py
index 45372154b2..f2aa7f10b2 100644
--- a/tests/table/table_io/test_read_time_cast.py
+++ b/tests/table/table_io/test_read_time_cast.py
@@ -41,7 +41,7 @@
     ],
 )
 def test_parquet_cast_at_read_time(data, schema, expected):
-    f = _parquet_write_helper(data)
-    table = table_io.read_parquet(f, schema)
-    assert table.schema() == schema
-    assert table.to_arrow() == expected.to_arrow()
+    with _parquet_write_helper(data) as f:
+        table = table_io.read_parquet(f, schema)
+        assert table.schema() == schema
+        assert table.to_arrow() == expected.to_arrow()
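
Reviewer note (illustrative, not part of the patch): the int96 behavior that the new
test_parquet_read_timestamps case exercises can be reproduced with pyarrow alone. Below is a
minimal sketch, assuming a POSIX-style temp file and a made-up column name; the
use_deprecated_int96_timestamps and coerce_timestamps arguments are the same
pyarrow.parquet.write_table options the test forwards through papq_write_table_kwargs.

    import tempfile

    import pyarrow as pa
    import pyarrow.parquet as papq

    # Nanosecond-resolution timestamps to round-trip through Parquet.
    data = pa.Table.from_pydict({"ts_ns": pa.array([1, 2, 3], pa.timestamp("ns"))})

    with tempfile.NamedTemporaryFile() as tmpfile:
        # Same write pattern as the patched _parquet_write_helper: int96 storage
        # preserves nanoseconds, whereas the test's int64 path coerces to "us".
        papq.write_table(data, tmpfile.name, use_deprecated_int96_timestamps=True)
        roundtripped = papq.read_table(tmpfile.name)
        # Arrow reads int96 Parquet timestamps back as timestamp[ns].
        assert roundtripped["ts_ns"].type == pa.timestamp("ns")

Writing the same table without int96 but with coerce_timestamps="us" comes back as
timestamp[us] instead, which is why the test only appends the timestamp_ns column on the
int96 path.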