Skip to content

Commit

Permalink
skip a test under py3.8
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed May 17, 2024
1 parent c022600 commit 92bfe04
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
3 changes: 2 additions & 1 deletion py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1684,7 +1684,8 @@ def to_jax(
directs otherwise (eg: "jax_enable_x64" was set True in the config object
at startup, or "JAX_ENABLE_X64" is set to "1" in the environment).
order : {"c", "fortran"}
The index order of the returned Jax array, either C-like or Fortran-like.
The index order of the returned Jax array, either C-like (row-major) or
Fortran-like (column-major).
See Also
--------
Expand Down
7 changes: 6 additions & 1 deletion py-polars/tests/unit/ml/test_to_jax.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import sys
from typing import Any

import pytest
Expand All @@ -19,7 +20,7 @@ def df() -> pl.DataFrame:
return pl.DataFrame(
{
"x": [1, 2, 2, 3],
"y": [True, False, True, False],
"y": [1, 0, 1, 0],
"z": [1.5, -0.5, 0.0, -2.0],
},
schema_overrides={"x": pl.Int8, "z": pl.Float32},
Expand Down Expand Up @@ -78,6 +79,10 @@ def test_to_jax_dict(self, df: pl.DataFrame) -> None:
arr_dict["z"], jxn.array([1.5, -0.5, 0.0, -2.0], dtype=jxn.float32)
)

@pytest.mark.skipif(
sys.version_info < (3, 9),
reason="jax.numpy.bool requires Python >= 3.9",
)
def test_to_jax_feature_label_dict(self, df: pl.DataFrame) -> None:
df = pl.DataFrame(
{
Expand Down

0 comments on commit 92bfe04

Please sign in to comment.