301 lines
11 KiB
Python
301 lines
11 KiB
Python
|
"""Radix Trie with radix 256
|
||
|
|
||
|
A Radix Trie[1] - once built - allows efficient prefix search. The trie works
|
||
|
on byte strings and hence is oblivious to encoding. The encoding for creation
|
||
|
and search must match. Payload of each node can be an arbitrary object.
|
||
|
|
||
|
Usage
|
||
|
-----
|
||
|
.. code :: python
|
||
|
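t = Trie()
t.insert(b"Hello", "P1")
t.insert(b"Hi", "P2")
t.insert(b"Hela", "P3")
# find() returns (Terminal, label) pairs; payloads live in Terminal.content
sorted(node.content for node, label in t.find(b"He"))  # ["P1", "P3"]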
|
||
|
|
||
|
[1] https://en.wikipedia.org/wiki/Radix_tree
|
||
|
"""
|
||
|
|
||
|
from __future__ import annotations
|
||
|
from typing import Sequence, MutableSequence, ByteString, Any, Optional
|
||
|
from abc import ABC, abstractmethod
|
||
|
|
||
|
import logging
|
||
|
# Module-level logger shared by Trie and the node classes in this module.
log = logging.getLogger(__name__)
|
||
|
|
||
|
class Trie:
    """Radix trie over byte strings.

    Every inserted label ends in a ``Terminal`` node that carries the payload.
    With ``multi_value=True`` each terminal keeps a list of all payloads
    inserted under its label; otherwise inserting an existing label again
    overwrites its payload (a warning is logged).
    """

    def __init__(self, multi_value=False):
        self.root = Root([])
        self.multi_value = multi_value

    def insert(self, label: ByteString, content: Any):
        """Insert ``content`` under ``label``.

        :param label: non-empty byte string to store.
        :param content: arbitrary payload attached to the label's terminal.
        :returns: the ``Terminal`` node that now holds ``content``.
        :raises ValueError: if ``label`` is empty. Sibling lookup compares
            first bytes, so an empty label cannot be placed in the trie.
        """
        if not label:
            raise ValueError("Cannot insert an empty label into the Trie")
        log.info("Inserting %s into Trie", label)
        start = self.root.child_by_common_prefix(label)
        if start is None:
            log.debug("Creating new terminal for %s at root", label)
            new_node = Terminal(label, content, self.root, [], self.multi_value)
            self.root.put_child(new_node)
            return new_node
        log.debug("Found match %s for %s. Traversing down", start, label)
        # Return the terminal here too, so insert() behaves the same on both paths.
        return self._insert(start, label, content)

    def _insert(self, node, label, content):
        """Insert ``label`` into the subtree topped by ``node``.

        ``node`` must share at least its first byte with ``label``. Four cases:

        * ``node``'s label equals ``label``: wrap ``node`` as a Terminal.
        * ``node``'s label is a proper prefix of ``label``: descend with the rest.
        * ``label`` is a proper prefix of ``node``'s label: put a new Terminal
          above ``node``.
        * otherwise: split at the common prefix into a new inner node holding
          both ``node`` and a new Terminal.

        :returns: the Terminal now holding ``content``.
        """
        log.info("Inserting %s into Trie at %s", label, node)
        if node.has_label(label):
            log.debug("%s equals %s. Wrapping node as Terminal.", node, label)
            if isinstance(node, Terminal) and not self.multi_value:
                log.warning("%s is already a Terminal. Content will be overwritten.", node)
            terminal = Terminal.from_child(node, content, self.multi_value)
            node.replace_with(terminal)
            return terminal

        if node.is_prefix_of(label):
            log.debug("%s is prefix of %s", node, label)
            cutoff = node.cut_from(label)
            next_node = node.child_by_common_prefix(cutoff)
            if next_node is None:
                log.debug("No matching child found for %s. Creating new child terminal.", cutoff)
                terminal = Terminal(cutoff, content, node, [], self.multi_value)
                node.put_child(terminal)
                return terminal
            log.debug("Found match %s for %s. Traversing down.", next_node, cutoff)
            return self._insert(next_node, cutoff, content)

        if node.starts_with(label):
            log.debug("%s is part of %s. Creating new parent from %s", label, node, label)
            new_node = Terminal(label, content, node.parent, [], self.multi_value)
            node.replace_with(new_node)
            node.strip_prefix(label)
            new_node.put_child(node)
            return new_node

        log.debug("%s and %s have a common ancestor", label, node)
        # Coerce to immutable bytes: Child hashes its label, and a bytearray
        # label would make the new ancestor unhashable.
        shared = bytes(node.common_prefix(label))
        log.debug("Creating new ancestor for %s", shared)
        ancestor = Child(shared, node.parent, [])
        node.replace_with(ancestor)
        terminal = Terminal(cut_off_prefix(shared, label), content, ancestor, [], self.multi_value)
        node.strip_prefix(shared)
        ancestor.put_child(terminal)
        ancestor.put_child(node)
        return terminal

    def find(self, prefix):
        """Return every stored entry whose label starts with ``prefix``.

        :returns: list of ``(terminal, label)`` pairs in depth-first pre-order.
            ``label`` is the terminal's complete label; the payload is
            ``terminal.content``. Empty list when nothing matches.
        """
        node, label = self._locate(self.root, prefix)
        return self._get_terminals(node, label)

    def _find(self, node, prefix, collector=""):
        """Return the node whose subtree holds all labels starting with ``prefix``, or None.

        ``collector`` is unused; it is kept for signature compatibility.
        """
        return self._locate(node, prefix)[0]

    def _locate(self, node, prefix):
        """Walk down from ``node`` consuming ``prefix``.

        :returns: ``(match, label)``. ``match`` is the shallowest node whose
            subtree holds the labels starting with ``prefix``. ``label`` is
            the full label of ``match``, counted from ``node``'s label
            onwards. Returns ``(None, None)`` when no label starts with
            ``prefix``.
        """
        remaining = prefix
        while True:
            remaining = node.cut_from(remaining)
            log.debug("Searching for %s in %s", remaining, node)
            child = node.child_by_prefix_match(remaining)
            if child is not None:
                log.debug("Found node %s in %s for %s. Traversing down.", child, node, remaining)
                node = child
                continue
            if not remaining:
                # The prefix ends exactly at this node, so its path label is the prefix.
                return node, prefix
            log.debug("Leftover cutoff %s. Trying to find node with prefix %s", remaining, remaining)
            child = node.child_by_common_prefix(remaining)
            if child is None or not child.starts_with(remaining):
                return None, None
            log.debug("Found child %s starting with %s", child, remaining)
            # The prefix ends inside child's label: report child's full label,
            # not the (shorter) search prefix.
            consumed = prefix[:len(prefix) - len(remaining)]
            return child, child.extend(consumed)

    def _get_terminals(self, node, label_builder):
        """Collect ``(terminal, label)`` pairs for ``node`` and its descendants.

        ``label_builder`` is ``node``'s full label; each child's label is
        appended to it on the way down. A falsy ``node`` yields ``[]``.
        """
        if not node:
            return []
        collector = []
        if isinstance(node, Terminal):
            collector.append((node, label_builder))
        for child in node.children:
            collector.extend(self._get_terminals(child, child.extend(label_builder)))
        return collector

    def to_dot(self) -> str:
        """Render the whole trie as a Graphviz ``graph`` definition."""
        return "graph {\n\n" + self.root.to_dot() + "\n}"
|
||
|
|
||
|
def has_common_prefix(label: ByteString, other_label: ByteString) -> bool:
    """Whether label and other_label have a prefix in common.

    Only the first element is compared. Siblings in the trie are expected to
    differ in their first byte, so a shared first byte is enough to pick the
    child to descend into. Both labels must be non-empty (checked by assert).
    """
    assert label and other_label
    return label[0] == other_label[0]
|
||
|
|
||
|
def common_prefix(label: ByteString, other_label: ByteString) -> ByteString:
    """Get the longest common prefix of label and other_label.

    The result is immutable ``bytes`` rather than a ``bytearray``. It is used
    as a node label, and ``Child`` hashes its label, so it must be hashable.
    Returns ``b""`` when the labels differ at the first position.
    """
    length = 0
    for a, b in zip(label, other_label):
        if a != b:
            break
        length += 1
    return bytes(label[:length])
|
||
|
|
||
|
def is_prefix_of(prefix: ByteString, label: ByteString) -> bool:
    """Return True when ``label`` begins with ``prefix``, compared element by element.

    An empty ``prefix`` is a prefix of everything; a ``prefix`` longer than
    ``label`` never is.
    """
    return len(prefix) <= len(label) and all(p == c for p, c in zip(prefix, label))
|
||
|
|
||
|
def find_first(predicate, iterable):
    """Return the first item of ``iterable`` for which ``predicate`` is truthy, else None."""
    matches = filter(predicate, iterable)
    return next(matches, None)
|
||
|
|
||
|
def cut_off_prefix(prefix: ByteString, label: ByteString) -> ByteString:
    """Return ``label`` with its leading ``prefix`` removed, always as ``bytes``.

    The caller must guarantee that ``label`` starts with ``prefix``; this is
    only checked with an ``assert``.
    """
    assert is_prefix_of(prefix, label)
    remainder = label[len(prefix):]
    return bytes(remainder)
|
||
|
|
||
|
class Node(ABC):
    """Abstract trie node owning an ordered list of ``Child`` nodes.

    Subclasses provide the dot-graph hooks (``dot_label``, ``dot_id``) and
    ``cut_from``, which removes the node's own label from a search label.
    """

    def __init__(self, children: MutableSequence[Child]):
        self.children = children

    def child_by_common_prefix(self, label: ByteString) -> Optional[Child]:
        """Return the first child whose label shares its first element with ``label``, or None."""
        def by_common_prefix(child: Child):
            return has_common_prefix(child.label, label)

        return find_first(by_common_prefix, self.children)

    def child_by_prefix_match(self, label: ByteString) -> Optional[Child]:
        """Return the first child whose label is a prefix of ``label``, or None."""
        def by_prefix_match(child: Child):
            return is_prefix_of(child.label, label)

        return find_first(by_prefix_match, self.children)

    def put_child(self, child: Child):
        """Append ``child`` and make this node its parent.

        A child that compares equal to ``child`` (``Child`` equality is by
        label) is removed first, with a warning.
        """
        if child in self.children:
            log.warning("Replacing child %s", child.label)
            self.remove_child(child)
        child.parent = self
        self.children.append(child)

    def replace_child(self, child: Child, replacement: Child):
        """Remove ``child`` from this node's children and add ``replacement``."""
        self.remove_child(child)
        self.put_child(replacement)

    def remove_child(self, child: Child):
        """Remove ``child`` from this node's children.

        A missing child is logged as a warning and otherwise ignored.
        Previously ``list.remove`` was called anyway and raised ValueError.
        """
        if child not in self.children:
            log.warning("Trying to delete %s but it does not exist.", child.label)
            return
        self.children.remove(child)

    @abstractmethod
    def dot_label(self) -> str:
        """ Readable label for this node in a dot graph """
        ...

    @abstractmethod
    def dot_id(self) -> str:
        """ Technical id for this node in a dot graph. Must be unique. """
        ...

    @abstractmethod
    def cut_from(self, label: ByteString) -> ByteString:
        """ Cut off node's label considered as prefix from label. """
        ...

    def to_dot(self) -> str:
        """Render this node, its edges and its subtree as Graphviz statements."""
        own_id = self.dot_id()
        parts = [f'{own_id} [label="{self.dot_label()}"]\n']
        for child in self.children:
            parts.append(f"{own_id} -- {child.dot_id()}\n")
            parts.append(child.to_dot())
        # Collect the pieces and join once instead of repeated string `+=`.
        return "".join(parts)
|
||
|
|
||
|
class Root(Node):
    """Label-less top node of the trie."""

    def dot_id(self):
        """Identifier of the root in a dot graph."""
        return "root"

    def dot_label(self):
        """Text shown for the root in a dot graph."""
        return "root"

    def cut_from(self, label: ByteString) -> ByteString:
        """The root has no label of its own, so ``label`` comes back unchanged."""
        return label
|
||
|
|
||
|
class Child(Node):
    """Non-root node carrying an edge ``label`` and a link to its ``parent``.

    Equality and hashing use only the label. Siblings stay unique by label
    (``Node.put_child`` replaces an equal child), but nodes in different parts
    of the trie may compare equal. ``strip_prefix`` mutates the label and
    therefore the hash, so do not keep Child objects in sets or as dict keys.
    """

    def __init__(self, label: ByteString, parent: Node, children: MutableSequence[Child]):
        super().__init__(children)
        self.label = label
        self.parent = parent

    def __eq__(self, other_child):
        # Return NotImplemented for foreign types so Python can try the
        # other operand's __eq__ before falling back to identity.
        if not isinstance(other_child, Child):
            return NotImplemented
        return self.label == other_child.label

    def __hash__(self):
        return hash(self.label)

    def __repr__(self):
        return f"{type(self).__name__}({self.label!r})"

    def __str__(self):
        return self._escaped_label()

    def _escaped_label(self) -> str:
        """Label decoded as UTF-8 (undecodable bytes replaced) with double quotes escaped."""
        return self.label.decode('utf-8', 'replace').replace('"', '\\"')

    def dot_label(self):
        """Readable label for this node in a dot graph."""
        return self._escaped_label()

    def dot_id(self):
        """Unique id for this node in a dot graph: the object's ``id()``."""
        return id(self)

    def has_label(self, label):
        """Whether this node's label equals ``label``."""
        return self.label == label

    def is_prefix_of(self, label):
        """Whether this node's label is a prefix of ``label``."""
        return is_prefix_of(self.label, label)

    def replace_with(self, new_child: Child):
        """Put ``new_child`` in this node's place under the same parent."""
        new_child.parent = self.parent
        self.parent.replace_child(self, new_child)

    def starts_with(self, label: ByteString) -> bool:
        """Whether this node's label begins with ``label``."""
        return is_prefix_of(label, self.label)

    def cut_from(self, label: ByteString) -> ByteString:
        """Cut node's label from (start of) label."""
        return cut_off_prefix(self.label, label)

    def strip_prefix(self, prefix: ByteString):
        """Cut off prefix from node's label."""
        self.label = cut_off_prefix(prefix, self.label)

    def extend(self, label: ByteString) -> ByteString:
        """Return ``label`` followed by this node's label, as ``bytes``."""
        return bytes(label) + bytes(self.label)

    def split_label_at(self, index):
        """Split this node's label into ``(label[:index], label[index:])``."""
        return (self.label[:index], self.label[index:])

    def contains(self, label):
        """Whether this node's label begins with ``label`` (same test as ``starts_with``)."""
        return is_prefix_of(label, self.label)

    def common_prefix(self, label):
        """Longest common prefix of this node's label and ``label``."""
        return common_prefix(self.label, label)
|
||
|
|
||
|
class Terminal(Child):
    """Child node that ends a stored label and carries its payload.

    ``content`` is the payload itself, or, when ``multi_value`` is true, a
    list of every payload stored under this label.
    """

    def __init__(self, label: ByteString, content: Any, parent: Node, children: MutableSequence[Child], multi_value: bool):
        super().__init__(label, parent, children)
        self.multi_value = multi_value
        self.content = [content] if multi_value else content

    @classmethod
    def from_child(cls, child: Child, content: Any, multi_value: bool):
        """Build a Terminal holding ``content`` that can take ``child``'s place.

        The new node reuses ``child``'s label, parent and children list. The
        caller still has to splice it in (e.g. ``child.replace_with(...)``).

        If ``child`` is already a multi-value Terminal, the new node stays
        multi-value regardless of ``multi_value``. Its content list holds
        ``content`` first, followed by the payloads ``child`` already had.
        Otherwise ``multi_value`` decides the mode and any old payload is
        discarded.
        """
        if isinstance(child, Terminal) and child.multi_value:
            terminal = cls(child.label, content, child.parent, child.children, child.multi_value)
            terminal.content.extend(child.content)  # keep the earlier payloads
        else:
            terminal = cls(child.label, content, child.parent, child.children, multi_value)
        # The children list is shared with ``child``. Point each child's back
        # link at the new node so no parent reference targets the replaced one.
        for sub in terminal.children:
            sub.parent = terminal
        return terminal

    def to_dot(self) -> str:
        """Render like any node, then mark this node blue as a terminal."""
        s = super().to_dot()
        s += f"{self.dot_id()} [color=blue]\n"
        return s
|
||
|
|