Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for 'enum' extension type to v2 reader/writer #3038

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions python/python/tests/test_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,3 +436,106 @@ def test_blob(tmp_path):
reader = LanceFileReader(str(path))
assert len(reader.metadata().columns[0].pages) == 1
assert reader.read_all().to_table() == pa.table({"val": vals})


def test_enum_vs_categorical(tmp_path):
# Helper method to make two dict arrays, with same dictionary values
# but different indices
def make_tbls(values, indices1, indices2):
# Need to make two separate dictionaries here or else arrow-rs won't concat
d1 = pa.array(values, pa.string())
d2 = pa.array(values, pa.string())
i1 = pa.array(indices1, pa.int16())
i2 = pa.array(indices2, pa.int16())

dict1 = pa.DictionaryArray.from_arrays(i1, d1)
dict2 = pa.DictionaryArray.from_arrays(i2, d2)
tab1 = pa.table({"dictionary": dict1})
tab2 = pa.table({"dictionary": dict2})
return tab1, tab2

# Helper method to round trip two tables through lance and return the decoded
# dictionary array
def round_trip_dict(tab1: pa.Table, tab2: pa.Table) -> pa.DictionaryArray:
with LanceFileWriter(tmp_path / "categorical.lance") as writer:
writer.write_batch(tab1)
writer.write_batch(tab2)

reader = LanceFileReader(tmp_path / "categorical.lance")
round_tripped = reader.read_all().to_table()

arr2 = round_tripped.column("dictionary").chunk(0).dictionary
return arr2

# Helper method to convert a table with dictionary array into a table with
# enum array
def enumify(tbl) -> pa.Table:
categories = ",".join(tbl.column(0).chunk(0).dictionary.to_pylist())
enum_schema = pa.schema(
[
pa.field(
"dictionary",
pa.dictionary(pa.int16(), pa.string()),
metadata={
"ARROW:extension:name": "polars.enum",
"ARROW:extension:metadata": '{"categories": ['
+ categories
+ "]}",
},
)
]
)
return pa.table([tbl.column(0)], schema=enum_schema)

tab1, tab2 = make_tbls(
["blue", "red", "green", "yellow"],
[0, 1, 0, 1, 0, 1],
[1, 2, 1, 2, 1, 2],
)

round_trip = round_trip_dict(tab1, tab2)

# Sometimes array concatenation will just concatenate the dictionaries
assert round_trip.to_pylist() == [
"blue",
"red",
"green",
"yellow",
"blue",
"red",
"green",
"yellow",
]

tab1 = enumify(tab1)
tab2 = enumify(tab2)

round_trip = round_trip_dict(tab1, tab2)

# However, there should be no concatenation with the enum type
assert round_trip.to_pylist() == [
"blue",
"red",
"green",
"yellow",
]

tab1, tab2 = make_tbls(
[str(i) for i in range(1000)],
list(range(500)),
list(range(500, 900)),
)

round_trip = round_trip_dict(tab1, tab2)

# Other times array concatenation will combine the
# dictionaries and remove unused items
assert round_trip.to_pylist() == [str(i) for i in range(900)]

# Again, no concatenation with enum type
tab1 = enumify(tab1)
tab2 = enumify(tab2)

round_trip = round_trip_dict(tab1, tab2)

assert round_trip.to_pylist() == [str(i) for i in range(1000)]
2 changes: 2 additions & 0 deletions rust/lance-arrow/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ arrow-select = { workspace = true }
half = { workspace = true }
num-traits = { workspace = true }
rand.workspace = true
serde = { workspace = true }
serde_json = { workspace = true }

[target.'cfg(target_arch = "wasm32")'.dependencies]
getrandom = { version = "0.2", features = ["js"] }
146 changes: 146 additions & 0 deletions rust/lance-arrow/src/dict_enum.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
use std::sync::Arc;

use arrow_array::{cast::AsArray, Array, StringArray};
use arrow_schema::{ArrowError, Field as ArrowField};
use serde::{Deserialize, Serialize};

use crate::{
bfloat16::{ARROW_EXT_META_KEY, ARROW_EXT_NAME_KEY},
DataTypeExt, Result,
};

const ENUM_TYPE: &str = "polars.enum";

// TODO: Could be slightly more efficient to use custom JSON serialization
// to go straight from JSON to StringArray without the Vec<String> intermediate
// but this is fine for now
#[derive(Deserialize, Serialize)]
struct DictionaryEnumMetadata {
categories: Vec<String>,
}

pub struct DictionaryEnumType {
pub categories: Arc<dyn Array>,
}

impl DictionaryEnumType {
/// Adds extension type metadata to the given field
///
/// Fails if the field is already an extension type of some kind
pub fn wrap_field(&self, field: &ArrowField) -> Result<ArrowField> {
let mut metadata = field.metadata().clone();
if metadata.contains_key(ARROW_EXT_NAME_KEY) {
return Err(ArrowError::InvalidArgumentError(
"Field already has extension metadata".to_string(),
));
}
metadata.insert(ARROW_EXT_NAME_KEY.to_string(), ENUM_TYPE.to_string());
metadata.insert(
ARROW_EXT_META_KEY.to_string(),
serde_json::to_string(&DictionaryEnumMetadata {
categories: self
.categories
.as_any()
.downcast_ref::<StringArray>()
.unwrap()
.values()
.iter()
.map(|x| x.to_string())
.collect(),
})
.unwrap(),
);
Ok(field.clone().with_metadata(metadata))
}

/// Creates a new enum type from the given dictionary array
///
/// # Arguments
///
/// * `arr` - The dictionary array to create the enum type from
///
/// # Errors
///
/// An error is returned if the array is not a dictionary array or if the dictionary
/// array does not have string values
pub fn from_dict_array(arr: &dyn Array) -> Result<Self> {
let arr = arr.as_any_dictionary_opt().ok_or_else(|| {
ArrowError::InvalidArgumentError(
"Expected a dictionary array for enum type".to_string(),
)
})?;
if !arr.values().data_type().is_binary_like() {
Err(ArrowError::InvalidArgumentError(
"Expected a dictionary array with string values for enum type".to_string(),
))
} else {
Ok(Self {
categories: Arc::new(arr.values().clone()),
})
}
}

/// Attempts to parse the type from the given field
///
/// If the field is not an enum type then None is returned
///
/// Errors can occur if the field is an enum type but the metadata
/// is not correctly formatted
///
/// # Arguments
///
/// * `field` - The field to parse
/// * `sample_arr` - An optional sample array. If provided then categories will be extracted
/// from this array, avoiding the need to parse the metadata. This array should be a dictionary
/// array where the dictionary items are the categories.
///
/// The sample_arr is only used if the field is an enum type. E.g. it is safe to do something
/// like:
///
/// ```ignore
/// let arr = batch.column(0);
/// let field = batch.schema().field(0);
/// let enum_type = DictionaryEnumType::from_field(field, Some(arr));
/// ```
pub fn from_field(
field: &ArrowField,
sample_arr: Option<&Arc<dyn Array>>,
) -> Result<Option<Self>> {
if field
.metadata()
.get(ARROW_EXT_NAME_KEY)
.map(|k| k.eq_ignore_ascii_case(ENUM_TYPE))
.unwrap_or(false)
{
// Prefer extracting values from the first array if possible as it's cheaper
if let Some(arr) = sample_arr {
let dict_arr = arr.as_any_dictionary_opt().ok_or_else(|| {
ArrowError::InvalidArgumentError(
"Expected a dictionary array for enum type".to_string(),
)
})?;
Ok(Some(Self {
categories: dict_arr.values().clone(),
}))
} else {
// No arrays, need to use the field metadata
let meta = field.metadata().get(ARROW_EXT_META_KEY).ok_or_else(|| {
ArrowError::InvalidArgumentError(format!(
"Field {} is missing extension metadata",
field.name()
))
})?;
let meta: DictionaryEnumMetadata = serde_json::from_str(meta).map_err(|e| {
ArrowError::InvalidArgumentError(format!(
"Arrow extension metadata for enum was not correctly formed: {}",
e
))
})?;
let categories = Arc::new(StringArray::from_iter_values(meta.categories));
Ok(Some(Self { categories }))
}
} else {
Ok(None)
}
}
}
1 change: 1 addition & 0 deletions rust/lance-arrow/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use arrow_select::{interleave::interleave, take::take};
use rand::prelude::*;

pub mod deepcopy;
pub mod dict_enum;
pub mod schema;
pub use schema::*;
pub mod bfloat16;
Expand Down
Loading
Loading