Created
September 24, 2025 12:24
-
-
Save cheery/78213736a7b321d1ee7d5b23cd0e228e to your computer and use it in GitHub Desktop.
AVL trees and ropes
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from __future__ import annotations | |
| from dataclasses import dataclass, field | |
| from typing import Any, Union | |
@dataclass(eq=False)
class BalancedTree:
    """Common base for the height-balanced binary trees in this file.

    Subclasses provide a class attribute ``is_empty``. Non-empty subclasses
    also provide ``left`` and ``right`` children, which are BalancedTree
    instances themselves. ``height`` is never passed to the constructor: it is
    derived from the children once the dataclass ``__init__`` has run.
    """

    # Derived in __post_init__: 0 for an empty tree, otherwise one more than
    # the taller child.
    height: int = field(init=False)

    def __post_init__(self):
        self.height = (
            0 if self.is_empty
            else max(self.left.height, self.right.height) + 1
        )

    @property
    def balance(self):
        """Left-subtree height minus right-subtree height (0 when empty)."""
        if not self.height:
            return 0
        return self.left.height - self.right.height
def pluck(left : BalancedTree, right : BalancedTree) -> BalancedTree:
    """Join two subtrees whose parent node is being removed.

    Returns a tree whose in-order sequence is *left* followed by *right*.
    If either side is empty (height 0), the other side is returned
    unchanged. Otherwise the leftmost node of *right* is detached, with each
    ancestor rebuilt via ``retain`` and rebalanced on the way back up. That
    node then becomes the new root over *left* and the remainder of *right*.

    NOTE(review): only a single rebalance is applied at the new root, so this
    presumably expects *left* and *right* to be AVL siblings (heights within
    one of each other). Very different heights may leave the result
    unbalanced.
    """
    if not left.height:
        return right
    if not right.height:
        return left

    def detach_leftmost(tree):
        # Returns (leftmost node of tree, tree with that node removed).
        if tree.left.height == 0:
            return tree, tree.right
        leftmost, remainder = detach_leftmost(tree.left)
        return leftmost, rebalance(tree.retain(remainder, tree.right))

    new_root, rest = detach_leftmost(right)
    return rebalance(new_root.retain(left, rest))
def rebalance(node : BalancedTree) -> BalancedTree:
    """Restore the AVL condition at *node* with at most two rotations.

    Returns *node* itself when its balance is already within [-1, 1].
    Otherwise it returns the rotated subtree, whose nodes are rebuilt via
    ``retain``.
    """
    tilt = node.balance
    if tilt > 1:
        # Left-heavy. If the left child leans right (the left-right case),
        # straighten it first so that one right rotation finishes the job.
        if node.left.balance < 0:
            node = node.retain(left_rotate(node.left), node.right)
        return right_rotate(node)
    if tilt < -1:
        # Mirror image: right-heavy, with the right-left case handled first.
        if node.right.balance > 0:
            node = node.retain(node.left, right_rotate(node.right))
        return left_rotate(node)
    return node
def right_rotate(z : BalancedTree) -> BalancedTree:
    """Rotate clockwise so that *z*'s left child becomes the subtree root.

    The left child's right subtree is re-parented as *z*'s new left subtree.
    Both affected nodes are rebuilt via ``retain``.
    """
    pivot = z.left
    demoted = z.retain(pivot.right, z.right)
    return pivot.retain(pivot.left, demoted)
def left_rotate(x : BalancedTree) -> BalancedTree:
    """Rotate counter-clockwise so that *x*'s right child becomes the root.

    The right child's left subtree is re-parented as *x*'s new right subtree.
    Both affected nodes are rebuilt via ``retain``.
    """
    pivot = x.right
    demoted = x.retain(x.left, pivot.left)
    return pivot.retain(demoted, pivot.right)
@dataclass(eq=False)
class Rope(BalancedTree):
    """The empty rope, and the base class of RopeSegment.

    A plain ``Rope()`` holds no text (length 0, height 0). It is used as the
    empty child of segments. Editing methods return a rope instead of
    modifying this one. Iterating a rope yields its text chunks in order,
    and an empty rope yields none.
    """

    is_empty = True
    # Total number of characters in the rope; computed, never passed in.
    length: int = field(init=False)

    def __iter__(self):
        return iter(())

    def __post_init__(self):
        self.length = 0
        BalancedTree.__post_init__(self)

    def insert(self, pos, text):
        """Return a rope containing just *text*.

        The only valid position in an empty rope is 0; any other *pos*
        raises IndexError. Inserting "" returns this rope unchanged.
        """
        if pos != 0:
            raise IndexError
        if text == "":
            return self
        return RopeSegment(text, self, self)

    def erase(self, pos, length):
        """Erase the (necessarily empty) range ``[pos, length)``.

        RopeSegment.erase passes ``(start, stop)`` offsets to its children,
        so *length* is really the stop offset. For an empty rope only
        ``erase(0, 0)`` is valid; anything else raises IndexError.
        """
        # Check the arguments: the empty rope has no ``pos`` attribute, and
        # reading ``self.pos`` turned every bad erase into AttributeError.
        if pos != 0 or length != 0:
            raise IndexError
        return self
@dataclass(eq=False)
class RopeSegment(Rope):
    """A non-empty rope node whose contents are ``left`` + ``text`` + ``right``.

    ``text`` is one chunk of the rope's characters. ``left`` and ``right``
    are sub-ropes, possibly the empty ``Rope``. Edits build new nodes and
    reuse the untouched subtrees instead of modifying this node.
    """

    is_empty = False
    # While the merged chunk would stay within this many characters,
    # inserting into a chunk merges the strings. Beyond it, the chunk is split
    # around the inserted text. Subclasses may override this.
    max_chunk = 8
    text: str
    left: Rope
    right: Rope

    def __iter__(self):
        yield from self.left
        yield self.text
        yield from self.right

    def __post_init__(self):
        self.length = len(self.text) + self.left.length + self.right.length
        BalancedTree.__post_init__(self)

    def retain(self, left, right):
        """Return a copy of this node with the same text and new children."""
        return RopeSegment(self.text, left, right)

    def insert(self, pos, text):
        """Return a new rope with *text* inserted before character *pos*.

        Raises IndexError unless ``0 <= pos <= self.length``.
        """
        if not 0 <= pos <= self.length:
            raise IndexError
        if text == "":
            # Nothing to insert. Returning early also stops the split branch
            # below from creating segments with empty text.
            return self
        ledge = self.left.length
        redge = ledge + len(self.text)
        if pos < ledge:
            node = self.retain(self.left.insert(pos, text), self.right)
        elif pos > redge:
            node = self.retain(self.left, self.right.insert(pos - redge, text))
        else:
            cut = pos - ledge
            head, tail = self.text[:cut], self.text[cut:]
            if len(self.text) + len(text) > self.max_chunk:
                # Too big to merge: move the two halves of this chunk into the
                # neighbouring subtrees and let the new text own this node.
                left = self.left.insert(ledge, head)
                right = self.right.insert(0, tail)
                node = RopeSegment(text, left, right)
            else:
                node = RopeSegment(head + text + tail, self.left, self.right)
        # Each child received at most one insertion, so a single rebalance
        # at this node is enough.
        return rebalance(node)

    def erase(self, start, stop):
        """Return a new rope with the characters in ``[start, stop)`` removed.

        Raises IndexError unless ``0 <= start <= stop <= self.length``.
        """
        if not 0 <= start <= stop <= self.length:
            raise IndexError
        if start == stop:
            return self
        ledge = self.left.length
        redge = ledge + len(self.text)
        if start < ledge:
            left = self.left.erase(start, min(ledge, stop))
        else:
            left = self.left
        if redge < stop:
            right = self.right.erase(max(0, start - redge), stop - redge)
        else:
            right = self.right
        # Clamp the range to this node's own chunk. The offsets are relative to
        # the chunk, and lo == hi when the range misses it entirely.
        lo = max(ledge, min(redge, start)) - ledge
        hi = max(ledge, min(redge, stop)) - ledge
        text = self.text[:lo] + self.text[hi:]
        # Either subtree may have shrunk by any number of levels. A single
        # rebalance cannot repair that, so rebuild by joining instead.
        if text:
            return _rope_join(left, text, right)
        return _rope_concat(left, right)


def _rope_join(left, text, right):
    """Return a balanced rope equal to *left* + *text* + *right*.

    *left* and *right* must each be balanced, but their heights may differ
    arbitrarily. The function descends the inner spine of the taller side
    until the heights are within one, places the new segment there, and
    rebalances each level on the way back up.
    """
    if left.height > right.height + 1:
        return rebalance(left.retain(left.left, _rope_join(left.right, text, right)))
    if right.height > left.height + 1:
        return rebalance(right.retain(_rope_join(left, text, right.left), right.right))
    return RopeSegment(text, left, right)


def _rope_concat(left, right):
    """Return a balanced rope equal to *left* + *right*, for any heights."""
    if left.is_empty:
        return right
    if right.is_empty:
        return left
    first, rest = _rope_split_first(right)
    return _rope_join(left, first, rest)


def _rope_split_first(rope):
    """Return (text of the leftmost segment, *rope* without that segment)."""
    if rope.left.is_empty:
        return rope.text, rope.right
    first, rest = _rope_split_first(rope.left)
    return first, rebalance(rope.retain(rest, rope.right))
@dataclass(eq=False)
class Avl(BalancedTree):
    """Base for persistent AVL search trees.

    The methods below call these on non-empty nodes, so subclasses supply:
      * ``compare(key)``: negative when *key* belongs in the left subtree,
        positive when it belongs in the right subtree, and zero when it
        matches this node. Any sign-valued result works; -1/0/1 is not
        required.
      * ``retain(left, right)``: a copy of this node with new children.
      * ``refine(*args)``: the replacement node used when *key* is already
        present.
      * ``retrieve()``: the value returned by ``query`` on a match.
    Because every method here calls ``self.compare``, an empty subclass
    presumably overrides insert/delete/query itself.
    """

    def insert(self, key, *args):
        """Return a new tree with *key* added, or with its node refined."""
        order = self.compare(key)
        if order < 0:
            node = self.retain(self.left.insert(key, *args), self.right)
        elif order > 0:
            node = self.retain(self.left, self.right.insert(key, *args))
        else:
            node = self.refine(*args)
        return rebalance(node)

    def delete(self, key):
        """Return a new tree without the node matching *key*."""
        order = self.compare(key)
        if order < 0:
            node = self.retain(self.left.delete(key), self.right)
        elif order > 0:
            node = self.retain(self.left, self.right.delete(key))
        else:
            # Drop this node; pluck merges its two children into one tree.
            node = pluck(self.left, self.right)
        return rebalance(node)

    def query(self, key):
        """Return ``retrieve()`` of the node matching *key*."""
        order = self.compare(key)
        if order < 0:
            return self.left.query(key)
        if order > 0:
            return self.right.query(key)
        return self.retrieve()
# Demo: insert the same phrase at three offsets, cut out characters 5..15,
# then print the resulting chunks.
foo = (
    Rope()
    .insert(0, "Hello world")
    .insert(5, "Hello world")
    .insert(10, "Hello world")
    .erase(5, 15)
)
print(list(foo))
| # @dataclass(eq=False) | |
| # class Empty(Avl): | |
| # is_empty = True | |
| # def __iter__(self): | |
| # return iter(()) | |
| # | |
| # def insert(self, key, value=None): | |
| # return Node(key, value) | |
| # | |
| # def delete(self, key): | |
| # raise KeyError | |
| # | |
| # def query(self, key): | |
| # raise KeyError | |
| # | |
| # empty = Empty() | |
| # | |
| # @dataclass | |
| # class Node(Avl): | |
| # is_empty = False | |
| # key : Any | |
| # value : Any | |
| # left : Avl = field(default_factory=lambda: empty) | |
| # right : Avl = field(default_factory=lambda: empty) | |
| # | |
| # def __iter__(self): | |
| # yield from self.left | |
| # yield self.key, self.value | |
| # yield from self.right | |
| # | |
| # def compare(node, key): | |
| # if key < node.key: | |
| # return -1 | |
| # elif node.key < key: | |
| # return +1 | |
| # else: | |
| # return 0 | |
| # | |
| # def retain(node, left, right): | |
| # return Node(node.key, node.value, left, right) | |
| # | |
| # def refine(node, value): | |
| # return Node(node.key, value, node.left, node.right) | |
| # | |
| # def retrieve(node): | |
| # return node.value | |
| # | |
| # if __name__=='__main__': | |
| # root = empty | |
| # root = root.insert(10) | |
| # root = root.insert(20) | |
| # root = root.insert(30) | |
| # root = root.insert(40) | |
| # root = root.insert(50) | |
| # root = root.insert(25) | |
| # root = root.delete(40) | |
| # | |
| # for i in root: | |
| # print(i) | |
| # print(root.query(25)) |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.