Skip to content

Commit

Permalink
🚸 Enhance Trie functionality (#24)
Browse files Browse the repository at this point in the history
* 🚸 Support .add with an iterable of strings

* 🚸 Allow instantiation with arguments

* 🚸 Enable method chaining off .add

* 🚸 Implement __eq__

* 🚸 Implement __add__ and __iadd__

* ✅ Add tests for operators and chaining

* 📝 Update Trie documentation

* 🐛 Remove f-string for backwards compatibility

* 🐛 Remove f-string for backwards compatibility

* ✅ Add tests for type guards on + and +=

* 🩹 Improve type guard for +=

Co-authored-by: ddelange <[email protected]>

* ♻️ Reimplement __add__, __iadd__, and _merge

* ✅ Update tests for Trie

* ♻️ Simplify == implementation

* ⏪ Revert a6339c4

* 🎨 Reorganize test cases for Trie

* 🧪 Add tests for unwanted RHS sub-dict mutation after +=

* ✅ Prevent unwanted sub-dict propagation in __iadd__

---------

Co-authored-by: ddelange <[email protected]>
  • Loading branch information
michen00 and ddelange committed Feb 22, 2024
1 parent 39211a2 commit f7d39b9
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 3 deletions.
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,21 @@ trie.add("abxy")
assert trie.pattern() == "(?:ab(?:c|s(?:olute)?|xy?)|foo)"
```

A Trie may be populated with zero or more strings at instantiation or via `.add`, which returns the Trie itself so that calls can be chained. Two Trie objects may be merged with the `+` and `+=` operators, and they compare equal if their data dictionaries are equal.

```py
trie = Trie()
trie += Trie("abc")
assert (
trie + Trie().add("foo")
== Trie("abc", "foo")
== Trie(*["abc", "foo"])
== Trie().add(*["abc", "foo"])
== Trie().add("abc", "foo")
== Trie().add("abc").add("foo")
)
```


## Installation

Expand Down
70 changes: 67 additions & 3 deletions src/retrie/trie.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,27 @@
trie.add("abxy")
assert trie.pattern() == "(?:ab(?:c|s(?:olute)?|xy?)|foo)"
A Trie may be populated with zero or more strings at instantiation or via `.add`,
which returns the Trie itself so that calls can be chained. Two Trie objects may be
merged with the `+` and `+=` operators, and they compare equal if their data
dictionaries are equal.
::
trie = Trie()
trie += Trie("abc")
assert (
trie + Trie().add("foo")
== Trie("abc", "foo")
== Trie(*["abc", "foo"])
== Trie().add(*["abc", "foo"])
== Trie().add("abc", "foo")
== Trie().add("abc").add("foo")
)
"""

import re
from typing import Dict, Optional, Text # noqa:F401
from copy import deepcopy
from typing import Any, Dict, Optional, Text # noqa:F401

data_type = Dict[Text, Dict]

Expand All @@ -35,20 +53,66 @@ class Trie:

__slots__ = "data"

def __init__(self):
def __init__(
self, *word # type: Text
):
"""Initialize data dictionary."""
self.data = {} # type: data_type
self.add(*word)

def __eq__(
self, other # type: Any
): # type: (...) -> bool
"""Compare two Trie objects."""
return self.__class__ == other.__class__ and self.data == other.data

def __add__(
self, other # type: "Trie"
): # type: (...) -> "Trie"
"""Merge two Trie objects."""
new_trie = Trie()
new_trie += self
new_trie += other
return new_trie

def __iadd__(
self,
other, # type: "Trie"
): # type: (...) -> "Trie"
"""Merge another Trie object into the current Trie."""
if self.__class__ != other.__class__:
raise TypeError(
"Unsupported operand type(s) for +=: '{0}' and '{1}'".format(
type(self), type(other)
)
)
self._merge_subtrie(self.data, deepcopy(other.data))
return self

@classmethod
def _merge_subtrie(
cls,
current_subtrie, # type: data_type
other_subtrie, # type: data_type
): # type: (...) -> None
"""Recursively merge subtrie data."""
for key, value in other_subtrie.items():
if key in current_subtrie:
cls._merge_subtrie(current_subtrie[key], value)
else:
current_subtrie[key] = value

def add(
    self, *word  # type: Text
):  # type: (...) -> "Trie"
    """Add one or more words to the current Trie.

    Each word becomes a chain of nested dictionaries keyed by character,
    ending in an empty-string key that marks the end of the word.

    :param word: Zero or more words to insert.
    :return: This Trie, so that calls can be chained.
    """
    for text in word:
        node = self.data
        for char in text:
            # Reuse the existing child node, or create it on first use.
            node = node.setdefault(char, {})
        # Empty-string key marks "a word ends here".
        node[""] = {}
    return self

def dump(self): # type: (...) -> data_type
"""Dump the current trie as dictionary."""
Expand Down
29 changes: 29 additions & 0 deletions tests/test_trie.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest

from retrie.trie import Trie


Expand All @@ -22,3 +24,30 @@ def test_trie():

trie.add("fo")
assert trie.pattern() == "(?:ab(?:c|s(?:olute)?|xy?)|fo[eo]?)"

trie = Trie()
trie += Trie("abc")
assert trie.pattern() == "abc"
assert (
trie + Trie().add("foo")
== Trie("abc", "foo")
== Trie("abc") + Trie("foo")
== Trie("foo") + Trie("abc")
== Trie(*["abc", "foo"])
== Trie().add(*["abc", "foo"])
== Trie().add("abc", "foo")
== Trie().add("abc").add("foo")
)
assert trie != object
with pytest.raises(TypeError):
trie += None

assert Trie() + Trie() == Trie()
assert Trie("a", "b", "c").pattern() == "[abc]"
assert Trie("abs") + Trie("absolute") != Trie("absolute")

trie1, trie2 = Trie(), Trie("abc")
trie1 += trie2
assert trie1.data["a"] is not trie2.data["a"]
trie2.data["a"]["b"] = {"s": {"": {}}}
assert trie1 != trie2

0 comments on commit f7d39b9

Please sign in to comment.