Skip to content

Instantly share code, notes, and snippets.

@DarinM223
Last active September 28, 2025 19:31
Show Gist options
  • Select an option

  • Save DarinM223/6963cacf6129ee9d20e54c0d99220ea0 to your computer and use it in GitHub Desktop.

Select an option

Save DarinM223/6963cacf6129ee9d20e54c0d99220ea0 to your computer and use it in GitHub Desktop.
Linear scan register allocation in Python
from __future__ import annotations # type: ignore
from typing import * # type: ignore
from dataclasses import dataclass, field
@dataclass(frozen=True)
class Physical:
    """A hardware register, identified by its number.

    Frozen so that instances are hashable: the allocator keeps registers in
    sets (the ``free`` pool) and uses them as dict keys. A plain ``@dataclass``
    sets ``__hash__`` to ``None``, which made those operations raise
    ``TypeError``. Two ``Physical`` objects with the same number compare and
    hash equal.
    """

    register: int
# Constraint of the form {"reuse": k}. Presumably k is the index of another
# operand whose register the constrained value must share (the "add" demo uses
# {"reuse": 0} on its result) -- NOTE(review): confirm the intended meaning.
ReuseOperand = TypedDict(
    "ReuseOperand",
    {"reuse": int},
    total=True,
)


@dataclass
class Virtual:
    """A virtual register that still needs a location.

    ``constraint`` is optional. It is either a specific ``Physical`` register
    the value must live in, or a ``ReuseOperand`` dict. It defaults to
    ``None`` (unconstrained).
    """

    id: int
    constraint: Physical | ReuseOperand | None = None
@dataclass
class Immediate:
    """A constant integer operand encoded directly in the instruction."""

    # The constant itself.
    value: int
@dataclass
class StackSlot:
    """A spill location identified by its stack offset.

    ``new_memory_location`` in this file hands these out at offsets
    0, -8, -16, ... (units presumably bytes -- confirm).
    """

    # Offset of the slot; the frame register it is relative to is not shown here.
    offset: int
@dataclass(kw_only=True)
class MemoryAddress:
base: Physical | Virtual | None = None
index: Physical | Virtual | None = None
scale: int = 0
displacement: int = 0
# Every kind of thing an instruction can name as an operand.
Operand = (
    Physical
    | Virtual
    | Immediate
    | StackSlot
    | MemoryAddress
)
# An operand tagged with how the instruction treats it: "use" (read) or
# "def" (written/defined).
OperandPair = tuple[Operand, Literal["use", "def"]]
@dataclass(eq=False)
class Instruction:
    """An instruction: an opcode name followed by its tagged operands.

    ``eq=False`` gives identity equality and hashing, so each instruction
    object is a distinct dict key (``number_instructions`` relies on this).
    """

    name: str
    operands: list[OperandPair]

    def __init__(self, name: str, *args: OperandPair):
        # Hand-written so operands can be passed variadically; @dataclass keeps
        # a user-defined __init__ and still generates __repr__.
        self.name = name
        self.operands = [*args]
@dataclass(eq=False)
class Block:
    """A basic block of a function's control-flow graph.

    Identity-compared (``eq=False``), so blocks that point at each other
    through predecessor/successor lists remain usable as set/dict members.
    """

    id: int
    # CFG edges, held as direct references to the neighbouring blocks.
    predecessors: list[Block]
    successors: list[Block]
    # The block's instructions; number_instructions numbers them in this order.
    instructions: list[Instruction]
@dataclass(eq=False)
class Function:
    """A function: its parameters plus its basic blocks.

    Identity-compared (``eq=False``), like ``Block`` and ``Instruction``.
    """

    id: int
    # Incoming parameters as operands.
    params: list[Operand]
    # Basic blocks; number_instructions walks them in this list order.
    blocks: list[Block]
# The sixteen integer registers, numbered 0..15 in the order the names are
# bound below.
# NOTE(review): this numbering (RBX=1, RCX=2, RDX=3, ...) is not the x86-64
# hardware encoding (RCX=1, RDX=2, RBX=3, ...); confirm nothing depends on it.
# NOTE(review): RSP and RBP are part of this tuple, so callers presumably need
# to keep them out of the `free` set handed to linear_scan.
ALL_INTEGER_REGS = tuple(map(Physical, range(16)))
(
    RAX, RBX, RCX, RDX, RDI, RSI, RBP, RSP,
    R8, R9, R10, R11, R12, R13, R14, R15,
) = ALL_INTEGER_REGS
def number_instructions(function: Function) -> dict[Instruction, int]:
    """Map each instruction in ``function`` to a 0-based position.

    Blocks are walked in ``function.blocks`` order and instructions in each
    block's list order. The result is one consecutive numbering for the whole
    function.
    """
    in_order = (
        instruction
        for block in function.blocks
        for instruction in block.instructions
    )
    return {instruction: index for index, instruction in enumerate(in_order)}
@dataclass(eq=False)
class Interval:
    """Live range of one value over the inclusive positions [start, end].

    Positions are presumably instruction numbers such as those produced by
    ``number_instructions`` -- confirm against the caller that builds intervals.

    Equality and hashing are by identity. Two different values can have the
    same range, and the allocator stores intervals in sets (active, inactive,
    handled). The previous ``unsafe_hash=True`` hashed and compared on
    ``(start, end)`` only, so such intervals collapsed into a single set entry
    and the second one's register was never tracked or freed.
    """

    start: int
    end: int
    # Spill cost: assign_mem_loc sums the weights of overlapping intervals per
    # register and compares them with the current interval's weight.
    weight: int
    # True if the interval already has a physical register before register allocation.
    fixed: bool
    # Assigned location: a register, a memory location, or None while unassigned.
    reg: Physical | StackSlot | MemoryAddress | None
def linear_scan(free: set[Physical], intervals: list[Interval]) -> None:
    """
    Assign a location to every interval using linear scan register allocation.

    Intervals are visited in order. Each one gets a register from ``free``
    that is not needed by an overlapping inactive interval or by an
    overlapping fixed interval still waiting to be processed. If no such
    register exists, ``assign_mem_loc`` decides whether the current interval
    or the intervals holding the cheapest register are moved to memory.

    :param free: registers available for allocation. Mutated in place:
        registers are removed while an active interval holds them and are
        returned when that interval expires or becomes inactive.
    :param intervals: all intervals in increasing order of their start points.
        Each interval's ``reg`` is overwritten with its assigned location,
        which may be a stack slot. Intervals that already carry a
        ``Physical`` register keep it.
    """
    active: set[Interval] = set()
    inactive: set[Interval] = set()
    handled: set[Interval] = set()
    for position, cur in enumerate(intervals):
        # A list, not a generator: it is scanned below and then handed to
        # assign_mem_loc, and a generator would already be exhausted by then.
        fixed_unhandled = [i for i in intervals[position + 1 :] if i.fixed]

        # Check for active intervals that have expired. Iterate a snapshot,
        # because the loop moves intervals out of `active`.
        for i in list(active):
            if i.end < cur.start:
                active.remove(i)
                handled.add(i)
                if isinstance(i.reg, Physical):
                    free.add(i.reg)
            elif not overlap(i, cur.start):
                active.remove(i)
                inactive.add(i)
                if isinstance(i.reg, Physical):
                    free.add(i.reg)

        # Check for inactive intervals that expired or become reactivated
        # (again over a snapshot).
        for i in list(inactive):
            if i.end < cur.start:
                inactive.remove(i)
                handled.add(i)
            elif overlap(i, cur.start):
                inactive.remove(i)
                active.add(i)
                if isinstance(i.reg, Physical):
                    free.remove(i.reg)

        # Collect available registers. Use discard(): several overlapping
        # intervals may name the same register, or it may already be missing
        # from `free`.
        available = free.copy()
        for i in (*inactive, *fixed_unhandled):
            if overlap(i, cur) and isinstance(i.reg, Physical):
                available.discard(i.reg)

        # Select a register, or fall back to spilling.
        if available:
            if not isinstance(cur.reg, Physical):
                cur.reg = next(iter(available))
            # NOTE(review): for a fixed interval whose register is not in
            # `free`, this raises KeyError -- a conflict the algorithm does
            # not resolve.
            free.remove(cur.reg)
            active.add(cur)
        else:
            assign_mem_loc(active, inactive, handled, set(fixed_unhandled), cur)
# Offset of the next stack slot to hand out; it only ever decreases.
stack_offset = 0


def new_memory_location() -> StackSlot:
    """Return a fresh stack slot and move the module-wide offset down by 8.

    Successive calls yield offsets 0, -8, -16, ... The counter is a module
    global and is never reset here.
    """
    global stack_offset
    slot, stack_offset = StackSlot(stack_offset), stack_offset - 8
    return slot
def assign_mem_loc(
    active: set[Interval],
    inactive: set[Interval],
    handled: set[Interval],
    fixed_unhandled: set[Interval],
    cur: Interval,
) -> None:
    """
    Resolve a register shortage for ``cur`` by spilling.

    For each physical register, sum the weights of the active, inactive and
    fixed-unhandled intervals that overlap ``cur`` and hold that register,
    then pick the register ``r`` with the smallest total.

    - If ``cur`` weighs less than that total, ``cur`` gets a new stack slot
      and moves to ``handled``.
    - Otherwise, every active or inactive interval holding ``r`` gets a new
      stack slot and moves to ``handled``. ``cur`` then takes ``r`` and
      becomes active.

    Only registers actually held by an overlapping interval are candidates.
    ``linear_scan`` calls this when no register is available for ``cur``.
    At that point, every register it could hand out is held by such an
    interval, so nothing is lost. Registers outside the caller's allocatable
    set (for example RSP) can no longer be picked just because their weight
    is 0. If no overlapping interval holds a register at all, ``cur`` is
    spilled.

    NOTE(review): fixed unhandled intervals count toward ``r``'s weight but
    are not evicted. If ``r`` is reserved by a later fixed interval, that
    conflict is not resolved here.
    """
    weights: dict[Physical, int] = {}
    for i in active | inactive | fixed_unhandled:
        if overlap(i, cur) and isinstance(i.reg, Physical):
            weights[i.reg] = weights.get(i.reg, 0) + i.weight

    if not weights or cur.weight < min(weights.values()):
        # Assign a memory location to cur.
        cur.reg = new_memory_location()
        handled.add(cur)
        return

    r = min(weights, key=weights.__getitem__)
    # Move all active or inactive intervals holding r to handled and give
    # them memory locations. Collect first: the sets are mutated while
    # evicting.
    for group in (active, inactive):
        for i in [i for i in group if i.reg == r]:
            i.reg = new_memory_location()
            group.remove(i)
            handled.add(i)
    cur.reg = r
    active.add(cur)
def overlap(i: Interval, j: int | Interval) -> bool:
match j:
case int():
return i.start <= j <= i.end
case Interval():
return max(i.start, j.start) <= min(i.end, j.end)
# Demo: three instructions exercising the operand/constraint model, printed at
# module level.
instruction = Instruction(
    "add",
    (Virtual(0), "use"),
    (MemoryAddress(base=Virtual(1), displacement=10), "use"),
    # Result carries a {"reuse": 0} constraint (see ReuseOperand).
    (Virtual(2, constraint={"reuse": 0}), "def"),
)
# Parallel move whose two results are pinned to RCX and RDX.
parallel_move = Instruction(
    "parmov",
    (Virtual(3), "use"),
    (Virtual(4), "use"),
    (Virtual(5, constraint=RCX), "def"),
    (Virtual(6, constraint=RDX), "def"),
)
phi = Instruction(
    "phi",
    (Virtual(7), "use"),
    (Virtual(8), "use"),
    (Virtual(9), "def"),
)
print(f"Instructions: {[instruction, parallel_move, phi]}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.