diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 8cd3b1b16f346..4bdc4041a86d3 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -72,6 +72,7 @@ INTEGER_DTYPES, N_INFER_DEFAULT, Boolean, + Float32, Float64, Int32, Int64, @@ -113,7 +114,7 @@ ) from polars.selectors import _expand_selector_dicts, _expand_selectors from polars.slice import PolarsSlice -from polars.type_aliases import DbWriteMode, TorchExportType +from polars.type_aliases import DbWriteMode, JaxExportType, TorchExportType with contextlib.suppress(ImportError): # Module not available when building docs from polars.polars import dtype_str_repr as _dtype_str_repr @@ -126,6 +127,7 @@ from typing import Literal import deltalake + import jax import torch from hvplot.plotting.core import hvPlotTabularPolars from xlsxwriter import Workbook @@ -1527,7 +1529,7 @@ def to_numpy( However, the C-like order might be more appropriate to use for downstream applications to prevent cloning data, e.g. when reshaping into a one-dimensional array. Note that this option only takes effect if - `structured` is set to `False` and the DataFrame dtypes allow for a + `structured` is set to `False` and the DataFrame dtypes allow a global dtype for all columns. allow_copy Allow memory to be copied to perform the conversion. If set to `False`, @@ -1620,6 +1622,200 @@ def raise_on_copy(msg: str) -> None: return out + @overload + def to_jax( + self, + return_type: Literal["array"] = ..., + *, + device: jax.Device | str | None = ..., + label: str | Expr | Sequence[str | Expr] | None = ..., + features: str | Expr | Sequence[str | Expr] | None = ..., + dtype: PolarsDataType | None = ..., + order: IndexOrder = ..., + ) -> jax.Array: ... + + @overload + def to_jax( + self, + return_type: Literal["dict"], + *, + device: jax.Device | str | None = ..., + label: str | Expr | Sequence[str | Expr] | None = ..., + features: str | Expr | Sequence[str | Expr] | None = ..., + dtype: PolarsDataType | None = ..., + order: IndexOrder = ..., + ) -> dict[str, jax.Array]: ... + + def to_jax( + self, + return_type: JaxExportType = "array", + *, + device: jax.Device | str | None = None, + label: str | Expr | Sequence[str | Expr] | None = None, + features: str | Expr | Sequence[str | Expr] | None = None, + dtype: PolarsDataType | None = None, + order: IndexOrder = "fortran", + ) -> jax.Array | dict[str, jax.Array]: + """ + Convert DataFrame to a 2D Jax Array, or dict of Jax Arrays. + + Parameters + ---------- + return_type : {"array", "dict"} + Set return type; a 2D Jax Array, or dict of Jax Arrays. + device + Specify the jax `Device` on which the array will be created; can provide + a string (such as "cpu", "gpu", or "tpu") in which case the device is + retrieved as `jax.devices(string)[0]`. For more specific control you + can supply the instantiated `Device` directly. If None, arrays are + created on the default device. + label + One or more column names, expressions, or selectors that label the feature + data; results in a `{"label": ..., "features": ...}` dict being returned + when `return_type` is "dict" instead of a `{"col": array, }` dict. + features + One or more column names, expressions, or selectors that contain the feature + data; if omitted, all columns that are not designated as part of the label + are used. Only applies when `return_type` is "dict". 
+ dtype + Unify the dtype of all returned arrays; this casts any column that is + not already of the required dtype before converting to Array. Note that + export will be single-precision (32bit) unless the Jax config/environment + directs otherwise (eg: "jax_enable_x64" was set True in the config object + at startup, or "JAX_ENABLE_X64" is set to "1" in the environment). + order : {"c", "fortran"} + The index order of the returned Jax array, either C-like or Fortran-like. + + See Also + -------- + to_dummies + to_numpy + to_torch + + Examples + -------- + >>> df = pl.DataFrame( + ... { + ... "lbl": [0, 1, 2, 3], + ... "feat1": [1, 0, 0, 1], + ... "feat2": [1.5, -0.5, 0.0, -2.25], + ... } + ... ) + + Standard return type (2D Array), on the default device: + + >>> df.to_jax() + Array([[ 0. , 1. , 1.5 ], + [ 1. , 0. , -0.5 ], + [ 2. , 0. , 0. ], + [ 3. , 1. , -2.25]], dtype=float32) + + Create the Array on the default GPU device: + + >>> a = df.to_jax(device="gpu") # doctest: +SKIP + >>> a.device() # doctest: +SKIP + GpuDevice(id=0, process_index=0) + + Create the Array on a specific GPU device: + + >>> gpu_device = jax.devices("gpu")[1] # doctest: +SKIP + >>> a = df.to_jax(device=gpu_device) # doctest: +SKIP + >>> a.device() # doctest: +SKIP + GpuDevice(id=1, process_index=0) + + As a dictionary of individual Arrays: + + >>> df.to_jax("dict") + {'lbl': Array([0, 1, 2, 3], dtype=int32), + 'feat1': Array([1, 0, 0, 1], dtype=int32), + 'feat2': Array([ 1.5 , -0.5 , 0. , -2.25], dtype=float32)} + + As a "label" and "features" dictionary; note that as "features" is not + declared, it defaults to all the columns that are not in "label": + + >>> df.to_jax("dict", label="lbl") + {'label': Array([[0], + [1], + [2], + [3]], dtype=int32), + 'features': Array([[ 1. , 1.5 ], + [ 0. , -0.5 ], + [ 0. , 0. ], + [ 1. , -2.25]], dtype=float32)} + + As a "label" and "features" dictionary where each is designated using + a selector expression (which can also be used to cast the data if the + label and features are better-represented with different dtypes): + + >>> import polars.selectors as cs + >>> df.to_jax( + ... return_type="dict", + ... features=cs.float(), + ... label=pl.col("lbl").cast(pl.UInt8), + ... ) + {'label': Array([[0], + [1], + [2], + [3]], dtype=uint8), + 'features': Array([[ 1.5 ], + [-0.5 ], + [ 0. 
], + [-2.25]], dtype=float32)} + """ + if return_type != "dict" and (label is not None or features is not None): + msg = "`label` and `features` only apply when `return_type` is 'dict'" + raise ValueError(msg) + elif return_type == "dict" and label is None and features is not None: + msg = "`label` is required if setting `features` when `return_type='dict'`" + raise ValueError(msg) + + jx = import_optional( + "jax", + install_message="Please see `https://jax.readthedocs.io/en/latest/installation.html` " + "for specific installation recommendations for the Jax package", + ) + enabled_double_precision = jx.config.jax_enable_x64 or bool( + int(os.environ.get("JAX_ENABLE_X64", "0")) + ) + if dtype: + frame = self.cast(dtype) + elif not enabled_double_precision: + # enforce single-precision unless environment/config directs otherwise + frame = self.cast({Float64: Float32, Int64: Int32, UInt64: UInt32}) + else: + frame = self + + if isinstance(device, str): + device = jx.devices(device)[0] + + with contextlib.nullcontext() if device is None else jx.default_device(device): + if return_type == "array": + return jx.numpy.asarray( + # note: jax arrays are immutable, so can avoid a copy (vs torch) + a=frame.to_numpy(writable=False, use_pyarrow=False, order=order), + order="K", + ) + elif return_type == "dict": + if label is not None: + # return a {"label": array(s), "features": array(s)} dict + label_frame = frame.select(label) + features_frame = ( + frame.select(features) + if features is not None + else frame.drop(*label_frame.columns) + ) + return { + "label": label_frame.to_jax(), + "features": features_frame.to_jax(), + } + else: + # return a {"col": array} dict + return {srs.name: srs.to_jax() for srs in frame} + else: + valid_jax_types = ", ".join(get_args(JaxExportType)) + msg = f"invalid `return_type`: {return_type!r}\nExpected one of: {valid_jax_types}" + raise ValueError(msg) + @overload def to_torch( self, @@ -1659,31 +1855,35 @@ def to_torch( dtype: PolarsDataType | None = None, ) -> torch.Tensor | dict[str, torch.Tensor] | PolarsDataset: """ - Convert DataFrame to a 2D PyTorch tensor, Dataset, or dict of Tensors. + Convert DataFrame to a 2D PyTorch Tensor, Dataset, or dict of Tensors. .. versionadded:: 0.20.23 Parameters ---------- return_type : {"tensor", "dataset", "dict"} - Set return type; a 2D PyTorch tensor, PolarsDataset (a frame-specialized + Set return type; a 2D PyTorch Tensor, PolarsDataset (a frame-specialized TensorDataset), or dict of Tensors. label One or more column names, expressions, or selectors that label the feature data; when `return_type` is "dataset", the PolarsDataset will return `(features, label)` tensor tuples for each row. Otherwise, it returns - `(features,)` tensor tuples where the feature contains all the row data; - note that setting this parameter with any other result type will raise an - informative error. + `(features,)` tensor tuples where the feature contains all the row data. features One or more column names, expressions, or selectors that contain the feature data; if omitted, all columns that are not designated as part of the label - are used. This parameter is a no-op for return-types other than "dataset". + are used. dtype - Unify the dtype of all returned tensors; this casts any frame Series - that are not of the required dtype before converting to tensor. This - includes the label column *unless* the label is an expression (such - as `pl.col("label_column").cast(pl.Int16)`). 
+ Unify the dtype of all returned tensors; this casts any column that is + not of the required dtype before converting to Tensor. This includes + the label column *unless* the label is an expression (such as + `pl.col("label_column").cast(pl.Int16)`). + + See Also + -------- + to_dummies + to_jax + to_numpy Examples -------- @@ -1710,6 +1910,19 @@ def to_torch( 'feat1': tensor([1, 0, 0, 1]), 'feat2': tensor([ 1.5000, -0.5000, 0.0000, -2.2500], dtype=torch.float64)} + As a "label" and "features" dictionary; note that as "features" is not + declared, it defaults to all the columns that are not in "label": + + >>> df.to_torch("dict", label="lbl", dtype=pl.Float32) + {'label': tensor([[0.], + [1.], + [2.], + [3.]]), + 'features': tensor([[ 1.0000, 1.5000], + [ 0.0000, -0.5000], + [ 0.0000, 0.0000], + [ 1.0000, -2.2500]])} + As a PolarsDataset, with f64 supertype: >>> ds = df.to_torch("dataset", dtype=pl.Float64) @@ -1722,7 +1935,7 @@ def to_torch( (tensor([[ 0.0000, 1.0000, 1.5000], [ 3.0000, 1.0000, -2.2500]], dtype=torch.float64),) - As a convenience the PolarsDataset can opt-in to half-precision data + As a convenience the PolarsDataset can opt in to half-precision data for experimentation (usually this would be set on the model/pipeline): >>> list(ds.half()) @@ -1746,7 +1959,7 @@ def to_torch( supported). >>> ds = df.to_torch( - ... "dataset", + ... return_type="dataset", ... dtype=pl.Float32, ... label=pl.col("lbl").cast(pl.Int16), ... ) @@ -1771,8 +1984,13 @@ def to_torch( ... batch_size=64, ... ) # doctest: +SKIP """ - if return_type != "dataset" and (label is not None or features is not None): - msg = "the `label` and `features` parameters can only be set when `return_type='dataset'`" + if return_type not in ("dataset", "dict") and ( + label is not None or features is not None + ): + msg = "`label` and `features` only apply when `return_type` is 'dataset' or 'dict'" + raise ValueError(msg) + elif return_type == "dict" and label is None and features is not None: + msg = "`label` is required if setting `features` when `return_type='dict'`" + raise ValueError(msg) torch = import_optional("torch") @@ -1785,10 +2003,28 @@ def to_torch( frame = self.cast(to_dtype) # type: ignore[arg-type] if return_type == "tensor": + # note: torch tensors are not immutable, so we must consider them writable return torch.from_numpy(frame.to_numpy(writable=True, use_pyarrow=False)) + elif return_type == "dict": - return {srs.name: srs.to_torch() for srs in frame} + if label is not None: + # return a {"label": tensor(s), "features": tensor(s)} dict + label_frame = frame.select(label) + features_frame = ( + frame.select(features) + if features is not None + else frame.drop(*label_frame.columns) + ) + return { + "label": label_frame.to_torch(), + "features": features_frame.to_torch(), + } + else: + # return a {"col": tensor} dict + return {srs.name: srs.to_torch() for srs in frame} + elif return_type == "dataset": + # return a torch Dataset object from polars.ml.torch import PolarsDataset return PolarsDataset(frame, label=label, features=features) diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 9d1a939dc3acc..555b9f1132981 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -2,6 +2,8 @@ import contextlib import math +import os +from contextlib import nullcontext from datetime import date, datetime, time, timedelta from decimal import Decimal as PyDecimal from typing import ( @@ -66,6 +68,7 @@ Decimal, Duration, Enum, + Float32, Float64, 
Int8, Int16, @@ -117,6 +120,7 @@ if TYPE_CHECKING: import sys + import jax import torch from hvplot.plotting.core import hvPlotTabularPolars @@ -4472,9 +4476,52 @@ def to_numpy( return self._s.to_numpy(allow_copy=allow_copy, writable=writable) + def to_jax(self, device: jax.Device | str | None = None) -> jax.Array: + """ + Convert this Series to a Jax Array. + + Parameters + ---------- + device + Specify the jax `Device` on which the array will be created; can provide + a string (such as "cpu", "gpu", or "tpu") in which case the device is + retrieved as `jax.devices(string)[0]`. For more specific control you + can supply the instantiated `Device` directly. If None, arrays are + created on the default device. + + Examples + -------- + >>> s = pl.Series("x", [10.5, 0.0, -10.0, 5.5]) + >>> s.to_jax() + Array([ 10.5, 0. , -10. , 5.5], dtype=float32) + """ + jx = import_optional( + "jax", + install_message="Please see `https://jax.readthedocs.io/en/latest/installation.html` " + "for specific installation recommendations for the Jax package", + ) + if isinstance(device, str): + device = jx.devices(device)[0] + if ( + jx.config.jax_enable_x64 + or bool(int(os.environ.get("JAX_ENABLE_X64", "0"))) + or self.dtype not in {Float64, Int64, UInt64} + ): + srs = self + else: + single_precision = {Float64: Float32, Int64: Int32, UInt64: UInt32} + srs = self.cast(single_precision[self.dtype]) # type: ignore[index] + + with nullcontext() if device is None else jx.default_device(device): + return jx.numpy.asarray( + # note: jax arrays are immutable, so can avoid a copy (vs torch) + a=srs.to_numpy(writable=False, use_pyarrow=False), + order="K", + ) + def to_torch(self) -> torch.Tensor: """ - Convert this Series to a PyTorch tensor. + Convert this Series to a PyTorch Tensor. Examples -------- diff --git a/py-polars/polars/type_aliases.py b/py-polars/polars/type_aliases.py index b57dcee1f5a3f..92daf0b2dd0c9 100644 --- a/py-polars/polars/type_aliases.py +++ b/py-polars/polars/type_aliases.py @@ -163,6 +163,7 @@ DbWriteEngine: TypeAlias = Literal["sqlalchemy", "adbc"] DbWriteMode: TypeAlias = Literal["replace", "append", "fail"] EpochTimeUnit = Literal["ns", "us", "ms", "s", "d"] +JaxExportType: TypeAlias = Literal["array", "dict"] Orientation: TypeAlias = Literal["col", "row"] SearchSortedSide: TypeAlias = Literal["any", "left", "right"] TorchExportType: TypeAlias = Literal["tensor", "dataset", "dict"] diff --git a/py-polars/requirements-ci.txt b/py-polars/requirements-ci.txt index 3086002307dd5..fbb39463fcedf 100644 --- a/py-polars/requirements-ci.txt +++ b/py-polars/requirements-ci.txt @@ -4,4 +4,6 @@ # ------------------------------------------------------- --extra-index-url https://download.pytorch.org/whl/cpu torch +jax +jaxlib pyiceberg>=0.5.0 diff --git a/py-polars/tests/docs/run_doctest.py b/py-polars/tests/docs/run_doctest.py index 7da0150e33474..39b95a548ddca 100644 --- a/py-polars/tests/docs/run_doctest.py +++ b/py-polars/tests/docs/run_doctest.py @@ -58,6 +58,7 @@ # if the module is found in the environment those doctests will # run; if the module is not found, their doctests are skipped. 
OPTIONAL_MODULES_AND_METHODS: dict[str, set[str]] = { + "jax": {"to_jax"}, "torch": {"to_torch"}, } OPTIONAL_MODULES: set[str] = set() diff --git a/py-polars/tests/unit/ml/__init__.py b/py-polars/tests/unit/ml/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/py-polars/tests/unit/ml/test_to_jax.py b/py-polars/tests/unit/ml/test_to_jax.py new file mode 100644 index 0000000000000..0d26ce57c7b30 --- /dev/null +++ b/py-polars/tests/unit/ml/test_to_jax.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +import polars as pl +import polars.selectors as cs +from polars.dependencies import _lazy_import + +# don't import jax until an actual test is triggered (the decorator already +# ensures the tests aren't run locally; this avoids premature local import) +jx, _ = _lazy_import("jax") +jxn, _ = _lazy_import("jax.numpy") + + +@pytest.fixture() +def df() -> pl.DataFrame: + return pl.DataFrame( + { + "x": [1, 2, 2, 3], + "y": [True, False, True, False], + "z": [1.5, -0.5, 0.0, -2.0], + }, + schema_overrides={"x": pl.Int8, "z": pl.Float32}, + ) + + +@pytest.mark.ci_only() +class TestJaxIntegration: + """Test coverage for `to_jax` conversions.""" + + def assert_array_equal( + self, actual: Any, expected: Any, nans_equal: bool = True + ) -> None: + assert isinstance(actual, jx.Array) + assert jxn.array_equal(actual, expected, equal_nan=nans_equal) + + def test_to_jax_from_series(self) -> None: + s = pl.Series("x", [1, 2, 3, 4], dtype=pl.Int8) + a = s.to_jax() + + assert list(a.shape) == [4] + self.assert_array_equal(a, jxn.array([1, 2, 3, 4], dtype=jxn.int8)) + + for dtype in (pl.Int32, pl.Int64, pl.UInt32, pl.UInt64): + a = s.cast(dtype).to_jax() + self.assert_array_equal(a, jxn.array([1, 2, 3, 4], dtype=jxn.int32)) + + def test_to_jax_array(self, df: pl.DataFrame) -> None: + a1 = df.to_jax() + a2 = df.to_jax("array") + a3 = df.to_jax("array", device="cpu") + a4 = df.to_jax("array", device=jx.devices("cpu")[0]) + + expected = jxn.array( + [ + [1.0, 1.0, 1.5], + [2.0, 0.0, -0.5], + [2.0, 1.0, 0.0], + [3.0, 0.0, -2.0], + ], + dtype=jxn.float32, + ) + for a in (a1, a2, a3, a4): + self.assert_array_equal(a, expected) + + def test_to_jax_dict(self, df: pl.DataFrame) -> None: + arr_dict = df.to_jax("dict") + + assert list(arr_dict.keys()) == ["x", "y", "z"] + + self.assert_array_equal(arr_dict["x"], jxn.array([1, 2, 2, 3], dtype=jxn.int8)) + self.assert_array_equal( + arr_dict["y"], jxn.array([True, False, True, False], dtype=jxn.bool_) + ) + self.assert_array_equal( + arr_dict["z"], jxn.array([1.5, -0.5, 0.0, -2.0], dtype=jxn.float32) + ) + + def test_to_jax_feature_label_dict(self, df: pl.DataFrame) -> None: + df = pl.DataFrame( + { + "age": [25, 32, 45, 22, 34], + "income": [50000, 75000, 60000, 58000, 120000], + "education": ["bachelor", "master", "phd", "bachelor", "phd"], + "purchased": [False, True, True, False, True], + } + ).to_dummies("education", separator=":") + + lbl_feat_dict = df.to_jax(return_type="dict", label="purchased") + assert list(lbl_feat_dict.keys()) == ["label", "features"] + + self.assert_array_equal( + lbl_feat_dict["label"], + jxn.array([[False], [True], [True], [False], [True]], dtype=jxn.bool_), + ) + self.assert_array_equal( + lbl_feat_dict["features"], + jxn.array( + [ + [25, 50000, 1, 0, 0], + [32, 75000, 0, 1, 0], + [45, 60000, 0, 0, 1], + [22, 58000, 1, 0, 0], + [34, 120000, 0, 0, 1], + ], + dtype=jxn.int32, + ), + ) + + def test_misc_errors(self, df: pl.DataFrame) -> None: + with pytest.raises( + 
ValueError, + match="invalid `return_type`: 'stroopwafel'", + ): + _res0 = df.to_jax("stroopwafel") # type: ignore[call-overload] + + with pytest.raises( + ValueError, + match="`label` is required if setting `features` when `return_type='dict'`", + ): + _res2 = df.to_jax("dict", features=cs.float()) + + with pytest.raises( + ValueError, + match="`label` and `features` only apply when `return_type` is 'dict'", + ): + _res3 = df.to_jax(label="stroopwafel") diff --git a/py-polars/tests/unit/dataframe/test_to_torch.py b/py-polars/tests/unit/ml/test_to_torch.py similarity index 83% rename from py-polars/tests/unit/dataframe/test_to_torch.py rename to py-polars/tests/unit/ml/test_to_torch.py index be8de2d1f2d1d..d52c754eb83f7 100644 --- a/py-polars/tests/unit/dataframe/test_to_torch.py +++ b/py-polars/tests/unit/ml/test_to_torch.py @@ -9,7 +9,7 @@ from polars.dependencies import _lazy_import # don't import torch until an actual test is triggered (the decorator already -# ensures the tests aren't run locally, this will skip premature local import) +# ensures the tests aren't run locally; this avoids premature local import) torch, _ = _lazy_import("torch") @@ -29,27 +29,25 @@ def df() -> pl.DataFrame: class TestTorchIntegration: """Test coverage for `to_torch` conversions and `polars.ml.torch` classes.""" - def assert_tensor(self, actual: Any, expected: Any) -> None: + def assert_tensor_equal(self, actual: Any, expected: Any) -> None: torch.testing.assert_close(actual, expected) - def test_to_torch_series( - self, - ) -> None: + def test_to_torch_from_series(self) -> None: s = pl.Series("x", [1, 2, 3, 4], dtype=pl.Int8) t = s.to_torch() assert list(t.shape) == [4] - self.assert_tensor(t, torch.tensor([1, 2, 3, 4], dtype=torch.int8)) + self.assert_tensor_equal(t, torch.tensor([1, 2, 3, 4], dtype=torch.int8)) # note: torch doesn't natively support uint16/32/64. 
# confirm that we export to a suitable signed integer type s = s.cast(pl.UInt16) t = s.to_torch() - self.assert_tensor(t, torch.tensor([1, 2, 3, 4], dtype=torch.int32)) + self.assert_tensor_equal(t, torch.tensor([1, 2, 3, 4], dtype=torch.int32)) for dtype in (pl.UInt32, pl.UInt64): t = s.cast(dtype).to_torch() - self.assert_tensor(t, torch.tensor([1, 2, 3, 4], dtype=torch.int64)) + self.assert_tensor_equal(t, torch.tensor([1, 2, 3, 4], dtype=torch.int64)) def test_to_torch_tensor(self, df: pl.DataFrame) -> None: t1 = df.to_torch() @@ -63,11 +61,11 @@ def test_to_torch_dict(self, df: pl.DataFrame) -> None: assert list(td.keys()) == ["x", "y", "z"] - self.assert_tensor(td["x"], torch.tensor([1, 2, 2, 3], dtype=torch.int8)) - self.assert_tensor( + self.assert_tensor_equal(td["x"], torch.tensor([1, 2, 2, 3], dtype=torch.int8)) + self.assert_tensor_equal( td["y"], torch.tensor([True, False, True, False], dtype=torch.bool) ) - self.assert_tensor( + self.assert_tensor_equal( td["z"], torch.tensor([1.5, -0.5, 0.0, -2.0], dtype=torch.float32) ) @@ -81,11 +79,13 @@ def test_to_torch_dataset(self, df: pl.DataFrame) -> None: ts = ds[0] assert isinstance(ts, tuple) assert len(ts) == 1 - self.assert_tensor(ts[0], torch.tensor([1.0, 1.0, 1.5], dtype=torch.float64)) + self.assert_tensor_equal( + ts[0], torch.tensor([1.0, 1.0, 1.5], dtype=torch.float64) + ) def test_to_torch_dataset_feature_reorder(self, df: pl.DataFrame) -> None: ds = df.to_torch("dataset", label="x", features=["z", "y"]) - self.assert_tensor( + self.assert_tensor_equal( torch.tensor( [ [1.5000, 1.0000], @@ -96,15 +96,19 @@ def test_to_torch_dataset_feature_reorder(self, df: pl.DataFrame) -> None: ), ds.features, ) - self.assert_tensor(torch.tensor([1, 2, 2, 3], dtype=torch.int8), ds.labels) + self.assert_tensor_equal( + torch.tensor([1, 2, 2, 3], dtype=torch.int8), ds.labels + ) def test_to_torch_dataset_feature_subset(self, df: pl.DataFrame) -> None: ds = df.to_torch("dataset", label="x", features=["z"]) - self.assert_tensor( + self.assert_tensor_equal( torch.tensor([[1.5000], [-0.5000], [0.0000], [-2.0000]]), ds.features, ) - self.assert_tensor(torch.tensor([1, 2, 2, 3], dtype=torch.int8), ds.labels) + self.assert_tensor_equal( + torch.tensor([1, 2, 2, 3], dtype=torch.int8), ds.labels + ) def test_to_torch_dataset_index_slice(self, df: pl.DataFrame) -> None: ds = df.to_torch("dataset") @@ -113,11 +117,11 @@ def test_to_torch_dataset_index_slice(self, df: pl.DataFrame) -> None: expected = ( torch.tensor([[2.0000, 0.0000, -0.5000], [2.0000, 1.0000, 0.0000]]), ) - self.assert_tensor(expected, ts) + self.assert_tensor_equal(expected, ts) ts = ds[::2] expected = (torch.tensor([[1.0000, 1.0000, 1.5000], [2.0, 1.0, 0.0]]),) - self.assert_tensor(expected, ts) + self.assert_tensor_equal(expected, ts) @pytest.mark.parametrize( "index", @@ -132,7 +136,7 @@ def test_to_torch_dataset_index_multi(self, index: Any, df: pl.DataFrame) -> Non ts = ds[index] expected = (torch.tensor([[1.0, 1.0, 1.5], [3.0, 0.0, -2.0]]),) - self.assert_tensor(expected, ts) + self.assert_tensor_equal(expected, ts) assert ds.schema == {"features": torch.float32, "labels": None} def test_to_torch_dataset_index_range(self, df: pl.DataFrame) -> None: @@ -142,7 +146,7 @@ def test_to_torch_dataset_index_range(self, df: pl.DataFrame) -> None: expected = ( torch.tensor([[3.0, 0.0, -2.0], [2.0, 1.0, 0.0], [2.0, 0.0, -0.5]]), ) - self.assert_tensor(expected, ts) + self.assert_tensor_equal(expected, ts) def test_to_dataset_half_precision(self, df: pl.DataFrame) -> None: ds = 
df.to_torch("dataset", label="x") @@ -157,7 +161,7 @@ def test_to_dataset_half_precision(self, df: pl.DataFrame) -> None: torch.tensor([[1.0000, 1.5000], [1.0000, 0.0000]], dtype=torch.float16), torch.tensor([1.0, 2.0], dtype=torch.float16), ) - self.assert_tensor(expected, ts) + self.assert_tensor_equal(expected, ts) # only apply half precision to the feature data dsf16 = ds.half(labels=False) @@ -168,7 +172,7 @@ def test_to_dataset_half_precision(self, df: pl.DataFrame) -> None: torch.tensor([[1.0000, 1.5000], [1.0000, 0.0000]], dtype=torch.float16), torch.tensor([1, 2], dtype=torch.int8), ) - self.assert_tensor(expected, ts) + self.assert_tensor_equal(expected, ts) # only apply half precision to the label data dsf16 = ds.half(features=False) @@ -179,7 +183,7 @@ def test_to_dataset_half_precision(self, df: pl.DataFrame) -> None: torch.tensor([[1.0000, 1.5000], [1.0000, 0.0000]], dtype=torch.float32), torch.tensor([1.0, 2.0], dtype=torch.float16), ) - self.assert_tensor(expected, ts) + self.assert_tensor_equal(expected, ts) # no labels dsf16 = df.to_torch("dataset").half() @@ -192,7 +196,7 @@ def test_to_dataset_half_precision(self, df: pl.DataFrame) -> None: dtype=torch.float16, ), ) - self.assert_tensor(expected, ts) + self.assert_tensor_equal(expected, ts) @pytest.mark.parametrize( ("label", "features"), @@ -214,7 +218,7 @@ def test_to_torch_labelled_dataset( ] assert len(ts) == len(expected) for actual, exp in zip(ts, expected): - self.assert_tensor(exp, actual) + self.assert_tensor_equal(exp, actual) def test_to_torch_labelled_dataset_expr(self, df: pl.DataFrame) -> None: ds = df.to_torch( @@ -232,7 +236,7 @@ def test_to_torch_labelled_dataset_expr(self, df: pl.DataFrame) -> None: ) assert len(data) == len(expected) for actual, exp in zip(data, expected): - self.assert_tensor(exp, actual) + self.assert_tensor_equal(exp, actual) def test_to_torch_labelled_dataset_multi(self, df: pl.DataFrame) -> None: ds = df.to_torch("dataset", label=["x", "y"]) @@ -254,7 +258,7 @@ def test_to_torch_labelled_dataset_multi(self, df: pl.DataFrame) -> None: for actual, exp in zip(ts, expected): assert len(actual) == len(exp) for a, e in zip(actual, exp): - self.assert_tensor(e, a) + self.assert_tensor_equal(e, a) def test_misc_errors(self, df: pl.DataFrame) -> None: ds = df.to_torch("dataset") @@ -279,12 +283,12 @@ def test_misc_errors(self, df: pl.DataFrame) -> None: with pytest.raises( ValueError, - match="`label` and `features` parameters .* when `return_type='dataset'`", + match="`label` and `features` only apply when `return_type` is 'dataset' or 'dict'", ): _res3 = df.to_torch(label="stroopwafel") with pytest.raises( ValueError, - match="`label` and `features` parameters .* when `return_type='dataset'`", + match="`label` is required if setting `features` when `return_type='dict'`", ): _res4 = df.to_torch("dict", features=cs.float())
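A minimal usage sketch of the exports added above (not part of the patch; assumes `jax`/`jaxlib` are installed, plus `torch` for the final line, and reuses the frame from the doctest examples):

# sketch: exercising DataFrame.to_jax and the extended "dict" mode of to_torch
import polars as pl

df = pl.DataFrame(
    {
        "lbl": [0, 1, 2, 3],
        "feat1": [1, 0, 0, 1],
        "feat2": [1.5, -0.5, 0.0, -2.25],
    }
)

# 2D Jax Array; exported as single-precision unless "jax_enable_x64" (config)
# or "JAX_ENABLE_X64" (environment) enables 64-bit types
arr = df.to_jax()

# {"label": ..., "features": ...} dicts from both exporters; "features"
# defaults to all columns not designated as part of the label
jax_dict = df.to_jax("dict", label="lbl")
torch_dict = df.to_torch("dict", label="lbl", dtype=pl.Float32)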