From d9f138601f77e1d955dcc7cf8d7fc0262191714c Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Sun, 19 May 2024 18:04:11 +0400 Subject: [PATCH] make use of `pytestmark` --- py-polars/tests/unit/ml/test_to_jax.py | 230 +++++----- py-polars/tests/unit/ml/test_to_torch.py | 525 +++++++++++------------ 2 files changed, 377 insertions(+), 378 deletions(-) diff --git a/py-polars/tests/unit/ml/test_to_jax.py b/py-polars/tests/unit/ml/test_to_jax.py index 648545e15e41..5dc0c172f084 100644 --- a/py-polars/tests/unit/ml/test_to_jax.py +++ b/py-polars/tests/unit/ml/test_to_jax.py @@ -14,6 +14,8 @@ jx, _ = _lazy_import("jax") jxn, _ = _lazy_import("jax.numpy") +pytestmark = pytest.mark.ci_only + if TYPE_CHECKING: from polars.datatypes import PolarsDataType @@ -30,127 +32,125 @@ def df() -> pl.DataFrame: ) -@pytest.mark.ci_only() -class TestJaxIntegration: - """Test coverage for `to_jax` conversions.""" +def assert_array_equal(actual: Any, expected: Any, nans_equal: bool = True) -> None: + assert isinstance(actual, jx.Array) + jxn.array_equal(actual, expected, equal_nan=nans_equal) + + +@pytest.mark.parametrize( + ("dtype", "expected_jax_dtype"), + [ + (pl.Int8, "int8"), + (pl.Int16, "int16"), + (pl.Int32, "int32"), + (pl.Int64, "int32"), + (pl.UInt8, "uint8"), + (pl.UInt16, "uint16"), + (pl.UInt32, "uint32"), + (pl.UInt64, "uint32"), + ], +) +def test_to_jax_from_series( + dtype: PolarsDataType, + expected_jax_dtype: str, +) -> None: + s = pl.Series("x", [1, 2, 3, 4], dtype=dtype) + for dvc in (None, "cpu", jx.devices("cpu")[0]): + assert_array_equal( + s.to_jax(device=dvc), + jxn.array([1, 2, 3, 4], dtype=getattr(jxn, expected_jax_dtype)), + ) + - def assert_array_equal( - self, actual: Any, expected: Any, nans_equal: bool = True - ) -> None: - assert isinstance(actual, jx.Array) - jxn.array_equal(actual, expected, equal_nan=nans_equal) +def test_to_jax_array(df: pl.DataFrame) -> None: + a1 = df.to_jax() + a2 = df.to_jax("array") + a3 = df.to_jax("array", device="cpu") + a4 = df.to_jax("array", device=jx.devices("cpu")[0]) - @pytest.mark.parametrize( - ("dtype", "expected_jax_dtype"), + expected = jxn.array( [ - (pl.Int8, "int8"), - (pl.Int16, "int16"), - (pl.Int32, "int32"), - (pl.Int64, "int32"), - (pl.UInt8, "uint8"), - (pl.UInt16, "uint16"), - (pl.UInt32, "uint32"), - (pl.UInt64, "uint32"), + [1.0, 1.0, 1.5], + [2.0, 0.0, -0.5], + [2.0, 1.0, 0.0], + [3.0, 0.0, -2.0], ], + dtype=jxn.float32, ) - def test_to_jax_from_series( - self, - dtype: PolarsDataType, - expected_jax_dtype: str, - ) -> None: - s = pl.Series("x", [1, 2, 3, 4], dtype=dtype) - for dvc in (None, "cpu", jx.devices("cpu")[0]): - self.assert_array_equal( - s.to_jax(device=dvc), - jxn.array([1, 2, 3, 4], dtype=getattr(jxn, expected_jax_dtype)), - ) - - def test_to_jax_array(self, df: pl.DataFrame) -> None: - a1 = df.to_jax() - a2 = df.to_jax("array") - a3 = df.to_jax("array", device="cpu") - a4 = df.to_jax("array", device=jx.devices("cpu")[0]) - - expected = jxn.array( - [ - [1.0, 1.0, 1.5], - [2.0, 0.0, -0.5], - [2.0, 1.0, 0.0], - [3.0, 0.0, -2.0], - ], - dtype=jxn.float32, - ) - for a in (a1, a2, a3, a4): - self.assert_array_equal(a, expected) - - def test_to_jax_dict(self, df: pl.DataFrame) -> None: - arr_dict = df.to_jax("dict") - assert list(arr_dict.keys()) == ["x", "y", "z"] - - self.assert_array_equal(arr_dict["x"], jxn.array([1, 2, 2, 3], dtype=jxn.int8)) - self.assert_array_equal(arr_dict["y"], jxn.array([1, 0, 1, 0], dtype=jxn.int32)) - self.assert_array_equal( - arr_dict["z"], - jxn.array([1.5, -0.5, 0.0, -2.0], dtype=jxn.float32), - ) + for a in (a1, a2, a3, a4): + assert_array_equal(a, expected) + - arr_dict = df.to_jax("dict", dtype=pl.Float32) - for a, expected_data in zip( - arr_dict.values(), - ([1.0, 2.0, 2.0, 3.0], [1.0, 0.0, 1.0, 0.0], [1.5, -0.5, 0.0, -2.0]), - ): - self.assert_array_equal(a, jxn.array(expected_data, dtype=jxn.float32)) +def test_to_jax_dict(df: pl.DataFrame) -> None: + arr_dict = df.to_jax("dict") + assert list(arr_dict.keys()) == ["x", "y", "z"] - @pytest.mark.skipif( - sys.version_info < (3, 9), - reason="jax.numpy.bool requires Python >= 3.9", + assert_array_equal(arr_dict["x"], jxn.array([1, 2, 2, 3], dtype=jxn.int8)) + assert_array_equal(arr_dict["y"], jxn.array([1, 0, 1, 0], dtype=jxn.int32)) + assert_array_equal( + arr_dict["z"], + jxn.array([1.5, -0.5, 0.0, -2.0], dtype=jxn.float32), ) - def test_to_jax_feature_label_dict(self, df: pl.DataFrame) -> None: - df = pl.DataFrame( - { - "age": [25, 32, 45, 22, 34], - "income": [50000, 75000, 60000, 58000, 120000], - "education": ["bachelor", "master", "phd", "bachelor", "phd"], - "purchased": [False, True, True, False, True], - } - ).to_dummies("education", separator=":") - - lbl_feat_dict = df.to_jax(return_type="dict", label="purchased") - assert list(lbl_feat_dict.keys()) == ["label", "features"] - - self.assert_array_equal( - lbl_feat_dict["label"], - jxn.array([[False], [True], [True], [False], [True]], dtype=jxn.bool), - ) - self.assert_array_equal( - lbl_feat_dict["features"], - jxn.array( - [ - [25, 50000, 1, 0, 0], - [32, 75000, 0, 1, 0], - [45, 60000, 0, 0, 1], - [22, 58000, 1, 0, 0], - [34, 120000, 0, 0, 1], - ], - dtype=jxn.int32, - ), - ) - def test_misc_errors(self, df: pl.DataFrame) -> None: - with pytest.raises( - ValueError, - match="invalid `return_type`: 'stroopwafel'", - ): - _res0 = df.to_jax("stroopwafel") # type: ignore[call-overload] - - with pytest.raises( - ValueError, - match="`label` is required if setting `features` when `return_type='dict'", - ): - _res2 = df.to_jax("dict", features=cs.float()) - - with pytest.raises( - ValueError, - match="`label` and `features` only apply when `return_type` is 'dict'", - ): - _res3 = df.to_jax(label="stroopwafel") + arr_dict = df.to_jax("dict", dtype=pl.Float32) + for a, expected_data in zip( + arr_dict.values(), + ([1.0, 2.0, 2.0, 3.0], [1.0, 0.0, 1.0, 0.0], [1.5, -0.5, 0.0, -2.0]), + ): + assert_array_equal(a, jxn.array(expected_data, dtype=jxn.float32)) + + +@pytest.mark.skipif( + sys.version_info < (3, 9), + reason="jax.numpy.bool requires Python >= 3.9", +) +def test_to_jax_feature_label_dict(df: pl.DataFrame) -> None: + df = pl.DataFrame( + { + "age": [25, 32, 45, 22, 34], + "income": [50000, 75000, 60000, 58000, 120000], + "education": ["bachelor", "master", "phd", "bachelor", "phd"], + "purchased": [False, True, True, False, True], + } + ).to_dummies("education", separator=":") + + lbl_feat_dict = df.to_jax(return_type="dict", label="purchased") + assert list(lbl_feat_dict.keys()) == ["label", "features"] + + assert_array_equal( + lbl_feat_dict["label"], + jxn.array([[False], [True], [True], [False], [True]], dtype=jxn.bool), + ) + assert_array_equal( + lbl_feat_dict["features"], + jxn.array( + [ + [25, 50000, 1, 0, 0], + [32, 75000, 0, 1, 0], + [45, 60000, 0, 0, 1], + [22, 58000, 1, 0, 0], + [34, 120000, 0, 0, 1], + ], + dtype=jxn.int32, + ), + ) + + +def test_misc_errors(df: pl.DataFrame) -> None: + with pytest.raises( + ValueError, + match="invalid `return_type`: 'stroopwafel'", + ): + _res0 = df.to_jax("stroopwafel") # type: ignore[call-overload] + + with pytest.raises( + ValueError, + match="`label` is required if setting `features` when `return_type='dict'", + ): + _res2 = df.to_jax("dict", features=cs.float()) + + with pytest.raises( + ValueError, + match="`label` and `features` only apply when `return_type` is 'dict'", + ): + _res3 = df.to_jax(label="stroopwafel") diff --git a/py-polars/tests/unit/ml/test_to_torch.py b/py-polars/tests/unit/ml/test_to_torch.py index 3cfbe2bc01d8..7f1a4711c8ac 100644 --- a/py-polars/tests/unit/ml/test_to_torch.py +++ b/py-polars/tests/unit/ml/test_to_torch.py @@ -12,6 +12,8 @@ # ensures the tests aren't run locally; this avoids premature local import) torch, _ = _lazy_import("torch") +pytestmark = pytest.mark.ci_only + @pytest.fixture() def df() -> pl.DataFrame: @@ -25,302 +27,299 @@ def df() -> pl.DataFrame: ) -@pytest.mark.ci_only() -class TestTorchIntegration: - """Test coverage for `to_torch` conversions and `polars.ml.torch` classes.""" +def assert_tensor_equal(actual: Any, expected: Any) -> None: + torch.testing.assert_close(actual, expected) - def assert_tensor_equal(self, actual: Any, expected: Any) -> None: - torch.testing.assert_close(actual, expected) - def test_to_torch_from_series(self) -> None: - s = pl.Series("x", [1, 2, 3, 4], dtype=pl.Int8) - t = s.to_torch() +def test_to_torch_from_series() -> None: + s = pl.Series("x", [1, 2, 3, 4], dtype=pl.Int8) + t = s.to_torch() - assert list(t.shape) == [4] - self.assert_tensor_equal(t, torch.tensor([1, 2, 3, 4], dtype=torch.int8)) + assert list(t.shape) == [4] + assert_tensor_equal(t, torch.tensor([1, 2, 3, 4], dtype=torch.int8)) - # note: torch doesn't natively support uint16/32/64. - # confirm that we export to a suitable signed integer type - s = s.cast(pl.UInt16) - t = s.to_torch() - self.assert_tensor_equal(t, torch.tensor([1, 2, 3, 4], dtype=torch.int32)) + # note: torch doesn't natively support uint16/32/64. + # confirm that we export to a suitable signed integer type + s = s.cast(pl.UInt16) + t = s.to_torch() + assert_tensor_equal(t, torch.tensor([1, 2, 3, 4], dtype=torch.int32)) - for dtype in (pl.UInt32, pl.UInt64): - t = s.cast(dtype).to_torch() - self.assert_tensor_equal(t, torch.tensor([1, 2, 3, 4], dtype=torch.int64)) + for dtype in (pl.UInt32, pl.UInt64): + t = s.cast(dtype).to_torch() + assert_tensor_equal(t, torch.tensor([1, 2, 3, 4], dtype=torch.int64)) - def test_to_torch_tensor(self, df: pl.DataFrame) -> None: - t1 = df.to_torch() - t2 = df.to_torch("tensor") - assert list(t1.shape) == [4, 3] - assert (t1 == t2).all().item() is True +def test_to_torch_tensor(df: pl.DataFrame) -> None: + t1 = df.to_torch() + t2 = df.to_torch("tensor") - def test_to_torch_dict(self, df: pl.DataFrame) -> None: - td = df.to_torch("dict") + assert list(t1.shape) == [4, 3] + assert (t1 == t2).all().item() is True - assert list(td.keys()) == ["x", "y", "z"] - self.assert_tensor_equal(td["x"], torch.tensor([1, 2, 2, 3], dtype=torch.int8)) - self.assert_tensor_equal( - td["y"], torch.tensor([True, False, True, False], dtype=torch.bool) - ) - self.assert_tensor_equal( - td["z"], torch.tensor([1.5, -0.5, 0.0, -2.0], dtype=torch.float32) - ) +def test_to_torch_dict(df: pl.DataFrame) -> None: + td = df.to_torch("dict") - def test_to_torch_feature_label_dict(self, df: pl.DataFrame) -> None: - df = pl.DataFrame( - { - "age": [25, 32, 45, 22, 34], - "income": [50000, 75000, 60000, 58000, 120000], - "education": ["bachelor", "master", "phd", "bachelor", "phd"], - "purchased": [False, True, True, False, True], - }, - schema_overrides={"age": pl.Int32, "income": pl.Int32}, - ).to_dummies("education", separator=":") - - lbl_feat_dict = df.to_torch(return_type="dict", label="purchased") - assert list(lbl_feat_dict.keys()) == ["label", "features"] - - self.assert_tensor_equal( - lbl_feat_dict["label"], - torch.tensor([[False], [True], [True], [False], [True]], dtype=torch.bool), - ) - self.assert_tensor_equal( - lbl_feat_dict["features"], - torch.tensor( - [ - [25, 50000, 1, 0, 0], - [32, 75000, 0, 1, 0], - [45, 60000, 0, 0, 1], - [22, 58000, 1, 0, 0], - [34, 120000, 0, 0, 1], - ], - dtype=torch.int32, - ), - ) + assert list(td.keys()) == ["x", "y", "z"] - def test_to_torch_dataset(self, df: pl.DataFrame) -> None: - ds = df.to_torch("dataset", dtype=pl.Float64) + assert_tensor_equal(td["x"], torch.tensor([1, 2, 2, 3], dtype=torch.int8)) + assert_tensor_equal( + td["y"], torch.tensor([True, False, True, False], dtype=torch.bool) + ) + assert_tensor_equal( + td["z"], torch.tensor([1.5, -0.5, 0.0, -2.0], dtype=torch.float32) + ) - assert len(ds) == 4 - assert isinstance(ds, torch.utils.data.Dataset) - assert repr(ds).startswith(" None: + df = pl.DataFrame( + { + "age": [25, 32, 45, 22, 34], + "income": [50000, 75000, 60000, 58000, 120000], + "education": ["bachelor", "master", "phd", "bachelor", "phd"], + "purchased": [False, True, True, False, True], + }, + schema_overrides={"age": pl.Int32, "income": pl.Int32}, + ).to_dummies("education", separator=":") - def test_to_torch_dataset_feature_reorder(self, df: pl.DataFrame) -> None: - ds = df.to_torch("dataset", label="x", features=["z", "y"]) - self.assert_tensor_equal( - torch.tensor( - [ - [1.5000, 1.0000], - [-0.5000, 0.0000], - [0.0000, 1.0000], - [-2.0000, 0.0000], - ] - ), - ds.features, - ) - self.assert_tensor_equal( - torch.tensor([1, 2, 2, 3], dtype=torch.int8), ds.labels - ) + lbl_feat_dict = df.to_torch(return_type="dict", label="purchased") + assert list(lbl_feat_dict.keys()) == ["label", "features"] - def test_to_torch_dataset_feature_subset(self, df: pl.DataFrame) -> None: - ds = df.to_torch("dataset", label="x", features=["z"]) - self.assert_tensor_equal( - torch.tensor([[1.5000], [-0.5000], [0.0000], [-2.0000]]), - ds.features, - ) - self.assert_tensor_equal( - torch.tensor([1, 2, 2, 3], dtype=torch.int8), ds.labels - ) + assert_tensor_equal( + lbl_feat_dict["label"], + torch.tensor([[False], [True], [True], [False], [True]], dtype=torch.bool), + ) + assert_tensor_equal( + lbl_feat_dict["features"], + torch.tensor( + [ + [25, 50000, 1, 0, 0], + [32, 75000, 0, 1, 0], + [45, 60000, 0, 0, 1], + [22, 58000, 1, 0, 0], + [34, 120000, 0, 0, 1], + ], + dtype=torch.int32, + ), + ) - def test_to_torch_dataset_index_slice(self, df: pl.DataFrame) -> None: - ds = df.to_torch("dataset") - ts = ds[1:3] - expected = ( - torch.tensor([[2.0000, 0.0000, -0.5000], [2.0000, 1.0000, 0.0000]]), - ) - self.assert_tensor_equal(expected, ts) +def test_to_torch_dataset(df: pl.DataFrame) -> None: + ds = df.to_torch("dataset", dtype=pl.Float64) - ts = ds[::2] - expected = (torch.tensor([[1.0000, 1.0000, 1.5000], [2.0, 1.0, 0.0]]),) - self.assert_tensor_equal(expected, ts) + assert len(ds) == 4 + assert isinstance(ds, torch.utils.data.Dataset) + assert repr(ds).startswith(" None: + ds = df.to_torch("dataset", label="x", features=["z", "y"]) + assert_tensor_equal( + torch.tensor( + [ + [1.5000, 1.0000], + [-0.5000, 0.0000], + [0.0000, 1.0000], + [-2.0000, 0.0000], + ] + ), + ds.features, ) - def test_to_torch_dataset_index_multi(self, index: Any, df: pl.DataFrame) -> None: - ds = df.to_torch("dataset") - ts = ds[index] + assert_tensor_equal(torch.tensor([1, 2, 2, 3], dtype=torch.int8), ds.labels) - expected = (torch.tensor([[1.0, 1.0, 1.5], [3.0, 0.0, -2.0]]),) - self.assert_tensor_equal(expected, ts) - assert ds.schema == {"features": torch.float32, "labels": None} - def test_to_torch_dataset_index_range(self, df: pl.DataFrame) -> None: - ds = df.to_torch("dataset") - ts = ds[range(3, 0, -1)] +def test_to_torch_dataset_feature_subset(df: pl.DataFrame) -> None: + ds = df.to_torch("dataset", label="x", features=["z"]) + assert_tensor_equal( + torch.tensor([[1.5000], [-0.5000], [0.0000], [-2.0000]]), + ds.features, + ) + assert_tensor_equal(torch.tensor([1, 2, 2, 3], dtype=torch.int8), ds.labels) - expected = ( - torch.tensor([[3.0, 0.0, -2.0], [2.0, 1.0, 0.0], [2.0, 0.0, -0.5]]), - ) - self.assert_tensor_equal(expected, ts) - def test_to_dataset_half_precision(self, df: pl.DataFrame) -> None: - ds = df.to_torch("dataset", label="x") - assert ds.schema == {"features": torch.float32, "labels": torch.int8} +def test_to_torch_dataset_index_slice(df: pl.DataFrame) -> None: + ds = df.to_torch("dataset") + ts = ds[1:3] - dsf16 = ds.half() - assert dsf16.schema == {"features": torch.float16, "labels": torch.float16} + expected = (torch.tensor([[2.0000, 0.0000, -0.5000], [2.0000, 1.0000, 0.0000]]),) + assert_tensor_equal(expected, ts) - # half precision across all data - ts = dsf16[:3:2] - expected = ( - torch.tensor([[1.0000, 1.5000], [1.0000, 0.0000]], dtype=torch.float16), - torch.tensor([1.0, 2.0], dtype=torch.float16), - ) - self.assert_tensor_equal(expected, ts) + ts = ds[::2] + expected = (torch.tensor([[1.0000, 1.0000, 1.5000], [2.0, 1.0, 0.0]]),) + assert_tensor_equal(expected, ts) - # only apply half precision to the feature data - dsf16 = ds.half(labels=False) - assert dsf16.schema == {"features": torch.float16, "labels": torch.int8} - ts = dsf16[:3:2] - expected = ( - torch.tensor([[1.0000, 1.5000], [1.0000, 0.0000]], dtype=torch.float16), - torch.tensor([1, 2], dtype=torch.int8), - ) - self.assert_tensor_equal(expected, ts) +@pytest.mark.parametrize( + "index", + [ + [0, 3], + range(0, 4, 3), + slice(0, 4, 3), + ], +) +def test_to_torch_dataset_index_multi(index: Any, df: pl.DataFrame) -> None: + ds = df.to_torch("dataset") + ts = ds[index] + + expected = (torch.tensor([[1.0, 1.0, 1.5], [3.0, 0.0, -2.0]]),) + assert_tensor_equal(expected, ts) + assert ds.schema == {"features": torch.float32, "labels": None} + + +def test_to_torch_dataset_index_range(df: pl.DataFrame) -> None: + ds = df.to_torch("dataset") + ts = ds[range(3, 0, -1)] + + expected = (torch.tensor([[3.0, 0.0, -2.0], [2.0, 1.0, 0.0], [2.0, 0.0, -0.5]]),) + assert_tensor_equal(expected, ts) + + +def test_to_dataset_half_precision(df: pl.DataFrame) -> None: + ds = df.to_torch("dataset", label="x") + assert ds.schema == {"features": torch.float32, "labels": torch.int8} + + dsf16 = ds.half() + assert dsf16.schema == {"features": torch.float16, "labels": torch.float16} + + # half precision across all data + ts = dsf16[:3:2] + expected = ( + torch.tensor([[1.0000, 1.5000], [1.0000, 0.0000]], dtype=torch.float16), + torch.tensor([1.0, 2.0], dtype=torch.float16), + ) + assert_tensor_equal(expected, ts) + + # only apply half precision to the feature data + dsf16 = ds.half(labels=False) + assert dsf16.schema == {"features": torch.float16, "labels": torch.int8} + + ts = dsf16[:3:2] + expected = ( + torch.tensor([[1.0000, 1.5000], [1.0000, 0.0000]], dtype=torch.float16), + torch.tensor([1, 2], dtype=torch.int8), + ) + assert_tensor_equal(expected, ts) - # only apply half precision to the label data - dsf16 = ds.half(features=False) - assert dsf16.schema == {"features": torch.float32, "labels": torch.float16} + # only apply half precision to the label data + dsf16 = ds.half(features=False) + assert dsf16.schema == {"features": torch.float32, "labels": torch.float16} - ts = dsf16[:3:2] + ts = dsf16[:3:2] + expected = ( + torch.tensor([[1.0000, 1.5000], [1.0000, 0.0000]], dtype=torch.float32), + torch.tensor([1.0, 2.0], dtype=torch.float16), + ) + assert_tensor_equal(expected, ts) + + # no labels + dsf16 = df.to_torch("dataset").half() + assert dsf16.schema == {"features": torch.float16, "labels": None} + + ts = dsf16[:3:2] + expected = ( # type: ignore[assignment] + torch.tensor( + data=[[1.0000, 1.0000, 1.5000], [2.0000, 1.0000, 0.0000]], + dtype=torch.float16, + ), + ) + assert_tensor_equal(expected, ts) + + +@pytest.mark.parametrize( + ("label", "features"), + [ + ("x", None), + ("x", ["y", "z"]), + (cs.by_dtype(pl.INTEGER_DTYPES), ~cs.by_dtype(pl.INTEGER_DTYPES)), + ], +) +def test_to_torch_labelled_dataset(label: Any, features: Any, df: pl.DataFrame) -> None: + ds = df.to_torch("dataset", label=label, features=features) + ts = next(iter(torch.utils.data.DataLoader(ds, batch_size=2, shuffle=False))) + + expected = [ + torch.tensor([[1.0, 1.5], [0.0, -0.5]]), + torch.tensor([1, 2], dtype=torch.int8), + ] + assert len(ts) == len(expected) + for actual, exp in zip(ts, expected): + assert_tensor_equal(exp, actual) + + +def test_to_torch_labelled_dataset_expr(df: pl.DataFrame) -> None: + ds = df.to_torch( + "dataset", + dtype=pl.Float64, + label=(pl.col("x") * 8).cast(pl.Int16), + ) + dl = torch.utils.data.DataLoader(ds, batch_size=2, shuffle=False) + for data in (tuple(ds[:2]), tuple(next(iter(dl)))): expected = ( - torch.tensor([[1.0000, 1.5000], [1.0000, 0.0000]], dtype=torch.float32), - torch.tensor([1.0, 2.0], dtype=torch.float16), - ) - self.assert_tensor_equal(expected, ts) - - # no labels - dsf16 = df.to_torch("dataset").half() - assert dsf16.schema == {"features": torch.float16, "labels": None} - - ts = dsf16[:3:2] - expected = ( # type: ignore[assignment] - torch.tensor( - data=[[1.0000, 1.0000, 1.5000], [2.0000, 1.0000, 0.0000]], - dtype=torch.float16, - ), + torch.tensor([[1.0000, 1.5000], [0.0000, -0.5000]], dtype=torch.float64), + torch.tensor([8, 16], dtype=torch.int16), ) - self.assert_tensor_equal(expected, ts) + assert len(data) == len(expected) + for actual, exp in zip(data, expected): + assert_tensor_equal(exp, actual) - @pytest.mark.parametrize( - ("label", "features"), + +def test_to_torch_labelled_dataset_multi(df: pl.DataFrame) -> None: + ds = df.to_torch("dataset", label=["x", "y"]) + dl = torch.utils.data.DataLoader(ds, batch_size=3, shuffle=False) + ts = list(dl) + + expected = [ [ - ("x", None), - ("x", ["y", "z"]), - (cs.by_dtype(pl.INTEGER_DTYPES), ~cs.by_dtype(pl.INTEGER_DTYPES)), + torch.tensor([[1.5000], [-0.5000], [0.0000]]), + torch.tensor([[1, 1], [2, 0], [2, 1]], dtype=torch.int8), ], - ) - def test_to_torch_labelled_dataset( - self, label: Any, features: Any, df: pl.DataFrame - ) -> None: - ds = df.to_torch("dataset", label=label, features=features) - ts = next(iter(torch.utils.data.DataLoader(ds, batch_size=2, shuffle=False))) - - expected = [ - torch.tensor([[1.0, 1.5], [0.0, -0.5]]), - torch.tensor([1, 2], dtype=torch.int8), - ] - assert len(ts) == len(expected) - for actual, exp in zip(ts, expected): - self.assert_tensor_equal(exp, actual) - - def test_to_torch_labelled_dataset_expr(self, df: pl.DataFrame) -> None: - ds = df.to_torch( - "dataset", - dtype=pl.Float64, - label=(pl.col("x") * 8).cast(pl.Int16), - ) - dl = torch.utils.data.DataLoader(ds, batch_size=2, shuffle=False) - for data in (tuple(ds[:2]), tuple(next(iter(dl)))): - expected = ( - torch.tensor( - [[1.0000, 1.5000], [0.0000, -0.5000]], dtype=torch.float64 - ), - torch.tensor([8, 16], dtype=torch.int16), - ) - assert len(data) == len(expected) - for actual, exp in zip(data, expected): - self.assert_tensor_equal(exp, actual) - - def test_to_torch_labelled_dataset_multi(self, df: pl.DataFrame) -> None: - ds = df.to_torch("dataset", label=["x", "y"]) - dl = torch.utils.data.DataLoader(ds, batch_size=3, shuffle=False) - ts = list(dl) - - expected = [ - [ - torch.tensor([[1.5000], [-0.5000], [0.0000]]), - torch.tensor([[1, 1], [2, 0], [2, 1]], dtype=torch.int8), - ], - [ - torch.tensor([[-2.0]]), - torch.tensor([[3, 0]], dtype=torch.int8), - ], - ] - assert len(ts) == len(expected) - - for actual, exp in zip(ts, expected): - assert len(actual) == len(exp) - for a, e in zip(actual, exp): - self.assert_tensor_equal(e, a) - - def test_misc_errors(self, df: pl.DataFrame) -> None: - ds = df.to_torch("dataset") - - with pytest.raises( - ValueError, - match="invalid `return_type`: 'stroopwafel'", - ): - _res0 = df.to_torch("stroopwafel") # type: ignore[call-overload] - - with pytest.raises( - ValueError, - match="does not support u16, u32, or u64 dtypes", - ): - _res1 = df.to_torch(dtype=pl.UInt16) - - with pytest.raises( - IndexError, - match="tensors used as indices must be long, int", - ): - _res2 = ds[torch.tensor([0, 3], dtype=torch.complex64)] - - with pytest.raises( - ValueError, - match="`label` and `features` only apply when `return_type` is 'dataset' or 'dict'", - ): - _res3 = df.to_torch(label="stroopwafel") - - with pytest.raises( - ValueError, - match="`label` is required if setting `features` when `return_type='dict'", - ): - _res4 = df.to_torch("dict", features=cs.float()) + [ + torch.tensor([[-2.0]]), + torch.tensor([[3, 0]], dtype=torch.int8), + ], + ] + assert len(ts) == len(expected) + + for actual, exp in zip(ts, expected): + assert len(actual) == len(exp) + for a, e in zip(actual, exp): + assert_tensor_equal(e, a) + + +def test_misc_errors(df: pl.DataFrame) -> None: + ds = df.to_torch("dataset") + + with pytest.raises( + ValueError, + match="invalid `return_type`: 'stroopwafel'", + ): + _res0 = df.to_torch("stroopwafel") # type: ignore[call-overload] + + with pytest.raises( + ValueError, + match="does not support u16, u32, or u64 dtypes", + ): + _res1 = df.to_torch(dtype=pl.UInt16) + + with pytest.raises( + IndexError, + match="tensors used as indices must be long, int", + ): + _res2 = ds[torch.tensor([0, 3], dtype=torch.complex64)] + + with pytest.raises( + ValueError, + match="`label` and `features` only apply when `return_type` is 'dataset' or 'dict'", + ): + _res3 = df.to_torch(label="stroopwafel") + + with pytest.raises( + ValueError, + match="`label` is required if setting `features` when `return_type='dict'", + ): + _res4 = df.to_torch("dict", features=cs.float())