Last active
April 9, 2023 02:33
-
-
Save SchrodingerZhu/8255a94ead7c3c50c297fb83a4c08ad8 to your computer and use it in GitHub Desktop.
dynforest
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #![no_std] | |
| extern crate alloc; | |
| use alloc::rc::Rc; | |
| use core::cell::UnsafeCell; | |
| use core::ptr::NonNull; | |
/// A vertex of the dynamic forest, stored as a node of a link-cut tree.
///
/// The represented forest is split into vertex-disjoint paths. Each path is kept in an
/// *auxiliary tree* (a splay tree) whose in-order traversal lists the path from its shallowest
/// to its deepest vertex; auxiliary trees are chained together through `path_parent`.
///
/// All links are raw pointers, and every pointer stored in a link must come from
/// [`Node::from_cell`], i.e. be derived directly from the owning `UnsafeCell`. A pointer
/// derived from a temporary `&mut Node` is invalidated (under Rust's aliasing rules, as checked
/// by Miri) as soon as another reference to the same node is created, so storing such a pointer
/// and using it later is undefined behaviour.
struct Node {
    /// Parent inside the auxiliary tree; `None` for the root of an auxiliary tree.
    tree_parent: Option<NonNull<Self>>,
    /// Left child inside the auxiliary tree (shallower vertices, once reversals are pushed down).
    tree_left: Option<NonNull<Self>>,
    /// Right child inside the auxiliary tree (deeper vertices, once reversals are pushed down).
    tree_right: Option<NonNull<Self>>,
    /// Only meaningful on the root of an auxiliary tree: the vertex this path hangs from in the
    /// represented forest, or `None` if the path starts at the root of its tree.
    path_parent: Option<NonNull<Self>>,
    /// Pending lazy reversal of this auxiliary subtree; resolved by `Node::push_down`.
    tree_reversed: bool,
}

/// A vertex of a dynamic forest.
///
/// Cloning a `Handle` yields another reference to the same vertex, and two handles compare
/// equal exactly when they refer to the same vertex.
#[derive(Clone, Debug)]
pub struct Handle {
    target: Rc<UnsafeCell<Node>>,
}

/// An edge between two vertices, created by [`Handle::connect`].
///
/// The edge exists for as long as this value is alive; dropping it cuts the edge again.
/// The connection keeps both endpoints alive.
#[must_use = "dropping a `Connection` immediately removes the edge it represents"]
#[derive(Debug)]
pub struct Connection {
    targets: [Rc<UnsafeCell<Node>>; 2],
}

// Every `unsafe fn` below requires that each pointer passed in, and each pointer reachable from
// it through the tree links, points to a live node obtained via `Node::from_cell`, and that no
// Rust reference to any of those nodes is alive for the duration of the call.
impl Node {
    /// Allocates a new, isolated vertex.
    fn allocate() -> Handle {
        Handle {
            target: Rc::new(UnsafeCell::new(Self {
                tree_parent: None,
                tree_left: None,
                tree_right: None,
                path_parent: None,
                tree_reversed: false,
            })),
        }
    }

    /// Returns a pointer to the node owned by `cell`.
    ///
    /// The pointer is derived straight from the `UnsafeCell`, so it carries shared, mutable
    /// provenance: it remains usable for as long as the allocation lives, no matter how many
    /// other pointers obtained this way are used in between.
    fn from_cell(cell: &Rc<UnsafeCell<Node>>) -> NonNull<Node> {
        // SAFETY: `UnsafeCell::get` points into a live `Rc` allocation, hence is non-null.
        unsafe { NonNull::new_unchecked(cell.get()) }
    }

    /// Returns the right child of `this` if `is_right`, otherwise the left child.
    unsafe fn get_child(this: NonNull<Node>, is_right: bool) -> Option<NonNull<Node>> {
        let node = this.as_ptr();
        if is_right {
            (*node).tree_right
        } else {
            (*node).tree_left
        }
    }

    /// Replaces the right child of `this` if `is_right`, otherwise the left child.
    ///
    /// Only the child slot is written; updating the child's `tree_parent` is up to the caller.
    unsafe fn set_child(this: NonNull<Node>, is_right: bool, child: Option<NonNull<Node>>) {
        let node = this.as_ptr();
        if is_right {
            (*node).tree_right = child;
        } else {
            (*node).tree_left = child;
        }
    }

    /// Whether `this` is the right child of its auxiliary-tree parent (`false` for a root).
    ///
    /// Compares addresses only; no reference to any node is created.
    unsafe fn is_right_child(this: NonNull<Node>) -> bool {
        match (*this.as_ptr()).tree_parent {
            None => false,
            Some(parent) => (*parent.as_ptr()).tree_right == Some(this),
        }
    }

    /// Resolves a pending reversal on `this`: swaps its children and passes the flag down.
    unsafe fn push_down(this: NonNull<Node>) {
        let node = this.as_ptr();
        if !(*node).tree_reversed {
            return;
        }
        core::mem::swap(&mut (*node).tree_left, &mut (*node).tree_right);
        if let Some(left) = (*node).tree_left {
            (*left.as_ptr()).tree_reversed = !(*left.as_ptr()).tree_reversed;
        }
        if let Some(right) = (*node).tree_right {
            (*right.as_ptr()).tree_reversed = !(*right.as_ptr()).tree_reversed;
        }
        (*node).tree_reversed = false;
    }

    /// Rotates `this` above its auxiliary-tree parent, preserving the in-order (depth) order.
    ///
    /// `this` must have an auxiliary-tree parent.
    unsafe fn rotate(this: NonNull<Node>) {
        let node = this.as_ptr();
        debug_assert!((*node).tree_parent.is_some());
        let parent = (*node).tree_parent.unwrap_unchecked();
        let parent_ptr = parent.as_ptr();
        let grandparent = (*parent_ptr).tree_parent;

        // First resolve the reversal flags of every touched node, so that each child slot
        // reflects the true order before any pointer is rewired.
        if let Some(grandparent) = grandparent {
            Self::push_down(grandparent);
        }
        Self::push_down(parent);
        Self::push_down(this);

        // Only the root of an auxiliary tree may carry a path-parent pointer; as `this` moves up,
        // it takes over whatever `parent` held (a no-op unless `parent` was the root).
        core::mem::swap(&mut (*parent_ptr).path_parent, &mut (*node).path_parent);

        let is_right = Self::is_right_child(this);
        // Hang `this` where `parent` used to be (the side is read while `parent` is still there).
        if let Some(grandparent) = grandparent {
            Self::set_child(grandparent, Self::is_right_child(parent), Some(this));
        }
        (*node).tree_parent = grandparent;
        // The inner subtree of `this` crosses over to `parent`.
        let inner = Self::get_child(this, !is_right);
        Self::set_child(parent, is_right, inner);
        if let Some(inner) = inner {
            (*inner.as_ptr()).tree_parent = Some(parent);
        }
        // Finally `parent` becomes a child of `this`.
        Self::set_child(this, !is_right, Some(parent));
        (*parent_ptr).tree_parent = Some(this);
    }

    /// Rotates `this` until it is the root of its auxiliary tree.
    unsafe fn splay(this: NonNull<Node>) {
        while let Some(parent) = (*this.as_ptr()).tree_parent {
            match (*parent.as_ptr()).tree_parent {
                None => Self::rotate(this),
                Some(grandparent) => {
                    Self::push_down(grandparent);
                    Self::push_down(parent);
                    if Self::is_right_child(this) == Self::is_right_child(parent) {
                        // zig-zig: rotate the parent first.
                        Self::rotate(parent);
                    } else {
                        // zig-zag: rotate `this` twice.
                        Self::rotate(this);
                    }
                    Self::rotate(this);
                }
            }
        }
    }

    /// Cuts every vertex deeper than `this` (later in in-order traversal) out of its auxiliary
    /// tree; the cut-off part hangs from `this` via `path_parent`.
    ///
    /// Afterwards `this` is the root of its auxiliary tree.
    unsafe fn separate_deeper_nodes(this: NonNull<Node>) {
        Self::splay(this);
        Self::push_down(this);
        let node = this.as_ptr();
        if let Some(right) = (*node).tree_right.take() {
            let right = right.as_ptr();
            (*right).tree_parent = None;
            (*right).path_parent = Some(this);
        }
    }

    /// Merges the auxiliary tree of `this` into the upper-level one it hangs from, first cutting
    /// off the part of the upper path that lies deeper than the attachment vertex.
    ///
    /// Returns `false` (after splaying `this`) if there is no upper-level path.
    unsafe fn extend_upper_level_path(this: NonNull<Node>) -> bool {
        Self::splay(this);
        let node = this.as_ptr();
        match (*node).path_parent {
            None => false,
            Some(upper) => {
                Self::separate_deeper_nodes(upper);
                (*node).tree_parent = Some(upper);
                (*node).path_parent = None;
                (*upper.as_ptr()).tree_right = Some(this);
                true
            }
        }
    }

    /// Makes the auxiliary tree of `this` represent exactly the path from the root of its tree
    /// down to `this`. Afterwards `this` is the root of that auxiliary tree.
    unsafe fn extend_to_root(this: NonNull<Node>) {
        Self::separate_deeper_nodes(this);
        while Self::extend_upper_level_path(this) {}
    }

    /// Makes `this` the root of its (represented) tree.
    ///
    /// After exposing the root-to-`this` path, the whole path is reversed, so the old root
    /// becomes the deepest vertex and `this` the shallowest one.
    unsafe fn lift_to_root(this: NonNull<Node>) {
        Self::extend_to_root(this);
        Self::splay(this);
        let node = this.as_ptr();
        (*node).tree_reversed = !(*node).tree_reversed;
    }

    /// Returns the shallowest vertex of the auxiliary tree rooted at `this`, splayed to the root.
    unsafe fn find_min(this: NonNull<Node>) -> NonNull<Node> {
        let mut current = this;
        Self::push_down(current);
        while let Some(left) = (*current.as_ptr()).tree_left {
            current = left;
            Self::push_down(current);
        }
        Self::splay(current);
        current
    }
}

impl Handle {
    /// Creates a new vertex that is not connected to any other vertex.
    pub fn new() -> Self {
        Node::allocate()
    }

    /// Pointer to this handle's node, with provenance suitable for storing in tree links.
    fn node(&self) -> NonNull<Node> {
        Node::from_cell(&self.target)
    }

    /// Returns whether `self` and `other` lie in the same tree of the forest.
    ///
    /// Every vertex is connected to itself. Although this takes `&self`, it restructures the
    /// internal splay trees (without changing which vertices are connected).
    pub fn is_connected(&self, other: &Self) -> bool {
        if Rc::ptr_eq(&self.target, &other.target) {
            return true;
        }
        let (this, that) = (self.node(), other.node());
        // SAFETY: both nodes are kept alive by `self`/`other`; every other node linked into
        // their trees is an endpoint of a live `Connection`, which holds an `Rc` to it. No
        // references into the tree exist while these calls run.
        unsafe {
            // Root the tree at `self`, then find the root of `other`'s tree: they are connected
            // exactly when that root is `self`.
            Node::lift_to_root(this);
            Node::extend_to_root(that);
            Node::splay(that);
            Node::find_min(that) == this
        }
    }

    /// Links `self` and `other` with a new edge without checking that they are disconnected.
    ///
    /// # Safety
    /// The caller must ensure that the two handles are not connected (in particular, that they
    /// do not refer to the same vertex); otherwise the forest would acquire a cycle.
    pub unsafe fn connect_unchecked(&self, other: &Self) -> Connection {
        let that = other.node();
        // Make `other` the root of its tree, then hang that whole tree below `self`.
        Node::lift_to_root(that);
        (*that.as_ptr()).path_parent = Some(self.node());
        Connection {
            targets: [self.target.clone(), other.target.clone()],
        }
    }

    /// Links `self` and `other` with a new edge, or returns `None` if they are already
    /// connected (including when both handles refer to the same vertex).
    #[must_use = "dropping the returned `Connection` immediately removes the edge"]
    pub fn connect(&self, other: &Self) -> Option<Connection> {
        if self.is_connected(other) {
            return None;
        }
        // SAFETY: checked just above that the two vertices are not connected.
        Some(unsafe { self.connect_unchecked(other) })
    }
}

impl Default for Handle {
    fn default() -> Self {
        Self::new()
    }
}

impl PartialEq for Handle {
    /// Two handles are equal exactly when they refer to the same vertex.
    fn eq(&self, other: &Self) -> bool {
        Rc::ptr_eq(&self.target, &other.target)
    }
}

impl Eq for Handle {}

impl Drop for Connection {
    /// Cuts the edge between the two endpoints.
    fn drop(&mut self) {
        let u = Node::from_cell(&self.targets[0]);
        let v = Node::from_cell(&self.targets[1]);
        // SAFETY: both endpoints are kept alive by `self.targets`; every other node linked into
        // their tree is an endpoint of a live `Connection`. No references into the tree exist.
        unsafe {
            // Root the tree at `u` and expose the root-to-`v` path. Since `u` and `v` are
            // adjacent, that path is exactly `u, v`: once `v` is splayed to the top and its
            // pending reversal resolved, its left child is `u` and nothing else.
            Node::lift_to_root(u);
            Node::extend_to_root(v);
            Node::splay(v);
            Node::push_down(v);
            let v_ptr = v.as_ptr();
            debug_assert_eq!((*v_ptr).tree_left, Some(u));
            let left = (*v_ptr).tree_left.take().unwrap_unchecked();
            (*left.as_ptr()).tree_parent = None;
        }
    }
}
#[cfg(test)]
mod test {
    extern crate std;
    use super::*;
    use alloc::vec::Vec;

    #[test]
    fn test_trivial_connection() {
        let a = Handle::new();
        let b = Handle::new();
        let c = Handle::new();
        let d = Handle::new();
        let e = Handle::new();
        let handles = [a.clone(), b.clone(), c.clone(), d.clone(), e.clone()];
        // Initially every vertex is connected only to itself.
        for i in handles.iter() {
            for j in handles.iter() {
                assert!(!i.is_connected(j) || i == j);
            }
        }
        let _ab = a.connect(&b).unwrap();
        let cd = c.connect(&d).unwrap();
        assert!(a.is_connected(&b));
        assert!(b.is_connected(&a));
        assert!(c.is_connected(&d));
        assert!(d.is_connected(&c));
        for i in [a.clone(), b.clone()] {
            for j in [c.clone(), d.clone(), e.clone()] {
                assert!(!i.is_connected(&j));
            }
        }
        for i in [c.clone(), d.clone()] {
            for j in [a.clone(), b.clone(), e.clone()] {
                assert!(!i.is_connected(&j));
            }
        }
        for i in [a.clone(), b.clone(), c.clone(), d.clone()] {
            assert!(!i.is_connected(&e));
        }
        let _eb = e.connect(&b).unwrap();
        let _ad = a.connect(&d).unwrap();
        for i in handles.iter() {
            for j in handles.iter() {
                assert!(i.is_connected(j));
            }
        }
        // Cutting c-d isolates c; the other four stay joined through a-d.
        drop(cd);
        for i in [a.clone(), b.clone(), d.clone(), e.clone()] {
            for j in [a.clone(), b.clone(), d.clone(), e.clone()] {
                assert!(i.is_connected(&j));
                assert!(!c.is_connected(&i));
                assert!(!i.is_connected(&c));
            }
        }
    }

    #[test]
    fn test_large_forests() {
        const LENGTH: usize = 1000;
        const STEP: usize = LENGTH / 10;
        let mut handles = Vec::new();
        let mut connections = std::collections::HashMap::new();
        for _ in 0..LENGTH {
            handles.push(Handle::new());
        }
        // One long path 0 - 1 - ... - (LENGTH - 1).
        for i in 1..LENGTH {
            connections.insert((i - 1, i), handles[i - 1].connect(&handles[i]).unwrap());
        }
        for i in 0..LENGTH {
            for j in 0..LENGTH {
                assert!(handles[i].is_connected(&handles[j]));
            }
        }
        // Split the path into groups of STEP consecutive vertices (dropping cuts the edge).
        for i in (STEP..LENGTH).step_by(STEP) {
            drop(connections.remove(&(i - 1, i)));
        }
        for i in (0..LENGTH).step_by(STEP) {
            for j in i..(i + STEP) {
                for k in i..(i + STEP) {
                    assert!(handles[j].is_connected(&handles[k]));
                }
            }
        }
        for i in (0..LENGTH).step_by(STEP) {
            for j in i..(i + STEP) {
                for k in 0..i {
                    assert!(!handles[j].is_connected(&handles[k]));
                }
                for k in i + STEP..LENGTH {
                    assert!(!handles[j].is_connected(&handles[k]));
                }
            }
        }
        // Re-join consecutive groups through non-endpoint vertices.
        let mut count = 0;
        for i in 0..(LENGTH / STEP - 1) {
            let a = handles[count + STEP + i].clone();
            let b = handles[count + i].clone();
            let handle = a.connect(&b).unwrap();
            connections.insert((count + i, count + STEP + i), handle);
            count += STEP;
        }
        for i in 0..LENGTH {
            for j in 0..LENGTH {
                assert!(handles[i].is_connected(&handles[j]));
            }
        }
        // Temporarily cut off the upper half of a group, check it is isolated, then re-link it.
        for i in (0..(LENGTH / 2 - STEP)).step_by(STEP) {
            drop(
                connections
                    .remove(&(i + (STEP / 2) - 1, i + (STEP / 2)))
                    .unwrap(),
            );
            for j in (i + STEP / 2)..(i + STEP) {
                for k in (i + STEP / 2)..(i + STEP) {
                    assert!(handles[j].is_connected(&handles[k]));
                }
            }
            for j in (i + STEP / 2)..(i + STEP) {
                for k in 0..(i + STEP / 2) {
                    assert!(!handles[j].is_connected(&handles[k]));
                }
            }
            for j in (i + STEP / 2)..(i + STEP) {
                for k in (i + STEP)..LENGTH {
                    assert!(!handles[j].is_connected(&handles[k]));
                }
            }
            let a = handles[i + (STEP / 2) - 1].clone();
            let b = handles[i + (STEP / 2)].clone();
            connections.insert((i + (STEP / 2) - 1, i + (STEP / 2)), a.connect(&b).unwrap());
        }
        for i in 0..LENGTH {
            for j in 0..LENGTH {
                assert!(handles[i].is_connected(&handles[j]));
            }
        }
    }

    #[test]
    fn test_random() {
        const LENGTH: usize = 200;
        use rand::Rng;
        let mut rng = rand::thread_rng();
        let mut handles = Vec::new();
        let mut connections = std::collections::HashMap::new();
        for _ in 0..LENGTH {
            handles.push(Handle::new());
        }
        for _ in 0..10 * LENGTH {
            // Pick a pair j < i; every vertex, including the last one, can be chosen.
            let i = rng.gen_range(1..LENGTH);
            let j = rng.gen_range(0..i);
            if let std::collections::hash_map::Entry::Vacant(e) = connections.entry((j, i)) {
                let a = handles[i].clone();
                let b = handles[j].clone();
                if let Some(h) = a.connect(&b) {
                    e.insert(h);
                }
            } else {
                assert!(handles[i].is_connected(&handles[j]));
                // In a forest the edge is the only path between i and j, so cutting it
                // separates them.
                drop(connections.remove(&(j, i)));
                assert!(!handles[j].is_connected(&handles[i]));
            }
        }
        for i in 0..LENGTH {
            for j in i..LENGTH {
                if i == j || connections.contains_key(&(i, j)) {
                    assert!(handles[i].is_connected(&handles[j]));
                }
            }
        }
        for i in 0..LENGTH {
            // reflexive
            assert!(handles[i].is_connected(&handles[i]));
            for j in i..LENGTH {
                if handles[i].is_connected(&handles[j]) {
                    // symmetric
                    assert!(handles[j].is_connected(&handles[i]));
                    for k in 0..LENGTH {
                        if handles[j].is_connected(&handles[k]) {
                            // transitive
                            assert!(handles[i].is_connected(&handles[k]));
                        }
                    }
                }
            }
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment