Last active
December 4, 2025 14:24
-
-
Save csjh/a70d0b0f2508a6e3d9a80eaf3213298a to your computer and use it in GitHub Desktop.
regex compiler
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections.abc import Buffer
from typing import Callable, Union
| # ascii is [0, 128) | |
| MAX_CHAR = 0x80 | |
| EPSILON = 0 | |
class RegEx:
    """A regular expression compiled down to native machine code.

    Wraps the module-level `compile()` pipeline (parse -> NFA -> DFA -> JIT)
    and exposes a whole-string matcher via `test()`.
    """

    # dummy reference so mmap isn't freed before we're done.
    # Annotation is quoted: `Buffer` is imported near the bottom of this
    # file, after the class body executes, so an eager evaluation here
    # would raise NameError on Python < 3.14.
    coderef: "Buffer"
    # JIT-compiled matcher: returns True iff the whole string matches.
    func: Callable[[str], bool]

    def __init__(self, regex: str):
        # `compile` here is the module-level compiler below (it shadows
        # the builtin of the same name).
        self.coderef, self.func = compile(regex)

    def test(self, s: str) -> bool:
        """Return True if `s` matches the compiled regex in full."""
        return self.func(s)
def compile(regex: str) -> tuple[Buffer, Callable[[str], bool]]:
    """Compile `regex` all the way to native code.

    Pipeline: parse -> epsilon-NFA -> minimal DFA -> JIT.
    Returns (code-buffer keep-alive, matcher callable).
    """
    return jit(dfaize(nfaize(parse(regex))))
def interpret(regex: str, s: str) -> bool:
    """Match `s` against `regex` by walking the DFA directly (no JIT).

    Returns True iff the entire string is accepted.
    """
    table, state, accepting = dfaize(nfaize(parse(regex)))
    for symbol in s:
        # Absent entries mean the implicit sink state: no match possible.
        if (state := table[state].get(ord(symbol))) is None:
            return False
    return state in accepting
| # region Parsing | |
| from dataclasses import dataclass | |
@dataclass
class kleene:
    """AST node for 'X*': zero or more repetitions of `starred`."""
    starred: "AST"
@dataclass
class union:
    """AST node for 'X+Y': match either `lhs` or `rhs` ('+' is union here)."""
    lhs: "AST"
    rhs: "AST"
@dataclass
class concat:
    """AST node for 'XY': match `lhs` followed immediately by `rhs`."""
    lhs: "AST"
    rhs: "AST"
class wildcard:
    """AST node for '.': matches any single non-NUL ascii character."""
    pass

# A regex AST is either a character code (int) or a composite node.
AST = int | concat | union | kleene | wildcard
def parse(regex: str) -> AST:
    """Parse `regex` into an AST by recursive descent (one char lookahead).

    Grammar (lowest precedence first):
        S -> R ('+' S)?        union ('+' plays the role of '|')
        R -> K R?              concatenation
        K -> G '*'?            Kleene star
        G -> '(' S ')' | C     grouping
        C -> '\\' esc | '.' | literal | epsilon

    Raises Exception on malformed escapes, non-ascii/NUL escapes, and
    unbalanced parentheses.
    """
    word = peekable(regex)

    def parse_C() -> AST:
        # Single atom: escape sequence, wildcard, literal char, or epsilon.
        if word.peek() == "\\":
            next(word)
            if word.exhausted():
                raise Exception("invalid escape sequence")
            c = next(word)
            if c == "x":
                # \xNN hex escape.  next(word, None) yields None instead of
                # raising StopIteration when the input runs out, so the
                # check below actually fires on a truncated escape.
                c1, c2 = next(word, None), next(word, None)
                if c1 is None or c2 is None:
                    raise Exception("invalid hex escape sequence")
                try:
                    hexnum = int(c1 + c2, 16)
                except ValueError:
                    # Non-hex digits, e.g. \xgz
                    raise Exception("invalid hex escape sequence") from None
                if hexnum == 0:
                    raise Exception("cannot match on null character")
                if hexnum > 127:
                    raise Exception("cannot match non-ascii characters")
                return hexnum
            else:
                # \c matches c literally (e.g. \*, \(, \\).
                return ord(c)
        elif word.peek() == ".":
            next(word)
            return wildcard()
        elif (c := word.peek()) and c not in ("(", ")", "*", "+", "\0"):
            next(word)
            return ord(c)
        return EPSILON

    def parse_G() -> AST:
        # Parenthesized group, or a bare atom.
        if word.peek() == "(":
            next(word)
            s = parse_S()
            # next(word, None): report the missing ')' instead of leaking
            # StopIteration when the input ends inside a group.
            if next(word, None) != ")":
                raise Exception("missing closing bracket on regex group")
            return s
        else:
            return parse_C()

    def parse_K() -> AST:
        # Postfix '*' binds tighter than concatenation.
        g = parse_G()
        if word.peek() == "*":
            next(word)
            return kleene(g)
        else:
            return g

    def parse_R() -> AST:
        # Concatenation: fold adjacent atoms, right-associatively.
        lhs = parse_K()
        if not word.exhausted() and lhs != EPSILON:
            rhs = parse_R()
            if rhs != EPSILON:
                return concat(lhs, rhs)
        return lhs

    def parse_S() -> AST:
        # Union; '+' here is what '|' is in conventional regex syntax.
        lhs = parse_R()
        if word.peek() == "+":
            next(word)
            rhs = parse_S()
            return union(lhs, rhs)
        return lhs

    return parse_S()
| # endregion | |
| # region IR / Optimization | |
| from collections import defaultdict | |
| from functools import cache, reduce | |
| from operator import or_ | |
# implicit alphabet: [1, MAX_CHAR) — symbol 0 is reserved as the EPSILON label
# (NOTE(review): an earlier comment said [0, 255]; MAX_CHAR is 0x80 here)
# FA = (transitions, start_state, accept_states)
# transitions is specifying all non-sink state transitions (transitions[i] = transitions for state i)
# sink state is the state where all transitions go to itself
# An NFA's start/accept components are sets of states; a DFA has a single
# int start state and exactly one destination per (state, symbol).
NFA = tuple[list[dict[int, list[int]]], set[int], set[int]]
DFA = tuple[list[dict[int, int]], int, set[int]]
def nfaize(ast: AST) -> NFA:
    """Build an epsilon-NFA from an AST (Thompson-style construction).

    Each AST node allocates fresh states in `transitions` (state ids are
    just indices into that list) and returns its (start state, acceptors).
    """
    transitions: list[dict[int, list[int]]] = []

    def traverse(node: AST) -> tuple[int, set[int]]:
        match node:
            case wildcard():
                # Two fresh states with an edge for every non-NUL ascii char.
                start = len(transitions)
                transitions.append(defaultdict(list))
                end = len(transitions)
                transitions.append(defaultdict(list))
                for c in range(1, MAX_CHAR):
                    transitions[start][c].append(end)
                return start, {end}
            case int(n):
                # Single-character literal: start --n--> end.
                start = len(transitions)
                transitions.append(defaultdict(list))
                end = len(transitions)
                transitions.append(defaultdict(list))
                transitions[start][n].append(end)
                return start, {end}
            case concat(lhs, rhs):
                # Wire every lhs acceptor into the rhs start via epsilon.
                lhs_start, lhs_acceptors = traverse(lhs)
                rhs_start, rhs_acceptors = traverse(rhs)
                for acceptor in lhs_acceptors:
                    transitions[acceptor][EPSILON].append(rhs_start)
                return lhs_start, rhs_acceptors
            case union(lhs, rhs):
                # Fresh start state with epsilon edges into both branches.
                start = len(transitions)
                transitions.append(defaultdict(list))
                lhs_start, lhs_acceptors = traverse(lhs)
                rhs_start, rhs_acceptors = traverse(rhs)
                transitions[start][EPSILON].append(lhs_start)
                transitions[start][EPSILON].append(rhs_start)
                return start, lhs_acceptors | rhs_acceptors
            case kleene(starred):
                # Fresh accepting start; loop the body's acceptors back to
                # it, so zero or more repetitions are accepted.
                start, acceptors = traverse(starred)
                new_start = len(transitions)
                transitions.append(defaultdict(list))
                transitions[new_start][EPSILON].append(start)
                for acceptor in acceptors:
                    transitions[acceptor][EPSILON].append(new_start)
                return new_start, {new_start}

    start, acceptors = traverse(ast)
    return transitions, {start}, acceptors
def reverse(nfa: NFA) -> NFA:
    """Reverse every edge of `nfa`, swapping the start and accept sets.

    The reversed automaton accepts exactly the reversed strings of the
    original's language.
    """
    old_transitions, start, acceptors = nfa
    flipped: list = [defaultdict(list) for _ in old_transitions]
    for origin, edges in enumerate(old_transitions):
        for symbol, targets in edges.items():
            for target in targets:
                flipped[target][symbol].append(origin)
    return flipped, acceptors, start
def reachable_powerset(nfa: NFA) -> NFA:
    """Determinize `nfa` via the powerset construction, visiting only the
    subsets reachable from the start closure.

    Subsets of NFA states are represented as int bitsets (bit i set means
    state i is in the subset).  Returns an NFA-shaped triple whose
    transition lists are all singletons (i.e. deterministic).
    """
    transitions, start, acceptors = nfa
    # Bitset with a 1 for every accepting NFA state.
    acceptor_mask = reduce(or_, map(lambda x: 1 << x, acceptors))
    # maps bitsets to serial id; the factory also allocates the new state's
    # empty transition row.  len(set_to_state_id) is evaluated before the
    # missing key is inserted, so ids are handed out densely from 0.
    # (`new_transitions` is defined below; the lambda only resolves it at
    # call time, so the forward reference is fine.)
    set_to_state_id: dict[int, int] = defaultdict(
        lambda: new_transitions.append(defaultdict(list)) or len(set_to_state_id)
    )
    new_transitions: list[dict[int, list[int]]] = []
    new_acceptors: set[int] = set()

    @cache
    def _epsilon_closure(n: int):
        # Bitset of every state reachable from `n` via epsilon edges alone.
        epsilon_seen = set()
        q = [n]
        r = 0
        while q:
            n = q.pop()
            if n in epsilon_seen:
                continue
            epsilon_seen.add(n)
            r |= 1 << n
            q.extend(transitions[n][EPSILON])
        return r

    def epsilon_closure(states: set[int]):
        # Closure of a whole set: union of per-state closures.
        return reduce(or_, map(_epsilon_closure, states))

    new_start = epsilon_closure(start)
    seen = set()
    q: list[int] = [new_start]
    while q:
        n = q.pop()
        if n in seen:
            continue
        seen.add(n)
        # Accepting iff the subset contains any NFA acceptor.
        if n & acceptor_mask:
            new_acceptors.add(set_to_state_id[n])
        for c in range(1, MAX_CHAR):
            # Union of all c-successors of the member states, then close.
            results = set()
            for i in range(n.bit_length()):
                if n & (1 << i):
                    results.update(transitions[i][c])
            if results:
                new_state = epsilon_closure(results)
                new_transitions[set_to_state_id[n]][c].append(
                    set_to_state_id[new_state]
                )
                q.append(new_state)
    return new_transitions, {set_to_state_id[new_start]}, new_acceptors
def dfaize(nfa: NFA) -> DFA:
    """Determinize (and minimize) `nfa` via Brzozowski's algorithm:
    reverse, determinize, reverse, determinize.
    """
    reversed_det = reachable_powerset(reverse(nfa))
    table, start_states, accepting = reachable_powerset(reverse(reversed_det))
    # The powerset construction emits exactly one destination per symbol,
    # so collapse each singleton list to its sole element.
    flat = [{symbol: dests[0] for symbol, dests in row.items()} for row in table]
    # ... and its start set is always a singleton.
    (start_state,) = start_states
    return flat, start_state, accepting
| # endregion | |
| # region JIT Compilation | |
| import cffi | |
| from io import BytesIO | |
| from itertools import pairwise | |
def jit(dfa: DFA) -> tuple[Buffer, Callable[[str], bool]]:
    """JIT-compile `dfa` to machine code (AArch64, per the per-instruction
    comments below) and return (code-buffer keep-alive, matcher callable).

    Emitted code's contract (as wired up via cffi at the bottom): x0 holds a
    pointer to a NUL-terminated byte string; returns 1 for a match, else 0.
    Each DFA state gets a handler that loads the next byte and branches to
    the successor state's handler on a matching character.
    """
    transitions, start, acceptors = dfa
    writer = BytesIO()
    # Byte offset of each state's handler in the buffer (0 = not emitted yet;
    # offset 0 itself is unambiguous because handlers are emitted in reverse
    # state order, so state 0's handler is never first).
    starts = [0] * len(transitions)
    # Forward branches to handlers not yet emitted: (target state, patch pos).
    pending: list[tuple[int, int]] = []

    def write_instruction(instr: int):
        writer.write(instr.to_bytes(4, "little"))

    def bcond(offset: int, cond: int):
        # B.cond with a signed 19-bit word offset (two's-complement encode).
        assert offset % 4 == 0
        offset //= 4
        imm19 = (offset + 0x80000) & 0x7FFFF
        write_instruction(0b01010100000000000000000000000000 | (imm19 << 5) | cond)

    def b(offset: int):
        # Unconditional B with a signed 26-bit word offset.
        assert offset % 4 == 0
        offset //= 4
        imm26 = (offset + 0x4000000) & 0x3FFFFFF
        write_instruction(0b00010100000000000000000000000000 | imm26)

    def conditional_jump(state: int, cond: int):
        if starts[state] == 0:
            # Target handler not emitted yet: invert the condition to hop
            # over an unconditional branch that gets backpatched later.
            bcond(2 * 4, cond ^ 1)
            pending.append((state, writer.tell()))
            b(0)
        else:
            offset = starts[state] - writer.tell()
            # check if jump fits in 19 bit immediate
            if -0x100000 <= offset < 0xFFFFC:
                bcond(offset, cond)
            else:
                raise Exception("regex compilation error")

    def single_jump(state: int, char: int):
        # cmp x1, #char
        write_instruction(0b1111000100000000000000_00001_11111 | (char << 10))
        conditional_jump(state, 0b0000)  # if x1 == char, jump to state

    def range_jump(state: int, start_char: int, length: int):
        # sub x2, x1, #start_char
        write_instruction(0b1101000100000000000000_00001_00010 | (start_char << 10))
        # cmp x2, #length
        write_instruction(0b1111000100000000000000_00010_11111 | (length << 10))
        conditional_jump(state, 0b0011)  # if x2 < length, jump to state

    def write_handler(state: int):
        starts[state] = writer.tell()
        # ldrb x1, [x0], #1
        writer.write(b"\x01\x14\x40\x38")
        if state in acceptors:
            # Accept only at end of input: return 1 when the loaded byte
            # is the NUL terminator, otherwise fall through to matching.
            # cbnz x1, .+8
            writer.write(b"\x61\x00\x00\xb5")
            # mov x0, #1
            writer.write(b"\x20\x00\x80\xd2")
            # ret
            writer.write(b"\xc0\x03\x5f\xd6")
        connections = transitions[state]
        # Keys were inserted in ascending char order by the powerset
        # construction, so dict order is already sorted.
        chars = list(connections.keys())
        assert chars == sorted(chars)
        # len(chars) == 0 is possible for acceptor states with no out-edges
        if len(chars) > 0:
            # categorize characters into singles or ranges
            streaks = []
            start, streak = chars[0], 1
            for c1, c2 in pairwise(chars):
                # Extend the streak only for consecutive chars that go to
                # the same destination state.
                if c2 - c1 == 1 and connections[c1] == connections[c2]:
                    streak += 1
                else:
                    streaks.append((start, streak))
                    start, streak = c2, 1
            streaks.append((start, streak))
            for c, s in streaks:
                if s == 1:
                    single_jump(connections[c], c)
                else:
                    range_jump(connections[c], c, s)
        # No transition matched: reject.
        # mov x0, #0
        writer.write(b"\x00\x00\x80\xd2")
        # ret
        writer.write(b"\xc0\x03\x5f\xd6")

    for i in reversed(range(len(transitions))):
        write_handler(i)
    # Backpatch the forward branches recorded in `pending`.
    for state, pos in pending:
        writer.seek(pos)
        b(starts[state] - pos)
    code = to_executable_buffer(writer)
    ffi = cffi.FFI()
    # Entry point is the start state's handler, not buffer offset 0.
    c_buf = ffi.from_buffer("char *", code) + starts[start]
    c_ptr = ffi.cast("bool (*)(const char *)", c_buf)
    # ffi.new("char[]", ...) NUL-terminates the encoded input for us.
    return code, lambda s: c_ptr(ffi.new("char[]", s.encode()))
| # below this isn't very relevant to the class, just code generation boilerplate | |
| from collections.abc import Buffer | |
| import ctypes | |
| import mmap | |
libc = ctypes.CDLL(None)  # handle to the process's already-loaded C library
# int mprotect(void *addr, size_t len, int prot)
libc.mprotect.restype = ctypes.c_int
libc.mprotect.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int]
# void sys_icache_invalidate(void *start, size_t len)
# NOTE(review): Apple-libc-specific; flushes the instruction cache after
# writing code — this lookup will fail on non-macOS platforms.
libc.sys_icache_invalidate.restype = None
libc.sys_icache_invalidate.argtypes = [ctypes.c_void_p, ctypes.c_size_t]
| from io import SEEK_END | |
def to_executable_buffer(buf: BytesIO) -> Buffer:
    """Copy `buf`'s machine code into an anonymous mmap and make it executable.

    NOTE(review): MAP_JIT and sys_icache_invalidate are macOS/Apple-Silicon
    specific; this function will not run as-is elsewhere.
    """
    buf.flush()
    buf.seek(0, SEEK_END)  # so buf.tell() below yields the total code size
    code = mmap.mmap(
        -1,
        buf.tell(),
        flags=mmap.MAP_PRIVATE | mmap.MAP_ANON | mmap.MAP_JIT,
        prot=mmap.PROT_READ | mmap.PROT_WRITE,
    )
    code.write(buf.getvalue())
    code.flush()
    # W^X: flip the mapping from read/write to read/execute.
    addr = ctypes.addressof(ctypes.c_char.from_buffer(code))
    if libc.mprotect(addr, len(code), mmap.PROT_READ | mmap.PROT_EXEC) != 0:
        raise Exception("mprotect failed")
    # Make sure the CPU's instruction cache sees the freshly written code.
    libc.sys_icache_invalidate(addr, len(code))
    return code
class peekable:
    """Iterator adapter providing one item of lookahead.

    peek() returns the upcoming item without consuming it (None once the
    underlying iterator is exhausted); exhausted() reports end-of-input;
    __next__ consumes as usual, raising StopIteration at the end.
    """

    # Sentinel marking "no lookahead buffered".  The original used None,
    # which made __next__ silently *skip* a legitimately yielded None
    # element and forced peek() to re-poll on every call; a private
    # sentinel distinguishes "empty buffer" from a buffered None.
    _EMPTY = object()

    def __init__(self, iterable):
        self.iterable = iter(iterable)
        self.peeked = self._EMPTY

    def __iter__(self):
        # Make peekable a proper iterator (iter(p) is p).
        return self

    def exhausted(self):
        """True when no items remain."""
        if self.peeked is self._EMPTY:
            self.peeked = next(self.iterable, self._EMPTY)
        return self.peeked is self._EMPTY

    def peek(self):
        """Return the next item without consuming it, or None at the end."""
        if self.peeked is self._EMPTY:
            self.peeked = next(self.iterable, self._EMPTY)
        return None if self.peeked is self._EMPTY else self.peeked

    def __next__(self):
        if self.peeked is not self._EMPTY:
            result, self.peeked = self.peeked, self._EMPTY
            return result
        return next(self.iterable)
| # endregion |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment