Skip to content

Commit

Permalink
feat: Support Decimal read from IPC (#15965)
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion authored May 1, 2024
1 parent c9e786b commit 31eaabe
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 27 deletions.
2 changes: 1 addition & 1 deletion crates/polars-arrow/src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ impl std::fmt::Debug for dyn Array + '_ {
match self.data_type().to_physical_type() {
Null => fmt_dyn!(self, NullArray, f),
Boolean => fmt_dyn!(self, BooleanArray, f),
Primitive(primitive) => with_match_primitive_type!(primitive, |$T| {
Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| {
fmt_dyn!(self, PrimitiveArray<$T>, f)
}),
BinaryView => fmt_dyn!(self, BinaryViewArray, f),
Expand Down
89 changes: 68 additions & 21 deletions crates/polars-arrow/src/mmap/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::io::ipc::read::{Dictionaries, IpcBuffer, Node, OutOfSpecKind};
use crate::io::ipc::IpcField;
use crate::offset::Offset;
use crate::types::NativeType;
use crate::{match_integer_type, with_match_primitive_type};
use crate::{match_integer_type, with_match_primitive_type_full};

fn get_buffer_bounds(buffers: &mut VecDeque<IpcBuffer>) -> PolarsResult<(usize, usize)> {
let buffer = buffers.pop_front().ok_or_else(
Expand All @@ -29,6 +29,19 @@ fn get_buffer_bounds(buffers: &mut VecDeque<IpcBuffer>) -> PolarsResult<(usize,
Ok((offset, length))
}

/// Checks that `bytes` is long enough to hold `expected_len` values of type `T`, and
/// reports whether it can be reinterpreted in place as a `&[T]`.
///
/// # Returns
/// `Ok(true)` when `bytemuck::try_cast_slice` accepts `bytes` as `&[T]`, `Ok(false)`
/// otherwise. Per bytemuck's casting rules the cast is refused not only when the start
/// of `bytes` is misaligned for `T`, but also when `bytes.len()` is not a multiple of
/// `size_of::<T>()`; callers that copy on `false` must size the copy accordingly.
///
/// # Errors
/// Bails with a `ComputeError` when `bytes.len()` is smaller than
/// `size_of::<T>() * expected_len`, including when that product does not fit in a
/// `usize`.
fn check_bytes_len_and_is_aligned<T: NativeType>(
    bytes: &[u8],
    expected_len: usize,
) -> PolarsResult<bool> {
    // `expected_len` originates from file metadata. A plain `*` would wrap in release
    // builds and could turn a huge requirement into a tiny one, letting an undersized
    // buffer through; a byte count that overflows `usize` can never be satisfied.
    let too_short = match std::mem::size_of::<T>().checked_mul(expected_len) {
        Some(min_len) => bytes.len() < min_len,
        None => true,
    };
    if too_short {
        polars_bail!(ComputeError: "buffer's length is too small in mmap")
    };

    Ok(bytemuck::try_cast_slice::<_, T>(bytes).is_ok())
}

fn get_buffer<'a, T: NativeType>(
data: &'a [u8],
block_offset: usize,
Expand All @@ -42,13 +55,8 @@ fn get_buffer<'a, T: NativeType>(
.get(block_offset + offset..block_offset + offset + length)
.ok_or_else(|| polars_err!(ComputeError: "buffer out of bounds"))?;

// validate alignment
let v: &[T] = bytemuck::try_cast_slice(values)
.map_err(|_| polars_err!(ComputeError: "buffer not aligned for mmap"))?;

if v.len() < num_rows {
polars_bail!(ComputeError: "buffer's length is too small in mmap",
)
if !check_bytes_len_and_is_aligned::<T>(values, num_rows)? {
polars_bail!(ComputeError: "buffer not aligned for mmap");
}

Ok(values)
Expand Down Expand Up @@ -270,19 +278,58 @@ fn mmap_primitive<P: NativeType, T: AsRef<[u8]>>(

let validity = get_validity(data_ref, block_offset, buffers, null_count)?.map(|x| x.as_ptr());

let values = get_buffer::<P>(data_ref, block_offset, buffers, num_rows)?.as_ptr();
let bytes = get_bytes(data_ref, block_offset, buffers)?;
let is_aligned = check_bytes_len_and_is_aligned::<P>(bytes, num_rows)?;

Ok(unsafe {
create_array(
data,
num_rows,
null_count,
[validity, Some(values)].into_iter(),
[].into_iter(),
None,
None,
)
})
let out = if is_aligned || std::mem::size_of::<T>() <= 8 {
assert!(
is_aligned,
"primitive type with size <= 8 bytes should have been aligned"
);
let bytes_ptr = bytes.as_ptr();

unsafe {
create_array(
data,
num_rows,
null_count,
[validity, Some(bytes_ptr)].into_iter(),
[].into_iter(),
None,
None,
)
}
} else {
let mut values = vec![P::default(); num_rows];
unsafe {
std::ptr::copy_nonoverlapping(
bytes.as_ptr(),
values.as_mut_ptr() as *mut u8,
bytes.len(),
)
};
// Now we need to keep the new buffer alive
let owned_data = Arc::new((
// We can drop the original ref if we don't have a validity
validity.and(Some(data)),
values,
));
let bytes_ptr = owned_data.1.as_ptr() as *mut u8;

unsafe {
create_array(
owned_data,
num_rows,
null_count,
[validity, Some(bytes_ptr)].into_iter(),
[].into_iter(),
None,
None,
)
}
};

Ok(out)
}

#[allow(clippy::too_many_arguments)]
Expand Down Expand Up @@ -482,7 +529,7 @@ fn get_array<T: AsRef<[u8]>>(
match data_type.to_physical_type() {
Null => mmap_null(data, &node, block_offset, buffers),
Boolean => mmap_boolean(data, &node, block_offset, buffers),
Primitive(p) => with_match_primitive_type!(p, |$T| {
Primitive(p) => with_match_primitive_type_full!(p, |$T| {
mmap_primitive::<$T, _>(data, &node, block_offset, buffers)
}),
Utf8 | Binary => mmap_binary::<i32, _>(data, &node, block_offset, buffers),
Expand Down
5 changes: 0 additions & 5 deletions crates/polars-core/src/series/from.rs
Original file line number Diff line number Diff line change
Expand Up @@ -443,11 +443,6 @@ impl Series {
Ok(StructChunked::new_unchecked(name, &fields).into_series())
},
ArrowDataType::FixedSizeBinary(_) => {
if verbose() {
eprintln!(
"Polars does not support decimal types so the 'Series' are read as Float64"
);
}
let chunks = cast_chunks(&chunks, &DataType::Binary, true)?;
Ok(BinaryChunked::from_chunks(name, chunks).into_series())
},
Expand Down
31 changes: 31 additions & 0 deletions py-polars/tests/unit/io/test_ipc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import io
from decimal import Decimal
from typing import TYPE_CHECKING, Any

import pandas as pd
Expand Down Expand Up @@ -306,3 +307,33 @@ def test_read_ipc_only_loads_selected_columns(
# Only one column's worth of memory should be used; 2 columns would be
# 32_000_000 at least, but there's some overhead.
assert 16_000_000 < memory_usage_without_pyarrow.get_peak() < 23_000_000


@pytest.mark.write_disk()
def test_ipc_decimal_15920(
    monkeypatch: Any,
    tmp_path: Path,
) -> None:
    """Write a Decimal(18, 2) column to an IPC file and read it back unchanged.

    The round trip is exercised twice: once with 50 trailing nulls and once
    with the nulls dropped.
    """
    monkeypatch.setenv("POLARS_ACTIVATE_DECIMAL", "1")
    tmp_path.mkdir(exist_ok=True)

    literals = [
        "10.1", "11.2", "12.3", "13.4", "14.5", "15.6", "16.7", "17.8", "18.9", "19.0",
        "20.1", "21.2", "22.3", "23.4", "24.5", "25.6", "26.7", "27.8", "28.9", "29.0",
        "30.1", "31.2", "32.3", "33.4", "34.5", "35.6", "36.7", "37.8", "38.9", "39.0"
    ]  # fmt: skip
    values = [*(Decimal(text) for text in literals), *(50 * [None])]

    with_nulls = pl.Series("x", values, dtype=pl.Decimal(18, 2)).to_frame()
    without_nulls = with_nulls.drop_nulls()

    # Both frames reuse the same file; each write replaces the previous contents.
    path = f"{tmp_path}/data"
    for df in (with_nulls, without_nulls):
        df.write_ipc(path)
        assert_frame_equal(pl.read_ipc(path), df)

0 comments on commit 31eaabe

Please sign in to comment.