diff --git a/src/sort/i64.rs b/src/sort/i64.rs index fef30925..977d3bdf 100644 --- a/src/sort/i64.rs +++ b/src/sort/i64.rs @@ -59,6 +59,13 @@ impl Sort for I64Sort { add_primitives!(typeinfo, "to-string" = |a: i64| -> Symbol { a.to_string().into() }); + // Must be in the i64 sort register function because the string sort is registered before the i64 sort. + typeinfo.add_primitive(CountMatches { + name: "count-matches".into(), + string: typeinfo.get_sort(), + int: self.clone(), + }); + } fn make_expr(&self, _egraph: &EGraph, value: Value) -> (Cost, Expr) { @@ -83,3 +90,32 @@ impl FromSort for i64 { value.bits as Self } } + +struct CountMatches { + name: Symbol, + string: Arc, + int: Arc, +} + +impl PrimitiveLike for CountMatches { + fn name(&self) -> Symbol { + self.name + } + + fn accept(&self, types: &[ArcSort]) -> Option { + if types.len() == 2 + && types[0].name() == self.string.name + && types[1].name() == self.string.name + { + Some(self.int.clone()) + } else { + None + } + } + + fn apply(&self, values: &[Value]) -> Option { + let string1 = Symbol::load(&self.string, &values[0]).to_string(); + let string2 = Symbol::load(&self.string, &values[1]).to_string(); + Some(Value::from(string1.matches(&string2).count() as i64)) + } +} diff --git a/src/sort/string.rs b/src/sort/string.rs index a8ffb03c..34bbac78 100644 --- a/src/sort/string.rs +++ b/src/sort/string.rs @@ -6,7 +6,7 @@ use super::*; #[derive(Debug)] pub struct StringSort { - name: Symbol, + pub name: Symbol, } impl StringSort { @@ -33,6 +33,10 @@ impl Sort for StringSort { fn register_primitives(self: Arc, typeinfo: &mut TypeInfo) { typeinfo.add_primitive(Add { name: "+".into(), + string: self.clone(), + }); + typeinfo.add_primitive(Replace { + name: "replace".into(), string: self, }); } @@ -85,3 +89,34 @@ impl PrimitiveLike for Add { Some(Value::from(res_symbol)) } } + +struct Replace { + name: Symbol, + string: Arc, +} + +impl PrimitiveLike for Replace { + fn name(&self) -> Symbol { + self.name + } + + fn accept(&self, types: &[ArcSort]) -> Option { + if types.len() == 3 + && types[0].name() == self.string.name + && types[1].name() == self.string.name + && types[2].name() == self.string.name + { + Some(self.string.clone()) + } else { + None + } + } + + fn apply(&self, values: &[Value]) -> Option { + let string1 = Symbol::load(&self.string, &values[0]).to_string(); + let string2 = Symbol::load(&self.string, &values[1]).to_string(); + let string3 = Symbol::load(&self.string, &values[2]).to_string(); + let res: Symbol = string1.replace(&string2, &string3).into(); + Some(Value::from(res)) + } +} diff --git a/tests/string.egg b/tests/string.egg index 34e84b3a..74a879ef 100644 --- a/tests/string.egg +++ b/tests/string.egg @@ -1,2 +1,8 @@ ; Tests for the string sort + +; Concatenation (check (= (+ "a" "bc" "de") "abcde")) +; Counting the number of substring occurances +(check (= (count-matches "ab ab" "ab") 2)) +; replacing a substring +(check (= (replace "ab ab" "ab" "cd") "cd cd"))