Skip to content

Commit

Permalink
Add typing annotations to Python scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
wismill committed Oct 30, 2024
1 parent cd02943 commit b0b5d5c
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 38 deletions.
66 changes: 39 additions & 27 deletions scripts/perfect_hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,16 @@

from __future__ import absolute_import, division, print_function

import sys
import random
import shutil
import string
import subprocess
import shutil
import sys
import tempfile
from collections import defaultdict
from optparse import Values
from os.path import join
from typing import Any, Sequence, TypeVar

if sys.version_info[0] == 2:
from cStringIO import StringIO
Expand All @@ -109,14 +111,14 @@ class Graph(object):
the desired edge value (mod N).
"""

def __init__(self, N):
def __init__(self, N: int):
self.N = N # number of vertices

# maps a vertex number to the list of tuples (vertex, edge value)
# to which it is connected by edges.
self.adjacent = defaultdict(list)
self.adjacent: dict[int, list[tuple[int, int]]] = defaultdict(list)

def connect(self, vertex1, vertex2, edge_value):
def connect(self, vertex1: int, vertex2: int, edge_value: int) -> None:
"""
Connect 'vertex1' and 'vertex2' with an edge, with associated
value 'value'
Expand All @@ -125,7 +127,7 @@ def connect(self, vertex1, vertex2, edge_value):
self.adjacent[vertex1].append((vertex2, edge_value))
self.adjacent[vertex2].append((vertex1, edge_value))

def assign_vertex_values(self):
def assign_vertex_values(self) -> bool:
"""
Try to assign the vertex values, such that, for each edge, you can
add the values for the two vertices involved and get the desired
Expand All @@ -150,7 +152,7 @@ def assign_vertex_values(self):
self.vertex_values[root] = 0 # set arbitrarily to zero

# Stack of vertices to visit, a list of tuples (parent, vertex)
tovisit = [(None, root)]
tovisit: list[tuple[int | None, int]] = [(None, root)]
while tovisit:
parent, vertex = tovisit.pop()
visited[vertex] = True
Expand Down Expand Up @@ -184,7 +186,7 @@ def assign_vertex_values(self):
return True


class StrSaltHash(object):
class StrSaltHash:
"""
Random hash function generator.
Simple byte level hashing: each byte is multiplied to another byte from
Expand All @@ -194,11 +196,11 @@ class StrSaltHash(object):

chars = string.ascii_letters + string.digits

def __init__(self, N):
def __init__(self, N: int):
self.N = N
self.salt = ""

def __call__(self, key):
def __call__(self, key: Sequence[str]) -> int:
# XXX: xkbcommon modification: make the salt length a power of 2
# so that the % operation in the hash is fast.
while len(self.salt) < max(len(key), 32): # add more salt as necessary
Expand All @@ -216,18 +218,18 @@ def perfect_hash(key):
"""


class IntSaltHash(object):
class IntSaltHash:
"""
Random hash function generator.
Simple byte level hashing, each byte is multiplied in sequence to a table
containing random numbers, summed tp, and finally modulo NG is taken.
"""

def __init__(self, N):
self.N = N
self.salt = []
def __init__(self, N: int):
self.N: int = N
self.salt: list[int] = []

def __call__(self, key):
def __call__(self, key: Sequence[str]) -> int:
while len(self.salt) < len(key): # add more salt as necessary
self.salt.append(random.randint(1, self.N - 1))

Expand All @@ -246,7 +248,10 @@ def perfect_hash(key):
"""


def builtin_template(Hash):
H = TypeVar("H", StrSaltHash, IntSaltHash)


def builtin_template(Hash: type[H]) -> str:
return (
"""\
# =======================================================================
Expand All @@ -272,7 +277,9 @@ class TooManyInterationsError(Exception):
pass


def generate_hash(keys, Hash=StrSaltHash):
def generate_hash(
keys: list[str], Hash: type[H] = StrSaltHash
) -> tuple[H, H, list[int]]:
"""
Return hash functions f1 and f2, and G for a perfect minimal hash.
Input is an iterable of 'keys', whos indicies are the desired hash values.
Expand Down Expand Up @@ -349,17 +356,17 @@ def generate_hash(keys, Hash=StrSaltHash):


class Format(object):
def __init__(self, width=76, indent=4, delimiter=", "):
def __init__(self, width: int = 76, indent: int = 4, delimiter: str = ", "):
self.width = width
self.indent = indent
self.delimiter = delimiter

def print_format(self):
def print_format(self) -> None:
print("Format options:")
for name in "width", "indent", "delimiter":
print(" %s: %r" % (name, getattr(self, name)))

def __call__(self, data, quote=False):
def __call__(self, data: Any, quote: bool = False) -> str:
if not isinstance(data, (list, tuple)):
return str(data)

Expand All @@ -384,7 +391,12 @@ def __call__(self, data, quote=False):
return "\n".join(l.rstrip() for l in aux.getvalue().split("\n"))


def generate_code(keys, Hash=StrSaltHash, template=None, options=None):
def generate_code(
keys: list[str],
Hash: type = StrSaltHash,
template: str | None = None,
options: Values | None = None,
) -> str:
"""
Takes a list of key value pairs and inserts the generated parameter
lists into the 'template' string. 'Hash' is the random hash function
Expand Down Expand Up @@ -424,7 +436,7 @@ def generate_code(keys, Hash=StrSaltHash, template=None, options=None):
)


def read_table(filename, options):
def read_table(filename: str, options: Values) -> list[str]:
"""
Reads keys and desired hash value pairs from a file. If no column
for the hash value is specified, a sequence of hash values is generated,
Expand Down Expand Up @@ -455,7 +467,7 @@ def read_table(filename, options):
row = [col.strip() for col in line.split(options.splitby)]

try:
key = row[options.keycol - 1]
key: str = row[options.keycol - 1]
except IndexError:
sys.exit(
"%s:%d: Error: Cannot read key, not enough columns." % (filename, n + 1)
Expand All @@ -471,7 +483,7 @@ def read_table(filename, options):
return keys


def read_template(filename):
def read_template(filename: str) -> str:
if verbose:
print("Reading template from file `%s'" % filename)
try:
Expand All @@ -481,7 +493,7 @@ def read_template(filename):
sys.exit("Error: Could not open `%s' for reading." % filename)


def run_code(code):
def run_code(code: str) -> None:
tmpdir = tempfile.mkdtemp()
path = join(tmpdir, "t.py")
with open(path, "w") as fo:
Expand All @@ -494,7 +506,7 @@ def run_code(code):
shutil.rmtree(tmpdir)


def main():
def main() -> None:
from optparse import OptionParser

usage = "usage: %prog [options] KEYS_FILE [TMPL_FILE]"
Expand Down Expand Up @@ -642,7 +654,7 @@ def main():
parser.error("template filename does not contain 'tmpl'")

if options.hft == 1:
Hash = StrSaltHash
Hash: type = StrSaltHash
elif options.hft == 2:
Hash = IntSaltHash
else:
Expand Down
2 changes: 1 addition & 1 deletion scripts/update-headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def generate(
data: dict[str, Any],
root: Path,
file: Path,
):
) -> None:
"""Generate a file from its Jinja2 template"""
template_path = file.with_suffix(f"{file.suffix}.jinja")
template = env.get_template(str(template_path))
Expand Down
4 changes: 2 additions & 2 deletions scripts/update-message-registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class Example:
after: str | None

@classmethod
def parse(cls, entry: Any) -> Example:
def parse(cls, entry: dict[str, Any]) -> Example:
name = entry.get("name")
assert name, entry

Expand Down Expand Up @@ -89,7 +89,7 @@ class Entry:
"""

@classmethod
def parse(cls, entry: Any) -> Entry:
def parse(cls, entry: dict[str, Any]) -> Entry:
code = entry.get("code")
assert code is not None and isinstance(code, int) and code > 0, entry

Expand Down
16 changes: 8 additions & 8 deletions scripts/update-unicode.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,11 +621,11 @@ def __iadd__(self, x):
return NotImplemented

@classmethod
def from_singleton(cls, chunk: tuple[T, ...]):
def from_singleton(cls, chunk: tuple[T, ...]) -> Self:
return cls(data=chunk, offsets={chunk: 0})

@classmethod
def from_pair(cls, pair: DeltasPair):
def from_pair(cls, pair: DeltasPair) -> Self:
return cls(
data=pair.d1 + pair.d2[pair.overlap :],
offsets={
Expand All @@ -635,7 +635,7 @@ def from_pair(cls, pair: DeltasPair):
)

@classmethod
def from_iterable(cls, ts: Iterable[tuple[T, ...]]):
def from_iterable(cls, ts: Iterable[tuple[T, ...]]) -> Self:
return reduce(lambda s, t: s.add(t), ts, cls((), {}))

@classmethod
Expand Down Expand Up @@ -1228,7 +1228,7 @@ def stats(self, int_size) -> Stats:
offsets2_int_size=0,
)

@classmethod
@staticmethod
def test(cls):
c1 = (1, 2, 3, 4)
c2 = (2, 3)
Expand All @@ -1238,9 +1238,9 @@ def test(cls):
s += c2
s += c3
s += c4
groups = {c1: [0, 3], c2: [4], c3: [1, 2]}
a = cls.from_overlapped_sequences(s, groups)
assert a == cls(
groups: Groups[int] = {c1: [0, 3], c2: [4], c3: [1, 2]}
a = CompressedArray.from_overlapped_sequences(s, groups)
assert a == CompressedArray(
data=s.data, offsets=(0, 2, 2, 0, 1), chunk_offsets=s.offsets
), a

Expand Down Expand Up @@ -1275,7 +1275,7 @@ def test_compression(cls):
c4 = (3, 4, 5)
c5 = (0, 1, 2)
c6 = (2, 3, 5)
groups = {
groups: Groups[int] = {
c1: [0],
c2: [1],
c3: [2],
Expand Down

0 comments on commit b0b5d5c

Please sign in to comment.