Last active
September 28, 2025 18:12
-
-
Save DarinM223/e2d75e55d3e869fa6b7f6aecf934efc5 to your computer and use it in GitHub Desktop.
Numba experiments
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Quadtree in Numba: | |
| Notes: | |
| 1. `if field:` or `if not field:` for an optional field | |
| type doesn't work in numba, instead do `if field is not None:` or | |
| `if field is None:`. | |
| 2. There is a bug with numba code generation with assignment | |
| to deferred types. To work around that, add `set_field` methods | |
| in the @jitclass (like `set_top_left_tree`) and use those inside a @njit | |
| function. | |
| 3. Insert and search can't be methods because they recurse into | |
| a deferred type, so they are moved out into `quadtree_insert` and `quadtree_search`. | |
| """ | |
| from __future__ import annotations | |
| from numba import njit # type: ignore | |
| from numba import deferred_type, int64, optional # type: ignore | |
| from numba.experimental import jitclass # type: ignore | |
| from typing import Optional | |
@jitclass([("x", int64), ("y", int64)])  # type: ignore
class Point:
    """Integer 2-D coordinate compiled as a Numba jitclass."""

    x: int
    y: int

    def __init__(self, x: int, y: int):
        # Both fields are declared as int64 in the jitclass spec above.
        self.x, self.y = x, y
@jitclass([("data", int64)])  # type: ignore
class Node:
    """Tree payload: an integer value tagged with the point it lives at."""

    # `pos` is typed from its annotation; `data` comes from the explicit spec.
    pos: Point
    data: int

    def __init__(self, pos: Point, data: int):
        self.data = data
        self.pos = pos
# Placeholder for the Quadtree instance type, which must be referenced by the
# class's own fields before the class exists.
quadtree_type = deferred_type()

# Every child slot has the same optional, self-referential type.
spec = [
    (child_field, optional(quadtree_type))
    for child_field in (
        "top_left_tree",
        "top_right_tree",
        "bot_left_tree",
        "bot_right_tree",
    )
]
@jitclass(spec)  # type: ignore
class Quadtree:
    """Quadtree cell covering the rectangle from `top_left` to `bot_right`.

    A cell holds at most one `Node` plus four optional child cells. The
    recursive insert/search logic lives in the module-level functions
    `quadtree_insert` and `quadtree_search`; the methods here only delegate.
    """

    # Corners of the region this cell covers (both inclusive, see in_boundary).
    top_left: Point
    bot_right: Point
    # Payload stored in this cell, or None when the cell is empty.
    node: Optional[Node]

    def __init__(self, topL: Point, botR: Point):
        """Create an empty cell spanning `topL`..`botR` with no children."""
        self.node = None
        self.top_left = topL
        self.bot_right = botR
        self.top_left_tree: Optional[Quadtree] = None
        self.top_right_tree: Optional[Quadtree] = None
        self.bot_left_tree: Optional[Quadtree] = None
        self.bot_right_tree: Optional[Quadtree] = None

    # The set_* methods below exist so that @njit code can assign these fields
    # through a method call; per the module notes, direct assignment to
    # deferred-type fields from @njit code hits a Numba code-generation bug.

    def set_top_left_tree(self, tree: Optional[Quadtree] = None):
        """Replace the top-left child (None clears it)."""
        self.top_left_tree = tree

    def set_bot_left_tree(self, tree: Optional[Quadtree] = None):
        """Replace the bottom-left child (None clears it)."""
        self.bot_left_tree = tree

    def set_top_right_tree(self, tree: Optional[Quadtree] = None):
        """Replace the top-right child (None clears it)."""
        self.top_right_tree = tree

    def set_bot_right_tree(self, tree: Optional[Quadtree] = None):
        """Replace the bottom-right child (None clears it)."""
        self.bot_right_tree = tree

    def set_node(self, node: Optional[Node] = None):
        """Replace the payload stored in this cell (None clears it)."""
        self.node = node

    def in_boundary(self, p: Point) -> bool:
        """Return True when `p` lies inside this cell, edges included."""
        return (
            p.x >= self.top_left.x
            and p.x <= self.bot_right.x
            and p.y >= self.top_left.y
            and p.y <= self.bot_right.y
        )

    def insert(self, node: Optional[Node] = None):
        """Insert `node` into this tree by delegating to `quadtree_insert`."""
        return quadtree_insert(self, node)

    def search(self, p: Point) -> Optional[Node]:
        """Look up `p` in this tree by delegating to `quadtree_search`."""
        return quadtree_search(self, p)
@njit
def quadtree_insert(self: Quadtree, node: Optional[Node] = None):
    """Insert `node` into the tree rooted at `self`.

    Nothing happens when `node` is None or lies outside the cell. A cell whose
    width and height are both at most 1 is a leaf: it stores the node only if
    it is still empty. Otherwise the node is routed to one of the four
    quadrants around the integer midpoint, creating that child on demand.
    Points exactly on a midpoint go to the left / top side.
    """
    if node is None:
        return
    if not self.in_boundary(node.pos):
        return

    x0 = self.top_left.x
    y0 = self.top_left.y
    x1 = self.bot_right.x
    y1 = self.bot_right.y

    # Unit-sized cell: cannot be subdivided further.
    if abs(x0 - x1) <= 1 and abs(y0 - y1) <= 1:
        if self.node is None:
            self.set_node(node)
        return

    mid_x = (x0 + x1) // 2
    mid_y = (y0 + y1) // 2
    on_left = mid_x >= node.pos.x
    on_top = mid_y >= node.pos.y

    # Children are assigned through set_* methods (see module notes on the
    # Numba deferred-type assignment bug).
    if on_left and on_top:
        if self.top_left_tree is None:
            self.set_top_left_tree(
                Quadtree(Point(x0, y0), Point(mid_x, mid_y))
            )
        quadtree_insert(self.top_left_tree, node)  # type: ignore
    elif on_left:
        if self.bot_left_tree is None:
            self.set_bot_left_tree(
                Quadtree(Point(x0, mid_y), Point(mid_x, y1))
            )
        quadtree_insert(self.bot_left_tree, node)  # type: ignore
    elif on_top:
        if self.top_right_tree is None:
            self.set_top_right_tree(
                Quadtree(Point(mid_x, y0), Point(x1, mid_y))
            )
        quadtree_insert(self.top_right_tree, node)  # type: ignore
    else:
        if self.bot_right_tree is None:
            self.set_bot_right_tree(
                Quadtree(Point(mid_x, mid_y), Point(x1, y1))
            )
        quadtree_insert(self.bot_right_tree, node)  # type: ignore
@njit
def quadtree_search(self: Quadtree, p: Point) -> Optional[Node]:
    """Return the node stored at exactly `p`, or None if there is none.

    The descent follows the same quadrant rule as `quadtree_insert`
    (coordinates equal to the midpoint go left / top), so a node that was
    inserted at `p` is always reached.

    A leaf cell can span up to four integer coordinates but stores a single
    node, so the stored node's position is compared against `p` before it is
    returned; a neighbour sharing the same cell is not reported as a hit.
    """
    if not self.in_boundary(p):
        return None

    if self.node is not None:
        # Only leaf cells carry a node, and leaves never have children, so a
        # position mismatch means `p` is not in the tree.
        if self.node.pos.x == p.x and self.node.pos.y == p.y:
            return self.node
        return None

    mid_x = (self.top_left.x + self.bot_right.x) // 2
    mid_y = (self.top_left.y + self.bot_right.y) // 2

    if mid_x >= p.x:
        if mid_y >= p.y:
            if self.top_left_tree is None:
                return None
            return quadtree_search(self.top_left_tree, p)
        else:
            if self.bot_left_tree is None:
                return None
            return quadtree_search(self.bot_left_tree, p)
    else:
        if mid_y >= p.y:
            if self.top_right_tree is None:
                return None
            return quadtree_search(self.top_right_tree, p)
        else:
            if self.bot_right_tree is None:
                return None
            return quadtree_search(self.bot_right_tree, p)
# Resolve the deferred placeholder used in `spec` to the concrete Quadtree
# instance type, now that the class has been created.
quadtree_type.define(Quadtree.class_type.instance_type)  # type: ignore
@njit
def main():
    """Build a small demo tree and return the results of four lookups.

    Three nodes are inserted into a tree covering (0, 0)..(8, 8); the first
    three lookups use the inserted positions and the last one, (5, 5), uses a
    position that was never inserted.
    """
    tree = Quadtree(Point(0, 0), Point(8, 8))

    tree.insert(Node(Point(1, 1), 1))
    tree.insert(Node(Point(2, 5), 2))
    tree.insert(Node(Point(7, 6), 3))

    first = tree.search(Point(1, 1))
    second = tree.search(Point(2, 5))
    third = tree.search(Point(7, 6))
    absent = tree.search(Point(5, 5))
    return (first, second, third, absent)
def print_node(node: Optional[Node]) -> None:
    """Print a node's `data` value, or a placeholder line when it is None."""
    if node is not None:
        print(f"Node: {node.data}")
    else:
        print("Non existing node")
# Run the compiled demo and report each lookup result in order.
a, b, c, d = main()
for result in (a, b, c, d):
    print_node(result)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment