Skip to content

Instantly share code, notes, and snippets.

@SchrodingerZhu
Last active April 9, 2023 02:33
Show Gist options
  • Select an option

  • Save SchrodingerZhu/8255a94ead7c3c50c297fb83a4c08ad8 to your computer and use it in GitHub Desktop.

Select an option

Save SchrodingerZhu/8255a94ead7c3c50c297fb83a4c08ad8 to your computer and use it in GitHub Desktop.
dynforest
#![no_std]
extern crate alloc;
use alloc::rc::Rc;
use core::cell::UnsafeCell;
use core::ptr::NonNull;
/// A vertex of the forest, stored as a node of a splay ("auxiliary") tree.
///
/// All links are raw, non-owning pointers to other `Node`s.
struct Node {
/// Parent inside the auxiliary tree; `None` when this node is the auxiliary root.
tree_parent: Option<NonNull<Self>>,
/// Left child inside the auxiliary tree (comes earlier in in-order traversal).
tree_left: Option<NonNull<Self>>,
/// Right child inside the auxiliary tree (comes later in in-order traversal, i.e. deeper nodes).
tree_right: Option<NonNull<Self>>,
/// Link from this auxiliary tree to the upper-level path. `rotate` hands it over to whichever
/// node moves up, so that only the root of an auxiliary tree carries it.
path_parent: Option<NonNull<Self>>,
/// Lazy "reversed" flag: when set, `push_down` swaps the two children and toggles their flags.
tree_reversed: bool,
}
/// A shared, reference-counted handle to one node of the forest.
///
/// Cloning a handle clones the `Rc`, so the clone refers to the same node.
#[derive(Clone)]
pub struct Handle {
/// The node this handle refers to; mutated through the `UnsafeCell` by the forest operations.
target: Rc<UnsafeCell<Node>>,
}
/// An edge between two nodes of the forest.
///
/// The connection keeps both endpoints alive through its `Rc` clones. Its `Drop`
/// implementation cuts the edge again, so the value has to be stored for as long as the
/// two nodes are supposed to stay linked.
#[must_use = "dropping a `Connection` immediately removes the edge it represents"]
pub struct Connection {
    /// The two endpoints, in the order `[self, other]` as passed to `connect_unchecked`.
    targets: [Rc<UnsafeCell<Node>>; 2],
}
impl Node {
fn get_child(&self, is_right: bool) -> Option<NonNull<Self>> {
if is_right {
self.tree_right
} else {
self.tree_left
}
}
fn set_child(&mut self, is_right: bool, node: Option<NonNull<Self>>) {
if is_right {
self.tree_right = node;
} else {
self.tree_left = node;
}
}
unsafe fn is_right_child(&self) -> bool {
match self.tree_parent {
None => false,
Some(x) => unsafe {
match x.as_ref().tree_right.as_ref() {
None => false,
Some(x) => core::ptr::eq(x.as_ref(), self),
}
},
}
}
fn allocate() -> Handle {
Handle {
target: Rc::new(UnsafeCell::from(Self {
tree_parent: None,
tree_left: None,
tree_right: None,
path_parent: None,
tree_reversed: false,
})),
}
}
unsafe fn push_down(&mut self) {
if self.tree_reversed {
core::mem::swap(&mut self.tree_left, &mut self.tree_right);
if let Some(mut left) = self.tree_left {
(*left.as_mut()).tree_reversed = !(*left.as_mut()).tree_reversed;
}
if let Some(mut right) = self.tree_right {
right.as_mut().tree_reversed = !right.as_mut().tree_reversed;
}
self.tree_reversed = false;
}
}
unsafe fn unchecked_mut_parent(&mut self) -> &mut Node {
self.tree_parent.unwrap_unchecked().as_mut()
}
unsafe fn rotate(&mut self) {
debug_assert!(self.tree_parent.is_some());
let mut this = NonNull::from(self);
// First, check the `reversed` flags for all touched nodes. Push down the flags if needed.
if let Some(mut grandparent) = this.as_mut().unchecked_mut_parent().tree_parent {
grandparent.as_mut().push_down();
}
this.as_mut().unchecked_mut_parent().push_down();
this.as_mut().push_down();
// Secondly, during the process of going up, the lower node also need to hand over the pointer to upper-level
// paths so that only the root for each auxiliary tree can have non-null upper pointers.
let tmp = this.as_mut().unchecked_mut_parent().path_parent;
this.as_mut().unchecked_mut_parent().path_parent = this.as_ref().path_parent;
this.as_mut().path_parent = tmp;
// Prepare to swap pointers
let is_right = this.as_ref().is_right_child();
let mut parent_backup = this.as_ref().tree_parent.unwrap_unchecked();
// Update parents
if let Some(mut grandparent) = parent_backup.as_ref().tree_parent {
grandparent
.as_mut()
.set_child(parent_backup.as_ref().is_right_child(), Some(this));
}
this.as_mut().tree_parent = parent_backup.as_ref().tree_parent;
// Update children
let target_child = this.as_ref().get_child(!is_right);
parent_backup.as_mut().set_child(is_right, target_child);
if let Some(mut target_child) = target_child {
target_child.as_mut().tree_parent = Some(parent_backup);
}
// Update current node
this.as_mut().set_child(!is_right, Some(parent_backup));
parent_backup.as_mut().tree_parent = Some(this);
}
// Keep rotating until current node reaches the root
unsafe fn splay(&mut self) {
while let Some(mut parent) = self.tree_parent {
match parent.as_ref().tree_parent {
None => self.rotate(),
Some(mut grandparent) => {
grandparent.as_mut().push_down();
parent.as_mut().push_down();
if self.is_right_child() == parent.as_ref().is_right_child() {
parent.as_mut().rotate();
self.rotate();
} else {
self.rotate();
self.rotate();
}
}
}
}
}
// Separate current auxiliary tree into two smaller trees such that all nodes that are deeper than the current node
// (those who come later during in-order traversal) are cut off from the auxiliary tree.
// After the separation, current node is the root of its auxiliary tree.
unsafe fn separate_deeper_nodes(&mut self) {
self.splay();
self.push_down();
if let Some(mut right) = self.tree_right {
right.as_mut().tree_parent = None;
self.tree_right = None;
right.as_mut().path_parent = Some(NonNull::from(self));
}
}
// Merge current auxiliary tree with the upper-level one.
// The merge process makes sure the current node is the "deepest" one in the merged auxiliary tree (by cutting off
// irrelevant subtrees).
// After the extension, current node is the root of the merged tree.
// Return false if there is no upper level path.
unsafe fn extend_upper_level_path(&mut self) -> bool {
self.splay();
match self.path_parent {
None => false,
Some(mut upper) => {
(*upper.as_mut()).separate_deeper_nodes();
self.tree_parent = Some(upper);
self.path_parent = None;
(*upper.as_mut()).tree_right = Some(NonNull::from(self));
true
}
}
}
// Extend the auxiliary tree all the way to root.
// After the extension, current node is the root of its auxiliary tree.
unsafe fn extend_to_root(&mut self) {
self.separate_deeper_nodes();
while self.extend_upper_level_path() {}
}
// Lift the node to the root of its tree (not the auxiliary tree).
// To do so, we first extend the auxiliary tree to root, which represents the path from root to the current node.
// To set the current node as root, we reverse the order of the auxiliary tree such that previous
// root (who has the least depth) now has the deepest depth and the current node (who has the deepest depth) now has
// the lowest depth.
unsafe fn lift_to_root(&mut self) {
self.extend_to_root();
self.splay();
self.tree_reversed = !self.tree_reversed;
}
unsafe fn find_min(&mut self) -> NonNull<Node> {
let mut current = NonNull::from(self);
current.as_mut().push_down();
while let Some(left) = current.as_ref().tree_left {
current = left;
current.as_mut().push_down();
}
current.as_mut().splay();
current
}
}
impl Handle {
    /// Creates a handle to a fresh node that is not linked to any other node.
    pub fn new() -> Self {
        Node::allocate()
    }

    /// Returns `true` when `self` and `other` refer to the same node or to two nodes of the
    /// same tree.
    ///
    /// The query restructures the forest internally: `self` is lifted to the root of its tree
    /// and the root of `other`'s tree is then compared against it.
    pub fn is_connected(&self, other: &Self) -> bool {
        if Rc::ptr_eq(&self.target, &other.target) {
            return true;
        }
        let this = self.target.get();
        let that = other.target.get();
        // SAFETY: both nodes are kept alive by the handles. Links between nodes are only
        // created by `connect_unchecked`, whose `Connection` holds an `Rc` to both endpoints
        // until its `Drop` has cut the edge again.
        unsafe {
            (*this).lift_to_root();
            (*that).extend_to_root();
            (*that).splay();
            // After `extend_to_root` the leftmost node is the root of `other`'s tree.
            core::ptr::eq((*that).find_min().as_ptr(), this)
        }
    }

    /// Links `other` to `self` without checking whether they are already connected.
    ///
    /// The returned [`Connection`] owns the edge: dropping it disconnects the two nodes again.
    ///
    /// # Safety
    /// The caller must ensure that the two handles are not connected.
    #[must_use = "dropping the returned `Connection` immediately removes the edge again"]
    pub unsafe fn connect_unchecked(&self, other: &Self) -> Connection {
        // `other` becomes the root of its own tree and is then hung below `self`.
        (*other.target.get()).lift_to_root();
        (*other.target.get()).path_parent = Some(NonNull::from(&mut *self.target.get()));
        Connection {
            targets: [self.target.clone(), other.target.clone()],
        }
    }

    /// Links `other` to `self` and returns the [`Connection`] that owns the new edge.
    ///
    /// Returns `None` when the two handles are already connected (this includes the case of
    /// both handles referring to the same node). Dropping the returned connection disconnects
    /// the nodes again, so it has to be kept for as long as the edge should exist.
    #[must_use = "dropping the returned `Connection` immediately removes the edge again"]
    pub fn connect(&self, other: &Self) -> Option<Connection> {
        if self.is_connected(other) {
            return None;
        }
        // SAFETY: `is_connected` has just reported that the handles are not connected.
        Some(unsafe { self.connect_unchecked(other) })
    }
}
impl Default for Handle {
fn default() -> Self {
Self::new()
}
}
/// Handles compare by identity: two handles are equal exactly when they point to the same
/// allocation.
impl PartialEq for Handle {
    fn eq(&self, other: &Self) -> bool {
        let mine = Rc::as_ptr(&self.target);
        let theirs = Rc::as_ptr(&other.target);
        core::ptr::eq(mine, theirs)
    }
}
// Handle equality is pointer identity (see the `PartialEq` impl), which is a full
// equivalence relation, so the marker trait `Eq` applies.
impl Eq for Handle {}
/// Dropping a connection removes the edge between its two endpoints.
impl Drop for Connection {
    fn drop(&mut self) {
        let near = self.targets[0].get();
        let far = self.targets[1].get();
        unsafe {
            // Root the tree at the first endpoint and expose the path down to the second one.
            (*near).lift_to_root();
            (*far).extend_to_root();
            (*far).splay();
            (*far).push_down();
            // Everything on the left of the second endpoint is the part of the path above it;
            // detaching that subtree cuts the edge.
            let above = (*far).tree_left.take();
            debug_assert!(above.is_some());
            above.unwrap_unchecked().as_mut().tree_parent = None;
        }
    }
}
#[cfg(test)]
mod test {
    extern crate std;
    use super::*;
    use alloc::vec::Vec;

    /// Five nodes: checks connectivity before any edge exists, after a few edges are added,
    /// and after one edge is dropped again.
    #[test]
    fn test_trivial_connection() {
        let a = Handle::new();
        let b = Handle::new();
        let c = Handle::new();
        let d = Handle::new();
        let e = Handle::new();
        let handles = [a.clone(), b.clone(), c.clone(), d.clone(), e.clone()];
        // Without edges a node is only connected to itself.
        for i in handles.iter() {
            for j in handles.iter() {
                assert!(!i.is_connected(j) || i == j);
            }
        }
        let _ab = a.connect(&b).unwrap();
        let cd = c.connect(&d).unwrap();
        assert!(a.is_connected(&b));
        assert!(b.is_connected(&a));
        assert!(c.is_connected(&d));
        assert!(d.is_connected(&c));
        for i in [a.clone(), b.clone()] {
            for j in [c.clone(), d.clone(), e.clone()] {
                assert!(!i.is_connected(&j));
            }
        }
        for i in [c.clone(), d.clone()] {
            for j in [a.clone(), b.clone(), e.clone()] {
                assert!(!i.is_connected(&j));
            }
        }
        for i in [a.clone(), b.clone(), c.clone(), d.clone()] {
            assert!(!i.is_connected(&e));
        }
        let _eb = e.connect(&b).unwrap();
        let _ad = a.connect(&d).unwrap();
        for i in handles.iter() {
            for j in handles.iter() {
                assert!(i.is_connected(j));
            }
        }
        // Dropping the c-d edge isolates `c`; everything else stays in one tree.
        drop(cd);
        for i in [a.clone(), b.clone(), d.clone(), e.clone()] {
            for j in [a.clone(), b.clone(), d.clone(), e.clone()] {
                assert!(i.is_connected(&j));
                assert!(!c.is_connected(&i));
                assert!(!i.is_connected(&c));
            }
        }
    }

    /// Builds a chain of `LENGTH` nodes, splits it into `STEP`-sized segments, re-joins the
    /// segments and finally cuts and restores edges in the middle of segments.
    #[test]
    fn test_large_forests() {
        const LENGTH: usize = 1000;
        const STEP: usize = LENGTH / 10;
        let mut handles = Vec::new();
        let mut connections = std::collections::HashMap::new();
        for _ in 0..LENGTH {
            handles.push(Handle::new());
        }
        // One long chain: 0 - 1 - 2 - ... - (LENGTH - 1).
        for i in 1..LENGTH {
            connections.insert((i - 1, i), handles[i - 1].connect(&handles[i]).unwrap());
        }
        for i in 0..LENGTH {
            for j in 0..LENGTH {
                assert!(handles[i].is_connected(&handles[j]));
            }
        }
        // Removing a connection from the map drops it, which cuts the edge.
        for i in (STEP..LENGTH).step_by(STEP) {
            connections.remove(&(i - 1, i));
        }
        for i in (0..LENGTH).step_by(STEP) {
            for j in i..(i + STEP) {
                for k in i..(i + STEP) {
                    assert!(handles[j].is_connected(&handles[k]));
                }
            }
        }
        for i in (0..LENGTH).step_by(STEP) {
            for j in i..(i + STEP) {
                for k in 0..i {
                    assert!(!handles[j].is_connected(&handles[k]));
                }
                for k in i + STEP..LENGTH {
                    assert!(!handles[j].is_connected(&handles[k]));
                }
            }
        }
        // Join neighbouring segments again, each time through a different pair of nodes.
        let mut count = 0;
        for i in 0..(LENGTH / STEP - 1) {
            let a = handles[count + STEP + i].clone();
            let b = handles[count + i].clone();
            let handle = a.connect(&b).unwrap();
            connections.insert((count + i, count + STEP + i), handle);
            count += STEP;
        }
        for i in 0..LENGTH {
            for j in 0..LENGTH {
                assert!(handles[i].is_connected(&handles[j]));
            }
        }
        // Cut an edge in the middle of a segment, check the isolated half, then restore it.
        for i in (0..(LENGTH / 2 - STEP)).step_by(STEP) {
            connections
                .remove(&(i + (STEP / 2) - 1, i + (STEP / 2)))
                .unwrap();
            for j in (i + STEP / 2)..(i + STEP) {
                for k in (i + STEP / 2)..(i + STEP) {
                    assert!(handles[j].is_connected(&handles[k]));
                }
            }
            for j in (i + STEP / 2)..(i + STEP) {
                for k in 0..(i + STEP / 2) {
                    assert!(!handles[j].is_connected(&handles[k]));
                }
            }
            for j in (i + STEP / 2)..(i + STEP) {
                for k in (i + STEP)..LENGTH {
                    assert!(!handles[j].is_connected(&handles[k]));
                }
            }
            let a = handles[i + (STEP / 2) - 1].clone();
            let b = handles[i + (STEP / 2)].clone();
            connections.insert((i + (STEP / 2) - 1, i + (STEP / 2)), a.connect(&b).unwrap());
        }
        for i in 0..LENGTH {
            for j in 0..LENGTH {
                assert!(handles[i].is_connected(&handles[j]));
            }
        }
    }

    /// Randomly adds and removes edges, then checks that `is_connected` behaves like an
    /// equivalence relation that contains every edge still present.
    #[test]
    fn test_random() {
        const LENGTH: usize = 200;
        use rand::Rng;
        let mut rng = rand::thread_rng();
        let mut handles = Vec::new();
        let mut connections = std::collections::HashMap::new();
        for _ in 0..LENGTH {
            handles.push(Handle::new());
        }
        for _ in 0..10 * LENGTH {
            // NOTE(review): `0..LENGTH - 1` never selects the last handle — confirm that this
            // is intended.
            let i = rng.gen_range(0..LENGTH - 1);
            let j = rng.gen_range(0..(i + 1));
            if i == j {
                continue;
            }
            if let std::collections::hash_map::Entry::Vacant(e) = connections.entry((j, i)) {
                // No such edge yet: add it unless it would close a cycle.
                let a = handles[i].clone();
                let b = handles[j].clone();
                if let Some(h) = a.connect(&b) {
                    e.insert(h);
                }
            } else {
                // The edge exists: dropping it must separate its two endpoints.
                assert!(handles[i].is_connected(&handles[j]));
                connections.remove(&(j, i));
                assert!(!handles[j].is_connected(&handles[i]));
            }
        }
        for i in 0..LENGTH {
            for j in i..LENGTH {
                if i == j || connections.contains_key(&(i, j)) {
                    assert!(handles[i].is_connected(&handles[j]));
                }
            }
        }
        for i in 0..LENGTH {
            // reflexive
            assert!(handles[i].is_connected(&handles[i]));
            for j in i..LENGTH {
                if handles[i].is_connected(&handles[j]) {
                    // symmetric
                    assert!(handles[j].is_connected(&handles[i]));
                    for k in 0..LENGTH {
                        if handles[j].is_connected(&handles[k]) {
                            // transitive
                            assert!(handles[i].is_connected(&handles[k]));
                        }
                    }
                }
            }
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment