Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Mar 27, 2024
1 parent af81de0 commit e67dfc4
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 41 deletions.
91 changes: 51 additions & 40 deletions py-polars/src/dataframe/construction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use pyo3::prelude::*;

use super::*;
use crate::arrow_interop;
use crate::conversion::any_value::py_object_to_any_value;
use crate::conversion::{vec_extract_wrapped, Wrap};

#[pymethods]
Expand Down Expand Up @@ -31,13 +32,8 @@ impl PyDataFrame {
let schema = schema.map(|wrap| wrap.0);
let schema_overrides = schema_overrides.map(|wrap| wrap.0);

// If given, read dict fields in schema order.
let mut schema_columns = PlIndexSet::new();
if let Some(ref s) = schema {
schema_columns.extend(s.iter_names().map(|n| n.to_string()))
}

let (rows, names) = dicts_to_rows(data, infer_schema_length, schema_columns)?;
let names = get_schema_names(data, schema.as_ref(), infer_schema_length)?;
let rows = dicts_to_rows(data, &names)?;

let schema = schema.or_else(|| {
Some(columns_names_to_empty_schema(
Expand Down Expand Up @@ -138,48 +134,63 @@ where
Schema::from_iter(fields)
}

fn dicts_to_rows(
records: &PyAny,
infer_schema_len: Option<usize>,
schema_columns: PlIndexSet<String>,
) -> PyResult<(Vec<Row>, Vec<String>)> {
let infer_schema_len = infer_schema_len
.map(|n| std::cmp::max(1, n))
.unwrap_or(usize::MAX);
let len = records.len()?;

let key_names = {
if !schema_columns.is_empty() {
schema_columns
} else {
let mut inferred_keys = PlIndexSet::new();
for d in records.iter()?.take(infer_schema_len) {
let d = d?;
let d = d.downcast::<PyDict>()?;
let keys = d.keys();
for name in keys {
let name = name.extract::<String>()?;
inferred_keys.insert(name);
}
}
inferred_keys
}
};
/// Convert an iterable of Python dictionaries into rows, one row per dictionary.
///
/// The values of each row are looked up by key in the order of the given column names. A key that
/// is missing from a dictionary produces `AnyValue::Null` for that column.
fn dicts_to_rows<'a>(data: &'a PyAny, names: &'a [String]) -> PyResult<Vec<Row<'a>>> {
let len = data.len()?;
let mut rows = Vec::with_capacity(len);

for d in records.iter()? {
for d in data.iter()? {
let d = d?;
let d = d.downcast::<PyDict>()?;

let mut row = Vec::with_capacity(key_names.len());
for k in key_names.iter() {
let mut row = Vec::with_capacity(names.len());
for k in names.iter() {
let val = match d.get_item(k)? {
None => AnyValue::Null,
Some(val) => val.extract::<Wrap<AnyValue>>()?.0,
// TODO: Propagate strictness here.
// https://github.com/pola-rs/polars/issues/14427
Some(val) => py_object_to_any_value(val, false)?,
};
row.push(val)
}
rows.push(Row(row))
}
Ok((rows, key_names.into_iter().collect()))
Ok(rows)
}

/// Resolve the column names to use when reading the data.
///
/// If a schema is supplied, its names are returned in schema order and the data is not inspected.
/// Otherwise the names are inferred from the dictionaries in `data`.
fn get_schema_names(
    data: &PyAny,
    schema: Option<&Schema>,
    infer_schema_length: Option<usize>,
) -> PyResult<Vec<String>> {
    match schema {
        Some(given) => {
            let names = given.iter_names().map(|name| name.to_string()).collect();
            Ok(names)
        },
        None => infer_schema_names_from_data(data, infer_schema_length),
    }
}

/// Collect the distinct keys of an iterable of dictionaries.
///
/// Names are returned in the order in which they are first encountered. At most
/// `infer_schema_length` items are scanned (a value of zero is raised to one); when it is `None`,
/// the scan is bounded by `data.len()`.
///
/// Returns an error if the length of `data` cannot be determined, if an item is not a `dict`, or
/// if a key cannot be extracted as a string.
fn infer_schema_names_from_data(
    data: &PyAny,
    infer_schema_length: Option<usize>,
) -> PyResult<Vec<String>> {
    let data_len = data.len()?;
    let n_items_to_scan = match infer_schema_length {
        Some(n) => n.max(1),
        None => data_len,
    };

    // An index set deduplicates the keys while preserving insertion order.
    let mut seen = PlIndexSet::new();
    for item in data.iter()?.take(n_items_to_scan) {
        let dict = item?.downcast::<PyDict>()?;
        for key in dict.keys() {
            seen.insert(key.extract::<String>()?);
        }
    }
    Ok(seen.into_iter().collect())
}
4 changes: 3 additions & 1 deletion py-polars/tests/unit/interop/test_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,9 @@ def test_from_dicts() -> None:
def test_from_dict_no_inference() -> None:
schema = {"a": pl.String}
data = [{"a": "aa"}]
pl.from_dicts(data, schema_overrides=schema, infer_schema_length=0)
df = pl.from_dicts(data, schema_overrides=schema, infer_schema_length=0)
assert df.schema == schema
assert df.to_dicts() == data


def test_from_dicts_schema_override() -> None:
Expand Down

0 comments on commit e67dfc4

Please sign in to comment.