Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SmallAtom optimization #372

Merged
merged 12 commits into from
Feb 9, 2024
601 changes: 514 additions & 87 deletions src/allocator.rs

Large diffs are not rendered by default.

28 changes: 15 additions & 13 deletions src/chia_dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ impl Dialect for ChiaDialect {
max_cost: Cost,
extension: OperatorSet,
) -> Response {
let b = allocator.atom(o);
if b.len() == 4 {
let op_len = allocator.atom_len(o);
if op_len == 4 {
// these are unkown operators with assigned cost
// the formula is:
// +---+---+---+------------+
Expand All @@ -83,6 +83,7 @@ impl Dialect for ChiaDialect {
// (3 bytes) + 2 bits
// cost_function

let b = allocator.atom(o);
let opcode = u32::from_be_bytes(b.try_into().unwrap());

// the secp operators have a fixed cost of 1850000 and 1300000,
Expand All @@ -97,10 +98,13 @@ impl Dialect for ChiaDialect {
};
return f(allocator, argument_list, max_cost);
}
if b.len() != 1 {
if op_len != 1 {
return unknown_operator(allocator, o, argument_list, self.flags, max_cost);
}
let f = match b[0] {
let Some(op) = allocator.small_number(o) else {
return unknown_operator(allocator, o, argument_list, self.flags, max_cost);
};
let f = match op {
// 1 = quote
// 2 = apply
3 => op_if,
Expand Down Expand Up @@ -146,7 +150,7 @@ impl Dialect for ChiaDialect {
_ => {
if extension == OperatorSet::BLS || (self.flags & ENABLE_BLS_OPS_OUTSIDE_GUARD) != 0
{
match b[0] {
match op {
48 => op_coinid,
49 => op_bls_g1_subtract,
50 => op_bls_g1_multiply,
Expand Down Expand Up @@ -179,16 +183,14 @@ impl Dialect for ChiaDialect {
f(allocator, argument_list, max_cost)
}

fn quote_kw(&self) -> &[u8] {
&[1]
fn quote_kw(&self) -> u32 {
1
}

fn apply_kw(&self) -> &[u8] {
&[2]
fn apply_kw(&self) -> u32 {
2
}

fn softfork_kw(&self) -> &[u8] {
&[36]
fn softfork_kw(&self) -> u32 {
36
}

// interpret the extension argument passed to the softfork operator, and
Expand Down
6 changes: 3 additions & 3 deletions src/dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ pub enum OperatorSet {
}

pub trait Dialect {
fn quote_kw(&self) -> &[u8];
fn apply_kw(&self) -> &[u8];
fn softfork_kw(&self) -> &[u8];
fn quote_kw(&self) -> u32;
fn apply_kw(&self) -> u32;
fn softfork_kw(&self) -> u32;
fn softfork_extension(&self, ext: u32) -> OperatorSet;
fn op(
&self,
Expand Down
77 changes: 61 additions & 16 deletions src/more_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::ops::BitAndAssign;
use std::ops::BitOrAssign;
use std::ops::BitXorAssign;

use crate::allocator::{Allocator, NodePtr, SExp};
use crate::allocator::{len_for_value, Allocator, NodePtr, NodeVisitor, SExp};
use crate::cost::{check_cost, Cost};
use crate::err_utils::err;
use crate::number::Number;
Expand Down Expand Up @@ -365,9 +365,21 @@ pub fn op_add(a: &mut Allocator, mut input: NodePtr, max_cost: Cost) -> Response
cost + (byte_count as Cost * ARITH_COST_PER_BYTE),
max_cost,
)?;
let (v, len) = int_atom(a, arg, "+")?;
byte_count += len;
total += v;

match a.node(arg) {
NodeVisitor::Buffer(buf) => {
use crate::number::number_from_u8;
total += number_from_u8(buf);
byte_count += buf.len();
}
NodeVisitor::U32(val) => {
total += val;
byte_count += len_for_value(val);
}
NodeVisitor::Pair(_, _) => {
return err(arg, "+ requires int args");
}
}
}
let total = a.new_number(total)?;
cost += byte_count as Cost * ARITH_COST_PER_BYTE;
Expand All @@ -383,12 +395,25 @@ pub fn op_subtract(a: &mut Allocator, mut input: NodePtr, max_cost: Cost) -> Res
input = rest;
cost += ARITH_COST_PER_ARG;
check_cost(a, cost + byte_count as Cost * ARITH_COST_PER_BYTE, max_cost)?;
let (v, len) = int_atom(a, arg, "-")?;
byte_count += len;
if is_first {
total += v;
let (v, len) = int_atom(a, arg, "-")?;
byte_count = len;
total = v;
} else {
total -= v;
match a.node(arg) {
NodeVisitor::Buffer(buf) => {
use crate::number::number_from_u8;
total -= number_from_u8(buf);
byte_count += buf.len();
}
NodeVisitor::U32(val) => {
total -= val;
byte_count += len_for_value(val);
}
NodeVisitor::Pair(_, _) => {
return err(arg, "- requires int args");
}
}
};
is_first = false;
}
Expand All @@ -411,14 +436,24 @@ pub fn op_multiply(a: &mut Allocator, mut input: NodePtr, max_cost: Cost) -> Res
continue;
}

let (v0, l1) = int_atom(a, arg, "*")?;
let l1 = match a.node(arg) {
NodeVisitor::Buffer(buf) => {
use crate::number::number_from_u8;
total *= number_from_u8(buf);
buf.len()
}
NodeVisitor::U32(val) => {
total *= val;
len_for_value(val)
}
NodeVisitor::Pair(_, _) => {
return err(arg, "* requires int args");
}
};

total *= v0;
cost += MUL_COST_PER_OP;

cost += (l0 + l1) as Cost * MUL_LINEAR_COST_PER_BYTE;
cost += (l0 * l1) as Cost / MUL_SQUARE_COST_PER_BYTE_DIVIDER;

l0 = limbs_for_int(&total);
}
let total = a.new_number(total)?;
Expand Down Expand Up @@ -490,10 +525,20 @@ pub fn op_mod(a: &mut Allocator, input: NodePtr, _max_cost: Cost) -> Response {

pub fn op_gr(a: &mut Allocator, input: NodePtr, _max_cost: Cost) -> Response {
let [v0, v1] = get_args::<2>(a, input, ">")?;
let (v0, v0_len) = int_atom(a, v0, ">")?;
let (v1, v1_len) = int_atom(a, v1, ">")?;
let cost = GR_BASE_COST + (v0_len + v1_len) as Cost * GR_COST_PER_BYTE;
Ok(Reduction(cost, if v0 > v1 { a.one() } else { a.nil() }))

match (a.small_number(v0), a.small_number(v1)) {
(Some(lhs), Some(rhs)) => {
let cost =
GR_BASE_COST + (len_for_value(lhs) + len_for_value(rhs)) as Cost * GR_COST_PER_BYTE;
Ok(Reduction(cost, if lhs > rhs { a.one() } else { a.nil() }))
}
_ => {
let (v0, v0_len) = int_atom(a, v0, ">")?;
let (v1, v1_len) = int_atom(a, v1, ">")?;
let cost = GR_BASE_COST + (v0_len + v1_len) as Cost * GR_COST_PER_BYTE;
Ok(Reduction(cost, if v0 > v1 { a.one() } else { a.nil() }))
}
}
}

pub fn op_gr_bytes(a: &mut Allocator, input: NodePtr, _max_cost: Cost) -> Response {
Expand Down
81 changes: 39 additions & 42 deletions src/op_utils.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::allocator::{Allocator, NodePtr, SExp};
use crate::allocator::{Allocator, NodePtr, NodeVisitor, SExp};
use crate::cost::Cost;
use crate::err_utils::err;
use crate::number::Number;
Expand Down Expand Up @@ -279,37 +279,36 @@ pub fn uint_atom<const SIZE: usize>(
args: NodePtr,
op_name: &str,
) -> Result<u64, EvalErr> {
let bytes = match a.sexp(args) {
SExp::Atom => a.atom(args),
_ => {
return err(args, &format!("{op_name} requires int arg"));
match a.node(args) {
NodeVisitor::Buffer(bytes) => {
if bytes.is_empty() {
return Ok(0);
}

if (bytes[0] & 0x80) != 0 {
return err(args, &format!("{op_name} requires positive int arg"));
}

// strip leading zeros
let mut buf: &[u8] = bytes;
while !buf.is_empty() && buf[0] == 0 {
buf = &buf[1..];
}

if buf.len() > SIZE {
return err(args, &format!("{op_name} requires u{} arg", SIZE * 8));
}

let mut ret = 0;
for b in buf {
ret <<= 8;
ret |= *b as u64;
}
Ok(ret)
}
};

if bytes.is_empty() {
return Ok(0);
}

if (bytes[0] & 0x80) != 0 {
return err(args, &format!("{op_name} requires positive int arg"));
}

// strip leading zeros
let mut buf: &[u8] = bytes;
while !buf.is_empty() && buf[0] == 0 {
buf = &buf[1..];
}

if buf.len() > SIZE {
return err(args, &format!("{op_name} requires u{} arg", SIZE * 8));
}

let mut ret = 0;
for b in buf {
ret <<= 8;
ret |= *b as u64;
NodeVisitor::U32(val) => Ok(val as u64),
NodeVisitor::Pair(_, _) => err(args, &format!("{op_name} requires int arg")),
}
Ok(ret)
}

#[cfg(test)]
Expand Down Expand Up @@ -532,18 +531,16 @@ fn test_u64_from_bytes() {
}

pub fn i32_atom(a: &Allocator, args: NodePtr, op_name: &str) -> Result<i32, EvalErr> {
let buf = match a.sexp(args) {
SExp::Atom => a.atom(args),
_ => {
return err(args, &format!("{op_name} requires int32 args"));
}
};
match i32_from_u8(buf) {
Some(v) => Ok(v),
_ => err(
args,
&format!("{op_name} requires int32 args (with no leading zeros)"),
),
match a.node(args) {
NodeVisitor::Buffer(buf) => match i32_from_u8(buf) {
Some(v) => Ok(v),
_ => err(
args,
&format!("{op_name} requires int32 args (with no leading zeros)"),
),
},
NodeVisitor::U32(val) => Ok(val as i32),
NodeVisitor::Pair(_, _) => err(args, &format!("{op_name} requires int32 args")),
}
}

Expand Down
Loading
Loading