Factor out utility methods

This commit is contained in:
Armin Friedl 2020-10-10 10:11:34 +02:00
parent 92077efb43
commit aace705f2e
2 changed files with 32 additions and 30 deletions

View file

@ -1,6 +1,8 @@
from __future__ import annotations from __future__ import annotations
from typing import Sequence, MutableSequence, ByteString, Any, Optional from typing import Sequence, MutableSequence, ByteString, Any, Optional
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from .util import (has_common_prefix, common_prefix, is_prefix_of, find_first,
cut_off_prefix)
import logging import logging
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -97,36 +99,6 @@ class ByteTrie:
def to_dot(self) -> str: def to_dot(self) -> str:
return "graph {\n\n"+self.root.to_dot()+"\n}" return "graph {\n\n"+self.root.to_dot()+"\n}"
def has_common_prefix(label: ByteString, other_label: ByteString) -> bool:
""" Whether label and other_label have a prefix in common. """
assert label and other_label
return True if label[0] == other_label[0] else False
def common_prefix(label: ByteString, other_label: ByteString) -> ByteString:
""" Get the common prefix of label and other_label. """
buffer = bytearray()
for (a,b) in zip(label, other_label):
if a == b: buffer.append(a)
else: break
return buffer
def is_prefix_of(prefix: ByteString, label: ByteString) -> bool:
""" Whether label starts with prefix """
if len(prefix) > len(label):
return False
for (a,b) in zip(prefix, label):
if a != b: return False
return True
def find_first(predicate, iterable):
""" Return the first element in iterable that satisfies predicate or None """
try: return next(filter(predicate, iterable))
except StopIteration: return None
def cut_off_prefix(prefix: ByteString, label: ByteString) -> ByteString:
""" Cut prefix from start of label. Return rest of label. """
assert is_prefix_of(prefix, label)
return bytes(label[len(prefix):])
class Node(ABC): class Node(ABC):
def __init__(self, children: MutableSequence[Child]): def __init__(self, children: MutableSequence[Child]):

30
bytetrie/util.py Normal file
View file

@ -0,0 +1,30 @@
def has_common_prefix(label: ByteString, other_label: ByteString) -> bool:
""" Whether label and other_label have a prefix in common. """
assert label and other_label
return True if label[0] == other_label[0] else False
def common_prefix(label: ByteString, other_label: ByteString) -> ByteString:
""" Get the common prefix of label and other_label. """
buffer = bytearray()
for (a,b) in zip(label, other_label):
if a == b: buffer.append(a)
else: break
return buffer
def is_prefix_of(prefix: ByteString, label: ByteString) -> bool:
""" Whether label starts with prefix """
if len(prefix) > len(label):
return False
for (a,b) in zip(prefix, label):
if a != b: return False
return True
def find_first(predicate, iterable):
""" Return the first element in iterable that satisfies predicate or None """
try: return next(filter(predicate, iterable))
except StopIteration: return None
def cut_off_prefix(prefix: ByteString, label: ByteString) -> ByteString:
""" Cut prefix from start of label. Return rest of label. """
assert is_prefix_of(prefix, label)
return bytes(label[len(prefix):])