diff --git a/crates/polars-arrow/src/compute/decimal.rs b/crates/polars-arrow/src/compute/decimal.rs index d4199d260362..cc6dc3eab321 100644 --- a/crates/polars-arrow/src/compute/decimal.rs +++ b/crates/polars-arrow/src/compute/decimal.rs @@ -1,4 +1,5 @@ use atoi::FromRadix10SignedChecked; +use num_traits::Euclid; /// Count the number of b'0's at the beginning of a slice. fn leading_zeros(bytes: &[u8]) -> u8 { @@ -178,34 +179,48 @@ pub fn format_decimal(v: i128, scale: usize, trim_zeros: bool) -> FormatBuffer { const ZEROS: [u8; BUF_LEN] = [b'0'; BUF_LEN]; let mut buf = FormatBuffer::new(); - let factor = POW10[scale]; //10_i128.pow(scale as _); - let (div, rem) = (v / factor, v.abs() % factor); + let factor = POW10[scale]; + let (div, rem) = v.abs().div_rem_euclid(&factor); unsafe { let mut ptr = buf.data.as_mut_ptr(); - if div == 0 && v < 0 { + if v < 0 { *ptr = b'-'; - ptr = ptr.add(1); buf.len = 1; + ptr = ptr.add(1); } let n_whole = itoap::write_to_ptr(ptr, div); buf.len += n_whole; + ptr = ptr.add(n_whole); + + if scale == 0 { + return buf; + } + + *ptr = b'.'; + ptr = ptr.add(1); + if rem != 0 { - ptr = ptr.add(n_whole); - *ptr = b'.'; - ptr = ptr.add(1); let mut frac_buf = [0_u8; BUF_LEN]; let n_frac = itoap::write_to_ptr(frac_buf.as_mut_ptr(), rem); std::ptr::copy_nonoverlapping(ZEROS.as_ptr(), ptr, scale - n_frac); ptr = ptr.add(scale - n_frac); std::ptr::copy_nonoverlapping(frac_buf.as_mut_ptr(), ptr, n_frac); - buf.len += 1 + scale; - if trim_zeros { - ptr = ptr.add(n_frac - 1); - while *ptr == b'0' { - ptr = ptr.sub(1); - buf.len -= 1; - } + ptr = ptr.add(n_frac); + } else { + std::ptr::copy_nonoverlapping(ZEROS.as_ptr(), ptr, scale); + ptr = ptr.add(scale); + } + buf.len += 1 + scale; + + if trim_zeros { + ptr = ptr.sub(1); + while *ptr == b'0' { + ptr = ptr.sub(1); + buf.len -= 1; + } + if *ptr == b'.' { + buf.len -= 1; } } } diff --git a/py-polars/tests/unit/datatypes/test_decimal.py b/py-polars/tests/unit/datatypes/test_decimal.py index 5661f542ad2e..1e95adb4a3e7 100644 --- a/py-polars/tests/unit/datatypes/test_decimal.py +++ b/py-polars/tests/unit/datatypes/test_decimal.py @@ -70,30 +70,33 @@ class Y: @pytest.mark.parametrize( - ("trim_zeros", "expected"), + ("input", "trim_zeros", "expected"), [ - (True, "0.01"), - (False, "0.010000000000000000000000000"), + ("0.00", True, "0"), + ("0.00", False, "0.00"), + ("-1", True, "-1"), + ("-1.000000000000000000000000000", False, "-1.000000000000000000000000000"), + ("0.0100", True, "0.01"), + ("0.0100", False, "0.0100"), + ("0.010000000000000000000000000", False, "0.010000000000000000000000000"), + ("-1.123801239123981293891283123", True, "-1.123801239123981293891283123"), + ( + "12345678901.234567890123458390192857685", + True, + "12345678901.234567890123458390192857685", + ), + ( + "-99999999999.999999999999999999999999999", + True, + "-99999999999.999999999999999999999999999", + ), ], ) -def test_to_from_pydecimal_and_format(trim_zeros: bool, expected: str) -> None: - dec_strs = [ - "0", - "-1", - expected, - "-1.123801239123981293891283123", - "12345678901.234567890123458390192857685", - "-99999999999.999999999999999999999999999", - ] +def test_decimal_format(input: str, trim_zeros: bool, expected: str) -> None: with pl.Config(trim_decimal_zeros=trim_zeros): - formatted = ( - str(pl.Series(list(map(D, dec_strs)))) - .split("[", 1)[1] - .split("\n", 1)[1] - .strip()[1:-1] - .split() - ) - assert formatted == dec_strs + series = pl.Series([input]).str.to_decimal() + formatted = str(series).split("\n")[-2].strip() + assert formatted == expected def test_init_decimal_dtype() -> None: