diff --git a/csvkit/cli.py b/csvkit/cli.py index 415cd672b..2cf1a5ceb 100644 --- a/csvkit/cli.py +++ b/csvkit/cli.py @@ -412,7 +412,7 @@ def isatty(f): return False -def default(obj): +def default_str_decimal(obj): if isinstance(obj, (datetime.date, datetime.datetime)): return obj.isoformat() if isinstance(obj, decimal.Decimal): @@ -420,6 +420,12 @@ def default(obj): raise TypeError(f'{repr(obj)} is not JSON serializable') +def default_float_decimal(obj): + if isinstance(obj, decimal.Decimal): + return float(obj) + return default_str_decimal(obj) + + def make_default_headers(n): """ Make a set of simple, default headers for files that are missing them. diff --git a/csvkit/utilities/csvjson.py b/csvkit/utilities/csvjson.py index a4356c4cb..afbe7af95 100644 --- a/csvkit/utilities/csvjson.py +++ b/csvkit/utilities/csvjson.py @@ -6,7 +6,7 @@ import agate -from csvkit.cli import CSVKitUtility, default, match_column_identifier +from csvkit.cli import CSVKitUtility, default_str_decimal, match_column_identifier class CSVJSON(CSVKitUtility): @@ -95,7 +95,7 @@ def main(self): self.output_json() def dump_json(self, data, newline=False): - json.dump(data, self.output_file, default=default, ensure_ascii=False, **self.json_kwargs) + json.dump(data, self.output_file, default=default_str_decimal, ensure_ascii=False, **self.json_kwargs) if newline: self.output_file.write("\n") diff --git a/csvkit/utilities/csvstat.py b/csvkit/utilities/csvstat.py index c661c4d94..2a6342c0b 100644 --- a/csvkit/utilities/csvstat.py +++ b/csvkit/utilities/csvstat.py @@ -8,7 +8,7 @@ import agate -from csvkit.cli import CSVKitUtility, default, parse_column_identifiers +from csvkit.cli import CSVKitUtility, default_float_decimal, parse_column_identifiers locale.setlocale(locale.LC_ALL, '') OPERATIONS = OrderedDict([ @@ -266,7 +266,7 @@ def calculate_stats(self, table, column_id, **kwargs): op = op_data['aggregation'] v = table.aggregate(op(column_id)) - if self.is_finite_decimal(v): + if self.is_finite_decimal(v) and not self.args.json_output: v = format_decimal(v, self.args.decimal_format, self.args.no_grouping_separator) stats[op_name] = v @@ -352,7 +352,7 @@ def print_json(self, table, column_ids, stats): """ data = list(self._rows(table, column_ids, stats)) - json.dump(data, self.output_file, default=default, ensure_ascii=False, indent=self.args.indent) + json.dump(data, self.output_file, default=default_float_decimal, ensure_ascii=False, indent=self.args.indent) def _rows(self, table, column_ids, stats): for column_id in column_ids: