From e67dfc4818d2627fcc08082481783dc720f1c6ce Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Wed, 27 Mar 2024 21:26:46 +0100 Subject: [PATCH] Refactor --- py-polars/src/dataframe/construction.rs | 91 +++++++++++--------- py-polars/tests/unit/interop/test_interop.py | 4 +- 2 files changed, 54 insertions(+), 41 deletions(-) diff --git a/py-polars/src/dataframe/construction.rs b/py-polars/src/dataframe/construction.rs index 3eeb0fe9fc60..d3cd40bc1e0b 100644 --- a/py-polars/src/dataframe/construction.rs +++ b/py-polars/src/dataframe/construction.rs @@ -3,6 +3,7 @@ use pyo3::prelude::*; use super::*; use crate::arrow_interop; +use crate::conversion::any_value::py_object_to_any_value; use crate::conversion::{vec_extract_wrapped, Wrap}; #[pymethods] @@ -31,13 +32,8 @@ impl PyDataFrame { let schema = schema.map(|wrap| wrap.0); let schema_overrides = schema_overrides.map(|wrap| wrap.0); - // If given, read dict fields in schema order. - let mut schema_columns = PlIndexSet::new(); - if let Some(ref s) = schema { - schema_columns.extend(s.iter_names().map(|n| n.to_string())) - } - - let (rows, names) = dicts_to_rows(data, infer_schema_length, schema_columns)?; + let names = get_schema_names(data, schema.as_ref(), infer_schema_length)?; + let rows = dicts_to_rows(data, &names)?; let schema = schema.or_else(|| { Some(columns_names_to_empty_schema( @@ -138,48 +134,63 @@ where Schema::from_iter(fields) } -fn dicts_to_rows( - records: &PyAny, - infer_schema_len: Option, - schema_columns: PlIndexSet, -) -> PyResult<(Vec, Vec)> { - let infer_schema_len = infer_schema_len - .map(|n| std::cmp::max(1, n)) - .unwrap_or(usize::MAX); - let len = records.len()?; - - let key_names = { - if !schema_columns.is_empty() { - schema_columns - } else { - let mut inferred_keys = PlIndexSet::new(); - for d in records.iter()?.take(infer_schema_len) { - let d = d?; - let d = d.downcast::()?; - let keys = d.keys(); - for name in keys { - let name = name.extract::()?; - inferred_keys.insert(name); - } - } - inferred_keys - } - }; +fn dicts_to_rows<'a>(data: &'a PyAny, names: &'a [String]) -> PyResult>> { + let len = data.len()?; let mut rows = Vec::with_capacity(len); - - for d in records.iter()? { + for d in data.iter()? { let d = d?; let d = d.downcast::()?; - let mut row = Vec::with_capacity(key_names.len()); - for k in key_names.iter() { + let mut row = Vec::with_capacity(names.len()); + for k in names.iter() { let val = match d.get_item(k)? { None => AnyValue::Null, - Some(val) => val.extract::>()?.0, + // TODO: Propagate strictness here. + // https://github.com/pola-rs/polars/issues/14427 + Some(val) => py_object_to_any_value(val, false)?, }; row.push(val) } rows.push(Row(row)) } - Ok((rows, key_names.into_iter().collect())) + Ok(rows) +} + +/// Either read the given schema, or infer the schema names from the data. +fn get_schema_names( + data: &PyAny, + schema: Option<&Schema>, + infer_schema_length: Option, +) -> PyResult> { + if let Some(schema) = schema { + Ok(schema.iter_names().map(|n| n.to_string()).collect()) + } else { + infer_schema_names_from_data(data, infer_schema_length) + } +} + +/// Infer schema names from an iterable of dictionaries. +/// +/// The resulting schema order is determined by the order in which the names are encountered in +/// the data. +fn infer_schema_names_from_data( + data: &PyAny, + infer_schema_length: Option, +) -> PyResult> { + let data_len = data.len()?; + let infer_schema_length = infer_schema_length + .map(|n| std::cmp::max(1, n)) + .unwrap_or(data_len); + + let mut names = PlIndexSet::new(); + for d in data.iter()?.take(infer_schema_length) { + let d = d?; + let d = d.downcast::()?; + let keys = d.keys(); + for name in keys { + let name = name.extract::()?; + names.insert(name); + } + } + Ok(names.into_iter().collect()) } diff --git a/py-polars/tests/unit/interop/test_interop.py b/py-polars/tests/unit/interop/test_interop.py index 062a68a02b0b..630530f457be 100644 --- a/py-polars/tests/unit/interop/test_interop.py +++ b/py-polars/tests/unit/interop/test_interop.py @@ -308,7 +308,9 @@ def test_from_dicts() -> None: def test_from_dict_no_inference() -> None: schema = {"a": pl.String} data = [{"a": "aa"}] - pl.from_dicts(data, schema_overrides=schema, infer_schema_length=0) + df = pl.from_dicts(data, schema_overrides=schema, infer_schema_length=0) + assert df.schema == schema + assert df.to_dicts() == data def test_from_dicts_schema_override() -> None: