// Package aa implements immutable AA trees. package aa import "cmp" // Tree is an immutable AA tree, // a form of self-balancing binary search tree. // // Use *Tree as reference type; the zero value for *Tree (nil) is the empty tree: // // var empty *aa.Tree[int, string] // one := empty.Put(1, "one") // one.Has(1) ⟹ true // // Note: the zero value for Tree{} is a valid, but non-empty, tree. type Tree[K cmp.Ordered, V any] struct { left *Tree[K, V] right *Tree[K, V] key K value V level int8 } // Key returns the key at the root of this tree. // // Note: getting the root key of an empty tree (nil) // causes a runtime panic. func (tree *Tree[K, V]) Key() K { return tree.key } // Value returns the value at the root of this tree. // // Note: getting the root value of an empty tree (nil) // causes a runtime panic. func (tree *Tree[K, V]) Value() V { return tree.value } // Left returns the left subtree of this tree, // containing all keys less than its root key. // // Note: the left subtree of the empty tree is the empty tree (nil). func (tree *Tree[K, V]) Left() *Tree[K, V] { if tree == nil { return nil } return tree.left } // Right returns the right subtree of this tree, // containing all keys greater than its root key. // // Note: the right subtree of the empty tree is the empty tree (nil). func (tree *Tree[K, V]) Right() *Tree[K, V] { if tree == nil { return nil } return tree.right } // Level returns the level of this AA tree. // // Notes: // - the level of the empty tree (nil) is 0. // - the height of a tree of level n is between n and 2·n. // - the number of nodes in a tree of level n is between 2ⁿ-1 and 3ⁿ-1. func (tree *Tree[K, V]) Level() int { if tree == nil { return 0 } return int(tree.level) + 1 } // Len counts the number of nodes in this tree. func (tree *Tree[K, V]) Len() int { // AA trees lean right, so recurse to the left. var len int for tree != nil { len += tree.left.Len() + 1 tree = tree.right } return len } // Min finds the least key in this tree, // and returns the node for that key, // or nil if this tree is empty. func (tree *Tree[K, V]) Min() *Tree[K, V] { if tree == nil { return nil } for tree.left != nil { tree = tree.left } return tree } // Max finds the greatest key in this tree, // and returns the node for that key, // or nil if this tree is empty. func (tree *Tree[K, V]) Max() *Tree[K, V] { if tree == nil { return nil } for tree.right != nil { tree = tree.right } return tree } // Floor finds the greatest key in this tree less-than or equal-to key, // and returns the node for that key, // or nil if no such key exists in this tree. func (tree *Tree[K, V]) Floor(key K) *Tree[K, V] { var node *Tree[K, V] for tree != nil { if cmp.Less(key, tree.key) { tree = tree.left } else { node = tree tree = tree.right } } return node } // Ceil finds the least key in this tree greater-than or equal-to key, // and returns the node for that key, // or nil if no such key exists in this tree. func (tree *Tree[K, V]) Ceil(key K) *Tree[K, V] { var node *Tree[K, V] for tree != nil { if cmp.Less(tree.key, key) { tree = tree.right } else { node = tree tree = tree.left } } return node } // Get retrieves the value for a given key; // found indicates whether key exists in this tree. func (tree *Tree[K, V]) Get(key K) (value V, found bool) { // Floor uses 2-way search, which is faster for strings: // https://go.dev/issue/71270 // https://user.it.uu.se/~arnea/ps/searchproc.pdf node := tree.Floor(key) if node != nil && (key == node.key || key != key) { return node.value, true } return // zero, false } // Has reports whether key exists in this tree. func (tree *Tree[K, V]) Has(key K) bool { _, found := tree.Get(key) return found } // Put returns a modified tree with key set to value. // // tree.Put(key, value).Get(key) ⟹ (value, true) func (tree *Tree[K, V]) Put(key K, value V) *Tree[K, V] { return tree.Patch(key, func(*Tree[K, V]) (V, bool) { return value, true }) } // Add returns a (possibly) modified tree that contains key. // // tree.Add(key).Has(key) ⟹ true func (tree *Tree[K, V]) Add(key K) *Tree[K, V] { return tree.Patch(key, func(node *Tree[K, V]) (value V, ok bool) { return value, node == nil }) } // Patch finds key in this tree, calls update with the node for that key // (or nil, if key is not found), and returns a (possibly) modified tree. // // The update callback can opt to set/update the value for the key, // by returning (value, true), or not, by returning false. func (tree *Tree[K, V]) Patch(key K, update func(node *Tree[K, V]) (value V, ok bool)) *Tree[K, V] { if tree == nil { if value, ok := update(tree); ok { return &Tree[K, V]{key: key, value: value} } return nil } switch cmp.Compare(key, tree.key) { default: if value, ok := update(tree); ok { copy := *tree copy.value = value return © } return tree case -1: left := tree.left.Patch(key, update) if left == tree.left { return tree } copy := *tree copy.left = left return copy.ins_rebalance() case +1: right := tree.right.Patch(key, update) if right == tree.right { return tree } copy := *tree copy.right = right return copy.ins_rebalance() } } // Delete returns a (possibly) modified tree with key removed from it. // // tree.Delete(key).Has(key) ⟹ false func (tree *Tree[K, V]) Delete(key K) *Tree[K, V] { if tree == nil { return nil } switch cmp.Compare(key, tree.key) { case -1: left := tree.left.Delete(key) if left == tree.left { return tree } copy := *tree copy.left = left return copy.del_rebalance() case +1: right := tree.right.Delete(key) if right == tree.right { return tree } copy := *tree copy.right = right return copy.del_rebalance() default: if tree.left == nil { return tree.right } var heir *Tree[K, V] copy := *tree copy.left, heir = tree.left.DeleteMax() copy.key = heir.key copy.value = heir.value return copy.del_rebalance() } } // DeleteMin returns a modified tree with its least key removed from it, // and the removed node. func (tree *Tree[K, V]) DeleteMin() (_, node *Tree[K, V]) { if tree == nil { return nil, nil } if tree.left == nil { return tree.right, tree } copy := *tree copy.left, node = tree.left.DeleteMin() return copy.del_rebalance(), node } // DeleteMax returns a modified tree with its greatest key removed from it, // and the removed node. func (tree *Tree[K, V]) DeleteMax() (_, node *Tree[K, V]) { if tree == nil { return nil, nil } if tree.right == nil { return tree.left, tree } copy := *tree copy.right, node = tree.right.DeleteMax() return copy.del_rebalance(), node }
package aa import ( "cmp" "iter" ) // Ascend returns an ascending iterator for this tree. func (tree *Tree[K, V]) Ascend() iter.Seq2[K, V] { return func(yield func(K, V) bool) { tree.ascend(yield) } } func (tree *Tree[K, V]) ascend(yield func(K, V) bool) bool { if tree == nil { return true } return tree.left.ascend(yield) && yield(tree.key, tree.value) && tree.right.ascend(yield) } // Descend returns a descending iterator for this tree. func (tree *Tree[K, V]) Descend() iter.Seq2[K, V] { return func(yield func(K, V) bool) { tree.descend(yield) } } func (tree *Tree[K, V]) descend(yield func(K, V) bool) bool { if tree == nil { return true } return tree.right.descend(yield) && yield(tree.key, tree.value) && tree.left.descend(yield) } // AscendCeil returns an ascending iterator for this tree, // starting at the least key in this tree greater-than or equal-to pivot. func (tree *Tree[K, V]) AscendCeil(pivot K) iter.Seq2[K, V] { return func(yield func(K, V) bool) { tree.ascendCeil(pivot, yield) } } func (tree *Tree[K, V]) ascendCeil(pivot K, yield func(K, V) bool) bool { for tree != nil { if !cmp.Less(tree.key, pivot) { return tree.left.ascendCeil(pivot, yield) && yield(tree.key, tree.value) && tree.right.ascend(yield) } tree = tree.right } return true } // AscendFloor returns an ascending iterator for this tree, // starting at the greatest key in this tree less-than or equal-to pivot. func (tree *Tree[K, V]) AscendFloor(pivot K) iter.Seq2[K, V] { return func(yield func(K, V) bool) { tree.ascendFloor(pivot, nil, yield) } } func (tree *Tree[K, V]) ascendFloor(pivot K, node *Tree[K, V], yield func(K, V) bool) bool { for tree != nil { if cmp.Less(pivot, tree.key) { return tree.left.ascendFloor(pivot, node, yield) && yield(tree.key, tree.value) && tree.right.ascend(yield) } node = tree tree = tree.right } if node != nil { return yield(node.key, node.value) } return true } // DescendFloor returns a descending iterator for this tree, // starting at the greatest key in this tree less-than or equal-to pivot. func (tree *Tree[K, V]) DescendFloor(pivot K) iter.Seq2[K, V] { return func(yield func(K, V) bool) { tree.descendFloor(pivot, yield) } } func (tree *Tree[K, V]) descendFloor(pivot K, yield func(K, V) bool) bool { for tree != nil { if !cmp.Less(pivot, tree.key) { return tree.right.descendFloor(pivot, yield) && yield(tree.key, tree.value) && tree.left.descend(yield) } tree = tree.left } return true } // DescendCeil returns a descending iterator for this tree, // starting at the least key in this tree greater-than or equal-to pivot. func (tree *Tree[K, V]) DescendCeil(pivot K) iter.Seq2[K, V] { return func(yield func(K, V) bool) { tree.descendCeil(pivot, nil, yield) } } func (tree *Tree[K, V]) descendCeil(pivot K, node *Tree[K, V], yield func(K, V) bool) bool { for tree != nil { if cmp.Less(tree.key, pivot) { return tree.right.descendCeil(pivot, node, yield) && yield(tree.key, tree.value) && tree.left.descend(yield) } node = tree tree = tree.left } if node != nil { return yield(node.key, node.value) } return true }
package aa func (tree *Tree[K, V]) ins_rebalance() *Tree[K, V] { if tree.need_raise() { // avoid 2 rotations and allocs tree.level++ return tree } return tree.skew().split() } func (tree *Tree[K, V]) del_rebalance() *Tree[K, V] { var want int8 if tree.left != nil && tree.right != nil { want = 1 + min(tree.left.level, tree.right.level) } if tree.level > want { tree.level = want if tree.right != nil && tree.right.level > want { copy := *tree.right copy.level = want tree.right = © } return tree.skew_rec().split_rec() } return tree } func (tree *Tree[K, V]) need_skew() bool { return tree != nil && tree.left != nil && tree.left.level == tree.level } func (tree *Tree[K, V]) need_split() bool { return tree != nil && tree.right != nil && tree.right.right != nil && tree.right.right.level == tree.level } func (tree *Tree[K, V]) need_raise() bool { return tree != nil && tree.left != nil && tree.right != nil && tree.left.level == tree.level && tree.right.level == tree.level } func (tree *Tree[K, V]) skew() *Tree[K, V] { if tree.need_skew() { copy := *tree.left tree.left = copy.right copy.right = tree return © } return tree } func (tree *Tree[K, V]) skew_rec() *Tree[K, V] { if tree.need_skew() { copy := *tree.left tree.left = copy.right copy.right = tree.skew_rec() return © } if tree.right.need_skew() { node := *tree.right tree.right = node.skew_rec() } return tree } func (tree *Tree[K, V]) split() *Tree[K, V] { if tree.need_split() { copy := *tree.right tree.right = copy.left copy.left = tree copy.level++ return © } return tree } func (tree *Tree[K, V]) split_rec() *Tree[K, V] { if tree.need_split() { copy := *tree.right tree.right = copy.left copy.left = tree copy.level++ if copy.right.need_split() { node := *copy.right copy.right = node.split() } return © } return tree }
package aa import "cmp" // Split partitions this tree around a key. It returns // a left tree with keys less than key, // a right tree with keys greater than key, // and the node for that key. func (tree *Tree[K, V]) Split(key K) (left, node, right *Tree[K, V]) { if tree == nil { return nil, nil, nil } switch cmp.Compare(key, tree.key) { default: return tree.left, tree, tree.right case -1: left, node, right = tree.left.Split(key) return left, node, join(right, tree, tree.right) case +1: left, node, right = tree.right.Split(key) return join(tree.left, tree, left), node, right } } // Union returns the set union of two trees. // For keys in both trees, the value from t1 is retained. func Union[K cmp.Ordered, V any](t1, t2 *Tree[K, V]) *Tree[K, V] { if t1 == nil { return t2 } if t2 == nil { return t1 } left, _, right := t2.Split(t1.key) left = Union(t1.left, left) right = Union(t1.right, right) return join(left, t1, right) } // Intersection returns the set intersection of two trees. // Values are taken from t1. func Intersection[K cmp.Ordered, V any](t1, t2 *Tree[K, V]) *Tree[K, V] { if t1 == nil { return nil } if t2 == nil { return nil } left, node, right := t1.Split(t2.key) left = Intersection(left, t2.left) right = Intersection(right, t2.right) if node == nil { return join2(left, right) } return join(left, node, right) } // Difference returns the set difference of two trees. func Difference[K cmp.Ordered, V any](t1, t2 *Tree[K, V]) *Tree[K, V] { if t1 == nil { return nil } if t2 == nil { return t1 } left, _, right := t1.Split(t2.key) left = Difference(left, t2.left) right = Difference(right, t2.right) return join2(left, right) } func join[K cmp.Ordered, V any](left, node, right *Tree[K, V]) *Tree[K, V] { ll := left.Level() rl := right.Level() switch { case ll < rl: copy := *right copy.left = join(left, node, copy.left) return copy.ins_rebalance() case ll > rl: copy := *left copy.right = join(copy.right, node, right) return copy.ins_rebalance() default: return &Tree[K, V]{ left: left, right: right, key: node.key, value: node.value, level: int8(ll), // left.level + 1 } } } func join2[K cmp.Ordered, V any](left, right *Tree[K, V]) *Tree[K, V] { if left == nil { return right } if right == nil { return left } left, node := left.DeleteMax() return join(left, node, right) }