Skip to content

Commit

Permalink
[BP] Fix boolean array for arrow-backed DF. (#10527) (#10901)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Oct 17, 2024
1 parent 7f87b9e commit a4c6cde
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 44 deletions.
32 changes: 2 additions & 30 deletions python-package/xgboost/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ def pandas_pa_type(ser: Any) -> np.ndarray:
# combine_chunks takes the most significant amount of time
chunk: pa.Array = aa.combine_chunks()
# When there's null value, we have to use copy
zero_copy = chunk.null_count == 0
zero_copy = chunk.null_count == 0 and not pa.types.is_boolean(chunk.type)
# Alternately, we can use chunk.buffers(), which returns a list of buffers and
# we need to concatenate them ourselves.
# FIXME(jiamingy): Is there a better way to access the arrow buffer along with
Expand Down Expand Up @@ -825,37 +825,9 @@ def _arrow_transform(data: DataType) -> Any:

data = cast(pa.Table, data)

def type_mapper(dtype: pa.DataType) -> Optional[str]:
"""Maps pyarrow type to pandas arrow extension type."""
if pa.types.is_int8(dtype):
return pd.ArrowDtype(pa.int8())
if pa.types.is_int16(dtype):
return pd.ArrowDtype(pa.int16())
if pa.types.is_int32(dtype):
return pd.ArrowDtype(pa.int32())
if pa.types.is_int64(dtype):
return pd.ArrowDtype(pa.int64())
if pa.types.is_uint8(dtype):
return pd.ArrowDtype(pa.uint8())
if pa.types.is_uint16(dtype):
return pd.ArrowDtype(pa.uint16())
if pa.types.is_uint32(dtype):
return pd.ArrowDtype(pa.uint32())
if pa.types.is_uint64(dtype):
return pd.ArrowDtype(pa.uint64())
if pa.types.is_float16(dtype):
return pd.ArrowDtype(pa.float16())
if pa.types.is_float32(dtype):
return pd.ArrowDtype(pa.float32())
if pa.types.is_float64(dtype):
return pd.ArrowDtype(pa.float64())
if pa.types.is_boolean(dtype):
return pd.ArrowDtype(pa.bool_())
return None

# For common cases, this is zero-copy, can check with:
# pa.total_allocated_bytes()
df = data.to_pandas(types_mapper=type_mapper)
df = data.to_pandas(types_mapper=pd.ArrowDtype)
return df


Expand Down
33 changes: 19 additions & 14 deletions python-package/xgboost/testing/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,6 @@ def pd_arrow_dtypes() -> Generator:

# Integer
dtypes = pandas_pyarrow_mapper
Null: Union[float, None, Any] = np.nan
orig = pd.DataFrame(
{"f0": [1, 2, Null, 3], "f1": [4, 3, Null, 1]}, dtype=np.float32
)
# Create a dictionary-backed dataframe, enable this when the roundtrip is
# implemented in pandas/pyarrow
#
Expand All @@ -191,24 +187,33 @@ def pd_arrow_dtypes() -> Generator:
# pd_catcodes = pd_cat_df["f1"].cat.codes
# assert pd_catcodes.equals(pa_catcodes)

for Null in (None, pd.NA):
for Null in (None, pd.NA, 0):
for dtype in dtypes:
if dtype.startswith("float16") or dtype.startswith("bool"):
continue
# Use np.nan is a baseline
orig_null = Null if not pd.isna(Null) and Null == 0 else np.nan
orig = pd.DataFrame(
{"f0": [1, 2, orig_null, 3], "f1": [4, 3, orig_null, 1]},
dtype=np.float32,
)

df = pd.DataFrame(
{"f0": [1, 2, Null, 3], "f1": [4, 3, Null, 1]}, dtype=dtype
)
yield orig, df

orig = pd.DataFrame(
{"f0": [True, False, pd.NA, True], "f1": [False, True, pd.NA, True]},
dtype=pd.BooleanDtype(),
)
df = pd.DataFrame(
{"f0": [True, False, pd.NA, True], "f1": [False, True, pd.NA, True]},
dtype=pd.ArrowDtype(pa.bool_()),
)
yield orig, df
# If Null is `False`, then there's no missing value.
for Null in (pd.NA, False):
orig = pd.DataFrame(
{"f0": [True, False, Null, True], "f1": [False, True, Null, True]},
dtype=pd.BooleanDtype(),
)
df = pd.DataFrame(
{"f0": [True, False, Null, True], "f1": [False, True, Null, True]},
dtype=pd.ArrowDtype(pa.bool_()),
)
yield orig, df


def check_inf(rng: RNG) -> None:
Expand Down

0 comments on commit a4c6cde

Please sign in to comment.