diff --git a/daft/logical/schema.py b/daft/logical/schema.py index 9413fdd0ec..cd23863aec 100644 --- a/daft/logical/schema.py +++ b/daft/logical/schema.py @@ -86,7 +86,7 @@ def to_name_set(self) -> set[str]: return set(self.column_names()) def __repr__(self) -> str: - return repr([(field.name, field.dtype) for field in self]) + return repr(self._schema) def union(self, other: Schema) -> Schema: if not isinstance(other, Schema): diff --git a/daft/viz/dataframe_display.py b/daft/viz/dataframe_display.py index 9f1a0b2cb6..407d6109f1 100644 --- a/daft/viz/dataframe_display.py +++ b/daft/viz/dataframe_display.py @@ -4,7 +4,7 @@ from daft.dataframe.preview import DataFramePreview from daft.logical.schema import Schema -from daft.viz.repr import vpartition_repr, vpartition_repr_html +from daft.viz.repr import vpartition_repr_html HAS_PILLOW = False try: @@ -22,6 +22,7 @@ class DataFrameDisplay: preview: DataFramePreview schema: Schema + # These formatting options are deprecated for now and not guaranteed to be supported. column_char_width: int = 20 max_col_rows: int = 3 num_rows: int = 10 @@ -46,11 +47,14 @@ def _repr_html_(self) -> str: ) def __repr__(self) -> str: - return vpartition_repr( - self.preview.preview_partition, - self.schema, - self.num_rows, - self._get_user_message(), - max_col_width=self.column_char_width, - max_lines=self.max_col_rows, - ) + if len(self.schema) == 0: + return "(No data to display: Dataframe has no columns)" + + if self.preview.preview_partition is not None: + res = repr(self.preview.preview_partition) + else: + res = repr(self.schema) + + res += f"\n{self._get_user_message()}" + + return res diff --git a/daft/viz/repr.py b/daft/viz/repr.py index 3e6984e810..cf7ec234ca 100644 --- a/daft/viz/repr.py +++ b/daft/viz/repr.py @@ -30,31 +30,6 @@ def _stringify_object_html(val: Any, max_col_width: int, max_lines: int): return html.escape(_truncate(str(val), max_col_width, max_lines)) -def _stringify_vpartition( - data: dict[str, list[Any]], - daft_schema: Schema, - max_col_width: int = DEFAULT_MAX_COL_WIDTH, - max_lines: int = DEFAULT_MAX_LINES, -) -> dict[str, Iterable[str]]: - """Converts a vPartition into a dictionary of display-friendly stringified values""" - assert all( - colname in data for colname in daft_schema.column_names() - ), f"Data does not contain columns: {set(daft_schema.column_names()) - set(data.keys())}" - - data_stringified: dict[str, Iterable[str]] = {} - for colname in daft_schema.column_names(): - field = daft_schema[colname] - if field.dtype._is_python_type(): - data_stringified[colname] = [_truncate(str(val), max_col_width, max_lines) for val in data[colname]] - elif field.dtype == DataType.bool(): - # BUG: tabulate library does not handle string literal values "True" and "False" correctly, so we lowercase them. - data_stringified[colname] = [_truncate(str(val).lower(), max_col_width, max_lines) for val in data[colname]] - else: - data_stringified[colname] = [_truncate(str(val), max_col_width, max_lines) for val in data[colname]] - - return data_stringified - - def _stringify_vpartition_html( data: dict[str, list[Any]], daft_schema: Schema, @@ -137,41 +112,3 @@ def vpartition_repr_html( {tabulate_html_string} {user_message} """ - - -def vpartition_repr( - vpartition: Table | None, - daft_schema: Schema, - num_rows: int, - user_message: str, - max_col_width: int = DEFAULT_MAX_COL_WIDTH, - max_lines: int = DEFAULT_MAX_LINES, -) -> str: - """Converts a vPartition into a prettified string for display in a REPL""" - if len(daft_schema) == 0: - return "(No data to display: Dataframe has no columns)" - - data = ( - {k: v[:num_rows] for k, v in vpartition.to_pydict().items()} - if vpartition is not None - else {colname: [] for colname in daft_schema.column_names()} - ) - data_stringified = _stringify_vpartition( - data, - daft_schema, - max_col_width=max_col_width, - max_lines=max_lines, - ) - - return ( - tabulate( - data_stringified, - headers=[f"{name}\n{daft_schema[name].dtype}" for name in daft_schema.column_names()], - tablefmt="grid", - missingval="None", - # Workaround for https://github.com/astanin/python-tabulate/issues/223 - # If table has no rows, specifying maxcolwidths always raises error. - maxcolwidths=max_col_width if vpartition is not None and len(vpartition) else None, - ) - + f"\n{user_message}" - ) diff --git a/src/array/ops/take.rs b/src/array/ops/take.rs index df5754c2f9..fcf0c9e9ed 100644 --- a/src/array/ops/take.rs +++ b/src/array/ops/take.rs @@ -78,7 +78,7 @@ impl Utf8Array { let val = self.get(idx); match val { None => Ok("None".to_string()), - Some(v) => Ok(format!("\"{v}\"")), + Some(v) => Ok(v.to_string()), } } } diff --git a/src/python/schema.rs b/src/python/schema.rs index 7100ce08e2..323b800798 100644 --- a/src/python/schema.rs +++ b/src/python/schema.rs @@ -77,6 +77,10 @@ impl PySchema { pub fn __getstate__(&self, py: Python) -> PyResult { Ok(PyBytes::new(py, &bincode::serialize(&self.schema).unwrap()).to_object(py)) } + + pub fn __repr__(&self) -> PyResult { + Ok(format!("{}", self.schema)) + } } impl From for PySchema { diff --git a/src/series/mod.rs b/src/series/mod.rs index 7fc85449d6..c37cb3c28c 100644 --- a/src/series/mod.rs +++ b/src/series/mod.rs @@ -50,11 +50,8 @@ impl Series { self.inner.cast(&physical_dtype) } } -} -impl Display for Series { - // `f` is a buffer, and this method must write the formatted string into it - fn fmt(&self, f: &mut Formatter) -> Result { + pub fn to_prettytable(&self) -> prettytable::Table { let mut table = prettytable::Table::new(); let header = @@ -87,6 +84,14 @@ impl Display for Series { table.add_row(row.into()); } + table + } +} + +impl Display for Series { + // `f` is a buffer, and this method must write the formatted string into it + fn fmt(&self, f: &mut Formatter) -> Result { + let table = self.to_prettytable(); write!(f, "{table}") } } diff --git a/src/table/mod.rs b/src/table/mod.rs index 9724b19863..adaf11d781 100644 --- a/src/table/mod.rs +++ b/src/table/mod.rs @@ -363,11 +363,8 @@ impl Table { let new_series: DaftResult> = self.columns.iter().map(|s| s.as_physical()).collect(); Table::from_columns(new_series?) } -} -impl Display for Table { - // `f` is a buffer, and this method must write the formatted string into it - fn fmt(&self, f: &mut Formatter) -> Result { + pub fn to_prettytable(&self, max_col_width: Option) -> prettytable::Table { let mut table = prettytable::Table::new(); let header = self .schema @@ -395,9 +392,16 @@ impl Display for Table { let row = self .columns .iter() - .map(|s| s.str_value(i)) - .collect::>>() - .unwrap(); + .map(|s| { + let mut str_val = s.str_value(i).unwrap(); + if let Some(max_col_width) = max_col_width { + if str_val.len() > max_col_width { + str_val = format!("{}...", &str_val[..max_col_width - 3]); + } + } + str_val + }) + .collect::>(); table.add_row(row.into()); } if tail_rows != 0 { @@ -405,16 +409,32 @@ impl Display for Table { table.add_row(row); } - for i in 0..tail_rows { + for i in (self.len() - tail_rows)..(self.len()) { let row = self .columns .iter() - .map(|s| s.str_value(self.len() - tail_rows - 1 + i)) - .collect::>>() - .unwrap(); + .map(|s| { + let mut str_val = s.str_value(i).unwrap(); + if let Some(max_col_width) = max_col_width { + if s.len() > max_col_width { + str_val = format!("{}...", &str_val[..max_col_width - 3]); + } + } + str_val + }) + .collect::>(); table.add_row(row.into()); } + table + } +} + +impl Display for Table { + // `f` is a buffer, and this method must write the formatted string into it + fn fmt(&self, f: &mut Formatter) -> Result { + let table = self.to_prettytable(Some(20)); + write!(f, "{table}") } } diff --git a/tests/dataframe/test_repr.py b/tests/dataframe/test_repr.py index 97e83572df..5154aaed40 100644 --- a/tests/dataframe/test_repr.py +++ b/tests/dataframe/test_repr.py @@ -27,7 +27,7 @@ def _split_table_row(row: str) -> list[str]: column_types = _split_table_row(lines[2]) data = [] - for line in lines[4:-1]: + for line in lines[4:-2]: if ROW_DIVIDER_REGEX.match(line): continue data.append(_split_table_row(line)) diff --git a/tests/test_schema.py b/tests/test_schema.py index 63f629aa00..c8d7904ec1 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -62,7 +62,14 @@ def test_schema_to_name_set(): def test_repr(): schema = TABLE.schema() - assert repr(schema) == "[('int', Int64), ('float', Float64), ('string', Utf8), ('bool', Boolean)]" + assert ( + repr(schema) + == """+-------+---------+--------+---------+ +| int | float | string | bool | +| Int64 | Float64 | Utf8 | Boolean | ++-------+---------+--------+---------+ +""" + ) def test_to_col_expr():