Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add bril ops #611

Merged
merged 8 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 6 additions & 6 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ smallvec = "1.11.1"

syn = { version = "2.0", features = ["full", "extra-traits"] }
# currently using the uwplse/bril fork of bril, on eggcc-main
bril2json = { git = "https://github.com/uwplse/bril", rev = "78881c45aa53231915f333d1d6dcc26cedc63b57" }
brilirs = { git = "https://github.com/uwplse/bril", rev = "78881c45aa53231915f333d1d6dcc26cedc63b57" }
bril-rs = { git = "https://github.com/uwplse/bril", rev = "78881c45aa53231915f333d1d6dcc26cedc63b57" }
brilift = { git = "https://github.com/uwplse/bril", rev = "78881c45aa53231915f333d1d6dcc26cedc63b57" }
rs2bril = { git = "https://github.com/uwplse/bril", rev = "78881c45aa53231915f333d1d6dcc26cedc63b57" ,features = [
bril2json = { git = "https://github.com/uwplse/bril", rev = "8fd97903e7f46decb89398cf57a6dabd55e4fecf" }
brilirs = { git = "https://github.com/uwplse/bril", rev = "8fd97903e7f46decb89398cf57a6dabd55e4fecf" }
bril-rs = { git = "https://github.com/uwplse/bril", rev = "8fd97903e7f46decb89398cf57a6dabd55e4fecf" }
brilift = { git = "https://github.com/uwplse/bril", rev = "8fd97903e7f46decb89398cf57a6dabd55e4fecf" }
rs2bril = { git = "https://github.com/uwplse/bril", rev = "8fd97903e7f46decb89398cf57a6dabd55e4fecf" ,features = [
"import",
] }
brillvm = { git = "https://github.com/uwplse/bril", rev = "78881c45aa53231915f333d1d6dcc26cedc63b57" }
brillvm = { git = "https://github.com/uwplse/bril", rev = "8fd97903e7f46decb89398cf57a6dabd55e4fecf" }


ordered-float = { version = "3.7" }
Expand Down
24 changes: 24 additions & 0 deletions dag_in_context/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,22 @@ pub fn div(l: RcExpr, r: RcExpr) -> RcExpr {
RcExpr::new(Expr::Bop(BinaryOp::Div, l, r))
}

pub fn smax(l: RcExpr, r: RcExpr) -> RcExpr {
RcExpr::new(Expr::Bop(BinaryOp::Smax, l, r))
}

pub fn smin(l: RcExpr, r: RcExpr) -> RcExpr {
RcExpr::new(Expr::Bop(BinaryOp::Smin, l, r))
}

pub fn shl(l: RcExpr, r: RcExpr) -> RcExpr {
RcExpr::new(Expr::Bop(BinaryOp::Shl, l, r))
}

pub fn shr(l: RcExpr, r: RcExpr) -> RcExpr {
RcExpr::new(Expr::Bop(BinaryOp::Shr, l, r))
}

pub fn fadd(l: RcExpr, r: RcExpr) -> RcExpr {
RcExpr::new(Expr::Bop(BinaryOp::FAdd, l, r))
}
Expand All @@ -118,6 +134,14 @@ pub fn fdiv(l: RcExpr, r: RcExpr) -> RcExpr {
RcExpr::new(Expr::Bop(BinaryOp::FDiv, l, r))
}

pub fn fmax(l: RcExpr, r: RcExpr) -> RcExpr {
RcExpr::new(Expr::Bop(BinaryOp::Fmax, l, r))
}

pub fn fmin(l: RcExpr, r: RcExpr) -> RcExpr {
RcExpr::new(Expr::Bop(BinaryOp::Fmin, l, r))
}

pub fn less_than(l: RcExpr, r: RcExpr) -> RcExpr {
RcExpr::new(Expr::Bop(BinaryOp::LessThan, l, r))
}
Expand Down
6 changes: 6 additions & 0 deletions dag_in_context/src/from_egglog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,10 @@ impl<'a> FromEgglog<'a> {
("GreaterThan", []) => BinaryOp::GreaterThan,
("LessEq", []) => BinaryOp::LessEq,
("GreaterEq", []) => BinaryOp::GreaterEq,
("Smax", []) => BinaryOp::Smax,
("Smin", []) => BinaryOp::Smin,
("Shl", []) => BinaryOp::Shl,
("Shr", []) => BinaryOp::Shr,
("FAdd", []) => BinaryOp::FAdd,
("FSub", []) => BinaryOp::FSub,
("FMul", []) => BinaryOp::FMul,
Expand All @@ -177,6 +181,8 @@ impl<'a> FromEgglog<'a> {
("FGreaterThan", []) => BinaryOp::FGreaterThan,
("FLessEq", []) => BinaryOp::FLessEq,
("FGreaterEq", []) => BinaryOp::FGreaterEq,
("Fmax", []) => BinaryOp::Fmax,
("Fmin", []) => BinaryOp::Fmin,
("And", []) => BinaryOp::And,
("Or", []) => BinaryOp::Or,
("PtrAdd", []) => BinaryOp::PtrAdd,
Expand Down
6 changes: 3 additions & 3 deletions dag_in_context/src/greedy_dag_extractor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -857,15 +857,15 @@ impl CostModel for DefaultCostModel {
"Base" | "TupleT" | "TNil" | "TCons" => 0.,
"Int" | "Bool" | "Float" => 0.,
// Algebra
"Add" | "PtrAdd" | "Sub" | "And" | "Or" | "Not" => 10.,
"FAdd" | "FSub" => 50.,
"Add" | "PtrAdd" | "Sub" | "And" | "Or" | "Not" | "Shl" | "Shr" => 10.,
"FAdd" | "FSub" | "Fmax" | "Fmin" => 50.,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you have any justification for these costs it's worth a comment. Otherwise, it's fine we can play with these

"Mul" => 30.,
"FMul" => 150.,
"Div" => 50.,
"FDiv" => 250.,
// Comparisons
"Eq" | "LessThan" | "GreaterThan" | "LessEq" | "GreaterEq" => 10.,
"Select" => 10.,
"Select" | "Smax" | "Smin" => 10.,
"FEq" => 10.,
"FLessThan" | "FGreaterThan" | "FLessEq" | "FGreaterEq" => 100.,
// Effects
Expand Down
29 changes: 28 additions & 1 deletion dag_in_context/src/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
//! shared as the same Rc pointer. Otherwise, effects may be executed multiple times.
//! The invariant is maintained by translation from RVSDG, type checking, and translation from egglog.

use std::{collections::HashMap, fmt::Display, rc::Rc};
use std::{
collections::HashMap,
fmt::Display,
ops::{Shl, Shr},
rc::Rc,
};

use crate::{
schema::{BinaryOp, Constant, Expr, RcExpr, TernaryOp, TreeProgram, UnaryOp},
Expand Down Expand Up @@ -242,6 +247,18 @@ impl<'a> VirtualMachine<'a> {
BinaryOp::Div => Const(Constant::Int(
get_int(e1, self).wrapping_div(get_int(e2, self)),
)),
BinaryOp::Smax => {
let a = get_int(e1, self);
let b = get_int(e2, self);
Const(Constant::Int(if a > b { a } else { b }))
}
BinaryOp::Smin => {
let a = get_int(e1, self);
let b = get_int(e2, self);
Const(Constant::Int(if a < b { a } else { b }))
}
BinaryOp::Shl => Const(Constant::Int(get_int(e1, self).shl(get_int(e2, self)))),
BinaryOp::Shr => Const(Constant::Int(get_int(e1, self).shr(get_int(e2, self)))),
BinaryOp::Eq => Const(Constant::Bool(get_int(e1, self) == get_int(e2, self))),
BinaryOp::LessThan => Const(Constant::Bool(get_int(e1, self) < get_int(e2, self))),
BinaryOp::GreaterThan => Const(Constant::Bool(get_int(e1, self) > get_int(e2, self))),
Expand Down Expand Up @@ -303,6 +320,16 @@ impl<'a> VirtualMachine<'a> {
BinaryOp::FGreaterEq => {
Const(Constant::Bool(get_float(e1, self) >= get_float(e2, self)))
}
BinaryOp::Fmax => {
let a = get_float(e1, self);
let b = get_float(e2, self);
Const(Constant::Float(if a > b { a } else { b }))
}
BinaryOp::Fmin => {
let a = get_float(e1, self);
let b = get_float(e2, self);
Const(Constant::Float(if a < b { a } else { b }))
}
}
}

Expand Down
6 changes: 4 additions & 2 deletions dag_in_context/src/optimizations/purity_analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ fn top_is_pure(top: &TernaryOp) -> bool {
fn bop_is_pure(bop: &BinaryOp) -> bool {
use BinaryOp::*;
match bop {
Add | Sub | Mul | LessThan | Div | Eq | GreaterThan | LessEq | GreaterEq => true,
FAdd | FSub | FMul | FLessThan | FDiv | FEq | FGreaterThan | FLessEq | FGreaterEq => true,
Add | Sub | Mul | LessThan | Div | Eq | GreaterThan | LessEq | GreaterEq | Smax | Smin
| Shl | Shr => true,
FAdd | FSub | FMul | FLessThan | FDiv | FEq | FGreaterThan | FLessEq | FGreaterEq
| Fmax | Fmin => true,
PtrAdd => true,
And | Or => true,
Load | Print | Free => false,
Expand Down
6 changes: 6 additions & 0 deletions dag_in_context/src/schema.egg
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@
(LessEq)
(GreaterEq)
(Eq)
(Smin)
(Smax)
(Shl)
(Shr)
;; float operators
(FAdd)
(FSub)
Expand All @@ -104,6 +108,8 @@
(FLessEq)
(FGreaterEq)
(FEq)
(Fmin)
(Fmax)
;; logical operators
(And)
(Or)
Expand Down
6 changes: 6 additions & 0 deletions dag_in_context/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ pub enum BinaryOp {
GreaterThan,
LessEq,
GreaterEq,
Smax,
Smin,
Shl,
Shr,
FAdd,
FSub,
FMul,
Expand All @@ -56,6 +60,8 @@ pub enum BinaryOp {
FGreaterThan,
FLessEq,
FGreaterEq,
Fmax,
Fmin,
And,
Or,
PtrAdd,
Expand Down
26 changes: 20 additions & 6 deletions dag_in_context/src/schema_helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ impl BinaryOp {
LessThan => "LessThan",
GreaterEq => "GreaterEq",
LessEq => "LessEq",
Smax => "Smax",
Smin => "Smin",
Shl => "Shl",
Shr => "Shr",
FAdd => "FAdd",
FSub => "FSub",
FMul => "FMul",
Expand All @@ -77,6 +81,8 @@ impl BinaryOp {
FLessThan => "FLessThan",
FGreaterEq => "FGreaterEq",
FLessEq => "FLessEq",
Fmax => "Fmax",
Fmin => "Fmin",
And => "And",
Or => "Or",
Load => "Load",
Expand Down Expand Up @@ -690,12 +696,20 @@ impl BinaryOp {
/// When a binary op has concrete input sorts, return them.
pub fn types(&self) -> Option<(Type, Type, Type)> {
match self {
BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div => {
Some((base(intt()), base(intt()), base(intt())))
}
BinaryOp::FAdd | BinaryOp::FSub | BinaryOp::FMul | BinaryOp::FDiv => {
Some((base(floatt()), base(floatt()), base(floatt())))
}
BinaryOp::Add
| BinaryOp::Sub
| BinaryOp::Mul
| BinaryOp::Div
| BinaryOp::Smax
| BinaryOp::Smin
| BinaryOp::Shl
| BinaryOp::Shr => Some((base(intt()), base(intt()), base(intt()))),
BinaryOp::FAdd
| BinaryOp::FSub
| BinaryOp::FMul
| BinaryOp::FDiv
| BinaryOp::Fmax
| BinaryOp::Fmin => Some((base(floatt()), base(floatt()), base(floatt()))),
BinaryOp::And | BinaryOp::Or => Some((base(boolt()), base(boolt()), base(boolt()))),
BinaryOp::LessThan
| BinaryOp::GreaterThan
Expand Down
6 changes: 6 additions & 0 deletions src/rvsdg/from_dag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,10 @@ fn value_op_from_binary_op(bop: BinaryOp) -> Option<ValueOps> {
BinaryOp::GreaterThan => Some(ValueOps::Gt),
BinaryOp::LessEq => Some(ValueOps::Le),
BinaryOp::GreaterEq => Some(ValueOps::Ge),
BinaryOp::Smax => Some(ValueOps::Smax),
BinaryOp::Smin => Some(ValueOps::Smin),
BinaryOp::Shl => Some(ValueOps::Shl),
BinaryOp::Shr => Some(ValueOps::Shr),
// float operators
BinaryOp::FAdd => Some(ValueOps::Fadd),
BinaryOp::FSub => Some(ValueOps::Fsub),
Expand All @@ -163,6 +167,8 @@ fn value_op_from_binary_op(bop: BinaryOp) -> Option<ValueOps> {
BinaryOp::FGreaterThan => Some(ValueOps::Fgt),
BinaryOp::FLessEq => Some(ValueOps::Fle),
BinaryOp::FGreaterEq => Some(ValueOps::Fge),
BinaryOp::Fmax => Some(ValueOps::Fmax),
BinaryOp::Fmin => Some(ValueOps::Fmin),
// logical op
BinaryOp::And => Some(ValueOps::And),
BinaryOp::Or => Some(ValueOps::Or),
Expand Down
7 changes: 7 additions & 0 deletions src/rvsdg/to_dag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,11 @@ impl<'a> DagTranslator<'a> {
(ValueOps::Sub, [a, b]) => sub(a.clone(), b.clone()),
(ValueOps::Div, [a, b]) => div(a.clone(), b.clone()),

(ValueOps::Smax, [a, b]) => smax(a.clone(), b.clone()),
(ValueOps::Smin, [a, b]) => smin(a.clone(), b.clone()),
(ValueOps::Shl, [a, b]) => shl(a.clone(), b.clone()),
(ValueOps::Shr, [a, b]) => shr(a.clone(), b.clone()),

(ValueOps::Fadd, [a, b]) => fadd(a.clone(), b.clone()),
(ValueOps::Fmul, [a, b]) => fmul(a.clone(), b.clone()),
(ValueOps::Fsub, [a, b]) => fsub(a.clone(), b.clone()),
Expand All @@ -256,6 +261,8 @@ impl<'a> DagTranslator<'a> {
(ValueOps::Flt, [a, b]) => fless_than(a.clone(), b.clone()),
(ValueOps::Fge, [a, b]) => fgreater_eq(a.clone(), b.clone()),
(ValueOps::Fle, [a, b]) => fless_eq(a.clone(), b.clone()),
(ValueOps::Fmax, [a, b]) => fmax(a.clone(), b.clone()),
(ValueOps::Fmin, [a, b]) => fmin(a.clone(), b.clone()),

(ValueOps::And, [a, b]) => and(a.clone(), b.clone()),
(ValueOps::Or, [a, b]) => or(a.clone(), b.clone()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ expression: visualization.result
v14_: bool = const false;
v15_: int = id v6_;
v16_: bool = id v14_;
v17_: int = id v7_;
v18_: int = id v7_;
v17_: int = id v8_;
v18_: int = id v8_;
v19_: int = id v9_;
v20_: int = id v10_;
v21_: int = id v11_;
Expand Down
Loading
Loading