Skip to content

Instantly share code, notes, and snippets.

@cheery
Created September 24, 2025 12:24
Show Gist options
  • Select an option

  • Save cheery/78213736a7b321d1ee7d5b23cd0e228e to your computer and use it in GitHub Desktop.

Select an option

Save cheery/78213736a7b321d1ee7d5b23cd0e228e to your computer and use it in GitHub Desktop.
AVL trees and ropes
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Union
@dataclass(eq=False)
class BalancedTree:
    """Height bookkeeping shared by the AVL-style trees in this module.

    Subclasses provide an ``is_empty`` flag and, when it is false, ``left``
    and ``right`` children that are themselves ``BalancedTree`` instances.
    """

    # 0 for an empty tree; otherwise one more than the taller child's height.
    height: int = field(init=False)

    def __post_init__(self):
        # Children are passed in already constructed, so their heights are
        # known by the time this node computes its own.
        self.height = (
            0
            if self.is_empty
            else 1 + max(self.left.height, self.right.height)
        )

    @property
    def balance(self):
        """Return left height minus right height (0 for a height-0 tree).

        A value outside ``-1..1`` means the AVL condition is violated at this
        node; ``rebalance`` checks exactly that.
        """
        if not self.height:
            return 0
        return self.left.height - self.right.height
def pluck(left : BalancedTree, right : BalancedTree) -> BalancedTree:
    """Merge two sibling subtrees after their parent node has been removed.

    The result's in-order sequence is ``left`` followed by ``right``.  When
    both are non-empty, the leftmost node of ``right`` (the in-order
    successor) is detached and becomes the new root over ``left`` and the
    remainder of ``right``.

    Only one ``rebalance`` is applied at that new root, so the result is
    guaranteed to be AVL-balanced only when the heights of ``left`` and
    ``right`` differ by at most one -- as they do for the two children of a
    node in a valid AVL tree (the ``Avl.delete`` case).
    """
    if left.height == 0:
        return right
    if right.height == 0:
        return left

    def detach_leftmost(tree):
        # Return (leftmost node, tree without it), rebuilding and
        # rebalancing every node on the path back up to ``tree``.
        if tree.left.height == 0:
            return tree, tree.right
        leftmost, remainder = detach_leftmost(tree.left)
        return leftmost, rebalance(tree.retain(remainder, tree.right))

    successor, rest = detach_leftmost(right)
    return rebalance(successor.retain(left, rest))
def rebalance(node : BalancedTree) -> BalancedTree:
    """Restore the AVL condition at ``node`` using at most two rotations.

    Returns ``node`` itself when its balance is already within ``-1..1``.
    Otherwise a single or double rotation is applied, which repairs the
    situation left by one insertion or deletion below ``node`` (children
    balanced, heights differing by two); larger height differences are not
    fully repaired by one call.
    """
    skew = node.balance
    if -1 <= skew <= 1:
        return node
    if skew > 1:
        # Left-heavy.  A right-leaning left child first needs a left
        # rotation (left-right double rotation).
        if node.left.balance < 0:
            node = node.retain(left_rotate(node.left), node.right)
        return right_rotate(node)
    # Right-heavy: mirror image of the case above.
    if node.right.balance > 0:
        node = node.retain(node.left, right_rotate(node.right))
    return left_rotate(node)
def right_rotate(z : BalancedTree) -> BalancedTree:
    """Rotate the subtree rooted at ``z`` to the right.

    ``z.left`` becomes the new root and ``z`` becomes its right child; the
    old ``z.left.right`` subtree moves over to become ``z``'s new left
    child.  In-order sequence is preserved; new nodes are built with
    ``retain`` rather than mutating the originals.
    """
    pivot = z.left
    lowered = z.retain(pivot.right, z.right)
    return pivot.retain(pivot.left, lowered)
def left_rotate(x : BalancedTree) -> BalancedTree:
    """Rotate the subtree rooted at ``x`` to the left.

    ``x.right`` becomes the new root and ``x`` becomes its left child; the
    old ``x.right.left`` subtree moves over to become ``x``'s new right
    child.  In-order sequence is preserved; new nodes are built with
    ``retain`` rather than mutating the originals.
    """
    pivot = x.right
    lowered = x.retain(x.left, pivot.left)
    return pivot.retain(lowered, pivot.right)
@dataclass(eq=False)
class Rope(BalancedTree):
    """Rope (balanced tree of text segments); this base class is the empty rope.

    Non-empty ropes are ``RopeSegment`` nodes, whose missing children are
    empty ``Rope`` instances.  Iterating a rope yields its text segments in
    order.
    """

    is_empty = True
    # Total number of characters in the rope; always 0 for the empty rope.
    length: int = field(init=False)

    def __iter__(self):
        # The empty rope has no segments.
        return iter(())

    def __post_init__(self):
        self.length = 0
        BalancedTree.__post_init__(self)

    def insert(self, pos, text):
        """Return a rope containing just ``text``; 0 is the only valid ``pos``.

        Inserting ``""`` returns this empty rope unchanged.

        Raises:
            IndexError: if ``pos`` is not 0.
        """
        if pos != 0:
            raise IndexError(f"insert position {pos} out of range 0..0")
        if text == "":
            return self
        # The empty rope doubles as both empty children of the new segment.
        return RopeSegment(text, self, self)

    def erase(self, pos, length):
        """Erase a range from the empty rope; only the empty range is valid.

        Despite its name, ``length`` plays the role of the exclusive end
        offset, matching ``RopeSegment.erase(start, stop)``; the parameter
        name is kept for compatibility with keyword callers.

        Raises:
            IndexError: unless ``pos`` and ``length`` are both 0.
        """
        if pos != 0 or length != 0:
            raise IndexError(f"erase range [{pos}, {length}) out of range 0..0")
        return self
@dataclass(eq=False)
class RopeSegment(Rope):
    """Non-empty rope node holding one text segment.

    The text represented by this subtree is the in-order concatenation of
    ``left``, ``text`` and ``right``; ``length`` caches its character count
    and ``height`` its tree height.  Methods never assign to an existing
    node: ``insert`` and ``erase`` return new ropes.
    """

    is_empty = False
    # When an insertion lands inside a segment and the combined text would be
    # longer than this, the new text becomes its own segment instead of being
    # spliced into the existing one.
    max_merge_length = 8

    text: str
    left: Rope
    right: Rope

    def __iter__(self):
        """Yield the text segments of this subtree in order."""
        yield from self.left
        yield self.text
        yield from self.right

    def __post_init__(self):
        self.length = len(self.text) + self.left.length + self.right.length
        BalancedTree.__post_init__(self)

    def retain(self, left, right):
        """Return a copy of this segment with new children."""
        return RopeSegment(self.text, left, right)

    def insert(self, pos, text):
        """Return a rope with ``text`` inserted at character offset ``pos``.

        Offsets are relative to this subtree.  Inserting ``""`` (at a valid
        position) returns this rope unchanged.

        Raises:
            IndexError: if ``pos`` is outside ``0..self.length``.
        """
        if not 0 <= pos <= self.length:
            raise IndexError(f"insert position {pos} out of range 0..{self.length}")
        if text == "":
            # Nothing to insert.  Returning early also keeps empty-text
            # segments out of the tree (the split branch below would
            # otherwise create one).
            return self
        ledge = self.left.length        # offset of this segment's first char
        redge = ledge + len(self.text)  # offset just past its last char
        cut = pos - ledge
        if pos < ledge:
            node = self.retain(self.left.insert(pos, text), self.right)
        elif pos > redge:
            node = self.retain(self.left, self.right.insert(pos - redge, text))
        elif len(self.text) + len(text) > self.max_merge_length:
            # Split this segment at the cut: its two halves move into the
            # neighbouring subtrees and the new text takes this node's place.
            left = self.left.insert(ledge, self.text[:cut])
            right = self.right.insert(0, self.text[cut:])
            node = RopeSegment(text, left, right)
        else:
            # Small enough: splice the new text into this segment.
            spliced = self.text[:cut] + text + self.text[cut:]
            node = RopeSegment(spliced, self.left, self.right)
        return rebalance(node)

    def erase(self, start, stop):
        """Return a rope with the characters in ``[start, stop)`` removed.

        Offsets are relative to this subtree.  Either child may shrink by an
        arbitrary amount, so the pieces are reassembled with ``_join`` /
        ``_concat`` instead of a single ``rebalance`` at this node, which
        could leave the tree unbalanced.

        Raises:
            IndexError: unless ``0 <= start <= stop <= self.length``.
        """
        if not 0 <= start <= stop <= self.length:
            raise IndexError(
                f"erase range [{start}, {stop}) out of range 0..{self.length}"
            )
        if start == stop:
            return self
        ledge = self.left.length
        redge = ledge + len(self.text)
        left = self.left.erase(start, min(ledge, stop)) if start < ledge else self.left
        if redge < stop:
            right = self.right.erase(max(0, start - redge), stop - redge)
        else:
            right = self.right
        # Clip [start, stop) to this node's own segment.  When the range does
        # not overlap the segment, lo == hi and the text is kept unchanged.
        lo = min(max(start, ledge), redge) - ledge
        hi = min(max(stop, ledge), redge) - ledge
        text = self.text[:lo] + self.text[hi:]
        if text:
            return RopeSegment._join(left, text, right)
        return RopeSegment._concat(left, right)

    @staticmethod
    def _join(left, text, right):
        """Return a balanced rope for ``left`` + ``text`` + ``right``.

        ``left`` and ``right`` must each be balanced, but their heights may
        differ by any amount: the new segment is pushed down the inner edge
        of the taller rope until it meets a subtree of comparable height,
        and each rebuilt node is rebalanced on the way back up.
        """
        if left.height > right.height + 1:
            inner = RopeSegment._join(left.right, text, right)
            return rebalance(left.retain(left.left, inner))
        if right.height > left.height + 1:
            inner = RopeSegment._join(left, text, right.left)
            return rebalance(right.retain(inner, right.right))
        return RopeSegment(text, left, right)

    @staticmethod
    def _split_first(rope):
        """Split a non-empty rope into (its first segment's text, the rest)."""
        if rope.left.is_empty:
            return rope.text, rope.right
        first, rest = RopeSegment._split_first(rope.left)
        return first, rebalance(rope.retain(rest, rope.right))

    @staticmethod
    def _concat(left, right):
        """Return a balanced rope for ``left`` + ``right`` (either may be empty)."""
        if left.is_empty:
            return right
        if right.is_empty:
            return left
        first, rest = RopeSegment._split_first(right)
        return RopeSegment._join(left, first, rest)
@dataclass(eq=False)
class Avl(BalancedTree):
    """Recursive insert/delete/query logic for an AVL-tree map.

    Operations return a new tree built with ``retain``/``refine`` instead of
    assigning to ``self``.  A concrete non-empty subclass provides ``left``,
    ``right`` and ``is_empty`` (used by ``BalancedTree``) plus:

    * ``compare(key)`` -- a number whose sign says where ``key`` belongs:
      negative for the left subtree, positive for the right subtree, zero
      for this node (so both -1/0/+1 and cmp-style results work);
    * ``retain(left, right)`` -- a copy of this node with new children;
    * ``refine(*args)`` -- the node that replaces this one when ``insert``
      finds ``key`` already present;
    * ``retrieve()`` -- the value ``query`` returns for a matching key.

    The methods here always call ``self.compare``, so an empty-tree subclass
    must override ``insert``/``delete``/``query`` to end the recursion.
    """

    def insert(self, key, *args):
        """Return a tree with ``key`` inserted.

        If ``key`` is already present, its node is replaced by
        ``refine(*args)``; otherwise ``args`` are forwarded down the
        recursion to wherever the key is finally placed.
        """
        order = self.compare(key)
        if order < 0:
            node = self.retain(self.left.insert(key, *args), self.right)
        elif order > 0:
            node = self.retain(self.left, self.right.insert(key, *args))
        else:
            node = self.refine(*args)
        return rebalance(node)

    def delete(self, key):
        """Return a tree without ``key``.

        The matching node is removed by merging its two children with
        ``pluck``.
        """
        order = self.compare(key)
        if order < 0:
            node = self.retain(self.left.delete(key), self.right)
        elif order > 0:
            node = self.retain(self.left, self.right.delete(key))
        else:
            node = pluck(self.left, self.right)
        return rebalance(node)

    def query(self, key):
        """Return ``retrieve()`` of the node matching ``key``."""
        order = self.compare(key)
        if order < 0:
            return self.left.query(key)
        if order > 0:
            return self.right.query(key)
        return self.retrieve()
if __name__ == "__main__":
    # Demo: three insertions followed by one erase.  The joined text is
    # "Hello world world world", printed as its segments:
    # ['Hello', ' world', ' world', ' world']
    foo = Rope()
    foo = foo.insert(0, "Hello world")
    foo = foo.insert(5, "Hello world")
    foo = foo.insert(10, "Hello world")
    foo = foo.erase(5, 15)
    print(list(foo))
# @dataclass(eq=False)
# class Empty(Avl):
# is_empty = True
# def __iter__(self):
# return iter(())
#
# def insert(self, key, value=None):
# return Node(key, value)
#
# def delete(self, key):
# raise KeyError
#
# def query(self, key):
# raise KeyError
#
# empty = Empty()
#
# @dataclass
# class Node(Avl):
# is_empty = False
# key : Any
# value : Any
# left : Avl = field(default_factory=lambda: empty)
# right : Avl = field(default_factory=lambda: empty)
#
# def __iter__(self):
# yield from self.left
# yield self.key, self.value
# yield from self.right
#
# def compare(node, key):
# if key < node.key:
# return -1
# elif node.key < key:
# return +1
# else:
# return 0
#
# def retain(node, left, right):
# return Node(node.key, node.value, left, right)
#
# def refine(node, value):
# return Node(node.key, value, self.left, self.right)
#
# def retrieve(node):
# return node.value
#
# if __name__=='__main__':
# root = empty
# root = root.insert(10)
# root = root.insert(20)
# root = root.insert(30)
# root = root.insert(40)
# root = root.insert(50)
# root = root.insert(25)
# root = root.delete(40)
#
# for i in root:
# print(i)
# print(root.query(25))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment