feat(python): Add new to_jax method to support export to jax arrays from `DataFrame`
alexander-beedie committed May 17, 2024
1 parent 7fa728e commit 23ec6da
Showing 8 changed files with 467 additions and 47 deletions.
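For context, a minimal usage sketch of the method this commit adds (assuming `jax` is installed; the frame mirrors the docstring example in the diff below):

import polars as pl

df = pl.DataFrame(
    {
        "lbl": [0, 1, 2, 3],
        "feat1": [1, 0, 0, 1],
        "feat2": [1.5, -0.5, 0.0, -2.25],
    }
)

arr = df.to_jax()                       # 2D jax.Array (float32 by default)
split = df.to_jax("dict", label="lbl")  # {"label": ..., "features": ...}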
270 changes: 253 additions & 17 deletions py-polars/polars/dataframe/frame.py
@@ -72,6 +72,7 @@
INTEGER_DTYPES,
N_INFER_DEFAULT,
Boolean,
Float32,
Float64,
Int32,
Int64,
@@ -113,7 +114,7 @@
)
from polars.selectors import _expand_selector_dicts, _expand_selectors
from polars.slice import PolarsSlice
from polars.type_aliases import DbWriteMode, TorchExportType
from polars.type_aliases import DbWriteMode, JaxExportType, TorchExportType

with contextlib.suppress(ImportError): # Module not available when building docs
from polars.polars import dtype_str_repr as _dtype_str_repr
@@ -126,6 +127,7 @@
from typing import Literal

import deltalake
import jax
import torch
from hvplot.plotting.core import hvPlotTabularPolars
from xlsxwriter import Workbook
@@ -1527,7 +1529,7 @@ def to_numpy(
However, the C-like order might be more appropriate to use for downstream
applications to prevent cloning data, e.g. when reshaping into a
one-dimensional array. Note that this option only takes effect if
`structured` is set to `False` and the DataFrame dtypes allow for a
`structured` is set to `False` and the DataFrame dtypes allow a
global dtype for all columns.
allow_copy
Allow memory to be copied to perform the conversion. If set to `False`,
@@ -1620,6 +1622,200 @@ def raise_on_copy(msg: str) -> None:

return out

@overload
def to_jax(
self,
return_type: Literal["array"] = ...,
*,
device: jax.Device | str | None = ...,
label: str | Expr | Sequence[str | Expr] | None = ...,
features: str | Expr | Sequence[str | Expr] | None = ...,
dtype: PolarsDataType | None = ...,
order: IndexOrder = ...,
) -> jax.Array: ...

@overload
def to_jax(
self,
return_type: Literal["dict"],
*,
device: jax.Device | str | None = ...,
label: str | Expr | Sequence[str | Expr] | None = ...,
features: str | Expr | Sequence[str | Expr] | None = ...,
dtype: PolarsDataType | None = ...,
order: IndexOrder = ...,
) -> dict[str, jax.Array]: ...

def to_jax(
self,
return_type: JaxExportType = "array",
*,
device: jax.Device | str | None = None,
label: str | Expr | Sequence[str | Expr] | None = None,
features: str | Expr | Sequence[str | Expr] | None = None,
dtype: PolarsDataType | None = None,
order: IndexOrder = "fortran",
) -> jax.Array | dict[str, jax.Array]:
"""
Convert DataFrame to a 2D Jax Array, or dict of Jax Arrays.

Parameters
----------
return_type : {"array", "dict"}
Set return type; a 2D Jax Array, or dict of Jax Arrays.
device
Specify the jax `Device` on which the array will be created; can provide
a string (such as "cpu", "gpu", or "tpu") in which case the device is
retrieved as `jax.devices(string)[0]`. For more specific control you
can supply the instantiated `Device` directly. If None, arrays are
created on the default device.
label
One or more column names, expressions, or selectors that label the feature
data; results in a `{"label": ..., "features": ...}` dict being returned
when `return_type` is "dict" instead of a `{"col": array}` dict.
features
One or more column names, expressions, or selectors that contain the feature
data; if omitted, all columns that are not designated as part of the label
are used. Only applies when `return_type` is "dict".
dtype
Unify the dtype of all returned arrays; this casts any column that is
not already of the required dtype before converting to Array. Note that
export will be single-precision (32bit) unless the Jax config/environment
directs otherwise (e.g. "jax_enable_x64" was set True in the config object
at startup, or "JAX_ENABLE_X64" is set to "1" in the environment).
order : {"c", "fortran"}
The index order of the returned Jax array, either C-like or Fortran-like.

See Also
--------
to_dummies
to_numpy
to_torch

Examples
--------
>>> df = pl.DataFrame(
... {
... "lbl": [0, 1, 2, 3],
... "feat1": [1, 0, 0, 1],
... "feat2": [1.5, -0.5, 0.0, -2.25],
... }
... )

Standard return type (2D Array), on the standard device:

>>> df.to_jax()
Array([[ 0. , 1. , 1.5 ],
[ 1. , 0. , -0.5 ],
[ 2. , 0. , 0. ],
[ 3. , 1. , -2.25]], dtype=float32)

Create the Array on the default GPU device:

>>> a = df.to_jax(device="gpu") # doctest: +SKIP
>>> a.device() # doctest: +SKIP
GpuDevice(id=0, process_index=0)

Create the Array on a specific GPU device:

>>> gpu_device = jax.devices("gpu")[1]  # doctest: +SKIP
>>> a = df.to_jax(device=gpu_device) # doctest: +SKIP
>>> a.device() # doctest: +SKIP
GpuDevice(id=1, process_index=0)

As a dictionary of individual Arrays:

>>> df.to_jax("dict")
{'lbl': Array([0, 1, 2, 3], dtype=int32),
'feat1': Array([1, 0, 0, 1], dtype=int32),
'feat2': Array([ 1.5 , -0.5 , 0. , -2.25], dtype=float32)}
As a "label" and "features" dictionary; note that as "features" is not
declared, it defaults to all the columns that are not in "label":
>>> df.to_jax("dict", label="lbl")
{'label': Array([[0],
[1],
[2],
[3]], dtype=int32),
'features': Array([[ 1. , 1.5 ],
[ 0. , -0.5 ],
[ 0. , 0. ],
[ 1. , -2.25]], dtype=float32)}
As a "label" and "features" dictionary where each is designated using
a selector expression (which can also be used to cast the data if the
label and features are better-represented with different dtypes):
>>> import polars.selectors as cs
>>> df.to_jax(
... return_type="dict",
... features=cs.float(),
... label=pl.col("lbl").cast(pl.UInt8),
... )
{'label': Array([[0],
[1],
[2],
[3]], dtype=uint8),
'features': Array([[ 1.5 ],
[-0.5 ],
[ 0. ],
[-2.25]], dtype=float32)}

"""
if return_type != "dict" and (label is not None or features is not None):
msg = "`label` and `features` only apply when `return_type` is 'dict'"
raise ValueError(msg)
elif return_type == "dict" and label is None and features is not None:
msg = "`label` is required if setting `features` when `return_type='dict'"
raise ValueError(msg)

jx = import_optional(
"jax",
install_message="Please see `https://jax.readthedocs.io/en/latest/installation.html` "
"for specific installation recommendations for the Jax package",
)
enabled_double_precision = jx.config.jax_enable_x64 or bool(
int(os.environ.get("JAX_ENABLE_X64", "0"))
)
if dtype:
frame = self.cast(dtype)
elif not enabled_double_precision:
# enforce single-precision unless environment/config directs otherwise
frame = self.cast({Float64: Float32, Int64: Int32, UInt64: UInt32})
else:
frame = self

if isinstance(device, str):
device = jx.devices(device)[0]

with contextlib.nullcontext() if device is None else jx.default_device(device):
if return_type == "array":
return jx.numpy.asarray(
# note: jax arrays are immutable, so can avoid a copy (vs torch)
a=frame.to_numpy(writable=False, use_pyarrow=False, order=order),
order="K",
)
elif return_type == "dict":
if label is not None:
# return a {"label": array(s), "features": array(s)} dict
label_frame = frame.select(label)
features_frame = (
frame.select(features)
if features is not None
else frame.drop(*label_frame.columns)
)
return {
"label": label_frame.to_jax(),
"features": features_frame.to_jax(),
}
else:
# return a {"col": array} dict
return {srs.name: srs.to_jax() for srs in frame}
else:
valid_jax_types = ", ".join(get_args(JaxExportType))
msg = f"invalid `return_type`: {return_type!r}\nExpected one of: {valid_jax_types}"
raise ValueError(msg)
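
# Illustrative sketch (not part of this diff): how the precision rule in
# `to_jax` behaves, assuming the default jax config (x64 disabled):
#   df.to_jax()  # Float64/Int64/UInt64 columns downcast to 32-bit dtypes
# and with 64-bit support enabled at startup:
#   jax.config.update("jax_enable_x64", True)
#   df.to_jax()  # 64-bit dtypes are preserved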

@overload
def to_torch(
self,
@@ -1659,31 +1855,35 @@ def to_torch(
dtype: PolarsDataType | None = None,
) -> torch.Tensor | dict[str, torch.Tensor] | PolarsDataset:
"""
Convert DataFrame to a 2D PyTorch tensor, Dataset, or dict of Tensors.
Convert DataFrame to a 2D PyTorch Tensor, Dataset, or dict of Tensors.

.. versionadded:: 0.20.23

Parameters
----------
return_type : {"tensor", "dataset", "dict"}
Set return type; a 2D PyTorch tensor, PolarsDataset (a frame-specialized
Set return type; a 2D PyTorch Tensor, PolarsDataset (a frame-specialized
TensorDataset), or dict of Tensors.
label
One or more column names, expressions, or selectors that label the feature
data; when `return_type` is "dataset", the PolarsDataset will return
`(features, label)` tensor tuples for each row. Otherwise, it returns
`(features,)` tensor tuples where the feature contains all the row data;
note that setting this parameter with any other result type will raise an
informative error.
`(features,)` tensor tuples where the feature contains all the row data.
features
One or more column names, expressions, or selectors that contain the feature
data; if omitted, all columns that are not designated as part of the label
are used. This parameter is a no-op for return-types other than "dataset".
are used.
dtype
Unify the dtype of all returned tensors; this casts any frame Series
that are not of the required dtype before converting to tensor. This
includes the label column *unless* the label is an expression (such
as `pl.col("label_column").cast(pl.Int16)`).
Unify the dtype of all returned tensors; this casts any column that is
not of the required dtype before converting to Tensor. This includes
the label column *unless* the label is an expression (such as
`pl.col("label_column").cast(pl.Int16)`).
See Also
--------
to_dummies
to_jax
to_numpy

Examples
--------
@@ -1710,6 +1910,19 @@ def to_torch(
'feat1': tensor([1, 0, 0, 1]),
'feat2': tensor([ 1.5000, -0.5000, 0.0000, -2.2500], dtype=torch.float64)}
As a "label" and "features" dictionary; note that as "features" is not
declared, it defaults to all the columns that are not in "label":
>>> df.to_torch("dict", label="lbl", dtype=pl.Float32)
{'label': tensor([[0.],
[1.],
[2.],
[3.]]),
'features': tensor([[ 1.0000, 1.5000],
[ 0.0000, -0.5000],
[ 0.0000, 0.0000],
[ 1.0000, -2.2500]])}

As a PolarsDataset, with f64 supertype:

>>> ds = df.to_torch("dataset", dtype=pl.Float64)
@@ -1722,7 +1935,7 @@
(tensor([[ 0.0000, 1.0000, 1.5000],
[ 3.0000, 1.0000, -2.2500]], dtype=torch.float64),)
As a convenience the PolarsDataset can opt-in to half-precision data
As a convenience the PolarsDataset can opt in to half-precision data
for experimentation (usually this would be set on the model/pipeline):

>>> list(ds.half())
@@ -1746,7 +1959,7 @@ def to_torch(
supported).
>>> ds = df.to_torch(
... "dataset",
... return_type="dataset",
... dtype=pl.Float32,
... label=pl.col("lbl").cast(pl.Int16),
... )
@@ -1771,8 +1984,13 @@ def to_torch(
... batch_size=64,
... ) # doctest: +SKIP
"""
if return_type != "dataset" and (label is not None or features is not None):
msg = "the `label` and `features` parameters can only be set when `return_type='dataset'`"
if return_type not in ("dataset", "dict") and (
label is not None or features is not None
):
msg = "`label` and `features` only apply when `return_type` is 'dataset' or 'dict'"
raise ValueError(msg)
elif return_type == "dict" and label is None and features is not None:
msg = "`label` is required if setting `features` when `return_type='dict'"
raise ValueError(msg)

torch = import_optional("torch")
@@ -1785,10 +2003,28 @@ def to_torch(
frame = self.cast(to_dtype) # type: ignore[arg-type]

if return_type == "tensor":
# note: torch tensors are not immutable, so we must consider them writable
return torch.from_numpy(frame.to_numpy(writable=True, use_pyarrow=False))

elif return_type == "dict":
return {srs.name: srs.to_torch() for srs in frame}
if label is not None:
# return a {"label": tensor(s), "features": tensor(s)} dict
label_frame = frame.select(label)
features_frame = (
frame.select(features)
if features is not None
else frame.drop(*label_frame.columns)
)
return {
"label": label_frame.to_torch(),
"features": features_frame.to_torch(),
}
else:
# return a {"col": tensor} dict
return {srs.name: srs.to_torch() for srs in frame}

elif return_type == "dataset":
# return a torch Dataset object
from polars.ml.torch import PolarsDataset

return PolarsDataset(frame, label=label, features=features)
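
# Illustrative sketch (assumes torch is installed; not part of this diff):
# the new dict mode mirrors the to_jax label/features split:
#   d = df.to_torch("dict", label="lbl", dtype=pl.Float32)
#   d["label"].shape     # -> torch.Size([4, 1])
#   d["features"].shape  # -> torch.Size([4, 2])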