Last active
October 18, 2023 10:25
-
-
Save eladn/d83ad7f0fac15d59c68b613202cce58c to your computer and use it in GitHub Desktop.
A Python heap-dict implementation: a priority queue whose items can also be accessed by a lookup key, for removal or for modifying their priority.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| __author__ = "Elad Nachmias" | |
| __email__ = "[email protected]" | |
| __date__ = "2023-10-18" | |
| from operator import lt | |
| from dataclasses import dataclass | |
| from typing import TypeVar, MutableMapping, Generic, List, Dict, Iterable, Tuple, Any, Callable, Optional, Collection | |
# Generic type parameters shared by the heap classes below.
# `LookupKeyT` is the type of the key an item is stored and looked up under;
# `PriorityValueT` is the type of the value the items are ordered by.
LookupKeyT = TypeVar("LookupKeyT")
PriorityValueT = TypeVar("PriorityValueT")
@dataclass
class HeapDictNode(Generic[LookupKeyT, PriorityValueT]):
    """
    A single entry of the heap: a lookup key, the priority it is ordered by, and its current array slot.
    """

    # The value the heap orders this node by.
    priority_value: PriorityValueT
    # The key this node is associated with.
    lookup_key: LookupKeyT
    # Index of this node in the flattened array that encodes the binary tree.
    heap_position: int
class HeapDict(MutableMapping[LookupKeyT, PriorityValueT]):
    """
    A priority queue, allows inserting/removing items in O(log n) while keeping the min/max accessible in O(1).

    The items are stored as key-value pairs. The value is a comparable priority-value used for comparing the items
    and defining their ordering. The key is used both for associating/coupling the rankings with their related
    elements, and also for being able to identify an item for later removal.

    Note that python already has a builtin heap data-structure named `heapq`. However, it supports neither
    associating a priority-value with a unique key nor removing an item by its key.

    Implementation details:
    The priority queue is implemented as a heap data-structure as described here:
    https://en.wikipedia.org/wiki/Heap_(data_structure)
    The heap's promised invariant is that the value of each node is not smaller/bigger (for max/min heap
    accordingly) than any value in its subtree. Thus, we can still be certain that the root is the biggest/smallest
    in the entire heap, although the tree is not entirely sorted (no order between siblings).
    The nodes of the binary-tree-structured heap are stored in a flattened array. To translate
    between a position in the tree and a position in the flattened array, we use the accessing pattern as described
    here: https://en.wikipedia.org/wiki/Heap_(data_structure)#Implementation

    This implementation is based on `https://github.com/DanielStutzbach/heapdict/blob/master/heapdict.py` but added
    more functionality, more verbose namings, tests, typing-annotations and fixed bugs.
    """

    def __init__(
            self,
            *args: Collection[Tuple[LookupKeyT, PriorityValueT]],
            min_or_max: str = 'min',
            less_than_comparator: Callable[[PriorityValueT, PriorityValueT], bool] = lt,
            **kwargs: PriorityValueT):
        """
        :param *args: the items to initialize the heap with upon construction (forwarded to `update()`)
        :param min_or_max: whether this is a minimum or a maximum heap (wrt the comparator's induced ordering)
        :param less_than_comparator: less-than comparator function to define the ordering between items
        :param **kwargs: to conform to `MutableMapping` interface (forwarded to `update()`)
        :raises ValueError: if `min_or_max` is neither 'min' nor 'max'
        """
        # Raise explicitly rather than `assert`: assertions are stripped when python runs with `-O`.
        if min_or_max not in ('min', 'max'):
            raise ValueError(f"min_or_max must be either 'min' or 'max', got {min_or_max!r}")
        self._nodes_ordered_by_heap_position: List[HeapDictNode[LookupKeyT, PriorityValueT]] = []
        self._key_to_node_mapping: Dict[LookupKeyT, HeapDictNode[LookupKeyT, PriorityValueT]] = {}
        self._min_or_max = min_or_max
        self._less_than_comparator = less_than_comparator
        # `min_or_max` / `less_than_comparator` are keyword-only parameters, so they can never land in `kwargs`.
        self.update(*args, **kwargs)

    def clear(self) -> None:
        """Remove all the items from the heap."""
        self._nodes_ordered_by_heap_position.clear()
        self._key_to_node_mapping.clear()

    def __setitem__(self, lookup_key: LookupKeyT, priority_value: PriorityValueT) -> None:
        """Insert `lookup_key` with the given priority; an existing entry under that key is replaced."""
        if lookup_key in self._key_to_node_mapping:
            self.pop(lookup_key)
        # The new node is appended as the last leaf and then bubbled up to its place.
        node = HeapDictNode(priority_value=priority_value, lookup_key=lookup_key, heap_position=len(self))
        self._key_to_node_mapping[lookup_key] = node
        self._nodes_ordered_by_heap_position.append(node)
        self._decrease_key(len(self._nodes_ordered_by_heap_position) - 1)

    def _compare_by_positions(self, position1: int, position2: int) -> bool:
        """
        Tell whether the node at `position1` is strictly ahead of the node at `position2` in the heap's ordering,
        i.e. strictly less for a min-heap and strictly greater for a max-heap (wrt the comparator).
        """
        value1 = self._nodes_ordered_by_heap_position[position1].priority_value
        value2 = self._nodes_ordered_by_heap_position[position2].priority_value
        if self._min_or_max == 'min':
            return self._less_than_comparator(value1, value2)
        return self._less_than_comparator(value2, value1)

    def _min_heapify(self, heap_position: int) -> None:
        """Sift the node at `heap_position` down until none of its children is ahead of it."""
        nodes_count = len(self._nodes_ordered_by_heap_position)
        while True:
            left = self._left_idx(heap_position)
            right = self._right_idx(heap_position)
            # Pick whichever of {node, left child, right child} should be closest to the root.
            top = heap_position
            if left < nodes_count and self._compare_by_positions(left, top):
                top = left
            if right < nodes_count and self._compare_by_positions(right, top):
                top = right
            if top == heap_position:
                return
            self._swap(heap_position, top)
            heap_position = top

    def _decrease_key(self, i: int) -> None:
        """Sift the node at position `i` up until its parent is strictly ahead of it (or it becomes the root)."""
        while i:
            parent = self._parent_idx(i)
            if self._compare_by_positions(parent, i):
                break
            self._swap(i, parent)
            i = parent

    def _swap(self, i: int, j: int) -> None:
        """Exchange the nodes at positions `i` and `j`, keeping their stored `heap_position` in sync."""
        nodes = self._nodes_ordered_by_heap_position
        nodes[i], nodes[j] = nodes[j], nodes[i]
        nodes[i].heap_position = i
        nodes[j].heap_position = j

    def __delitem__(self, lookup_key: LookupKeyT) -> None:
        """
        Remove the item stored under `lookup_key`.

        :raises KeyError: if the key is not in the heap
        """
        node = self._key_to_node_mapping[lookup_key]  # raises KeyError(lookup_key) for a missing key
        # Bubble the node unconditionally up to the root, then remove it as a regular root-pop.
        while node.heap_position:
            self._swap(node.heap_position, self._parent_idx(node.heap_position))
        self.popitem()

    def __getitem__(self, lookup_key: LookupKeyT) -> PriorityValueT:
        """
        Return the priority-value stored under `lookup_key`.

        :raises KeyError: if the key is not in the heap
        """
        return self._key_to_node_mapping[lookup_key].priority_value

    def __contains__(self, lookup_key: LookupKeyT) -> bool:
        return lookup_key in self._key_to_node_mapping

    def __iter__(self) -> Iterable[Tuple[LookupKeyT, PriorityValueT]]:
        """
        Iterate over `(lookup_key, priority_value)` pairs, in no priority-related order.

        NOTE: this deviates from the `Mapping` protocol, in which iteration yields keys only. It is kept for
        backward compatibility. Inherited mixin methods that iterate the mapping directly (e.g. `update()` given
        another `HeapDict`) therefore see pairs rather than keys; use `keys()` / `values()` / `items()`, which
        are overridden in this class to behave as for a regular mapping.
        """
        return ((node.lookup_key, node.priority_value) for node in self._key_to_node_mapping.values())

    def keys(self) -> Collection[LookupKeyT]:
        """Return a view over the lookup keys (overridden because `__iter__` yields pairs, not keys)."""
        return self._key_to_node_mapping.keys()

    def values(self) -> Collection[PriorityValueT]:
        """Return a snapshot view over the priority-values (overridden because `__iter__` yields pairs)."""
        return {key: node.priority_value for key, node in self._key_to_node_mapping.items()}.values()

    def items(self) -> Collection[Tuple[LookupKeyT, PriorityValueT]]:
        """Return a snapshot view over `(lookup_key, priority_value)` pairs (overridden, see `__iter__`)."""
        return {key: node.priority_value for key, node in self._key_to_node_mapping.items()}.items()

    def sorted_iter(self) -> Iterable[Tuple[LookupKeyT, PriorityValueT]]:
        """Yield `(lookup_key, priority_value)` pairs in popping order, without modifying this heap."""
        heap = self.copy()
        while len(heap):
            yield heap.popitem()

    def copy(self) -> 'HeapDict[LookupKeyT, PriorityValueT]':
        """
        Return an independent copy of the heap; keys and priority-values themselves are not copied.

        The nodes are cloned (they carry the mutable `heap_position`) and placed in fresh containers, so that
        mutating the copy never affects this heap and vice versa.
        """
        new_heap = HeapDict(min_or_max=self._min_or_max, less_than_comparator=self._less_than_comparator)
        cloned_mapping = {
            key: HeapDictNode(
                priority_value=node.priority_value, lookup_key=node.lookup_key, heap_position=node.heap_position)
            for key, node in self._key_to_node_mapping.items()}
        new_heap._key_to_node_mapping = cloned_mapping
        # Same array layout as the original, hence the heap invariant holds with no re-heapify.
        new_heap._nodes_ordered_by_heap_position = [
            cloned_mapping[node.lookup_key] for node in self._nodes_ordered_by_heap_position]
        return new_heap

    def popitem(self, default: Any = None) -> Optional[Tuple[LookupKeyT, PriorityValueT]]:
        """
        Remove and return the root item (the min/max) as a `(lookup_key, priority_value)` pair.

        :param default: returned as-is when the heap is empty
        """
        nodes = self._nodes_ordered_by_heap_position
        if not nodes:
            return default
        root = nodes[0]
        last = nodes.pop()
        if nodes:
            # Move the last leaf to the root and sift it down to restore the heap invariant.
            nodes[0] = last
            last.heap_position = 0
            self._min_heapify(0)
        del self._key_to_node_mapping[root.lookup_key]
        return root.lookup_key, root.priority_value

    def __len__(self) -> int:
        return len(self._key_to_node_mapping)

    def peek(self, default: Any = None) -> Optional[Tuple[LookupKeyT, PriorityValueT]]:
        """
        Return the root item (the min/max) as a `(lookup_key, priority_value)` pair without removing it.

        :param default: returned as-is when the heap is empty
        """
        if not self._nodes_ordered_by_heap_position:
            return default
        root = self._nodes_ordered_by_heap_position[0]
        return root.lookup_key, root.priority_value

    def is_empty(self) -> bool:
        """Tell whether the heap holds no items."""
        return len(self._nodes_ordered_by_heap_position) == 0

    @staticmethod
    def _parent_idx(heap_position: int) -> int:
        """
        Array index of the parent of the node at `heap_position`, per the flattened binary-tree layout:
        https://en.wikipedia.org/wiki/Heap_(data_structure)#Implementation
        """
        return (heap_position - 1) >> 1

    @staticmethod
    def _left_idx(heap_position: int) -> int:
        """
        Array index of the left child of the node at `heap_position`, per the flattened binary-tree layout:
        https://en.wikipedia.org/wiki/Heap_(data_structure)#Implementation
        """
        return (heap_position << 1) + 1

    @staticmethod
    def _right_idx(heap_position: int) -> int:
        """
        Array index of the right child of the node at `heap_position`, per the flattened binary-tree layout:
        https://en.wikipedia.org/wiki/Heap_(data_structure)#Implementation
        """
        return (heap_position + 1) << 1
def test_heap_dict_push_all_then_pop_all():
    """Push 100 random priorities into a min-heap, drain it, and check the values come out in ascending order."""
    import numpy as np
    rng = np.random.RandomState(seed=0)
    heap = HeapDict()
    arr = rng.random(size=100)
    # Each value is stored under its index in `arr`.
    for key, priority in enumerate(arr):
        heap[key] = priority
    popped_priorities = []
    while not heap.is_empty():
        _, priority = heap.popitem()
        popped_priorities.append(priority)
    assert len(popped_priorities) == len(arr)
    # Every popped value must not exceed the one popped right after it.
    for earlier, later in zip(popped_priorities, popped_priorities[1:]):
        assert earlier <= later
def test_heap_dict_randomly_pop_while_pushing():
    """
    Randomly interleave pushes and pops, checking every pop against a plain-dict reference model.

    With interleaved operations the popped sequence is not globally sorted (a smaller value may be pushed after
    a bigger one was already popped), so the property checked is the local one: each popped item is the minimum
    of the items held by the heap at the time of the pop.
    """
    import numpy as np
    rng = np.random.RandomState(seed=0)
    heap = HeapDict()
    total_nr_items = 100
    pending_items = list(rng.random(size=total_nr_items))
    held_items = {}  # reference model: lookup-key -> priority of every item currently in the heap
    nr_pushed = 0
    nr_popped = 0
    while pending_items or held_items:
        # Choose randomly only when both operations are possible. Gating the choice on the heap being non-empty
        # (rather than on something having been popped already) is what makes pushes and pops interleave.
        if pending_items and held_items:
            operation = rng.choice(['push', 'pop'], p=[0.7, 0.3])
        elif pending_items:
            operation = 'push'
        else:
            operation = 'pop'
        if operation == 'push':
            priority = pending_items.pop()
            heap[nr_pushed] = priority
            held_items[nr_pushed] = priority
            nr_pushed += 1
        else:
            popped_key, popped_priority = heap.popitem()
            assert held_items.pop(popped_key) == popped_priority
            assert all(popped_priority <= other_priority for other_priority in held_items.values())
            nr_popped += 1
        assert len(heap) == len(held_items)
    assert nr_pushed == total_nr_items
    assert nr_popped == total_nr_items
    assert heap.is_empty()
if __name__ == "__main__":
    # Run the module's self-tests, in order, when executed as a script.
    for run_test in (test_heap_dict_push_all_then_pop_all, test_heap_dict_randomly_pop_while_pushing):
        run_test()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment