From b64b56f9e9a4527bb7e40729e94e49a6845feab0 Mon Sep 17 00:00:00 2001 From: Sean Date: Tue, 25 Jun 2024 19:50:15 +0900 Subject: [PATCH] feat: `segmented_segment_tree` --- src/lib.cairo | 1 + src/libraries/segmented_segment_tree.cairo | 179 +++++++++++++++++++++ 2 files changed, 180 insertions(+) create mode 100644 src/libraries/segmented_segment_tree.cairo diff --git a/src/lib.cairo b/src/lib.cairo index 9d8adfb..4df9767 100644 --- a/src/lib.cairo +++ b/src/lib.cairo @@ -7,6 +7,7 @@ pub mod libraries { pub mod tick; pub mod tick_bitmap; pub mod significant_bit; + pub mod segmented_segment_tree; } pub mod alexandria { diff --git a/src/libraries/segmented_segment_tree.cairo b/src/libraries/segmented_segment_tree.cairo new file mode 100644 index 0000000..488867a --- /dev/null +++ b/src/libraries/segmented_segment_tree.cairo @@ -0,0 +1,179 @@ +pub mod SegmentedSegmentTree { + use core::traits::TryInto; + use core::array::ArrayTrait; + use core::traits::Into; + use clober_cairo::libraries::packed_felt252::get_u62; + use clober_cairo::libraries::packed_felt252::sum_u62; + use clober_cairo::libraries::packed_felt252::update_62; + + #[derive(Copy, Drop, Serde, Debug)] + struct LayerIndex { + pub group: u256, + pub node: u8 + } + + const TWO_POW_16: u256 = 0x10000; // 2**16 + + // const R: u256 = 2; // There are `2` root node groups + // const C: u256 = 4; // There are `4` children (each child is a node group of its own) for + // each node + const L: u8 = 4; // There are `4` layers of node groups + const P: u8 = 4; // uint256 / uint64 = `4` + const P_M: u256 = 3; // % 4 = & `3` + const P_P: u256 = 2; // 2 ** `2` = 4 + const N_P: u256 = 4; // C * P = 2 ** `4` + const MAX_NODES: u256 = 0x8000; // (R * P) * ((C * P) ** (L - 1)) = `32768` + const MAX_NODES_P_MINUS_ONE: u256 = 14; // MAX_NODES / R = 2 ** `14` + + fn get(mut layers: Felt252Dict, index: u256) -> u64 { + assert(index < MAX_NODES, 'INDEX_ERROR'); + let key: felt252 = ((L.into() - 1) * MAX_NODES + index / P.into()).try_into().unwrap(); + get_u62(layers[key], (index & P_M).try_into().unwrap()) + } + + fn total(mut layers: Felt252Dict) -> u256 { + sum_u62(layers[0], 0, 4) + sum_u62(layers[1], 0, 4) + } + + fn _get_layer_indices(index: u256) -> Array { + let mut indices: Array = ArrayTrait::new(); + let mut shifter: u256 = MAX_NODES / 2; + let mut l = 0; + while l < L { + indices + .append( + LayerIndex { + group: index / shifter, + node: ((index / (shifter / P.into())) & P_M).try_into().unwrap() + } + ); + shifter /= 16; + l += 1; + }; + indices + } + + fn query(mut layers: Felt252Dict, left: u256, right: u256) -> u256 { + if left == right { + return 0; + } + assert(left < right, 'INDEX_ERROR'); + assert(right <= MAX_NODES, 'INDEX_ERROR'); + + let left_indices: Array = _get_layer_indices(left); + let right_indices: Array = _get_layer_indices(right); + let mut ret: u256 = 0; + let mut deficit: u256 = 0; + + let mut left_node_index: u8 = 0; + let mut right_node_index: u8 = 0; + let mut l: u256 = (L - 1).into(); + + while l >= 0 { + let left_index: LayerIndex = *left_indices.at(l.try_into().unwrap()); + let right_index: LayerIndex = *right_indices.at(l.try_into().unwrap()); + left_node_index += left_index.node; + right_node_index += right_index.node; + + if right_index.group == left_index.group { + let key: felt252 = (l * MAX_NODES + left_index.group).try_into().unwrap(); + ret += sum_u62(layers[key], left_node_index, right_node_index); + break; + } + + if right_index.group - left_index.group < 4 { + let key: felt252 = (l * MAX_NODES + left_index.group).try_into().unwrap(); + ret += sum_u62(layers[key], left_node_index, P); + + let key: felt252 = (l * MAX_NODES + right_index.group).try_into().unwrap(); + ret += sum_u62(layers[key], 0, right_node_index); + let mut group = left_index.group + 1; + while group < right_index.group { + let key: felt252 = (l * MAX_NODES + group).try_into().unwrap(); + ret += sum_u62(layers[key], 0, P); + group += 1; + }; + break; + } + + if left_index.group % 4 == 0 { + let key: felt252 = (l * MAX_NODES + left_index.group).try_into().unwrap(); + deficit += sum_u62(layers[key], 0, left_node_index); + left_node_index = 0; + } else if left_index.group % 4 == 1 { + let key: felt252 = (l * MAX_NODES + left_index.group).try_into().unwrap(); + deficit += sum_u62(layers[key - 1], 0, P); + deficit += sum_u62(layers[key], 0, left_node_index); + left_node_index = 0; + } else if left_index.group % 4 == 2 { + let key: felt252 = (l * MAX_NODES + left_index.group).try_into().unwrap(); + ret += sum_u62(layers[key], left_node_index, P); + ret += sum_u62(layers[key + 1], 0, P); + left_node_index = 1; + } else { + let key: felt252 = (l * MAX_NODES + left_index.group).try_into().unwrap(); + ret += sum_u62(layers[key], left_node_index, P); + left_node_index = 1; + } + + if right_index.group % 4 == 0 { + let key: felt252 = (l * MAX_NODES + right_index.group).try_into().unwrap(); + ret += sum_u62(layers[key], 0, right_node_index); + right_node_index = 0; + } else if right_index.group % 4 == 1 { + let key: felt252 = (l * MAX_NODES + right_index.group).try_into().unwrap(); + ret += sum_u62(layers[key - 1], 0, P); + ret += sum_u62(layers[key], 0, right_node_index); + right_node_index = 0; + } else if right_index.group % 4 == 2 { + let key: felt252 = (l * MAX_NODES + right_index.group).try_into().unwrap(); + deficit += sum_u62(layers[key], right_node_index, P); + deficit += sum_u62(layers[key + 1], 0, P); + right_node_index = 1; + } else { + let key: felt252 = (l * MAX_NODES + right_index.group).try_into().unwrap(); + deficit += sum_u62(layers[key], right_node_index, P); + right_node_index = 1; + } + + l -= 1; + }; + ret - deficit + } + + fn update(ref layers: Felt252Dict, index: u256, value: u64) -> u64 { + assert(index < MAX_NODES, 'INDEX_ERROR'); + let indices: Array = _get_layer_indices(index); + let bottom_index: LayerIndex = *indices.at((L - 1).try_into().unwrap()).try_into().unwrap(); + let key: felt252 = (MAX_NODES * (L.into() - 1) + bottom_index.group).try_into().unwrap(); + let replaced = get_u62(layers[key], bottom_index.node); + if replaced >= value { + let diff = replaced - value; + let l: u8 = 0; + while l < L { + let layer_index: LayerIndex = *indices.at(l.into()); + let key: felt252 = (l.into() * MAX_NODES + layer_index.group).try_into().unwrap(); + let node: felt252 = layers[key]; + layers + .insert( + key, + update_62(node, layer_index.node, get_u62(node, layer_index.node) - diff) + ); + } + } else { + let diff = value - replaced; + let l: u8 = 0; + while l < L { + let layer_index: LayerIndex = *indices.at(l.into()); + let key: felt252 = (l.into() * MAX_NODES + layer_index.group).try_into().unwrap(); + let node: felt252 = layers[key]; + layers + .insert( + key, + update_62(node, layer_index.node, get_u62(node, layer_index.node) + diff) + ); + } + } + replaced + } +}