From 122b8ed0dfbafce4748d0511ab9aa101b2cbf51f Mon Sep 17 00:00:00 2001 From: xcharleslin <4212216+xcharleslin@users.noreply.github.com> Date: Tue, 6 Jun 2023 21:38:58 -0700 Subject: [PATCH] Visualization cleanup (1/n): Use Table for repr (#1011) This PR has DataFrame repr call into Table repr and removes the existing DataFrame specific codepath. For now, we also deprecate the table sizing options (height and width). It will be fixed to a max of 1x20. (heads up @jaychia) Also fixs off-by-one error in Table repr and adds length truncation in Table repr. This PR: - [x] DataFrame repr uses Table repr Future PRs: - [ ] Add default html_value method to Series - [ ] Special implementation of html_value for images - [ ] Manually implement html_repr for Table - [ ] DataFrame repr calls Table html_repr --------- Co-authored-by: Xiayue Charles Lin --- daft/logical/schema.py | 2 +- daft/viz/dataframe_display.py | 22 +++++++----- daft/viz/repr.py | 63 ----------------------------------- src/array/ops/take.rs | 2 +- src/python/schema.rs | 4 +++ src/series/mod.rs | 13 +++++--- src/table/mod.rs | 42 +++++++++++++++++------ tests/dataframe/test_repr.py | 2 +- tests/test_schema.py | 9 ++++- 9 files changed, 68 insertions(+), 91 deletions(-) diff --git a/daft/logical/schema.py b/daft/logical/schema.py index 9413fdd0ec..cd23863aec 100644 --- a/daft/logical/schema.py +++ b/daft/logical/schema.py @@ -86,7 +86,7 @@ def to_name_set(self) -> set[str]: return set(self.column_names()) def __repr__(self) -> str: - return repr([(field.name, field.dtype) for field in self]) + return repr(self._schema) def union(self, other: Schema) -> Schema: if not isinstance(other, Schema): diff --git a/daft/viz/dataframe_display.py b/daft/viz/dataframe_display.py index 9f1a0b2cb6..407d6109f1 100644 --- a/daft/viz/dataframe_display.py +++ b/daft/viz/dataframe_display.py @@ -4,7 +4,7 @@ from daft.dataframe.preview import DataFramePreview from daft.logical.schema import Schema -from daft.viz.repr import vpartition_repr, vpartition_repr_html +from daft.viz.repr import vpartition_repr_html HAS_PILLOW = False try: @@ -22,6 +22,7 @@ class DataFrameDisplay: preview: DataFramePreview schema: Schema + # These formatting options are deprecated for now and not guaranteed to be supported. column_char_width: int = 20 max_col_rows: int = 3 num_rows: int = 10 @@ -46,11 +47,14 @@ def _repr_html_(self) -> str: ) def __repr__(self) -> str: - return vpartition_repr( - self.preview.preview_partition, - self.schema, - self.num_rows, - self._get_user_message(), - max_col_width=self.column_char_width, - max_lines=self.max_col_rows, - ) + if len(self.schema) == 0: + return "(No data to display: Dataframe has no columns)" + + if self.preview.preview_partition is not None: + res = repr(self.preview.preview_partition) + else: + res = repr(self.schema) + + res += f"\n{self._get_user_message()}" + + return res diff --git a/daft/viz/repr.py b/daft/viz/repr.py index 3e6984e810..cf7ec234ca 100644 --- a/daft/viz/repr.py +++ b/daft/viz/repr.py @@ -30,31 +30,6 @@ def _stringify_object_html(val: Any, max_col_width: int, max_lines: int): return html.escape(_truncate(str(val), max_col_width, max_lines)) -def _stringify_vpartition( - data: dict[str, list[Any]], - daft_schema: Schema, - max_col_width: int = DEFAULT_MAX_COL_WIDTH, - max_lines: int = DEFAULT_MAX_LINES, -) -> dict[str, Iterable[str]]: - """Converts a vPartition into a dictionary of display-friendly stringified values""" - assert all( - colname in data for colname in daft_schema.column_names() - ), f"Data does not contain columns: {set(daft_schema.column_names()) - set(data.keys())}" - - data_stringified: dict[str, Iterable[str]] = {} - for colname in daft_schema.column_names(): - field = daft_schema[colname] - if field.dtype._is_python_type(): - data_stringified[colname] = [_truncate(str(val), max_col_width, max_lines) for val in data[colname]] - elif field.dtype == DataType.bool(): - # BUG: tabulate library does not handle string literal values "True" and "False" correctly, so we lowercase them. - data_stringified[colname] = [_truncate(str(val).lower(), max_col_width, max_lines) for val in data[colname]] - else: - data_stringified[colname] = [_truncate(str(val), max_col_width, max_lines) for val in data[colname]] - - return data_stringified - - def _stringify_vpartition_html( data: dict[str, list[Any]], daft_schema: Schema, @@ -137,41 +112,3 @@ def vpartition_repr_html( {tabulate_html_string} {user_message} """ - - -def vpartition_repr( - vpartition: Table | None, - daft_schema: Schema, - num_rows: int, - user_message: str, - max_col_width: int = DEFAULT_MAX_COL_WIDTH, - max_lines: int = DEFAULT_MAX_LINES, -) -> str: - """Converts a vPartition into a prettified string for display in a REPL""" - if len(daft_schema) == 0: - return "(No data to display: Dataframe has no columns)" - - data = ( - {k: v[:num_rows] for k, v in vpartition.to_pydict().items()} - if vpartition is not None - else {colname: [] for colname in daft_schema.column_names()} - ) - data_stringified = _stringify_vpartition( - data, - daft_schema, - max_col_width=max_col_width, - max_lines=max_lines, - ) - - return ( - tabulate( - data_stringified, - headers=[f"{name}\n{daft_schema[name].dtype}" for name in daft_schema.column_names()], - tablefmt="grid", - missingval="None", - # Workaround for https://github.com/astanin/python-tabulate/issues/223 - # If table has no rows, specifying maxcolwidths always raises error. - maxcolwidths=max_col_width if vpartition is not None and len(vpartition) else None, - ) - + f"\n{user_message}" - ) diff --git a/src/array/ops/take.rs b/src/array/ops/take.rs index df5754c2f9..fcf0c9e9ed 100644 --- a/src/array/ops/take.rs +++ b/src/array/ops/take.rs @@ -78,7 +78,7 @@ impl Utf8Array { let val = self.get(idx); match val { None => Ok("None".to_string()), - Some(v) => Ok(format!("\"{v}\"")), + Some(v) => Ok(v.to_string()), } } } diff --git a/src/python/schema.rs b/src/python/schema.rs index 7100ce08e2..323b800798 100644 --- a/src/python/schema.rs +++ b/src/python/schema.rs @@ -77,6 +77,10 @@ impl PySchema { pub fn __getstate__(&self, py: Python) -> PyResult { Ok(PyBytes::new(py, &bincode::serialize(&self.schema).unwrap()).to_object(py)) } + + pub fn __repr__(&self) -> PyResult { + Ok(format!("{}", self.schema)) + } } impl From for PySchema { diff --git a/src/series/mod.rs b/src/series/mod.rs index 7fc85449d6..c37cb3c28c 100644 --- a/src/series/mod.rs +++ b/src/series/mod.rs @@ -50,11 +50,8 @@ impl Series { self.inner.cast(&physical_dtype) } } -} -impl Display for Series { - // `f` is a buffer, and this method must write the formatted string into it - fn fmt(&self, f: &mut Formatter) -> Result { + pub fn to_prettytable(&self) -> prettytable::Table { let mut table = prettytable::Table::new(); let header = @@ -87,6 +84,14 @@ impl Display for Series { table.add_row(row.into()); } + table + } +} + +impl Display for Series { + // `f` is a buffer, and this method must write the formatted string into it + fn fmt(&self, f: &mut Formatter) -> Result { + let table = self.to_prettytable(); write!(f, "{table}") } } diff --git a/src/table/mod.rs b/src/table/mod.rs index 9724b19863..adaf11d781 100644 --- a/src/table/mod.rs +++ b/src/table/mod.rs @@ -363,11 +363,8 @@ impl Table { let new_series: DaftResult> = self.columns.iter().map(|s| s.as_physical()).collect(); Table::from_columns(new_series?) } -} -impl Display for Table { - // `f` is a buffer, and this method must write the formatted string into it - fn fmt(&self, f: &mut Formatter) -> Result { + pub fn to_prettytable(&self, max_col_width: Option) -> prettytable::Table { let mut table = prettytable::Table::new(); let header = self .schema @@ -395,9 +392,16 @@ impl Display for Table { let row = self .columns .iter() - .map(|s| s.str_value(i)) - .collect::>>() - .unwrap(); + .map(|s| { + let mut str_val = s.str_value(i).unwrap(); + if let Some(max_col_width) = max_col_width { + if str_val.len() > max_col_width { + str_val = format!("{}...", &str_val[..max_col_width - 3]); + } + } + str_val + }) + .collect::>(); table.add_row(row.into()); } if tail_rows != 0 { @@ -405,16 +409,32 @@ impl Display for Table { table.add_row(row); } - for i in 0..tail_rows { + for i in (self.len() - tail_rows)..(self.len()) { let row = self .columns .iter() - .map(|s| s.str_value(self.len() - tail_rows - 1 + i)) - .collect::>>() - .unwrap(); + .map(|s| { + let mut str_val = s.str_value(i).unwrap(); + if let Some(max_col_width) = max_col_width { + if s.len() > max_col_width { + str_val = format!("{}...", &str_val[..max_col_width - 3]); + } + } + str_val + }) + .collect::>(); table.add_row(row.into()); } + table + } +} + +impl Display for Table { + // `f` is a buffer, and this method must write the formatted string into it + fn fmt(&self, f: &mut Formatter) -> Result { + let table = self.to_prettytable(Some(20)); + write!(f, "{table}") } } diff --git a/tests/dataframe/test_repr.py b/tests/dataframe/test_repr.py index 97e83572df..5154aaed40 100644 --- a/tests/dataframe/test_repr.py +++ b/tests/dataframe/test_repr.py @@ -27,7 +27,7 @@ def _split_table_row(row: str) -> list[str]: column_types = _split_table_row(lines[2]) data = [] - for line in lines[4:-1]: + for line in lines[4:-2]: if ROW_DIVIDER_REGEX.match(line): continue data.append(_split_table_row(line)) diff --git a/tests/test_schema.py b/tests/test_schema.py index 63f629aa00..c8d7904ec1 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -62,7 +62,14 @@ def test_schema_to_name_set(): def test_repr(): schema = TABLE.schema() - assert repr(schema) == "[('int', Int64), ('float', Float64), ('string', Utf8), ('bool', Boolean)]" + assert ( + repr(schema) + == """+-------+---------+--------+---------+ +| int | float | string | bool | +| Int64 | Float64 | Utf8 | Boolean | ++-------+---------+--------+---------+ +""" + ) def test_to_col_expr():