Skip to content

Commit

Permalink
fix(python,rust): fix inconsistent decimal formatting
Browse files (browse the repository at this point in the history)
  • Loading branch information
Julian-J-S committed Apr 3, 2024
1 parent 273adf5 commit f973038
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 34 deletions.
43 changes: 29 additions & 14 deletions crates/polars-arrow/src/compute/decimal.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use atoi::FromRadix10SignedChecked;
use num_traits::Euclid;

/// Count the number of b'0's at the beginning of a slice.
fn leading_zeros(bytes: &[u8]) -> u8 {
@@ -178,34 +179,48 @@ pub fn format_decimal(v: i128, scale: usize, trim_zeros: bool) -> FormatBuffer {
const ZEROS: [u8; BUF_LEN] = [b'0'; BUF_LEN];

let mut buf = FormatBuffer::new();
let factor = POW10[scale]; //10_i128.pow(scale as _);
let (div, rem) = (v / factor, v.abs() % factor);
let factor = POW10[scale];
let (div, rem) = v.abs().div_rem_euclid(&factor);

unsafe {
let mut ptr = buf.data.as_mut_ptr();
if div == 0 && v < 0 {
if v < 0 {
*ptr = b'-';
ptr = ptr.add(1);
buf.len = 1;
ptr = ptr.add(1);
}
let n_whole = itoap::write_to_ptr(ptr, div);
buf.len += n_whole;
ptr = ptr.add(n_whole);

if scale == 0 {
return buf;
}

*ptr = b'.';
ptr = ptr.add(1);

if rem != 0 {
ptr = ptr.add(n_whole);
*ptr = b'.';
ptr = ptr.add(1);
let mut frac_buf = [0_u8; BUF_LEN];
let n_frac = itoap::write_to_ptr(frac_buf.as_mut_ptr(), rem);
std::ptr::copy_nonoverlapping(ZEROS.as_ptr(), ptr, scale - n_frac);
ptr = ptr.add(scale - n_frac);
std::ptr::copy_nonoverlapping(frac_buf.as_mut_ptr(), ptr, n_frac);
buf.len += 1 + scale;
if trim_zeros {
ptr = ptr.add(n_frac - 1);
while *ptr == b'0' {
ptr = ptr.sub(1);
buf.len -= 1;
}
ptr = ptr.add(n_frac);
} else {
std::ptr::copy_nonoverlapping(ZEROS.as_ptr(), ptr, scale);
ptr = ptr.add(scale);
}
buf.len += 1 + scale;

if trim_zeros {
ptr = ptr.sub(1);
while *ptr == b'0' {
ptr = ptr.sub(1);
buf.len -= 1;
}
if *ptr == b'.' {
buf.len -= 1;
}
}
}
Expand Down
35 changes: 15 additions & 20 deletions py-polars/tests/unit/datatypes/test_decimal.py
Original file line number Diff line number Diff line change
@@ -69,30 +69,25 @@ class Y:


@pytest.mark.parametrize(
("trim_zeros", "expected"),
("input", "trim_zeros", "expected"),
[
(True, "0.01"),
(False, "0.010000000000000000000000000"),
("0.00", True, "0"),
("0.00", False, "0.00"),
("-1", True, "-1"),
("-1.000000000000000000000000000", False, "-1.000000000000000000000000000"),
("0.0100", True, "0.01"),
("0.0100", False, "0.0100"),
("0.010000000000000000000000000", False, "0.010000000000000000000000000"),
("-1.123801239123981293891283123", True, "-1.123801239123981293891283123"),
("12345678901.234567890123458390192857685", True, "12345678901.234567890123458390192857685"),
("-99999999999.999999999999999999999999999", True, "-99999999999.999999999999999999999999999"),
],
)
def test_to_from_pydecimal_and_format(trim_zeros: bool, expected: str) -> None:
dec_strs = [
"0",
"-1",
expected,
"-1.123801239123981293891283123",
"12345678901.234567890123458390192857685",
"-99999999999.999999999999999999999999999",
]
def test_decimal_format(input: str, trim_zeros: bool, expected: str) -> None:
with pl.Config(trim_decimal_zeros=trim_zeros):
formatted = (
str(pl.Series(list(map(D, dec_strs))))
.split("[", 1)[1]
.split("\n", 1)[1]
.strip()[1:-1]
.split()
)
assert formatted == dec_strs
series = pl.Series([input]).str.to_decimal()
formatted = str(series).split('\n')[-2].strip()
assert formatted == expected


def test_init_decimal_dtype() -> None:
Expand Down

0 comments on commit f973038

Please sign in to comment.