Skip to content

Instantly share code, notes, and snippets.

@csjh
Last active December 4, 2025 14:24
Show Gist options
  • Select an option

  • Save csjh/a70d0b0f2508a6e3d9a80eaf3213298a to your computer and use it in GitHub Desktop.

Select an option

Save csjh/a70d0b0f2508a6e3d9a80eaf3213298a to your computer and use it in GitHub Desktop.
regex compiler
from collections.abc import Buffer
from typing import Callable, Union
# ascii is [0, 128)
MAX_CHAR = 0x80
# symbol 0 doubles as the epsilon label on NFA edges; NUL itself can never
# be matched (parse rejects \x00 escapes and "\0" literals)
EPSILON = 0
class RegEx:
    """A regular expression compiled down to native machine code."""

    # dummy reference so the mmap'd code isn't freed while func is still usable
    coderef: Buffer
    # the JIT-compiled matcher
    func: Callable[[str], bool]

    def __init__(self, regex: str):
        self.coderef, self.func = compile(regex)

    def test(self, s: str) -> bool:
        """Run the compiled matcher against *s*."""
        return self.func(s)
def compile(regex: str) -> tuple[Buffer, Callable[[str], bool]]:
    """Compile *regex* into native code.

    Pipeline: parse -> Thompson NFA -> minimal DFA -> JIT.
    Returns (code buffer, matcher); the buffer must outlive the matcher.
    NOTE: intentionally shadows the ``compile`` builtin within this module.
    """
    return jit(dfaize(nfaize(parse(regex))))
def interpret(regex: str, s: str) -> bool:
    """Match *s* against *regex* by walking the DFA in Python (no JIT)."""
    transitions, state, acceptors = dfaize(nfaize(parse(regex)))
    for code in map(ord, s):
        next_state = transitions[state].get(code)
        if next_state is None:
            # no out-edge on this character: implicit sink state, reject
            return False
        state = next_state
    return state in acceptors
# region Parsing
from dataclasses import dataclass
@dataclass
class kleene:
starred: "AST"
@dataclass
class union:
lhs: "AST"
rhs: "AST"
@dataclass
class concat:
lhs: "AST"
rhs: "AST"
class wildcard:
pass
AST = int | concat | union | kleene | wildcard
def parse(regex: str) -> AST:
    """Parse *regex* into an AST of int / wildcard / concat / union / kleene.

    Grammar (recursive descent, one nested function per nonterminal):
        S -> R ('+' S)?       alternation ('+' is union here, not "one or more")
        R -> K R?             concatenation
        K -> G '*'?           Kleene star
        G -> '(' S ')' | C    group or single atom
        C -> '\\' esc | '.' | literal

    Literal characters are returned as their ordinal; EPSILON (0) marks the
    empty regex / empty group.

    Raises Exception on malformed escapes or unbalanced parentheses.
    """
    word = peekable(regex)

    def parse_C() -> AST:
        # atom: escape sequence, wildcard, or literal character
        if word.peek() == "\\":
            next(word)
            if word.exhausted():
                raise Exception("invalid escape sequence")
            c = next(word)
            if c == "x":
                # \xNN two-digit hex escape
                # BUGFIX: the old `c1 is None or c2 is None` check never fired
                # because peekable raises StopIteration rather than returning
                # None; catch the exceptions these calls actually raise
                # (StopIteration on truncation, ValueError on non-hex digits).
                try:
                    c1, c2 = next(word), next(word)
                    hexnum = int(c1 + c2, 16)
                except (StopIteration, ValueError):
                    raise Exception("invalid hex escape sequence") from None
                if hexnum == 0:
                    raise Exception("cannot match on null character")
                if hexnum > 127:
                    raise Exception("cannot match non-ascii characters")
                return hexnum
            else:
                # \c for any other c matches c literally
                return ord(c)
        elif word.peek() == ".":
            next(word)
            return wildcard()
        elif (c := word.peek()) and c not in ("(", ")", "*", "+", "\0"):
            next(word)
            return ord(c)
        # no atom here (end of input or a metacharacter): empty match
        return EPSILON

    def parse_G() -> AST:
        # group: parenthesized subexpression or a single atom
        if word.peek() == "(":
            next(word)
            s = parse_S()
            # exhausted() guard: a dangling "(" would otherwise surface as a
            # bare StopIteration out of next() instead of this error
            if word.exhausted() or next(word) != ")":
                raise Exception("missing closing bracket on regex group")
            return s
        else:
            return parse_C()

    def parse_K() -> AST:
        # postfix Kleene star
        g = parse_G()
        if word.peek() == "*":
            next(word)
            return kleene(g)
        else:
            return g

    def parse_R() -> AST:
        # concatenation; EPSILON operands are dropped rather than wrapped
        lhs = parse_K()
        if not word.exhausted() and lhs != EPSILON:
            rhs = parse_R()
            if rhs != EPSILON:
                return concat(lhs, rhs)
        return lhs

    def parse_S() -> AST:
        # alternation, lowest precedence
        lhs = parse_R()
        if word.peek() == "+":
            next(word)
            rhs = parse_S()
            return union(lhs, rhs)
        return lhs

    return parse_S()
# endregion
# region IR / Optimization
from collections import defaultdict
from functools import cache, reduce
from operator import or_
# implicit: alphabet [0, 255]
# FA = (transitions, start_state, accept_states)
# transitions is specifying all non-sink state transitions (transitions[i] = transitions for state i)
# sink state is the state where all transitions go to itself
# NFA: per-state {symbol -> list of successor states}; start and accept are state *sets*
NFA = tuple[list[dict[int, list[int]]], set[int], set[int]]
# DFA: per-state {symbol -> single successor}; exactly one start state
DFA = tuple[list[dict[int, int]], int, set[int]]
def nfaize(ast: AST) -> NFA:
    """Thompson-style construction: build an epsilon-NFA from *ast*.

    Returns (transitions, {start_state}, accept_states). State ids index
    into the transitions list; EPSILON (0) edges are epsilon moves.
    """
    transitions: list[dict[int, list[int]]] = []

    def traverse(node: AST) -> tuple[int, set[int]]:
        # returns (start state, accept states) of the sub-automaton for node
        match node:
            case wildcard():
                # two fresh states joined by an edge for every non-NUL ascii char
                start = len(transitions)
                transitions.append(defaultdict(list))
                end = len(transitions)
                transitions.append(defaultdict(list))
                for c in range(1, MAX_CHAR):
                    transitions[start][c].append(end)
                return start, {end}
            case int(n):
                # single literal character (n == EPSILON yields an epsilon edge)
                start = len(transitions)
                transitions.append(defaultdict(list))
                end = len(transitions)
                transitions.append(defaultdict(list))
                transitions[start][n].append(end)
                return start, {end}
            case concat(lhs, rhs):
                # wire every lhs acceptor to the rhs start via epsilon
                lhs_start, lhs_acceptors = traverse(lhs)
                rhs_start, rhs_acceptors = traverse(rhs)
                for acceptor in lhs_acceptors:
                    transitions[acceptor][EPSILON].append(rhs_start)
                return lhs_start, rhs_acceptors
            case union(lhs, rhs):
                # fresh start state with epsilon edges into both branches
                start = len(transitions)
                transitions.append(defaultdict(list))
                lhs_start, lhs_acceptors = traverse(lhs)
                rhs_start, rhs_acceptors = traverse(rhs)
                transitions[start][EPSILON].append(lhs_start)
                transitions[start][EPSILON].append(rhs_start)
                return start, lhs_acceptors | rhs_acceptors
            case kleene(starred):
                # new accepting start state; acceptors loop back through it
                start, acceptors = traverse(starred)
                new_start = len(transitions)
                transitions.append(defaultdict(list))
                transitions[new_start][EPSILON].append(start)
                for acceptor in acceptors:
                    transitions[acceptor][EPSILON].append(new_start)
                return new_start, {new_start}

    start, acceptors = traverse(ast)
    return transitions, {start}, acceptors
def reverse(nfa: NFA) -> NFA:
    """Reverse every edge of *nfa*; start and accept sets swap roles."""
    forward, starts, accepts = nfa
    backward = [defaultdict(list) for _ in forward]
    for state, table in enumerate(forward):
        for symbol, targets in table.items():
            for target in targets:
                backward[target][symbol].append(state)
    return backward, accepts, starts
def reachable_powerset(nfa: NFA) -> NFA:
    """Subset (powerset) construction restricted to reachable states.

    NFA state sets are represented as int bitsets. The result is typed as
    an NFA but is effectively deterministic: each state is visited once, so
    every transition list holds exactly one target, and there are no
    EPSILON edges (the loop below only considers symbols 1..MAX_CHAR-1).
    Assumes the input transition tables are defaultdicts (as built by
    nfaize/reverse), since missing symbols are indexed directly.
    """
    transitions, start, acceptors = nfa
    acceptor_mask = reduce(or_, map(lambda x: 1 << x, acceptors))
    # maps bitsets to serial id
    # NOTE: the default factory has a deliberate side effect — it allocates
    # the new state's (empty) transition table, then hands out the next
    # serial id; len() is taken before the key is inserted, so ids come out
    # as 0, 1, 2, ...  new_transitions is bound below but only used when the
    # factory first fires, which is after the assignment.
    set_to_state_id: dict[int, int] = defaultdict(
        lambda: new_transitions.append(defaultdict(list)) or len(set_to_state_id)
    )
    new_transitions: list[dict[int, list[int]]] = []
    new_acceptors: set[int] = set()

    @cache
    def _epsilon_closure(n: int):
        # bitset of every state reachable from n via EPSILON edges (incl. n)
        epsilon_seen = set()
        q = [n]
        r = 0
        while q:
            n = q.pop()
            if n in epsilon_seen:
                continue
            epsilon_seen.add(n)
            r |= 1 << n
            q.extend(transitions[n][EPSILON])
        return r

    def epsilon_closure(states: set[int]):
        # union of the member states' closures, as a bitset
        return reduce(or_, map(_epsilon_closure, states))

    new_start = epsilon_closure(start)
    # worklist of bitset-states still to expand
    seen = set()
    q: list[int] = [new_start]
    while q:
        n = q.pop()
        if n in seen:
            continue
        seen.add(n)
        # accepting iff any member NFA state is accepting
        if n & acceptor_mask:
            new_acceptors.add(set_to_state_id[n])
        for c in range(1, MAX_CHAR):
            # union of all member states' moves on symbol c
            results = set()
            for i in range(n.bit_length()):
                if n & (1 << i):
                    results.update(transitions[i][c])
            if results:
                new_state = epsilon_closure(results)
                new_transitions[set_to_state_id[n]][c].append(
                    set_to_state_id[new_state]
                )
                q.append(new_state)
    return new_transitions, {set_to_state_id[new_start]}, new_acceptors
def dfaize(nfa: NFA) -> DFA:
    """Turn *nfa* into a minimal DFA via Brzozowski's algorithm.

    reverse -> determinize -> reverse -> determinize yields the minimal
    DFA for the original language.
    """
    transitions, start_states, acceptors = reachable_powerset(
        reverse(reachable_powerset(reverse(nfa)))
    )
    # the powerset output is deterministic: singleton transition lists and
    # a single start state, so flatten them into the DFA shape
    flattened = [
        {symbol: targets[0] for symbol, targets in table.items()}
        for table in transitions
    ]
    return flattened, start_states.pop(), acceptors
# endregion
# region JIT Compilation
import cffi
from io import BytesIO
from itertools import pairwise
def jit(dfa: DFA) -> tuple[Buffer, Callable[[str], bool]]:
    """Compile *dfa* into AArch64 machine code.

    Emits one handler per DFA state. Calling convention of the generated
    code: x0 = pointer to a NUL-terminated input string; returns 1 iff the
    whole string is accepted. Returns (executable buffer, Python callable);
    the buffer must be kept alive for as long as the callable is used.
    """
    transitions, start, acceptors = dfa
    writer = BytesIO()
    # byte offset of each state's handler in the buffer (0 = not yet emitted)
    starts = [0] * len(transitions)
    # forward branches to handlers not yet emitted: (target state, patch pos)
    pending: list[tuple[int, int]] = []

    def write_instruction(instr: int):
        # AArch64 instructions are 4 bytes, little-endian
        writer.write(instr.to_bytes(4, "little"))

    def bcond(offset: int, cond: int):
        # B.cond — conditional branch, signed 19-bit word offset
        assert offset % 4 == 0
        offset //= 4
        # two's-complement encode into the 19-bit immediate field
        imm19 = (offset + 0x80000) & 0x7FFFF
        write_instruction(0b01010100000000000000000000000000 | (imm19 << 5) | cond)

    def b(offset: int):
        # B — unconditional branch, signed 26-bit word offset
        assert offset % 4 == 0
        offset //= 4
        imm26 = (offset + 0x4000000) & 0x3FFFFFF
        write_instruction(0b00010100000000000000000000000000 | imm26)

    def conditional_jump(state: int, cond: int):
        if starts[state] == 0:
            # handler not emitted yet: invert the condition to hop over an
            # unconditional branch whose target is patched in later
            bcond(2 * 4, cond ^ 1)
            pending.append((state, writer.tell()))
            b(0)
        else:
            offset = starts[state] - writer.tell()
            # check if jump fits in 19 bit immediate
            if -0x100000 <= offset < 0xFFFFC:
                bcond(offset, cond)
            else:
                raise Exception("regex compilation error")

    def single_jump(state: int, char: int):
        # cmp x1, #char
        write_instruction(0b1111000100000000000000_00001_11111 | (char << 10))
        conditional_jump(state, 0b0000)  # if x1 == char, jump to state

    def range_jump(state: int, start_char: int, length: int):
        # sub x2, x1, #start_char
        write_instruction(0b1101000100000000000000_00001_00010 | (start_char << 10))
        # cmp x2, #length
        write_instruction(0b1111000100000000000000_00010_11111 | (length << 10))
        conditional_jump(state, 0b0011)  # if x2 < length, jump to state

    def write_handler(state: int):
        starts[state] = writer.tell()
        # ldrb x1, [x0], #1   — load next input byte, post-increment pointer
        writer.write(b"\x01\x14\x40\x38")
        if state in acceptors:
            # accept only when the loaded byte is the NUL terminator
            # cbnz x1, .+8
            writer.write(b"\x61\x00\x00\xb5")
            # mov x0, #1
            writer.write(b"\x20\x00\x80\xd2")
            # ret
            writer.write(b"\xc0\x03\x5f\xd6")
        connections = transitions[state]
        chars = list(connections.keys())
        # dict preserves insertion order; powerset construction emits symbols
        # in ascending order, which the streak detection below relies on
        assert chars == sorted(chars)
        # len(chars) == 0 is possible for acceptor states with no out-edges
        if len(chars) > 0:
            # categorize characters into singles or ranges
            streaks = []
            # NOTE: this `start` is local to write_handler (shadows the DFA
            # start state, which is only read again after all handlers)
            start, streak = chars[0], 1
            for c1, c2 in pairwise(chars):
                # extend the streak only while consecutive chars share a target
                if c2 - c1 == 1 and connections[c1] == connections[c2]:
                    streak += 1
                else:
                    streaks.append((start, streak))
                    start, streak = c2, 1
            streaks.append((start, streak))
            for c, s in streaks:
                if s == 1:
                    single_jump(connections[c], c)
                else:
                    range_jump(connections[c], c, s)
        # no transition matched: reject
        # mov x0, #0
        writer.write(b"\x00\x00\x80\xd2")
        # ret
        writer.write(b"\xc0\x03\x5f\xd6")

    # emit handlers in reverse state order so most branches point backward
    # to already-emitted handlers
    for i in reversed(range(len(transitions))):
        write_handler(i)
    # patch the recorded forward branches now that all offsets are known
    for state, pos in pending:
        writer.seek(pos)
        b(starts[state] - pos)
    code = to_executable_buffer(writer)
    ffi = cffi.FFI()
    # entry point is the start state's handler inside the executable buffer
    c_buf = ffi.from_buffer("char *", code) + starts[start]
    c_ptr = ffi.cast("bool (*)(const char *)", c_buf)
    # ffi.new allocates a NUL-terminated copy of s for the call
    return code, lambda s: c_ptr(ffi.new("char[]", s.encode()))
# below this isn't very relevant to the class, just code generation boilerplate
from collections.abc import Buffer
import ctypes
import mmap
# raw libc handles used to flip the mmap'd code buffer to read+execute
libc = ctypes.CDLL(None)
libc.mprotect.restype = ctypes.c_int
libc.mprotect.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int]
# NOTE(review): sys_icache_invalidate is an Apple libc call (flushes the
# instruction cache) — presumably this whole file targets macOS on arm64
libc.sys_icache_invalidate.restype = None
libc.sys_icache_invalidate.argtypes = [ctypes.c_void_p, ctypes.c_size_t]
from io import SEEK_END
def to_executable_buffer(buf: BytesIO) -> Buffer:
    """Copy *buf*'s bytes into an anonymous mmap and mark it read+execute.

    Raises Exception if mprotect fails.
    NOTE(review): MAP_JIT and sys_icache_invalidate are Apple-specific —
    confirm this is only ever run on macOS/Apple Silicon.
    """
    buf.flush()
    buf.seek(0, SEEK_END)  # so tell() below reports the total byte length
    code = mmap.mmap(
        -1,
        buf.tell(),
        flags=mmap.MAP_PRIVATE | mmap.MAP_ANON | mmap.MAP_JIT,
        prot=mmap.PROT_READ | mmap.PROT_WRITE,
    )
    code.write(buf.getvalue())
    code.flush()
    addr = ctypes.addressof(ctypes.c_char.from_buffer(code))
    # flip W^X: drop write permission, gain execute
    if libc.mprotect(addr, len(code), mmap.PROT_READ | mmap.PROT_EXEC) != 0:
        raise Exception("mprotect failed")
    # flush the CPU instruction cache before jumping into freshly written code
    libc.sys_icache_invalidate(addr, len(code))
    return code
class peekable:
    """Iterator wrapper with one-item lookahead.

    Exhaustion is signalled by ``peek()`` returning None, so this wrapper
    only supports iterables whose elements are never falsy-None — true for
    the character streams it is used on here.

    BUGFIX: the original defined ``__next__`` but not ``__iter__``, so
    instances were not actually iterators (unusable in for-loops, list(),
    etc.); ``__iter__`` is added, which is backward-compatible.
    """

    def __init__(self, iterable):
        self.iterable = iter(iterable)
        # lookahead slot; None means "nothing buffered" (or exhausted)
        self.peeked = None

    def __iter__(self):
        return self

    def exhausted(self):
        """True once the underlying iterator has no more items."""
        return self.peek() is None

    def peek(self):
        """Return the next item without consuming it, or None when exhausted."""
        if self.peeked is None:
            self.peeked = next(self.iterable, None)
        return self.peeked

    def __next__(self):
        if self.peeked is not None:
            result = self.peeked
            self.peeked = None
            return result
        else:
            # propagates StopIteration when truly exhausted
            return next(self.iterable)
# endregion
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment