Skip to content

Commit

Permalink
csvstat: Output decimals as floats when --json is set, #1216 #828
Browse files Browse the repository at this point in the history
  • Loading branch information
jpmckinney committed Oct 17, 2023
1 parent 39ca69f commit 42873d9
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 6 deletions.
8 changes: 7 additions & 1 deletion csvkit/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,14 +412,20 @@ def isatty(f):
return False


def default(obj):
def default_str_decimal(obj):
if isinstance(obj, (datetime.date, datetime.datetime)):
return obj.isoformat()
if isinstance(obj, decimal.Decimal):
return str(obj)
raise TypeError(f'{repr(obj)} is not JSON serializable')


def default_float_decimal(obj):
if isinstance(obj, decimal.Decimal):
return float(obj)
return default_str_decimal(obj)


def make_default_headers(n):
"""
Make a set of simple, default headers for files that are missing them.
Expand Down
4 changes: 2 additions & 2 deletions csvkit/utilities/csvjson.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import agate

from csvkit.cli import CSVKitUtility, default, match_column_identifier
from csvkit.cli import CSVKitUtility, default_str_decimal, match_column_identifier


class CSVJSON(CSVKitUtility):
Expand Down Expand Up @@ -95,7 +95,7 @@ def main(self):
self.output_json()

def dump_json(self, data, newline=False):
json.dump(data, self.output_file, default=default, ensure_ascii=False, **self.json_kwargs)
json.dump(data, self.output_file, default=default_str_decimal, ensure_ascii=False, **self.json_kwargs)
if newline:
self.output_file.write("\n")

Expand Down
6 changes: 3 additions & 3 deletions csvkit/utilities/csvstat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import agate

from csvkit.cli import CSVKitUtility, default, parse_column_identifiers
from csvkit.cli import CSVKitUtility, default_float_decimal, parse_column_identifiers

locale.setlocale(locale.LC_ALL, '')
OPERATIONS = OrderedDict([
Expand Down Expand Up @@ -266,7 +266,7 @@ def calculate_stats(self, table, column_id, **kwargs):
op = op_data['aggregation']
v = table.aggregate(op(column_id))

if self.is_finite_decimal(v):
if self.is_finite_decimal(v) and not self.args.json_output:
v = format_decimal(v, self.args.decimal_format, self.args.no_grouping_separator)

stats[op_name] = v
Expand Down Expand Up @@ -352,7 +352,7 @@ def print_json(self, table, column_ids, stats):
"""
data = list(self._rows(table, column_ids, stats))

json.dump(data, self.output_file, default=default, ensure_ascii=False, indent=self.args.indent)
json.dump(data, self.output_file, default=default_float_decimal, ensure_ascii=False, indent=self.args.indent)

def _rows(self, table, column_ids, stats):
for column_id in column_ids:
Expand Down

0 comments on commit 42873d9

Please sign in to comment.