Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support more integer types #57

Merged
merged 4 commits into from
Feb 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 66 additions & 3 deletions optd-core/src/rel_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,14 @@ pub trait RelNodeTyp:

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum Value {
Int(i64),
UInt8(u8),
UInt16(u16),
UInt32(u32),
UInt64(u64),
Int8(i8),
Int16(i16),
Int32(i32),
Int64(i64),
Float(OrderedFloat<f64>),
String(Arc<str>),
Bool(bool),
Expand All @@ -37,7 +44,14 @@ pub enum Value {
impl std::fmt::Display for Value {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Int(x) => write!(f, "{x}"),
Self::UInt8(x) => write!(f, "{x}"),
Self::UInt16(x) => write!(f, "{x}"),
Self::UInt32(x) => write!(f, "{x}"),
Self::UInt64(x) => write!(f, "{x}"),
Self::Int8(x) => write!(f, "{x}"),
Self::Int16(x) => write!(f, "{x}"),
Self::Int32(x) => write!(f, "{x}"),
Self::Int64(x) => write!(f, "{x}"),
Self::Float(x) => write!(f, "{x}"),
Self::String(x) => write!(f, "\"{x}\""),
Self::Bool(x) => write!(f, "{x}"),
Expand All @@ -47,9 +61,58 @@ impl std::fmt::Display for Value {
}

impl Value {
pub fn as_u8(&self) -> u8 {
match self {
Value::UInt8(i) => *i,
_ => panic!("Value is not an u8"),
}
}

pub fn as_u16(&self) -> u16 {
match self {
Value::UInt16(i) => *i,
_ => panic!("Value is not an u16"),
}
}

pub fn as_u32(&self) -> u32 {
match self {
Value::UInt32(i) => *i,
_ => panic!("Value is not an u32"),
}
}

pub fn as_u64(&self) -> u64 {
match self {
Value::UInt64(i) => *i,
_ => panic!("Value is not an u64"),
}
}

pub fn as_i8(&self) -> i8 {
match self {
Value::Int8(i) => *i,
_ => panic!("Value is not an i8"),
}
}

pub fn as_i16(&self) -> i16 {
match self {
Value::Int16(i) => *i,
_ => panic!("Value is not an i16"),
}
}

pub fn as_i32(&self) -> i32 {
match self {
Value::Int32(i) => *i,
_ => panic!("Value is not an i32"),
}
}

pub fn as_i64(&self) -> i64 {
match self {
Value::Int(i) => *i,
Value::Int64(i) => *i,
_ => panic!("Value is not an i64"),
}
}
Expand Down
18 changes: 16 additions & 2 deletions optd-datafusion-bridge/src/from_optd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,14 @@ fn from_optd_schema(optd_schema: &OptdSchema) -> Schema {
let match_type = |typ: &ConstantType| match typ {
ConstantType::Any => unimplemented!(),
ConstantType::Bool => DataType::Boolean,
ConstantType::Int => DataType::Int64,
ConstantType::UInt8 => DataType::UInt8,
ConstantType::UInt16 => DataType::UInt16,
ConstantType::UInt32 => DataType::UInt32,
ConstantType::UInt64 => DataType::UInt64,
ConstantType::Int8 => DataType::Int8,
ConstantType::Int16 => DataType::Int16,
ConstantType::Int32 => DataType::Int32,
ConstantType::Int64 => DataType::Int64,
ConstantType::Date => DataType::Date32,
ConstantType::Decimal => DataType::Float64,
ConstantType::Utf8String => DataType::Utf8,
Expand Down Expand Up @@ -127,7 +134,14 @@ impl OptdPlanContext<'_> {
let value = expr.value();
let value = match typ {
ConstantType::Bool => ScalarValue::Boolean(Some(value.as_bool())),
ConstantType::Int => ScalarValue::Int64(Some(value.as_i64())),
ConstantType::UInt8 => ScalarValue::UInt8(Some(value.as_u8())),
ConstantType::UInt16 => ScalarValue::UInt16(Some(value.as_u16())),
ConstantType::UInt32 => ScalarValue::UInt32(Some(value.as_u32())),
ConstantType::UInt64 => ScalarValue::UInt64(Some(value.as_u64())),
ConstantType::Int8 => ScalarValue::Int8(Some(value.as_i8())),
ConstantType::Int16 => ScalarValue::Int16(Some(value.as_i16())),
ConstantType::Int32 => ScalarValue::Int32(Some(value.as_i32())),
ConstantType::Int64 => ScalarValue::Int64(Some(value.as_i64())),
ConstantType::Decimal => {
ScalarValue::Decimal128(Some(value.as_f64() as i128), 20, 0)
// TODO(chi): no hard code decimal
Expand Down
72 changes: 49 additions & 23 deletions optd-datafusion-bridge/src/into_optd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@ impl OptdPlanContext<'_> {
Ok(scan.into_plan_node())
}

fn conv_into_optd_expr(&mut self, expr: &logical_expr::Expr, context: &DFSchema) -> Result<Expr> {
fn conv_into_optd_expr(
&mut self,
expr: &logical_expr::Expr,
context: &DFSchema,
) -> Result<Expr> {
use logical_expr::Expr;
match expr {
Expr::BinaryExpr(node) => {
Expand All @@ -64,15 +68,39 @@ impl OptdPlanContext<'_> {
Expr::Literal(x) => match x {
ScalarValue::UInt8(x) => {
let x = x.as_ref().unwrap();
Ok(ConstantExpr::int(*x as i64).into_expr())
Ok(ConstantExpr::uint8(*x).into_expr())
}
ScalarValue::Utf8(x) => {
ScalarValue::UInt16(x) => {
let x = x.as_ref().unwrap();
Ok(ConstantExpr::string(x).into_expr())
Ok(ConstantExpr::uint16(*x).into_expr())
}
ScalarValue::UInt32(x) => {
let x = x.as_ref().unwrap();
Ok(ConstantExpr::uint32(*x).into_expr())
}
ScalarValue::UInt64(x) => {
let x = x.as_ref().unwrap();
Ok(ConstantExpr::uint64(*x).into_expr())
}
ScalarValue::Int8(x) => {
let x = x.as_ref().unwrap();
Ok(ConstantExpr::int8(*x).into_expr())
}
ScalarValue::Int16(x) => {
let x = x.as_ref().unwrap();
Ok(ConstantExpr::int16(*x).into_expr())
}
ScalarValue::Int32(x) => {
let x = x.as_ref().unwrap();
Ok(ConstantExpr::int32(*x).into_expr())
}
ScalarValue::Int64(x) => {
let x = x.as_ref().unwrap();
Ok(ConstantExpr::int(*x).into_expr())
Ok(ConstantExpr::int64(*x).into_expr())
}
ScalarValue::Utf8(x) => {
let x = x.as_ref().unwrap();
Ok(ConstantExpr::string(x).into_expr())
}
ScalarValue::Date32(x) => {
let x = x.as_ref().unwrap();
Expand Down Expand Up @@ -118,7 +146,7 @@ impl OptdPlanContext<'_> {
expr,
)
.into_expr())
}
}
_ => bail!("Unsupported expression: {:?}", expr),
}
}
Expand Down Expand Up @@ -224,22 +252,18 @@ impl OptdPlanContext<'_> {
// instead of converting them to a join on true, we bail out

match node.filter {
Some(DFExpr::Literal(ScalarValue::Boolean(Some(val)))) => {
Ok(LogicalJoin::new(
left,
right,
ConstantExpr::bool(val).into_expr(),
join_type,
))
}
None => {
Ok(LogicalJoin::new(
left,
right,
ConstantExpr::bool(true).into_expr(),
join_type,
))
}
Some(DFExpr::Literal(ScalarValue::Boolean(Some(val)))) => Ok(LogicalJoin::new(
left,
right,
ConstantExpr::bool(val).into_expr(),
join_type,
)),
None => Ok(LogicalJoin::new(
left,
right,
ConstantExpr::bool(true).into_expr(),
join_type,
)),
_ => bail!("unsupported join filter: {:?}", node.filter),
}
} else if log_ops.len() == 1 {
Expand Down Expand Up @@ -279,7 +303,9 @@ impl OptdPlanContext<'_> {
LogicalPlan::Projection(node) => self.conv_into_optd_projection(node)?.into_plan_node(),
LogicalPlan::Sort(node) => self.conv_into_optd_sort(node)?.into_plan_node(),
LogicalPlan::Aggregate(node) => self.conv_into_optd_agg(node)?.into_plan_node(),
LogicalPlan::SubqueryAlias(node) => self.conv_into_optd_plan_node(node.input.as_ref())?,
LogicalPlan::SubqueryAlias(node) => {
self.conv_into_optd_plan_node(node.input.as_ref())?
}
LogicalPlan::Join(node) => self.conv_into_optd_join(node)?.into_plan_node(),
LogicalPlan::Filter(node) => self.conv_into_optd_filter(node)?.into_plan_node(),
LogicalPlan::CrossJoin(node) => self.conv_into_optd_cross_join(node)?.into_plan_node(),
Expand Down
4 changes: 2 additions & 2 deletions optd-datafusion-bridge/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ impl Catalog for DatafusionCatalog {
for field in fields.fields() {
let dt = match field.data_type() {
DataType::Date32 => ConstantType::Date,
DataType::Int32 => ConstantType::Int,
DataType::Int64 => ConstantType::Int,
DataType::Int32 => ConstantType::Int32,
DataType::Int64 => ConstantType::Int64,
DataType::Float64 => ConstantType::Decimal,
DataType::Utf8 => ConstantType::Utf8String,
DataType::Decimal128(_, _) => ConstantType::Decimal,
Expand Down
2 changes: 1 addition & 1 deletion optd-datafusion-repr/src/bin/test_optimize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ pub fn main() {
let scan1 = LogicalScan::new("t1".into());
let filter_cond = BinOpExpr::new(
ColumnRefExpr::new(1).0,
ConstantExpr::new(Value::Int(2)).0,
ConstantExpr::new(Value::Int64(2)).0,
BinOpType::Eq,
);
let filter1 = LogicalFilter::new(scan1.0, filter_cond.0);
Expand Down
54 changes: 48 additions & 6 deletions optd-datafusion-repr/src/plan_nodes/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,14 @@ impl OptRelNode for ExprList {
pub enum ConstantType {
Bool,
Utf8String,
Int,
UInt8,
UInt16,
UInt32,
UInt64,
Int8,
Int16,
Int32,
Int64,
Date,
Decimal,
Any,
Expand All @@ -78,7 +85,14 @@ impl ConstantExpr {
let typ = match &value {
Value::Bool(_) => ConstantType::Bool,
Value::String(_) => ConstantType::Utf8String,
Value::Int(_) => ConstantType::Int,
Value::UInt8(_) => ConstantType::UInt8,
Value::UInt16(_) => ConstantType::UInt16,
Value::UInt32(_) => ConstantType::UInt32,
Value::UInt64(_) => ConstantType::UInt64,
Value::Int8(_) => ConstantType::Int8,
Value::Int16(_) => ConstantType::Int16,
Value::Int32(_) => ConstantType::Int32,
Value::Int64(_) => ConstantType::Int64,
Value::Float(_) => ConstantType::Decimal,
_ => unimplemented!(),
};
Expand Down Expand Up @@ -107,12 +121,40 @@ impl ConstantExpr {
)
}

pub fn int(value: i64) -> Self {
Self::new_with_type(Value::Int(value), ConstantType::Int)
pub fn uint8(value: u8) -> Self {
Self::new_with_type(Value::UInt8(value), ConstantType::UInt8)
}

pub fn uint16(value: u16) -> Self {
Self::new_with_type(Value::UInt16(value), ConstantType::UInt16)
}

pub fn uint32(value: u32) -> Self {
Self::new_with_type(Value::UInt32(value), ConstantType::UInt32)
}

pub fn uint64(value: u64) -> Self {
Self::new_with_type(Value::UInt64(value), ConstantType::UInt64)
}

pub fn int8(value: i8) -> Self {
Self::new_with_type(Value::Int8(value), ConstantType::Int8)
}

pub fn int16(value: i16) -> Self {
Self::new_with_type(Value::Int16(value), ConstantType::Int16)
}

pub fn int32(value: i32) -> Self {
Self::new_with_type(Value::Int32(value), ConstantType::Int32)
}

pub fn int64(value: i64) -> Self {
Self::new_with_type(Value::Int64(value), ConstantType::Int64)
}

pub fn date(value: i64) -> Self {
Self::new_with_type(Value::Int(value), ConstantType::Date)
Self::new_with_type(Value::Int64(value), ConstantType::Date)
}

pub fn decimal(value: f64) -> Self {
Expand Down Expand Up @@ -152,7 +194,7 @@ impl ColumnRefExpr {
RelNode {
typ: OptRelNodeTyp::ColumnRef,
children: vec![],
data: Some(Value::Int(column_idx as i64)),
data: Some(Value::Int64(column_idx as i64)),
}
.into(),
))
Expand Down
29 changes: 29 additions & 0 deletions optd-sqlplannertest/tests/constant_predicate.planner.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
-- (no id or description)
create table t1(t1v1 int, t1v2 int);
insert into t1 values (0, 0), (1, 1), (2, 2);

/*
3
*/

-- Test whether the optimizer handles integer equality predicates correctly.
select * from t1 where t1v1 = 0;

/*
0 0
*/

-- Test whether the optimizer handles multiple integer equality predicates correctly.
select * from t1 where t1v1 = 0 and t1v2 = 1;

/*

*/

-- Test whether the optimizer handles multiple integer inequality predicates correctly.
select * from t1 where t1v1 = 0 and t1v2 != 1;

/*
0 0
*/

20 changes: 20 additions & 0 deletions optd-sqlplannertest/tests/constant_predicate.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
- sql: |
create table t1(t1v1 int, t1v2 int);
insert into t1 values (0, 0), (1, 1), (2, 2);
tasks:
- execute_with_logical
- sql: |
select * from t1 where t1v1 = 0;
desc: Test whether the optimizer handles integer equality predicates correctly.
tasks:
- execute_with_logical
- sql: |
select * from t1 where t1v1 = 0 and t1v2 = 1;
desc: Test whether the optimizer handles multiple integer equality predicates correctly.
tasks:
- execute_with_logical
- sql: |
select * from t1 where t1v1 = 0 and t1v2 != 1;
desc: Test whether the optimizer handles multiple integer inequality predicates correctly.
tasks:
- execute_with_logical
Loading