Skip to content

Commit

Permalink
[data] Sort with None (ray-project#48750)
Browse files Browse the repository at this point in the history
## Why are these changes needed?

Adds a Sentinel value for making it possible to sort.

Fixes ray-project#42142 

## Related issue number

<!-- For example: "Closes ray-project#1234" -->

## Checks

- [ ] I've signed off every commit(by using the -s flag, i.e., `git
commit -s`) in this PR.
- [ ] I've run `scripts/format.sh` to lint the changes in this PR.
- [ ] I've included any doc changes needed for
https://docs.ray.io/en/master/.
- [ ] I've added any new APIs to the API Reference. For example, if I
added a
method in Tune, I've added it in `doc/source/tune/api/` under the
           corresponding `.rst` file.
- [ ] I've made sure the tests are passing. Note that there might be a
few flaky tests, see the recent failures at https://flakey-tests.ray.io/
- Testing Strategy
   - [ ] Unit tests
   - [ ] Release tests
   - [ ] This PR is not tested :(

---------

Signed-off-by: Richard Liaw <[email protected]>
  • Loading branch information
richardliaw authored Nov 15, 2024
1 parent 42d101e commit 134e5ec
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 5 deletions.
10 changes: 7 additions & 3 deletions python/ray/data/_internal/arrow_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
)
from ray.data._internal.row import TableRow
from ray.data._internal.table_block import TableBlockAccessor, TableBlockBuilder
from ray.data._internal.util import find_partitions
from ray.data._internal.util import NULL_SENTINEL, find_partitions
from ray.data.block import (
Block,
BlockAccessor,
Expand Down Expand Up @@ -500,7 +500,6 @@ def sort_and_partition(
table = sort(self._table, sort_key)
if len(boundaries) == 0:
return [table]

return find_partitions(table, boundaries, sort_key)

def combine(self, sort_key: "SortKey", aggs: Tuple["AggregateFn"]) -> Block:
Expand Down Expand Up @@ -634,6 +633,11 @@ def key_fn(r):
else:
return (0,)

# Replace Nones with NULL_SENTINEL to ensure safe sorting.
def key_fn_with_null_sentinel(r):
values = key_fn(r)
return [NULL_SENTINEL if v is None else v for v in values]

# Handle blocks of different types.
blocks = TableBlockAccessor.normalize_block_types(blocks, "arrow")

Expand All @@ -642,7 +646,7 @@ def key_fn(r):
ArrowBlockAccessor(block).iter_rows(public_row_format=False)
for block in blocks
],
key=key_fn,
key=key_fn_with_null_sentinel,
)
next_row = None
builder = ArrowBlockBuilder()
Expand Down
21 changes: 19 additions & 2 deletions python/ray/data/_internal/planner/exchange/sort_task_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ray.data._internal.progress_bar import ProgressBar
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data._internal.table_block import TableBlockAccessor
from ray.data._internal.util import NULL_SENTINEL
from ray.data.block import Block, BlockAccessor, BlockExecStats, BlockMetadata
from ray.types import ObjectRef

Expand All @@ -23,7 +24,7 @@ def __init__(
self,
key: Optional[Union[str, List[str]]] = None,
descending: Union[bool, List[bool]] = False,
boundaries: Optional[list] = None,
boundaries: Optional[List[T]] = None,
):
if key is None:
key = []
Expand Down Expand Up @@ -195,7 +196,23 @@ def sample_boundaries(
samples_table = builder.build()
samples_dict = BlockAccessor.for_block(samples_table).to_numpy(columns=columns)
# This zip does the transposition from list of column values to list of tuples.
samples_list = sorted(zip(*samples_dict.values()))
samples_list = list(zip(*samples_dict.values()))

def is_na(x):
# Check if x is None or NaN. Type casting to np.array first to avoid
# isnan failing on strings and other types.
if x is None:
return True
x = np.asarray(x)
if np.issubdtype(x.dtype, np.number):
return np.isnan(x)
return False

def key_fn_with_nones(sample):
return tuple(NULL_SENTINEL if is_na(x) else x for x in sample)

# Sort the list, but Nones should be NULL_SENTINEL to ensure safe sorting.
samples_list = sorted(samples_list, key=key_fn_with_nones)

# Each boundary corresponds to a quantile of the data.
quantile_indices = [
Expand Down
32 changes: 32 additions & 0 deletions python/ray/data/_internal/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,28 @@
_pyarrow_dataset: LazyModule = None


class _NullSentinel:
"""Sentinel value that sorts greater than any other value."""

def __eq__(self, other):
return isinstance(other, _NullSentinel)

def __lt__(self, other):
return False

def __le__(self, other):
return isinstance(other, _NullSentinel)

def __gt__(self, other):
return True

def __ge__(self, other):
return True


NULL_SENTINEL = _NullSentinel()


def _lazy_import_pyarrow_dataset() -> LazyModule:
global _pyarrow_dataset
if _pyarrow_dataset is None:
Expand Down Expand Up @@ -723,6 +745,16 @@ def find_partition_index(
col_vals = table[col_name].to_numpy()[left:right]
desired_val = desired[i]

# Handle null values - replace them with sentinel values
if desired_val is None:
desired_val = NULL_SENTINEL

# Replace None/NaN values in col_vals with sentinel
null_mask = col_vals == None # noqa: E711
if null_mask.any():
col_vals = col_vals.copy() # Make a copy to avoid modifying original
col_vals[null_mask] = NULL_SENTINEL

prevleft = left
if descending is True:
left = prevleft + (
Expand Down
59 changes: 59 additions & 0 deletions python/ray/data/tests/test_all_to_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,65 @@ def test_unique(ray_start_regular_shared):
assert mock_validate.call_args_list[0].args[0].names == ["b"]


@pytest.mark.parametrize("batch_format", ["pandas", "pyarrow"])
def test_unique_with_nulls(ray_start_regular_shared, batch_format):
ds = ray.data.from_items([3, 2, 3, 1, 2, 3, None])
assert set(ds.unique("item")) == {1, 2, 3, None}
assert len(ds.unique("item")) == 4

ds = ray.data.from_items(
[
{"a": 1, "b": 1},
{"a": 1, "b": 2},
{"a": 1, "b": None},
{"a": None, "b": 3},
{"a": None, "b": 4},
]
)
assert set(ds.unique("a")) == {1, None}
assert len(ds.unique("a")) == 2
assert set(ds.unique("b")) == {1, 2, 3, 4, None}
assert len(ds.unique("b")) == 5

# Check with 3 columns
df = pd.DataFrame(
{
"col1": [1, 2, None, 3, None, 3, 2],
"col2": [None, 2, 2, 3, None, 3, 2],
"col3": [1, None, 2, None, None, None, 2],
}
)
# df["col"].unique() works fine, as expected
ds2 = ray.data.from_pandas(df)
ds2 = ds2.map_batches(lambda x: x, batch_format=batch_format)
assert set(ds2.unique("col1")) == {1, 2, 3, None}
assert len(ds2.unique("col1")) == 4
assert set(ds2.unique("col2")) == {2, 3, None}
assert len(ds2.unique("col2")) == 3
assert set(ds2.unique("col3")) == {1, 2, None}
assert len(ds2.unique("col3")) == 3

# Check with 3 columns and different dtypes
df = pd.DataFrame(
{
"col1": [1, 2, None, 3, None, 3, 2],
"col2": [None, 2, 2, 3, None, 3, 2],
"col3": [1, None, 2, None, None, None, 2],
}
)
df["col1"] = df["col1"].astype("Int64")
df["col2"] = df["col2"].astype("Float64")
df["col3"] = df["col3"].astype("string")
ds3 = ray.data.from_pandas(df)
ds3 = ds3.map_batches(lambda x: x, batch_format=batch_format)
assert set(ds3.unique("col1")) == {1, 2, 3, None}
assert len(ds3.unique("col1")) == 4
assert set(ds3.unique("col2")) == {2, 3, None}
assert len(ds3.unique("col2")) == 3
assert set(ds3.unique("col3")) == {"1.0", "2.0", None}
assert len(ds3.unique("col3")) == 3


def test_grouped_dataset_repr(ray_start_regular_shared):
ds = ray.data.from_items([{"key": "spam"}, {"key": "ham"}, {"key": "spam"}])
assert repr(ds.groupby("key")) == f"GroupedData(dataset={ds!r}, key='key')"
Expand Down
16 changes: 16 additions & 0 deletions python/ray/data/tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from ray.data._internal.remote_fn import _make_hashable, cached_remote_fn
from ray.data._internal.util import (
NULL_SENTINEL,
_check_pyarrow_version,
_split_list,
iterate_with_retry,
Expand All @@ -35,6 +36,21 @@ def foo():
assert cpu_only_foo != gpu_only_foo


def test_null_sentinel():
"""Check that NULL_SENTINEL sorts greater than any other value."""
assert NULL_SENTINEL > 1000
assert NULL_SENTINEL > "abc"
assert NULL_SENTINEL == NULL_SENTINEL
assert NULL_SENTINEL != 1000
assert NULL_SENTINEL != "abc"
assert not NULL_SENTINEL < 1000
assert not NULL_SENTINEL < "abc"
assert not NULL_SENTINEL <= 1000
assert not NULL_SENTINEL <= "abc"
assert NULL_SENTINEL >= 1000
assert NULL_SENTINEL >= "abc"


def test_make_hashable():
valid_args = {
"int": 0,
Expand Down

0 comments on commit 134e5ec

Please sign in to comment.