Implementing Binary Trees, Tries, and Union-Find in Rust
Binary Tree
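Implemented with Option<Box<T>>. Box is a smart pointer that allocates on the heap, which makes it ideal for recursive types that would otherwise have infinite size. LeetCode often uses Option<Rc<RefCell<T>>>, which is quite bloated: it adds reference counting and runtime borrow checks that a tree with a single owner per node does not need.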
/// Selects which child slot of a binary-tree node an operation targets.
///
/// The variant names stay in upper case because existing callers already
/// spell them `ChildSide::LEFT` and `ChildSide::RIGHT`.
///
/// The type is a plain tag, so it is `Copy` and can be compared and printed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ChildSide {
    /// The left child slot.
    LEFT,
    /// The right child slot.
    RIGHT,
}
/// A binary-tree node that owns its children through `Option<Box<BinaryNode>>`.
///
/// `Default` yields a childless node whose value is `0`.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
struct BinaryNode {
    /// Payload stored in this node.
    value: i32,
    /// Left subtree, or `None` when there is none.
    left: Option<Box<BinaryNode>>,
    /// Right subtree, or `None` when there is none.
    right: Option<Box<BinaryNode>>,
}

impl BinaryNode {
    /// Creates a childless node holding `0` (same as `BinaryNode::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a childless node holding `val`.
    pub fn new_with_value(val: i32) -> Self {
        Self {
            value: val,
            ..Self::default()
        }
    }

    /// Stores `child` in the slot named by `side`.
    ///
    /// Any subtree that already occupied that slot is replaced and dropped.
    pub fn attach_child(&mut self, child: BinaryNode, side: ChildSide) {
        *self.slot_mut(side) = Some(Box::new(child));
    }

    /// Empties the slot named by `side`, dropping the subtree it held (if any).
    pub fn detach_child(&mut self, side: ChildSide) {
        *self.slot_mut(side) = None;
    }

    /// Prints the values of this subtree to stdout, one per line, in
    /// pre-order: the node itself, then its left subtree, then its right one.
    pub fn preorder(&self) {
        println!("{}", self.value);
        // `Option::iter` yields zero or one child, so chaining left and right
        // visits the existing children in left-to-right order.
        for child in self.left.iter().chain(self.right.iter()) {
            child.preorder();
        }
    }

    /// Returns a mutable reference to the child slot selected by `side`.
    fn slot_mut(&mut self, side: ChildSide) -> &mut Option<Box<BinaryNode>> {
        match side {
            ChildSide::LEFT => &mut self.left,
            ChildSide::RIGHT => &mut self.right,
        }
    }
}
fn main() {
let mut root = BinaryNode::new();
root.attach_child(BinaryNode::new_with_value(5), ChildSide::LEFT);
if let Some(ref mut left_child) = root.left {
left_child.attach_child(BinaryNode::new_with_value(15), ChildSide::RIGHT);
}
root.preorder();
let p = root.left.as_ref().unwrap().right.as_ref().unwrap();
println!("Before detach: {}", p.value);
root.detach_child(ChildSide::LEFT);
root.attach_child(BinaryNode::new_with_value(12), ChildSide::RIGHT);
// println!("After detach: {}", p.value); // fails to compile: value borrowed after move
}
Trie
Built using HashMap<char, TrieNode>.
use std::collections::HashMap;
/// One node of the trie; the edges to its children are keyed by character.
#[derive(Default, Debug)]
struct TrieNode {
    /// `true` when the path from the root to this node spells an inserted word.
    is_terminal: bool,
    /// Child nodes, indexed by the next character of the word.
    child_nodes: HashMap<char, TrieNode>,
}
#[derive(Default, Debug)]
pub struct Trie {
root: TrieNode,
}
impl Trie {
pub fn new() -> Self {
Trie { root: TrieNode::default() }
}
pub fn insert(&mut self, word: &str) {
let mut node = &mut self.root;
for ch in word.chars() {
node = node.child_nodes.entry(ch).or_default();
}
node.is_terminal = true;
}
pub fn contains(&self, word: &str) -> bool {
let mut node = &self.root;
for ch in word.chars() {
match node.child_nodes.get(&ch) {
Some(next) => node = next,
None => return false,
}
}
node.is_terminal
}
}
// Demo: fill a trie with a few words and query one that was never inserted.
fn main() {
    let mut trie = Trie::new();
    for word in ["hello", "hi", "hey", "world"] {
        trie.insert(word);
    }

    // "hiiii" runs past the stored word "hi", so the lookup reports `false`.
    let found = trie.contains("hiiii");
    println!("contains 'hiiii'? {}", found);
}
For a faster alternative, replace the first few lines with:
use std::collections::HashMap;
use fxhash::FxBuildHasher;
type FastHashMap<K, V> = HashMap<K, V, FxBuildHasher>;
/// One node of the trie, with its children stored in an Fx-hashed map.
///
/// `Debug` is derived in addition to `Default` because a `Trie` that derives
/// `Debug` and holds a `TrieNode` only compiles if the node is `Debug` too.
#[derive(Default, Debug)]
struct TrieNode {
    /// `true` when the path from the root to this node spells an inserted word.
    is_terminal: bool,
    /// Child nodes, indexed by the next character of the word.
    child_nodes: FastHashMap<char, TrieNode>,
}
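This uses a simpler, faster hash function (FxHash, from the fxhash crate) instead of the standard library's default SipHash. The trade-off is that FxHash is not DoS-resistant, so prefer it only when the keys are not attacker-controlled.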
Union-Find
use std::collections::HashMap;
/// Union-find (disjoint-set) structure over the elements `0..size`.
///
/// Each element stores the index of its parent; an element that is its own
/// parent is the representative (root) of its set.
#[derive(Debug, Clone)]
struct DisjointSet {
    /// `parent[i]` is the parent of element `i`; roots satisfy `parent[i] == i`.
    parent: Vec<usize>,
}

impl DisjointSet {
    /// Creates `size` singleton sets, one per element in `0..size`.
    pub fn with_size(size: usize) -> Self {
        DisjointSet {
            parent: (0..size).collect(),
        }
    }

    /// Returns the number of disjoint sets, i.e. the number of roots.
    pub fn set_count(&self) -> usize {
        self.parent
            .iter()
            .enumerate()
            .filter(|&(index, &parent)| index == parent)
            .count()
    }

    /// Returns every set as a list of its members.
    ///
    /// Members are listed in ascending order and the sets are ordered by
    /// their smallest member, so the result is deterministic. Takes
    /// `&mut self` because the lookups compress paths as a side effect.
    pub fn list_sets(&mut self) -> Vec<Vec<usize>> {
        let mut groups: HashMap<usize, Vec<usize>> = HashMap::new();
        for i in 0..self.parent.len() {
            let root = self.root_of(i);
            groups.entry(root).or_default().push(i);
        }
        let mut sets: Vec<Vec<usize>> = groups.into_values().collect();
        // Every group holds at least one member, and members were pushed in
        // ascending order, so `members[0]` is the group's smallest element.
        sets.sort_unstable_by_key(|members| members[0]);
        sets
    }

    /// Returns the representative of the set containing `index`, pointing
    /// every element on the walked path directly at that representative.
    ///
    /// Implemented with loops rather than recursion so that a long parent
    /// chain cannot overflow the call stack.
    ///
    /// # Panics
    ///
    /// Panics if `index` is not smaller than the size given to `with_size`.
    pub fn root_of(&mut self, index: usize) -> usize {
        // Pass 1: follow parent links up to the root.
        let mut root = index;
        while self.parent[root] != root {
            root = self.parent[root];
        }
        // Pass 2: re-point each node on the path straight at the root.
        let mut current = index;
        while current != root {
            let next = self.parent[current];
            self.parent[current] = root;
            current = next;
        }
        root
    }

    /// Merges the sets containing `a` and `b`; does nothing if they already
    /// share a set. The root of `a`'s set becomes a child of the root of
    /// `b`'s set.
    ///
    /// # Panics
    ///
    /// Panics if `a` or `b` is out of range.
    pub fn unite(&mut self, a: usize, b: usize) {
        let root_a = self.root_of(a);
        let root_b = self.root_of(b);
        if root_a != root_b {
            self.parent[root_a] = root_b;
        }
    }

    /// Returns `true` when `a` and `b` belong to the same set.
    ///
    /// # Panics
    ///
    /// Panics if `a` or `b` is out of range.
    pub fn is_same_set(&mut self, a: usize, b: usize) -> bool {
        self.root_of(a) == self.root_of(b)
    }
}
// Demo: five elements, two unions, then report the resulting partition.
fn main() {
    let mut ds = DisjointSet::with_size(5);
    for (a, b) in [(0, 2), (0, 3)] {
        ds.unite(a, b);
    }

    let count = ds.set_count();
    println!("Number of sets: {}", count); // output: Number of sets: 3

    let sets = ds.list_sets();
    // e.g. [[0, 2, 3], [1], [4]]; do not rely on the order of the groups.
    println!("Disjoint sets: {:?}", sets);
}