package aa import ( "math" "math/bits" ) const mask = bits.UintSize - 1 // Balance stores both the level and the cardinality of an AA tree, // both offset by one: the zero value represents a leaf node of level one. // // The lowest 5 or 6 bits store the level, // the highest 27 or 58 bits store cardinality. // // This is sufficient for trees up to 32 or 64 levels, // with 134217728 or 288230376151711744 elements. // // With the smallest possible node size (4 machine words), // you'd need to exhaust at least half the address space (2GiB or 8EiB) // before you could hit these limits. type balance uint func (b balance) level() int { return 1 + int(b&mask) } func (b balance) len() int { return 1 + int(b>>bits.Len(mask)) } func (tree *Tree[K, V]) setLevel(level int) { tree.balance = balance(level-1) | tree.balance&^mask } func (tree *Tree[K, V]) fixup() *Tree[K, V] { var sum uint64 if tree.left != nil { sum += 1<<bits.Len(mask) + uint64(tree.left.balance) + 1 } if tree.right != nil { sum += 1<<bits.Len(mask) + uint64(tree.right.balance)&^mask } if sum > math.MaxUint { panic("overflow") } tree.balance = balance(sum) return tree }
package aa import ( "cmp" "iter" ) // Equal reports whether two trees contain the same key/value pairs. func Equal[K cmp.Ordered, V comparable](t1, t2 *Tree[K, V]) bool { if t1 == t2 { return true } l1 := t1.Len() l2 := t2.Len() if l1 != l2 { return false } n1, s1 := iter.Pull2(t1.Ascend()) defer s1() n2, s2 := iter.Pull2(t2.Ascend()) defer s2() for range l1 { k1, v1, _ := n1() k2, v2, _ := n2() if (k1 != k2 && (k1 == k1 || k2 == k2)) || v1 != v2 { return false } } return true } // Subset reports whether t2 contains every t1 key/value pair. func Subset[K cmp.Ordered, V comparable](t1, t2 *Tree[K, V]) bool { if t1 == t2 { return true } l1 := t1.Len() if l1 == 0 { return true } l2 := t2.Len() if l1 > l2 { return false } n1, s1 := iter.Pull2(t1.Ascend()) defer s1() n2, s2 := iter.Pull2(t2.Ascend()) defer s2() loop1: for ; l1 <= l2 && l1 > 0; l1-- { k1, v1, _ := n1() for ; l1 <= l2 && l2 > 0; l2-- { k2, v2, _ := n2() switch cmp.Compare(k1, k2) { case -1: return false // won't find k1 in t2 case +1: continue } if v1 == v2 { continue loop1 } return false // found k1, v1 != v2 } return false // couldn't find k1 in t2 } return l1 == 0 // checked all k1 } // Overlap reports whether t1 contains any keys from t2. func Overlap[K cmp.Ordered, V any](t1, t2 *Tree[K, V]) bool { if t1 == nil || t2 == nil { return false } if t1 == t2 { return true } n1, s1 := iter.Pull2(t1.Ascend()) defer s1() n2, s2 := iter.Pull2(t2.Ascend()) defer s2() k1, _, ok1 := n1() k2, _, ok2 := n2() for ok1 && ok2 { switch cmp.Compare(k1, k2) { case -1: k1, _, ok1 = n1() case +1: k2, _, ok2 = n2() default: return true } } return false }
package aa import ( "cmp" "iter" ) // Ascend returns an ascending iterator for this tree. func (tree *Tree[K, V]) Ascend() iter.Seq2[K, V] { return func(yield func(K, V) bool) { tree.ascend(yield) } } func (tree *Tree[K, V]) ascend(yield func(K, V) bool) bool { if tree == nil { return true } return tree.left.ascend(yield) && yield(tree.key, tree.value) && tree.right.ascend(yield) } // Descend returns a descending iterator for this tree. func (tree *Tree[K, V]) Descend() iter.Seq2[K, V] { return func(yield func(K, V) bool) { tree.descend(yield) } } func (tree *Tree[K, V]) descend(yield func(K, V) bool) bool { if tree == nil { return true } return tree.right.descend(yield) && yield(tree.key, tree.value) && tree.left.descend(yield) } // AscendCeil returns an ascending iterator for this tree, // starting at the least key in this tree greater-than or equal-to pivot. func (tree *Tree[K, V]) AscendCeil(pivot K) iter.Seq2[K, V] { return func(yield func(K, V) bool) { tree.ascendCeil(pivot, yield) } } func (tree *Tree[K, V]) ascendCeil(pivot K, yield func(K, V) bool) bool { for tree != nil { if !cmp.Less(tree.key, pivot) { return tree.left.ascendCeil(pivot, yield) && yield(tree.key, tree.value) && tree.right.ascend(yield) } tree = tree.right } return true } // AscendFloor returns an ascending iterator for this tree, // starting at the greatest key in this tree less-than or equal-to pivot. 
func (tree *Tree[K, V]) AscendFloor(pivot K) iter.Seq2[K, V] { return func(yield func(K, V) bool) { tree.ascendFloor(pivot, nil, yield) } } func (tree *Tree[K, V]) ascendFloor(pivot K, node *Tree[K, V], yield func(K, V) bool) bool { for tree != nil { if cmp.Less(pivot, tree.key) { return tree.left.ascendFloor(pivot, node, yield) && yield(tree.key, tree.value) && tree.right.ascend(yield) } node = tree tree = tree.right } if node != nil { return yield(node.key, node.value) } return true } // DescendFloor returns a descending iterator for this tree, // starting at the greatest key in this tree less-than or equal-to pivot. func (tree *Tree[K, V]) DescendFloor(pivot K) iter.Seq2[K, V] { return func(yield func(K, V) bool) { tree.descendFloor(pivot, yield) } } func (tree *Tree[K, V]) descendFloor(pivot K, yield func(K, V) bool) bool { for tree != nil { if !cmp.Less(pivot, tree.key) { return tree.right.descendFloor(pivot, yield) && yield(tree.key, tree.value) && tree.left.descend(yield) } tree = tree.left } return true } // DescendCeil returns a descending iterator for this tree, // starting at the least key in this tree greater-than or equal-to pivot. func (tree *Tree[K, V]) DescendCeil(pivot K) iter.Seq2[K, V] { return func(yield func(K, V) bool) { tree.descendCeil(pivot, nil, yield) } } func (tree *Tree[K, V]) descendCeil(pivot K, node *Tree[K, V], yield func(K, V) bool) bool { for tree != nil { if cmp.Less(tree.key, pivot) { return tree.right.descendCeil(pivot, node, yield) && yield(tree.key, tree.value) && tree.left.descend(yield) } node = tree tree = tree.left } if node != nil { return yield(node.key, node.value) } return true }
package aa import "cmp" // Split partitions this tree around a key. It returns // a left tree with keys less than key, // a right tree with keys greater than key, // and the node for the key // (or nil if no such key exists in this tree). func (tree *Tree[K, V]) Split(key K) (left, node, right *Tree[K, V]) { if tree == nil { return nil, nil, nil } switch cmp.Compare(key, tree.key) { default: return tree.left, tree, tree.right case -1: left, node, right = tree.left.Split(key) return left, node, join(right, tree, tree.right) case +1: left, node, right = tree.right.Split(key) return join(tree.left, tree, left), node, right } } // Filter returns a tree of nodes for which pred returns true. func (tree *Tree[K, V]) Filter(pred func(node *Tree[K, V]) bool) *Tree[K, V] { if tree == nil { return nil } left := tree.left.Filter(pred) right := tree.right.Filter(pred) if pred(tree) { return join(left, tree, right) } return join2(left, right) } // Partition returns a tree of nodes for which pred returns true, // and a tree of nodes for which it returns false. func (tree *Tree[K, V]) Partition(pred func(node *Tree[K, V]) bool) (t, f *Tree[K, V]) { if tree == nil { return nil, nil } lt, lf := tree.left.Partition(pred) rt, rf := tree.right.Partition(pred) if pred(tree) { return join(lt, tree, rt), join2(lf, rf) } return join2(lt, rt), join(lf, tree, rf) } func join[K cmp.Ordered, V any](left, node, right *Tree[K, V]) *Tree[K, V] { if left == node.left && right == node.right { return node } llevel := left.Level() rlevel := right.Level() //nolint:nilaway if rlevel == llevel+1 && llevel == right.right.Level() { // Can we create a 3-node? rlevel = llevel // Avoid recursion, rebalancing. 
} switch { case llevel < rlevel: copy := *right //nolint:nilaway copy.left = join(left, node, right.left) return copy.ins_rebalance() case llevel > rlevel: copy := *left //nolint:nilaway copy.right = join(left.right, node, right) return copy.ins_rebalance() default: return makeNode(node.key, node.value, left, right) } } func join2[K cmp.Ordered, V any](left, right *Tree[K, V]) *Tree[K, V] { switch { case left == nil: return right case right == nil: return left } // Both DeleteMin/DeleteMax work; DeleteMin minimizes allocs. right, node := right.DeleteMin() return join(left, node, right) }
package aa import ( "cmp" "slices" ) // MakeSet builds a tree from a set of keys. func MakeSet[K cmp.Ordered](keys ...K) *Tree[K, struct{}] { if !increasing(keys) { keys = slices.Clone(keys) slices.Sort(keys) keys = slices.Compact(keys) } return makeTree[K, struct{}](keys, nil) } // MakeMap builds a tree from a key-value map. func MakeMap[K cmp.Ordered, V any](m map[K]V) *Tree[K, V] { i, keys := 0, make([]K, len(m)) for key := range m { keys[i] = key i++ } slices.Sort(keys) return makeTree(keys, m) } func makeTree[K cmp.Ordered, V any](keys []K, m map[K]V) *Tree[K, V] { if len(keys) == 0 { return nil } // AA trees lean right, so round down. mid := (len(keys) - 1) / 2 left := makeTree(keys[:mid], m) right := makeTree(keys[mid+1:], m) key := keys[mid] return makeNode(key, m[key], left, right) } func makeNode[K cmp.Ordered, V any](k K, v V, left, right *Tree[K, V]) *Tree[K, V] { node := Tree[K, V]{ left: left, right: right, key: k, value: v, } return node.fixup() } // Increasing tests if keys are in strictly increasing order. func increasing[K cmp.Ordered](keys []K) bool { for i := len(keys) - 1; i > 0; i-- { //nolint:nilaway if !cmp.Less(keys[i-1], keys[i]) { return false } } return true } // Collect collects key-value pairs from this tree // into a new map and returns it. func (tree *Tree[K, V]) Collect() map[K]V { m := make(map[K]V, tree.Len()) tree.collect(m) return m } func (tree *Tree[K, V]) collect(m map[K]V) { // AA trees lean right, so recurse left. for tree != nil { m[tree.key] = tree.value tree.left.collect(m) tree = tree.right } }
package aa

// ins_rebalance restores the AA tree invariants at this node after one of
// its subtrees grew (by an insertion or a join), and recomputes its balance.
//
// The receiver is modified in place; callers pass a private copy.
func (tree *Tree[K, V]) ins_rebalance() *Tree[K, V] {
	if tree.need_raise() { // Avoid 2 rotations and allocs.
		// Both children are at this node's level. fixup derives the level
		// from the left child, so it raises this node by one, leaving both
		// children one level below it.
		return tree.fixup()
	}
	return tree.skew().split()
}

// del_rebalance restores the AA tree invariants at this node after one of
// its subtrees shrank (by a deletion), and recomputes its balance.
//
// The receiver is modified in place; callers pass a private copy.
func (tree *Tree[K, V]) del_rebalance() *Tree[K, V] {
	// A node may be at most one level above its lowest child.
	max := 1 + min(tree.left.Level(), tree.right.Level())
	if tree.Level() > max {
		// Lower this node; if its right child is now above it, lower that
		// child too. The child is copied first because it may be shared.
		tree.setLevel(max)
		if tree.right.Level() > max {
			copy := *tree.right
			copy.setLevel(max)
			tree.right = &copy
		}
		// Lowering levels may create horizontal links that skew/split remove.
		return tree.skew_rec().split_rec()
	}
	return tree.fixup()
}

// need_skew reports whether this node's left child is at the same level
// (a horizontal left link), which skew removes by rotating right.
func (tree *Tree[K, V]) need_skew() bool {
	return tree != nil && tree.Level() == tree.left.Level()
}

// need_split reports whether this node's right grandchild (right.right)
// is at the same level as this node (two consecutive horizontal right links),
// which split removes by rotating left.
func (tree *Tree[K, V]) need_split() bool {
	return tree != nil && tree.right != nil && tree.Level() == tree.right.right.Level()
}

// need_raise reports whether both children are at this node's level.
func (tree *Tree[K, V]) need_raise() bool {
	return tree != nil && tree.Level() == tree.left.Level() && tree.Level() == tree.right.Level()
}

// skew rotates right at this node if it needs it, then fixes up the
// balance of the resulting root. The receiver is modified in place.
func (tree *Tree[K, V]) skew() *Tree[K, V] {
	if tree.need_skew() { // Rotate right.
		copy := *tree.left
		tree.left = copy.right
		copy.right = tree.fixup()
		tree = &copy
	}
	return tree.fixup()
}

// skew_rec is skew, applied recursively. After a right rotation, it also
// skews the demoted former root. It then skews the right child (on a copy)
// if that child has a horizontal left link.
func (tree *Tree[K, V]) skew_rec() *Tree[K, V] {
	if tree.need_skew() { // Rotate right.
		copy := *tree.left
		tree.left = copy.right
		copy.right = tree.skew_rec() // Recurse.
		tree = &copy
	}
	if tree.right.need_skew() {
		node := *tree.right
		tree.right = node.skew_rec() // Recurse.
	}
	return tree.fixup()
}

// split rotates left at this node if it needs it, then fixes up the
// balance of the resulting root. fixup derives the level from the new
// left child (the former root), so the new root ends one level higher.
// The receiver is modified in place.
func (tree *Tree[K, V]) split() *Tree[K, V] {
	if tree.need_split() { // Rotate left.
		copy := *tree.right
		tree.right = copy.left
		copy.left = tree.fixup()
		tree = &copy
	}
	return tree.fixup()
}

// split_rec splits this node, then splits its right child (on a copy).
// The second step is a single, non-recursive split.
func (tree *Tree[K, V]) split_rec() *Tree[K, V] {
	if tree.need_split() { // Rotate left.
		copy := *tree.right
		tree.right = copy.left
		copy.left = tree.fixup()
		tree = &copy
	}
	if tree.right.need_split() {
		node := *tree.right
		tree.right = node.split() // Recurse once.
	}
	return tree.fixup()
}
package aa import "cmp" // Union returns the set union of two trees, // last value wins. func Union[K cmp.Ordered, V any](t1, t2 *Tree[K, V]) *Tree[K, V] { switch { case t1 == t2 || t1 == nil: return t2 case t2 == nil: return t1 } left, _, right := t1.Split(t2.key) left = Union(left, t2.left) right = Union(right, t2.right) return join(left, t2, right) } // Intersection returns the set intersection of two trees, // first value wins. func Intersection[K cmp.Ordered, V any](t1, t2 *Tree[K, V]) *Tree[K, V] { switch { case t1 == t2: return t1 case t1 == nil || t2 == nil: return nil } left, node, right := t1.Split(t2.key) left = Intersection(left, t2.left) right = Intersection(right, t2.right) if node == nil { return join2(left, right) } return join(left, node, right) } // Difference returns the set difference of two trees. func Difference[K cmp.Ordered, V any](t1, t2 *Tree[K, V]) *Tree[K, V] { switch { case t1 == t2 || t1 == nil: return nil case t2 == nil: return t1 } left, _, right := t1.Split(t2.key) left = Difference(left, t2.left) right = Difference(right, t2.right) return join2(left, right) } // SymmetricDifference returns the set symmetric difference of two trees. func SymmetricDifference[K cmp.Ordered, V any](t1, t2 *Tree[K, V]) *Tree[K, V] { switch { case t1 == t2: return nil case t1 == nil: return t2 case t2 == nil: return t1 } left, node, right := t1.Split(t2.key) left = SymmetricDifference(left, t2.left) right = SymmetricDifference(right, t2.right) if node == nil { return join(left, t2, right) } return join2(left, right) }
package aa import "cmp" // Select finds the node at index i of this tree and returns it // (or nil if i is out of range). func (tree *Tree[K, V]) Select(i int) *Tree[K, V] { for tree != nil { p := tree.left.Len() switch cmp.Compare(i, p) { case -1: tree = tree.left case +1: tree = tree.right i -= p + 1 default: return tree } } return tree } // Rank finds the rank of key, // the number of nodes in this tree less than key. // // tree.Rank(tree.Select(i).Key()) ⟹ i, iff 0 ≤ i < tree.Len() func (tree *Tree[K, V]) Rank(key K) int { k := 0 for tree != nil { switch cmp.Compare(key, tree.key) { case -1: tree = tree.left case +1: k += tree.left.Len() + 1 tree = tree.right default: return k + tree.left.Len() } } return k }
// Package aa implements immutable AA trees. package aa import "cmp" // Tree is an immutable AA tree, // a form of self-balancing binary search tree. // // Use *Tree as a reference type; the zero value for *Tree (nil) is the empty tree: // // var empty *aa.Tree[int, string] // one := empty.Put(1, "one") // one.Has(1) ⟹ true // // Note: the zero value for Tree{} is a valid — but non-empty — tree. type Tree[K cmp.Ordered, V any] struct { left *Tree[K, V] right *Tree[K, V] key K value V balance } // Key returns the key at the root of this tree. // // Note: getting the root key of an empty tree (nil) // causes a runtime panic. func (tree *Tree[K, V]) Key() K { return tree.key } // Value returns the value at the root of this tree. // // Note: getting the root value of an empty tree (nil) // causes a runtime panic. func (tree *Tree[K, V]) Value() V { return tree.value } // Left returns the left subtree of this tree, // containing all keys less than its root key. // // Note: the left subtree of the empty tree is the empty tree (nil). func (tree *Tree[K, V]) Left() *Tree[K, V] { if tree == nil { return nil } return tree.left } // Right returns the right subtree of this tree, // containing all keys greater than its root key. // // Note: the right subtree of the empty tree is the empty tree (nil). func (tree *Tree[K, V]) Right() *Tree[K, V] { if tree == nil { return nil } return tree.right } // Level returns the level of this AA tree. // // Notes: // - the level of the empty tree (nil) is 0. // - the height of a tree of level n is between n and 2·n. // - the number of nodes in a tree of level n is between 2ⁿ-1 and 3ⁿ-1. func (tree *Tree[K, V]) Level() int { if tree == nil { return 0 } return tree.balance.level() } // Len returns the number of nodes in this tree. func (tree *Tree[K, V]) Len() int { if tree == nil { return 0 } return tree.balance.len() } // Min finds the least key in this tree, // and returns the node for that key, // or nil if this tree is empty. 
func (tree *Tree[K, V]) Min() *Tree[K, V] { if tree == nil { return nil } for tree.left != nil { tree = tree.left } return tree } // Max finds the greatest key in this tree, // and returns the node for that key, // or nil if this tree is empty. func (tree *Tree[K, V]) Max() *Tree[K, V] { if tree == nil { return nil } for tree.right != nil { tree = tree.right } return tree } // Floor finds the greatest key in this tree less-than or equal-to key, // and returns the node for that key, // or nil if no such key exists in this tree. func (tree *Tree[K, V]) Floor(key K) *Tree[K, V] { var node *Tree[K, V] for tree != nil { if cmp.Less(key, tree.key) { tree = tree.left } else { node = tree tree = tree.right } } return node } // Ceil finds the least key in this tree greater-than or equal-to key, // and returns the node for that key, // or nil if no such key exists in this tree. func (tree *Tree[K, V]) Ceil(key K) *Tree[K, V] { var node *Tree[K, V] for tree != nil { if cmp.Less(tree.key, key) { tree = tree.right } else { node = tree tree = tree.left } } return node } // Get retrieves the value for a given key; // found indicates whether key exists in this tree. func (tree *Tree[K, V]) Get(key K) (value V, found bool) { // Floor uses 2-way search, which is faster for strings: // https://go.dev/issue/71270 // https://user.it.uu.se/~arnea/ps/searchproc.pdf // Both Floor/Ceil work; Floor is faster since AA trees lean right. node := tree.Floor(key) if node != nil && (key == node.key || key != key) { return node.value, true } return // zero, false } // Has reports whether key exists in this tree. func (tree *Tree[K, V]) Has(key K) bool { _, found := tree.Get(key) return found } // Put returns a modified tree with key set to value. // // tree.Put(key, value).Get(key) ⟹ (value, true) func (tree *Tree[K, V]) Put(key K, value V) *Tree[K, V] { return tree.Patch(key, func(*Tree[K, V]) (V, bool) { return value, true }) } // Add returns a (possibly) modified tree that contains key. 
// // tree.Add(key).Has(key) ⟹ true func (tree *Tree[K, V]) Add(key K) *Tree[K, V] { return tree.Patch(key, func(node *Tree[K, V]) (value V, ok bool) { return value, node == nil }) } // Patch finds key in this tree, calls update with the node for that key // (or nil, if key is not found), and returns a (possibly) modified tree. // // The update callback can opt to set/update the value for the key, // by returning (value, true), or not, by returning false. func (tree *Tree[K, V]) Patch(key K, update func(node *Tree[K, V]) (value V, ok bool)) *Tree[K, V] { if tree == nil { if value, ok := update(tree); ok { return &Tree[K, V]{key: key, value: value} } return nil } switch cmp.Compare(key, tree.key) { default: if value, ok := update(tree); ok { copy := *tree copy.value = value return © } return tree case -1: left := tree.left.Patch(key, update) if left == tree.left { return tree } copy := *tree copy.left = left return copy.ins_rebalance() case +1: right := tree.right.Patch(key, update) if right == tree.right { return tree } copy := *tree copy.right = right return copy.ins_rebalance() } } // Delete returns a (possibly) modified tree with key removed from it. // The optional pred is called to confirm deletion. // // tree.Delete(key).Has(key) ⟹ false func (tree *Tree[K, V]) Delete(key K, pred ...func(node *Tree[K, V]) bool) *Tree[K, V] { var p func(*Tree[K, V]) bool if len(pred) > 0 { p = pred[0] } return tree.delete(key, p) } func (tree *Tree[K, V]) delete(key K, pred func(node *Tree[K, V]) bool) *Tree[K, V] { if tree == nil { return nil } switch cmp.Compare(key, tree.key) { case -1: left := tree.left.delete(key, pred) if left == tree.left { return tree } copy := *tree copy.left = left return copy.del_rebalance() case +1: right := tree.right.delete(key, pred) if right == tree.right { return tree } copy := *tree copy.right = right return copy.del_rebalance() default: if pred != nil && !pred(tree) { return tree } // If tree.right is nil, tree.left is too. 
if tree.left == nil { return tree.right } copy := *tree var heir *Tree[K, V] // Either works; this saves a few allocs. if copy.Level() == copy.right.Level() { copy.right, heir = copy.right.DeleteMin() } else { copy.left, heir = copy.left.DeleteMax() } copy.key = heir.key copy.value = heir.value return copy.del_rebalance() } } // DeleteMin returns a modified tree with its least key removed from it, // and the removed node. func (tree *Tree[K, V]) DeleteMin() (_, node *Tree[K, V]) { if tree == nil { return nil, nil } if tree.left == nil { return tree.right, tree } copy := *tree copy.left, node = tree.left.DeleteMin() return copy.del_rebalance(), node } // DeleteMax returns a modified tree with its greatest key removed from it, // and the removed node. func (tree *Tree[K, V]) DeleteMax() (_, node *Tree[K, V]) { if tree == nil { return nil, nil } if tree.right == nil { return nil, tree // tree.left, tree } copy := *tree copy.right, node = tree.right.DeleteMax() return copy.del_rebalance(), node }