Skip to content

Instantly share code, notes, and snippets.

@eladn
Last active October 18, 2023 10:25
Show Gist options
  • Select an option

  • Save eladn/d83ad7f0fac15d59c68b613202cce58c to your computer and use it in GitHub Desktop.

Select an option

Save eladn/d83ad7f0fac15d59c68b613202cce58c to your computer and use it in GitHub Desktop.
Python implementation of a heap-dict that allows accessing items by their lookup key for removal or modification
# Module metadata.
__author__ = "Elad Nachmias"
# NOTE(review): the value below looks like an e-mail redaction placeholder left by the
# page this file was copied from — confirm the real address before publishing.
__email__ = "[email protected]"
__date__ = "2023-10-18"
from operator import lt
from dataclasses import dataclass
from typing import TypeVar, MutableMapping, Generic, List, Dict, Iterable, Tuple, Any, Callable, Optional, Collection
# Note: The key & value are kept as generic types.
LookupKeyT = TypeVar('LookupKeyT')
PriorityValueT = TypeVar('PriorityValueT')


@dataclass
class HeapDictNode(Generic[LookupKeyT, PriorityValueT]):
    """
    Represents a single node (entry) of a `HeapDict`.
    """
    priority_value: PriorityValueT  # The value the heap is ordered by.
    lookup_key: LookupKeyT  # The key the entry is looked up by.
    heap_position: int  # The index in the flattened array that represents the node's position in the binary-tree.


class HeapDict(MutableMapping[LookupKeyT, PriorityValueT]):
    """
    A priority queue that allows inserting/removing items in O(log n) while keeping the min/max accessible in O(1).

    The items are stored as key-value pairs. The value is a comparable priority-value used for comparing the items
    and defining their ordering. The key is used both for associating the priorities with their related elements,
    and for being able to identify an item for later removal or modification.

    Note that python already has a builtin heap module named `heapq`. However, it supports neither
    associating a priority-value with a unique key nor removing an item by its key.

    Implementation details:
    The priority queue is implemented as a heap data-structure as described here:
    https://en.wikipedia.org/wiki/Heap_(data_structure)
    The heap's invariant is that the value of each node is not smaller/bigger (for max/min heap accordingly) than
    any value in its subtree. Thus, the root is the biggest/smallest item in the entire heap, although the tree is
    not entirely sorted (there is no order between siblings).
    The nodes of the binary-tree-structured heap are stored in a flattened array. To translate between a position
    in the tree and a position in the flattened array, we use the accessing pattern as described here:
    https://en.wikipedia.org/wiki/Heap_(data_structure)#Implementation

    This implementation is based on `https://github.com/DanielStutzbach/heapdict/blob/master/heapdict.py` but adds
    more functionality, more verbose namings, tests, typing-annotations and bug fixes.

    Deviations from the plain `MutableMapping` behaviour:
    - iterating the heap yields `(lookup_key, priority_value)` pairs rather than keys only;
    - `popitem()` removes the top-priority item and returns a default (instead of raising) when the heap is empty.
    """

    def __init__(
            self,
            *args: Collection[Tuple[LookupKeyT, PriorityValueT]],
            min_or_max: str = 'min',
            less_than_comparator: Callable[[PriorityValueT, PriorityValueT], bool] = lt,
            **kwargs):
        """
        :param *args: the items to initialize the heap with upon construction (forwarded to `update()`)
        :param min_or_max: whether this is a minimum or a maximum heap (wrt the comparator's induced ordering)
        :param less_than_comparator: less-than comparator function to define the ordering between items
        :param **kwargs: forwarded to `update()`, which stores every keyword as an item keyed by its name
        :raises ValueError: if `min_or_max` is neither 'min' nor 'max'
        """
        # Raise explicitly rather than `assert`, so the validation survives `python -O`.
        if min_or_max not in ('min', 'max'):
            raise ValueError(f"min_or_max must be either 'min' or 'max', got {min_or_max!r}")
        self._nodes_ordered_by_heap_position: List[HeapDictNode[LookupKeyT, PriorityValueT]] = []
        self._key_to_node_mapping: Dict[LookupKeyT, HeapDictNode[LookupKeyT, PriorityValueT]] = {}
        self._min_or_max = min_or_max
        self._less_than_comparator = less_than_comparator
        # `min_or_max` and `less_than_comparator` are named parameters, hence they can never end up in `kwargs`.
        self.update(*args, **kwargs)

    def clear(self) -> None:
        """Remove all the items from the heap."""
        self._nodes_ordered_by_heap_position.clear()
        self._key_to_node_mapping.clear()

    def __setitem__(self, lookup_key: LookupKeyT, priority_value: PriorityValueT) -> None:
        """Insert an item; if the key is already present, its old entry is removed first and then re-inserted."""
        if lookup_key in self._key_to_node_mapping:
            del self[lookup_key]
        node = HeapDictNode(priority_value=priority_value, lookup_key=lookup_key, heap_position=len(self))
        self._key_to_node_mapping[lookup_key] = node
        self._nodes_ordered_by_heap_position.append(node)
        # The new node starts as the last leaf; float it up to where the heap invariant holds.
        self._decrease_key(len(self._nodes_ordered_by_heap_position) - 1)

    def _compare_by_positions(self, position1: int, position2: int) -> bool:
        """
        Tell whether the node at `position1` should be strictly closer to the root than the node at `position2`,
        i.e. strictly smaller for a min-heap or strictly bigger for a max-heap (wrt the comparator).
        """
        value1 = self._nodes_ordered_by_heap_position[position1].priority_value
        value2 = self._nodes_ordered_by_heap_position[position2].priority_value
        if self._min_or_max == 'min':
            return self._less_than_comparator(value1, value2)
        return self._less_than_comparator(value2, value1)

    def _min_heapify(self, heap_position: int) -> None:
        """
        Sift the node at `heap_position` down until none of its children should precede it.
        Despite its name, it respects the heap's direction (min or max).
        """
        nr_nodes = len(self._nodes_ordered_by_heap_position)
        # Iterative rather than recursive; the sequence of swaps is the same as the recursive formulation.
        while True:
            left = self._left_idx(heap_position)
            right = self._right_idx(heap_position)
            top = heap_position
            if left < nr_nodes and self._compare_by_positions(left, top):
                top = left
            if right < nr_nodes and self._compare_by_positions(right, top):
                top = right
            if top == heap_position:
                return
            self._swap(heap_position, top)
            heap_position = top

    def _decrease_key(self, i: int) -> None:
        """Sift the node at position `i` up for as long as its parent does not strictly precede it."""
        while i:
            parent = self._parent_idx(i)
            if self._compare_by_positions(parent, i):
                break
            self._swap(i, parent)
            i = parent

    def _swap(self, i: int, j: int) -> None:
        """Exchange the nodes at positions `i` and `j`, keeping their `heap_position` fields in sync."""
        nodes = self._nodes_ordered_by_heap_position
        nodes[i], nodes[j] = nodes[j], nodes[i]
        nodes[i].heap_position = i
        nodes[j].heap_position = j

    def __delitem__(self, lookup_key: LookupKeyT):
        """
        Remove the item associated with `lookup_key`.

        :raises KeyError: if the key is not in the heap
        """
        node = self._key_to_node_mapping[lookup_key]  # raises KeyError(lookup_key) if missing
        # Unconditionally lift the node to the root. Every ancestor moves one level down along the path, which keeps
        # the invariant for all the other nodes; the lifted node itself is removed right after.
        while node.heap_position:
            self._swap(node.heap_position, self._parent_idx(node.heap_position))
        self.popitem()

    def __getitem__(self, lookup_key: LookupKeyT) -> PriorityValueT:
        """
        Return the priority-value associated with `lookup_key`.

        :raises KeyError: if the key is not in the heap
        """
        return self._key_to_node_mapping[lookup_key].priority_value

    def __contains__(self, lookup_key: LookupKeyT) -> bool:
        """Tell whether an item with the given key is in the heap."""
        return lookup_key in self._key_to_node_mapping

    def __iter__(self) -> Iterable[Tuple[LookupKeyT, PriorityValueT]]:
        """
        Iterate over `(lookup_key, priority_value)` pairs, in the order of the internal key mapping (not by priority).

        NOTE: this deliberately-kept behaviour differs from the `Mapping` convention of iterating over keys only.
        Inherited helpers that iterate the mapping itself therefore see pairs; `keys()`, `values()` and `items()`
        are overridden below so they do not depend on it.
        """
        return ((node.lookup_key, node.priority_value) for node in self._key_to_node_mapping.values())

    def keys(self) -> Collection[LookupKeyT]:
        """Return a view over the lookup keys (not ordered by priority)."""
        return self._key_to_node_mapping.keys()

    def values(self) -> Collection[PriorityValueT]:
        """Return a snapshot list of the priority-values (not ordered by priority)."""
        return [node.priority_value for node in self._key_to_node_mapping.values()]

    def items(self) -> Collection[Tuple[LookupKeyT, PriorityValueT]]:
        """Return a snapshot list of the `(lookup_key, priority_value)` pairs (not ordered by priority)."""
        return [(key, node.priority_value) for key, node in self._key_to_node_mapping.items()]

    def sorted_iter(self) -> Iterable[Tuple[LookupKeyT, PriorityValueT]]:
        """Yield the `(lookup_key, priority_value)` pairs ordered by priority, leaving this heap untouched."""
        heap = self.copy()
        while len(heap):
            yield heap.popitem()

    def copy(self) -> 'HeapDict[LookupKeyT, PriorityValueT]':
        """
        Return an independent copy of the heap.

        The nodes and both internal containers are duplicated, so mutating the copy never affects the original
        (and vice versa). The keys and the priority-values themselves are shared, not deep-copied.
        """
        new_heap = HeapDict(min_or_max=self._min_or_max, less_than_comparator=self._less_than_comparator)
        cloned_nodes = [
            HeapDictNode(
                priority_value=node.priority_value, lookup_key=node.lookup_key, heap_position=node.heap_position)
            for node in self._nodes_ordered_by_heap_position]
        new_heap._nodes_ordered_by_heap_position = cloned_nodes
        # Rebuild the mapping in the original key order, pointing at the cloned nodes.
        new_heap._key_to_node_mapping = {
            key: cloned_nodes[node.heap_position] for key, node in self._key_to_node_mapping.items()}
        return new_heap

    def popitem(self, default: Any = None) -> Optional[Tuple[LookupKeyT, PriorityValueT]]:
        """
        Remove the top-priority item (the min or the max, per the heap's direction) and return it.

        :param default: the value to return when the heap is empty
        :return: the `(lookup_key, priority_value)` pair of the removed item, or `default` if the heap is empty
        """
        nodes = self._nodes_ordered_by_heap_position
        if not nodes:
            return default
        node = nodes[0]
        last_node = nodes.pop()
        if nodes:
            # Move the last leaf to the root and sift it down to restore the invariant.
            nodes[0] = last_node
            last_node.heap_position = 0
            self._min_heapify(0)
        del self._key_to_node_mapping[node.lookup_key]
        return node.lookup_key, node.priority_value

    def __len__(self) -> int:
        """Return the number of items in the heap."""
        return len(self._key_to_node_mapping)

    def peek(self, default: Any = None) -> Optional[Tuple[LookupKeyT, PriorityValueT]]:
        """
        Return the top-priority item without removing it.

        :param default: the value to return when the heap is empty
        :return: the `(lookup_key, priority_value)` pair at the root, or `default` if the heap is empty
        """
        if not self._nodes_ordered_by_heap_position:
            return default
        root = self._nodes_ordered_by_heap_position[0]
        return root.lookup_key, root.priority_value

    def is_empty(self) -> bool:
        """Tell whether the heap holds no items."""
        return not self._nodes_ordered_by_heap_position

    @staticmethod
    def _parent_idx(heap_position: int) -> int:
        """
        Return the array index of the parent of the node at `heap_position`.
        The nodes of the binary-tree-structured heap are stored in a flattened array. To translate between a position
        in the tree and a position in the flattened array, we use the accessing pattern as described here:
        https://en.wikipedia.org/wiki/Heap_(data_structure)#Implementation
        """
        return (heap_position - 1) >> 1

    @staticmethod
    def _left_idx(heap_position: int) -> int:
        """
        Return the array index of the left child of the node at `heap_position`.
        The nodes of the binary-tree-structured heap are stored in a flattened array. To translate between a position
        in the tree and a position in the flattened array, we use the accessing pattern as described here:
        https://en.wikipedia.org/wiki/Heap_(data_structure)#Implementation
        """
        return (heap_position << 1) + 1

    @staticmethod
    def _right_idx(heap_position: int) -> int:
        """
        Return the array index of the right child of the node at `heap_position`.
        The nodes of the binary-tree-structured heap are stored in a flattened array. To translate between a position
        in the tree and a position in the flattened array, we use the accessing pattern as described here:
        https://en.wikipedia.org/wiki/Heap_(data_structure)#Implementation
        """
        return (heap_position + 1) << 1
def test_heap_dict_push_all_then_pop_all():
    """Push 100 random priorities, drain the heap, and check they come out in non-decreasing order."""
    import numpy as np
    rng = np.random.RandomState(seed=0)
    heap = HeapDict()
    priorities = rng.random(size=100)
    for key, priority in enumerate(priorities):
        heap[key] = priority

    # Drain the heap completely, keeping only the priority of each popped pair.
    popped_priorities = []
    while not heap.is_empty():
        _, priority = heap.popitem()
        popped_priorities.append(priority)

    assert len(popped_priorities) == len(priorities)
    # Every popped priority must not exceed the one popped right after it.
    assert all(earlier <= later for earlier, later in zip(popped_priorities, popped_priorities[1:]))
def test_heap_dict_randomly_pop_while_pushing():
    """
    Randomly interleave pushes and pops, checking that every pop returns the minimum of the items that are in
    the heap at that moment.

    The random branch is gated on the heap being non-empty. Gating it on the number of already-popped items
    would never interleave anything: no pop could happen before all the pushes were done.
    With interleaving, the popped sequence is not globally sorted, so each pop is verified against a plain
    reference list holding the live items instead.
    """
    import numpy as np
    rng = np.random.RandomState(seed=0)
    heap = HeapDict()
    total_nr_items = 100
    items_to_push = list(rng.random(size=total_nr_items))
    live_items = []  # reference model: the priorities currently expected to be in the heap
    nr_pushed = 0
    nr_popped = 0
    while items_to_push or not heap.is_empty():
        if items_to_push and not heap.is_empty():
            operation = rng.choice(['push', 'pop'], p=[0.7, 0.3])
        elif items_to_push:
            operation = 'push'
        else:
            operation = 'pop'
        if operation == 'push':
            item = items_to_push.pop()
            heap[nr_pushed] = item  # the running push counter serves as a unique lookup key
            live_items.append(item)
            nr_pushed += 1
        else:
            popped_item = heap.popitem()[1]
            expected_item = min(live_items)
            assert popped_item == expected_item
            live_items.remove(expected_item)
            nr_popped += 1
        assert len(heap) == len(live_items)
    assert nr_pushed == total_nr_items
    assert nr_popped == total_nr_items
    assert heap.is_empty()
if __name__ == "__main__":
    # Run the self-tests, in order, when the file is executed as a script.
    for self_test in (test_heap_dict_push_all_then_pop_all, test_heap_dict_randomly_pop_while_pushing):
        self_test()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment