From 48c2faaf7ea2458fddb05da3aaed8d1db0dffec6 Mon Sep 17 00:00:00 2001 From: Best Olunusi Date: Fri, 19 Jan 2024 20:07:41 -0600 Subject: [PATCH] feat: allow precision specification --- babel/numbers.py | 54 +++++++++++-- tests/test_numbers.py | 172 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 219 insertions(+), 7 deletions(-) diff --git a/babel/numbers.py b/babel/numbers.py index 2240c65d5..5720d4e96 100644 --- a/babel/numbers.py +++ b/babel/numbers.py @@ -520,6 +520,7 @@ def format_decimal( format: str | NumberPattern | None = None, locale: Locale | str | None = LC_NUMERIC, decimal_quantization: bool = True, + precision: int | None = None, group_separator: bool = True, *, numbering_system: Literal["default"] | str = "latn", @@ -560,11 +561,19 @@ def format_decimal( >>> format_decimal(12345.67, locale='en_US', group_separator=True) u'12,345.67' + When you bypass locale truncation, you can specify precision: + + >>> format_decimal(1.2346, locale='en_US', decimal_quantization=False, precision=2) + u'1.23' + >>> format_decimal(0.3, locale='en_US', decimal_quantization=False, precision=2) + u'0.30' + :param number: the number to format :param format: :param locale: the `Locale` object or locale identifier :param decimal_quantization: Truncate and round high-precision numbers to the format pattern. Defaults to `True`. + :param precision: Optionally specify precision when decimal_quantization is False. :param group_separator: Boolean to switch group separator on/off in a locale's number format. :param numbering_system: The numbering system used for formatting number symbols. Defaults to "latn". @@ -576,7 +585,7 @@ def format_decimal( format = locale.decimal_formats[format] pattern = parse_pattern(format) return pattern.apply( - number, locale, decimal_quantization=decimal_quantization, group_separator=group_separator, numbering_system=numbering_system) + number, locale, decimal_quantization=decimal_quantization, precision=precision, group_separator=group_separator, numbering_system=numbering_system) def format_compact_decimal( @@ -674,6 +683,7 @@ def format_currency( currency_digits: bool = True, format_type: Literal["name", "standard", "accounting"] = "standard", decimal_quantization: bool = True, + precision: int | None = None, group_separator: bool = True, *, numbering_system: Literal["default"] | str = "latn", @@ -755,6 +765,11 @@ def format_currency( >>> format_currency(1099.9876, 'USD', locale='en_US', decimal_quantization=False) u'$1,099.9876' + When you bypass locale truncation, you can specify precision: + + >>> format_currency(1099.9876, 'USD', locale='en_US', decimal_quantization=False, precision=3) + u'$1,099.988' + :param number: the number to format :param currency: the currency code :param format: the format string to use @@ -763,6 +778,7 @@ def format_currency( :param format_type: the currency format type to use :param decimal_quantization: Truncate and round high-precision numbers to the format pattern. Defaults to `True`. + :param precision: Optionally specify precision when decimal_quantization is False. :param group_separator: Boolean to switch group separator on/off in a locale's number format. :param numbering_system: The numbering system used for formatting number symbols. Defaults to "latn". @@ -772,7 +788,7 @@ def format_currency( if format_type == 'name': return _format_currency_long_name(number, currency, format=format, locale=locale, currency_digits=currency_digits, - decimal_quantization=decimal_quantization, group_separator=group_separator, + decimal_quantization=decimal_quantization, precision=precision, group_separator=group_separator, numbering_system=numbering_system) locale = Locale.parse(locale) if format: @@ -785,7 +801,7 @@ def format_currency( return pattern.apply( number, locale, currency=currency, currency_digits=currency_digits, - decimal_quantization=decimal_quantization, group_separator=group_separator, numbering_system=numbering_system) + decimal_quantization=decimal_quantization, precision=precision, group_separator=group_separator, numbering_system=numbering_system) def _format_currency_long_name( @@ -796,6 +812,7 @@ def _format_currency_long_name( currency_digits: bool = True, format_type: Literal["name", "standard", "accounting"] = "standard", decimal_quantization: bool = True, + precision: int | None = None, group_separator: bool = True, *, numbering_system: Literal["default"] | str = "latn", @@ -825,7 +842,7 @@ def _format_currency_long_name( number_part = pattern.apply( number, locale, currency=currency, currency_digits=currency_digits, - decimal_quantization=decimal_quantization, group_separator=group_separator, numbering_system=numbering_system) + decimal_quantization=decimal_quantization, precision=precision, group_separator=group_separator, numbering_system=numbering_system) return unit_pattern.format(number_part, display_name) @@ -887,6 +904,7 @@ def format_percent( format: str | NumberPattern | None = None, locale: Locale | str | None = LC_NUMERIC, decimal_quantization: bool = True, + precision: int | None = None, group_separator: bool = True, *, numbering_system: Literal["default"] | str = "latn", @@ -922,11 +940,17 @@ def format_percent( >>> format_percent(229291.1234, locale='pt_BR', group_separator=True) u'22.929.112%' + When you bypass locale truncation, you can specify precision: + + >>> format_percent(0.0111, locale='en_US', decimal_quantization=False, precision=3) + u'1.110%' + :param number: the percent number to format :param format: :param locale: the `Locale` object or locale identifier :param decimal_quantization: Truncate and round high-precision numbers to the format pattern. Defaults to `True`. + :param precision: Optionally specify precision when decimal_quantization is False. :param group_separator: Boolean to switch group separator on/off in a locale's number format. :param numbering_system: The numbering system used for formatting number symbols. Defaults to "latn". @@ -938,7 +962,7 @@ def format_percent( format = locale.percent_formats[None] pattern = parse_pattern(format) return pattern.apply( - number, locale, decimal_quantization=decimal_quantization, group_separator=group_separator, + number, locale, decimal_quantization=decimal_quantization, precision=precision, group_separator=group_separator, numbering_system=numbering_system, ) @@ -949,6 +973,7 @@ def format_scientific( locale: Locale | str | None = LC_NUMERIC, decimal_quantization: bool = True, *, + precision: int | None = None, numbering_system: Literal["default"] | str = "latn", ) -> str: """Return value formatted in scientific notation for a specific locale. @@ -972,11 +997,17 @@ def format_scientific( >>> format_scientific(1234.9876, u'#.##E0', locale='en_US', decimal_quantization=False) u'1.2349876E3' + When you bypass locale truncation, you can specify precision: + + >>> format_scientific(000.00100, locale='en_US', decimal_quantization=False, precision=3) + u'1.000E-3' + :param number: the number to format :param format: :param locale: the `Locale` object or locale identifier :param decimal_quantization: Truncate and round high-precision numbers to the format pattern. Defaults to `True`. + :param precision: Optionally specify precision when decimal_quantization is False. :param numbering_system: The numbering system used for formatting number symbols. Defaults to "latn". The special value "default" will use the default numbering system of the locale. :raise `UnsupportedNumberingSystemError`: If the numbering system is not supported by the locale. @@ -986,7 +1017,7 @@ def format_scientific( format = locale.scientific_formats[None] pattern = parse_pattern(format) return pattern.apply( - number, locale, decimal_quantization=decimal_quantization, numbering_system=numbering_system) + number, locale, decimal_quantization=decimal_quantization, precision=precision, numbering_system=numbering_system) class NumberFormatError(ValueError): @@ -1346,6 +1377,7 @@ def apply( currency: str | None = None, currency_digits: bool = True, decimal_quantization: bool = True, + precision: int | None = None, force_frac: tuple[int, int] | None = None, group_separator: bool = True, *, @@ -1371,6 +1403,8 @@ def apply( strictly matching the CLDR definition for the locale. :type decimal_quantization: bool + :param precision: Optionally specify precision when decimal_quantization is False. + :type precision: int|None :param force_frac: DEPRECATED - a forced override for `self.frac_prec` for a single formatting invocation. :param numbering_system: The numbering system used for formatting number symbols. Defaults to "latn". @@ -1407,6 +1441,9 @@ def apply( else: frac_prec = self.frac_prec + if decimal_quantization and precision is not None: + raise ValueError("To specify precision, decimal_quantization should be set to False.") + # Bump decimal precision to the natural precision of the number if it # exceeds the one we're about to use. This adaptative precision is only # triggered if the decimal quantization is disabled or if a scientific @@ -1414,7 +1451,10 @@ def apply( # default '#E0' pattern). This special case has been extensively # discussed at https://github.com/python-babel/babel/pull/494#issuecomment-307649969 . if not decimal_quantization or (self.exp_prec and frac_prec == (0, 0)): - frac_prec = (frac_prec[0], max([frac_prec[1], get_decimal_precision(value)])) + if not decimal_quantization and precision is not None: + frac_prec = (precision, precision) + else: + frac_prec = (frac_prec[0], max([frac_prec[1], get_decimal_precision(value)])) # Render scientific notation. if self.exp_prec: diff --git a/tests/test_numbers.py b/tests/test_numbers.py index d89592a0e..c607a33ab 100644 --- a/tests/test_numbers.py +++ b/tests/test_numbers.py @@ -422,6 +422,24 @@ def test_format_decimal(): with pytest.raises(numbers.UnsupportedNumberingSystemError): numbers.format_decimal(12345.5, locale='en_US', numbering_system="unknown") +def test_format_with_specified_precision_with_decimal_quantization(): + # Specifying precision raises exception when decimal_quantization is not explicitly set to False. + + error_msg = "To specify precision, decimal_quantization should be set to False." + + with pytest.raises(ValueError, match=error_msg): + numbers.format_decimal('1.23', locale='en_US', precision=5) + + with pytest.raises(ValueError, match=error_msg): + numbers.format_currency('0.34', currency='USD', locale='en_US', decimal_quantization=True, precision=5) + + with pytest.raises(ValueError, match=error_msg): + numbers.format_scientific('0.78', locale='en_US', decimal_quantization=True, precision=5) + + with pytest.raises(ValueError, match=error_msg): + numbers.format_percent('6.7', locale='en_US', precision=5) + + @pytest.mark.parametrize('input_value, expected_value', [ ('10000', '10,000'), ('1', '1'), @@ -457,6 +475,44 @@ def test_format_decimal_quantization(): '0.9999999999', locale=locale_code, decimal_quantization=False).endswith('9999999999') is True +@pytest.mark.parametrize('input_value, precision, expected_value', [ + ('10000', 2, '10,000.00'), + ('1', 2, '1.00'), + ('1.0', 2, '1.00'), + ('1.1', 2, '1.10'), + ('1.11', 2, '1.11'), + ('1.110', 2, '1.11'), + ('1.001', 3, '1.001'), + ('1.00100', 3, '1.001'), + ('01.00100', 3, '1.001'), + ('101.00100', 3, '101.001'), + ('00000', 2, '0.00'), + ('0', 2, '0.00'), + ('0.0', 2, '0.00'), + ('0.1', 2, '0.10'), + ('0.11', 2, '0.11'), + ('0.110', 2, '0.11'), + ('0.001', 3, '0.001'), + ('0.00100', 3, '0.001'), + ('00.00100', 3, '0.001'), + ('000.00100', 3, '0.001'), +]) +def test_format_decimal_with_specified_precision(input_value, precision, expected_value): + assert numbers.format_decimal( + decimal.Decimal(input_value), locale='en_US', decimal_quantization=False, precision=precision + ) == expected_value + + +def test_format_decimal_with_specified_precision_all_locales(): + for locale_code in localedata.locale_identifiers(): + assert numbers.format_decimal( + '2.446', + locale=locale_code, + decimal_quantization=False, + precision=2 + ).endswith('45') is True + + def test_format_currency(): assert (numbers.format_currency(1099.98, 'USD', locale='en_US') == '$1,099.98') @@ -579,6 +635,48 @@ def test_format_currency_quantization(): '0.9999999999', 'USD', locale=locale_code, decimal_quantization=False).find('9999999999') > -1 +@pytest.mark.parametrize('input_value, precision, expected_value', [ + ('10000', 2, '$10,000.00'), + ('1', 2, '$1.00'), + ('1.0', 2, '$1.00'), + ('1.1', 2, '$1.10'), + ('1.11', 2, '$1.11'), + ('1.110', 2, '$1.11'), + ('1.001', 3, '$1.001'), + ('1.00100', 3, '$1.001'), + ('01.00100', 3, '$1.001'), + ('101.00100', 3, '$101.001'), + ('00000', 2, '$0.00'), + ('0', 2, '$0.00'), + ('0.0', 2, '$0.00'), + ('0.1', 2, '$0.10'), + ('0.11', 2, '$0.11'), + ('0.110', 2, '$0.11'), + ('0.001', 3, '$0.001'), + ('0.00100', 3, '$0.001'), + ('00.00100', 3, '$0.001'), + ('000.00100', 3, '$0.001'), + ('0.1', 0, '$0'), + ('0.9', 0, '$1'), + ('0.99', 0, '$1'), +]) +def test_format_currency_with_specified_precision(input_value, precision, expected_value): + assert numbers.format_currency( + decimal.Decimal(input_value), currency='USD', locale='en_US', decimal_quantization=False, precision=precision + ) == expected_value + + +def test_format_currency_with_specified_precision_all_locales(): + for locale_code in localedata.locale_identifiers(): + assert numbers.format_currency( + '68.856', + currency='USD', + locale=locale_code, + decimal_quantization=False, + precision=2 + ).find('86') > -1 + + def test_format_currency_long_display_name(): assert (numbers.format_currency(1099.98, 'USD', locale='en_US', format_type='name') == '1,099.98 US dollars') @@ -678,6 +776,42 @@ def test_format_percent_quantization(): '0.9999999999', locale=locale_code, decimal_quantization=False).find('99999999') > -1 +@pytest.mark.parametrize('input_value, precision, expected_value', [ + ('100', 2, '10,000.00%'), + ('0.01', 1, '1.0%'), + ('0.010', 1, '1.0%'), + ('0.011', 2, '1.10%'), + ('0.0111', 3, '1.110%'), + ('0.01110', 3, '1.110%'), + ('0.01001', 3, '1.001%'), + ('0.0100100', 3, '1.001%'), + ('0.010100100', 5, '1.01001%'), + ('0.000000', 0, '0%'), + ('0', 0, '0%'), + ('0.00', 0, '0%'), + ('0.011', 2, '1.10%'), + ('0.0110', 2, '1.10%'), + ('0.0001', 2, '0.01%'), + ('0.000100', 2, '0.01%'), + ('0.0000100', 3, '0.001%'), + ('0.00000100', 4, '0.0001%'), +]) +def test_format_percent_with_specified_precision(input_value, precision, expected_value): + assert numbers.format_percent( + decimal.Decimal(input_value), locale='en_US', decimal_quantization=False, precision=precision + ) == expected_value + + +def test_format_percent_with_specified_precision_all_locales(): + for locale_code in localedata.locale_identifiers(): + assert numbers.format_percent( + '0.12349', + locale=locale_code, + decimal_quantization=False, + precision=2 + ).find('35') > -1 + + def test_format_scientific(): assert numbers.format_scientific(10000, locale='en_US') == '1E4' assert numbers.format_scientific(10000, locale='en_US', numbering_system="default") == '1E4' @@ -737,6 +871,44 @@ def test_format_scientific_quantization(): '0.9999999999', locale=locale_code, decimal_quantization=False).find('999999999') > -1 +@pytest.mark.parametrize('input_value, precision, expected_value', [ + ('10000', 0, '1E4'), + ('1', 0, '1E0'), + ('1.0', 0, '1E0'), + ('1.1', 1, '1.1E0'), + ('1.11', 2, '1.11E0'), + ('1.110', 2, '1.11E0'), + ('1.001', 3, '1.001E0'), + ('1.00100', 3, '1.001E0'), + ('01.00100', 3, '1.001E0'), + ('101.00100', 5, '1.01001E2'), + ('00000', 0, '0E0'), + ('0', 0, '0E0'), + ('0.0', 0, '0E0'), + ('0.1', 1, '1.0E-1'), + ('0.11', 1, '1.1E-1'), + ('0.110', 1, '1.1E-1'), + ('0.001', 3, '1.000E-3'), + ('0.00100', 3, '1.000E-3'), + ('00.00100', 3, '1.000E-3'), + ('000.00100', 3, '1.000E-3'), +]) +def test_format_scientific_with_specified_precision(input_value, precision, expected_value): + assert numbers.format_scientific( + decimal.Decimal(input_value), locale='en_US', decimal_quantization=False, precision=precision + ) == expected_value + + +def test_format_scientific_with_specified_precision_all_locales(): + for locale_code in localedata.locale_identifiers(): + assert numbers.format_scientific( + '1.23456789', + locale=locale_code, + decimal_quantization=False, + precision=4 + ).find('2346') > -1 + + def test_parse_number(): assert numbers.parse_number('1,099', locale='en_US') == 1099 assert numbers.parse_number('1.099', locale='de_DE') == 1099