Last active
September 28, 2025 19:31
-
-
Save DarinM223/6963cacf6129ee9d20e54c0d99220ea0 to your computer and use it in GitHub Desktop.
Linear scan register allocation in Python
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from __future__ import annotations # type: ignore | |
| from typing import * # type: ignore | |
| from dataclasses import dataclass, field | |
@dataclass(unsafe_hash=True)
class Physical:
    """A physical machine register, identified by its integer number.

    Instances are kept in sets (the allocator's pool of free registers) and
    used as dictionary keys (per-register spill weights), so the class must be
    hashable.  A plain ``@dataclass`` generates ``__eq__`` and therefore sets
    ``__hash__`` to ``None``; ``unsafe_hash=True`` restores a hash derived from
    ``register`` while leaving construction, equality and repr unchanged.

    Do not mutate ``register`` while the instance is stored in a set or dict.
    """

    # Number that identifies the register; equality and hashing use it.
    register: int
# Dict-shaped constraint with a single integer entry under the key "reuse".
# NOTE(review): presumably the integer names the operand whose register the
# constrained virtual register must share -- confirm against the consumer.
ReuseOperand = TypedDict(
    "ReuseOperand",
    {"reuse": int},
)
@dataclass
class Virtual:
    """A virtual register, optionally carrying an allocation constraint."""

    # Number that identifies this virtual register.
    id: int
    # Either a Physical register, a ReuseOperand dict, or None when the
    # virtual register is unconstrained (the default).
    constraint: Physical | ReuseOperand | None = field(default=None)
@dataclass
class Immediate:
    """An operand that is a literal integer rather than a storage location."""

    # The literal integer carried by the operand.
    value: int
@dataclass
class StackSlot:
    """A spill location, identified by an integer offset."""

    # Offset of the slot.
    # NOTE(review): presumably a byte offset from the frame base -- confirm.
    offset: int
| @dataclass(kw_only=True) | |
| class MemoryAddress: | |
| base: Physical | Virtual | None = None | |
| index: Physical | Virtual | None = None | |
| scale: int = 0 | |
| displacement: int = 0 | |
# Any value that may appear as an instruction operand.
Operand = Physical | Virtual | Immediate | StackSlot | MemoryAddress

# An operand paired with its role in the instruction: "use" or "def".
OperandPair = tuple[Operand, Literal["use", "def"]]
@dataclass(eq=False)
class Instruction:
    """A named instruction together with its operand/role pairs.

    ``eq=False`` keeps the default identity-based equality and hash, so each
    instruction object can serve as its own dictionary key.  The hand-written
    ``__init__`` is kept by ``@dataclass`` and lets callers pass the operand
    pairs as separate positional arguments.
    """

    name: str
    operands: list[OperandPair]

    def __init__(self, name: str, *args: OperandPair):
        """Store ``name`` and collect the remaining arguments into a list."""
        self.name, self.operands = name, [*args]
@dataclass(eq=False)
class Block:
    """A block of instructions linked to its predecessor and successor blocks.

    ``eq=False`` keeps identity-based equality and hashing, so two blocks with
    the same contents remain distinct objects.
    """

    # Number that identifies the block.
    id: int
    # Blocks listed as predecessors of this one.
    predecessors: list[Block]
    # Blocks listed as successors of this one.
    successors: list[Block]
    # The block's instructions, in order.
    instructions: list[Instruction]
@dataclass(eq=False)
class Function:
    """A function: an identifier, its parameter operands and its blocks.

    ``eq=False`` keeps identity-based equality and hashing.
    """

    # Number that identifies the function.
    id: int
    # Operands that act as the function's parameters.
    params: list[Operand]
    # The function's blocks; number_instructions walks them in this order.
    blocks: list[Block]
# Sixteen physical registers numbered 0 through 15.
ALL_INTEGER_REGS = tuple(map(Physical, range(16)))

# Names are bound positionally: RAX is register 0, RBX is 1, ..., R15 is 15.
# NOTE(review): this numbering is the file's own and is not the x86-64
# hardware encoding order -- confirm nothing downstream relies on encodings.
(
    RAX,
    RBX,
    RCX,
    RDX,
    RDI,
    RSI,
    RBP,
    RSP,
    R8,
    R9,
    R10,
    R11,
    R12,
    R13,
    R14,
    R15,
) = ALL_INTEGER_REGS
def number_instructions(function: Function) -> dict[Instruction, int]:
    """Map each instruction of ``function`` to its position in program order.

    Blocks are visited in ``function.blocks`` order and instructions in each
    block's ``instructions`` order; positions start at 0 and increase by one
    per instruction visited.  If the same instruction object is visited more
    than once, it keeps the position of its last visit.

    :param function: the function whose instructions are numbered
    :return: dictionary from instruction to its position
    """
    numbering: dict[Instruction, int] = {}
    position = 0
    for block in function.blocks:
        for instruction in block.instructions:
            numbering[instruction] = position
            position += 1
    return numbering
@dataclass(unsafe_hash=True)
class Interval:
    """A live range ``[start, end]`` together with its allocation state.

    Only ``start`` and ``end`` take part in equality and hashing; the other
    fields are declared with ``compare=False``.
    NOTE(review): as a consequence, two distinct intervals that cover the same
    range compare equal and collapse into one element when stored in a set --
    confirm that callers never produce such intervals.
    """

    # First position covered by the interval.
    start: int
    # Last position covered by the interval.
    end: int
    # Weight summed per register when a spill candidate is chosen.
    weight: int = field(compare=False)
    # True if the interval already has a physical register before register
    # allocation.
    fixed: bool = field(compare=False)
    # Location assigned to the interval, or None while still unassigned.
    reg: Physical | StackSlot | MemoryAddress | None = field(compare=False)
def linear_scan(free: set[Physical], intervals: list[Interval]) -> None:
    """
    Performs linear scan register allocation over the given intervals.

    Every interval is visited once, in list order.  On return each visited
    interval's ``reg`` holds either a physical register or, if it was spilled
    by ``assign_mem_loc``, a memory location.  Nothing is returned; the
    intervals and ``free`` are updated in place.

    NOTE(review): the working sets rely on ``Interval`` equality, which only
    compares ``start`` and ``end``; distinct intervals covering the same range
    are indistinguishable here.

    :param free: registers that are currently unassigned; mutated in place
    :param intervals: all intervals in increasing order of their start points
    """
    active: set[Interval] = set()
    inactive: set[Interval] = set()
    # Bookkeeping only: intervals that have ended or been spilled.
    handled: set[Interval] = set()

    for position, cur in enumerate(intervals):
        # Materialised as a list because it is scanned below AND handed to
        # assign_mem_loc; a generator would already be exhausted by then.
        fixed_unhandled = [i for i in intervals[position + 1 :] if i.fixed]

        # Check for active intervals that have expired.  Iterate over a
        # snapshot: the loop body removes elements from `active`.
        for i in tuple(active):
            if i.end < cur.start:
                active.remove(i)
                handled.add(i)
                if isinstance(i.reg, Physical):
                    free.add(i.reg)
            elif not overlap(i, cur.start):
                active.remove(i)
                inactive.add(i)
                if isinstance(i.reg, Physical):
                    free.add(i.reg)

        # Check for inactive intervals that expired or become reactivated
        # (snapshot again, for the same reason).
        for i in tuple(inactive):
            if i.end < cur.start:
                inactive.remove(i)
                handled.add(i)
            elif overlap(i, cur.start):
                inactive.remove(i)
                active.add(i)
                if isinstance(i.reg, Physical):
                    free.remove(i.reg)

        # Collect available registers in f.  `discard` instead of `remove`:
        # the register may already be absent from `free`, or be named by more
        # than one overlapping interval.
        f = free.copy()
        for i in inactive:
            if overlap(i, cur) and isinstance(i.reg, Physical):
                f.discard(i.reg)
        for i in fixed_unhandled:
            if overlap(i, cur) and isinstance(i.reg, Physical):
                f.discard(i.reg)

        # Select a register from f:
        if f:
            if not isinstance(cur.reg, Physical):
                cur.reg = next(iter(f))
            # A pre-assigned register is not guaranteed to be in the pool.
            free.discard(cur.reg)
            active.add(cur)
        else:
            assign_mem_loc(active, inactive, handled, set(fixed_unhandled), cur)
            # If cur took over a register from spilled intervals, make sure
            # the pool no longer offers that register.
            if isinstance(cur.reg, Physical):
                free.discard(cur.reg)
# Offset that the next call to new_memory_location() will hand out.
stack_offset = 0


def new_memory_location() -> StackSlot:
    """Return a fresh stack slot and lower the module-level offset by 8.

    The first call returns the slot at offset 0, the next one at -8, and so
    on.  The counter is the module-level ``stack_offset``.
    """
    global stack_offset
    slot = StackSlot(stack_offset)
    stack_offset = stack_offset - 8
    return slot
def assign_mem_loc(
    active: set[Interval],
    inactive: set[Interval],
    handled: set[Interval],
    fixed_unhandled: set[Interval],
    cur: Interval,
) -> None:
    """Resolve a register shortage for ``cur`` by spilling.

    For every register in ``ALL_INTEGER_REGS`` the weights of the intervals in
    ``active``, ``inactive`` and ``fixed_unhandled`` that overlap ``cur`` and
    hold that register are summed.  The register ``r`` with the smallest sum is
    the candidate (ties go to the earliest register in ``ALL_INTEGER_REGS``):

    * if ``cur.weight`` is smaller than that sum, ``cur`` itself receives a
      new memory location and is moved to ``handled``;
    * otherwise every active or inactive interval holding ``r`` receives a new
      memory location and is moved to ``handled``, and ``cur`` takes ``r`` and
      joins ``active``.

    NOTE(review): intervals in ``fixed_unhandled`` that hold ``r`` are counted
    in the weights but are not relocated -- confirm this is intended.

    :param active: intervals currently holding a register; mutated in place
    :param inactive: intervals currently in a lifetime hole; mutated in place
    :param handled: finished or spilled intervals; mutated in place
    :param fixed_unhandled: pre-assigned intervals not yet visited
    :param cur: the interval that needs a location
    """
    w: dict[Physical, int] = {register: 0 for register in ALL_INTEGER_REGS}
    for i in active | inactive | fixed_unhandled:
        # Only intervals that hold one of the known registers contribute;
        # a stack slot, memory address or unknown register has no entry in w.
        if isinstance(i.reg, Physical) and i.reg in w and overlap(i, cur):
            w[i.reg] += i.weight

    r = min(w, key=lambda k: w[k])
    if cur.weight < w[r]:
        # assign a memory location to cur
        cur.reg = new_memory_location()
        handled.add(cur)
        return

    # Assign memory locations to the intervals occupied by r: move all active
    # or inactive intervals to which r was assigned to handled.  The victims
    # are collected first because the pools are modified while being drained.
    for pool in (active, inactive):
        victims = [i for i in pool if i.reg == r]
        for i in victims:
            i.reg = new_memory_location()
            pool.remove(i)
            handled.add(i)
    cur.reg = r
    active.add(cur)
def overlap(i: Interval, j: int | Interval) -> bool:
    """Tell whether interval ``i`` covers a position or meets another interval.

    Both ends of an interval are inclusive.

    :param i: the interval to test
    :param j: either a single position or a second interval
    :return: for an ``int``, whether ``i.start <= j <= i.end``; for an
        ``Interval``, whether the later start is not past the earlier end
    """
    if isinstance(j, int):
        return i.start <= j <= i.end
    if isinstance(j, Interval):
        latest_start = max(i.start, j.start)
        earliest_end = min(i.end, j.end)
        return latest_start <= earliest_end
    # Any other argument type falls through and yields None.
# Example instructions, built at import time and printed below.

# An "add" with a register use, a memory use, and a def that carries a
# {"reuse": 0} constraint.
instruction = Instruction(
    "add",
    (Virtual(0), "use"),
    (MemoryAddress(base=Virtual(1), displacement=10), "use"),
    (Virtual(2, {"reuse": 0}), "def"),
)

# A "parmov" whose two defs are constrained to RCX and RDX respectively.
parallel_move = Instruction(
    "parmov",
    (Virtual(3), "use"),
    (Virtual(4), "use"),
    (Virtual(5, RCX), "def"),
    (Virtual(6, RDX), "def"),
)

# A "phi" with two uses and one def.
phi = Instruction(
    "phi",
    (Virtual(7), "use"),
    (Virtual(8), "use"),
    (Virtual(9), "def"),
)

print(f"Instructions: {[instruction, parallel_move, phi]}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment