package aa
import (
"math"
"math/bits"
)
// mask selects the low bits of a balance, the ones that hold the level.
const mask = bits.UintSize - 1

// balance packs the level and the cardinality of an AA tree into one word.
// Both are stored offset by one, so the zero value describes
// a single node of level one.
//
// The level lives in the lowest 5 or 6 bits,
// the cardinality in the remaining 27 or 58 bits.
//
// That is enough for trees of up to 32 or 64 levels,
// holding 134217728 or 288230376151711744 elements.
//
// Given the smallest possible node size (4 machine words),
// at least half the address space (2GiB or 8EiB) would be used up
// before either limit could be reached.
type balance uint

// level returns the level stored in b, undoing the offset by one.
func (b balance) level() int {
	stored := int(b & mask)
	return stored + 1
}

// len returns the cardinality stored in b, undoing the offset by one.
func (b balance) len() int {
	shift := bits.Len(mask)
	return int(b>>shift) + 1
}
// setLevel overwrites the level bits of this node's balance,
// leaving the cardinality bits untouched.
func (tree *Tree[K, V]) setLevel(level int) {
	count := tree.balance &^ mask
	tree.balance = count | balance(level-1)
}
// fixup recomputes this node's balance from its children and returns the node.
//
// The level becomes one more than the left child's level
// (or one, without a left child);
// the cardinality becomes one plus the cardinalities of both children.
// It panics if the packed result does not fit in a uint.
func (tree *Tree[K, V]) fixup() *Tree[K, V] {
	// A cardinality of one, positioned above the level bits.
	one := uint64(1) << bits.Len(mask)

	var packed uint64
	if l := tree.left; l != nil {
		// Carries over the left level, plus one, and the left cardinality.
		packed += one + uint64(l.balance) + 1
	}
	if r := tree.right; r != nil {
		// Only the right cardinality counts; its level bits are dropped.
		packed += one + uint64(r.balance)&^mask
	}
	if packed > math.MaxUint {
		panic("overflow")
	}
	tree.balance = balance(packed)
	return tree
}
package aa
import (
"cmp"
"iter"
)
// Equal reports whether both trees hold exactly the same key/value pairs.
//
// Keys that are not equal to themselves (such as NaN)
// are considered equal to each other.
func Equal[K cmp.Ordered, V comparable](t1, t2 *Tree[K, V]) bool {
	if t1 == t2 {
		return true
	}
	size := t1.Len()
	if size != t2.Len() {
		return false
	}

	next1, stop1 := iter.Pull2(t1.Ascend())
	defer stop1()
	next2, stop2 := iter.Pull2(t2.Ascend())
	defer stop2()

	// Both trees have size elements, so neither iterator runs dry.
	for i := 0; i < size; i++ {
		k1, v1, _ := next1()
		k2, v2, _ := next2()
		sameKey := k1 == k2 || (k1 != k1 && k2 != k2)
		if !sameKey || v1 != v2 {
			return false
		}
	}
	return true
}
// Subset reports whether t2 contains every t1 key/value pair.
//
// Both trees are walked in ascending order, in lockstep.
// l1 and l2 track how many pairs are still to be consumed from each tree,
// which allows bailing out as soon as t1 has more pairs left than t2.
func Subset[K cmp.Ordered, V comparable](t1, t2 *Tree[K, V]) bool {
	if t1 == t2 {
		return true
	}
	l1 := t1.Len()
	if l1 == 0 {
		return true
	}
	l2 := t2.Len()
	if l1 > l2 {
		return false
	}

	n1, s1 := iter.Pull2(t1.Ascend())
	defer s1()
	n2, s2 := iter.Pull2(t2.Ascend())
	defer s2()

	// The counters bound the number of pulls,
	// so the iterators' ok results can be ignored.
loop1:
	for ; l1 <= l2 && l1 > 0; l1-- {
		k1, v1, _ := n1()
		for ; l1 <= l2 && l2 > 0; l2-- {
			k2, v2, _ := n2()
			switch cmp.Compare(k1, k2) {
			case -1:
				return false // won't find k1 in t2
			case +1:
				continue // skip k2, keep looking for k1
			}
			if v1 != v2 {
				return false // found k1, v1 != v2
			}
			// The labeled continue skips the inner post statement,
			// so account for the consumed k2 here. Otherwise l2 overcounts,
			// and the exhausted iterator's zero values may be compared.
			l2--
			continue loop1
		}
		return false // couldn't find k1 in t2
	}
	return l1 == 0 // checked all k1
}
// Overlap reports whether t1 and t2 have at least one key in common.
// Values are not compared.
func Overlap[K cmp.Ordered, V any](t1, t2 *Tree[K, V]) bool {
	if t1 == nil || t2 == nil {
		return false
	}
	if t1 == t2 {
		return true
	}

	next1, stop1 := iter.Pull2(t1.Ascend())
	defer stop1()
	next2, stop2 := iter.Pull2(t2.Ascend())
	defer stop2()

	// Merge-style walk: always advance the side with the lesser key.
	k1, _, more1 := next1()
	k2, _, more2 := next2()
	for more1 && more2 {
		order := cmp.Compare(k1, k2)
		if order == 0 {
			return true
		}
		if order < 0 {
			k1, _, more1 = next1()
		} else {
			k2, _, more2 = next2()
		}
	}
	return false
}
package aa
import (
"cmp"
"iter"
)
// Ascend returns an iterator over this tree's key/value pairs,
// in ascending key order.
func (tree *Tree[K, V]) Ascend() iter.Seq2[K, V] {
	return func(yield func(K, V) bool) {
		_ = tree.ascend(yield)
	}
}
// ascend yields every pair of this tree in ascending order.
// It returns false as soon as yield does, true otherwise.
func (tree *Tree[K, V]) ascend(yield func(K, V) bool) bool {
	if tree == nil {
		return true
	}
	if !tree.left.ascend(yield) {
		return false
	}
	if !yield(tree.key, tree.value) {
		return false
	}
	return tree.right.ascend(yield)
}
// Descend returns an iterator over this tree's key/value pairs,
// in descending key order.
func (tree *Tree[K, V]) Descend() iter.Seq2[K, V] {
	return func(yield func(K, V) bool) {
		_ = tree.descend(yield)
	}
}
// descend yields every pair of this tree in descending order.
// It returns false as soon as yield does, true otherwise.
func (tree *Tree[K, V]) descend(yield func(K, V) bool) bool {
	if tree == nil {
		return true
	}
	if !tree.right.descend(yield) {
		return false
	}
	if !yield(tree.key, tree.value) {
		return false
	}
	return tree.left.descend(yield)
}
// AscendCeil returns an ascending iterator for this tree,
// whose first key is the least key greater-than or equal-to pivot.
func (tree *Tree[K, V]) AscendCeil(pivot K) iter.Seq2[K, V] {
	return func(yield func(K, V) bool) {
		_ = tree.ascendCeil(pivot, yield)
	}
}
// ascendCeil yields, in ascending order, the pairs whose keys
// are not less than pivot. It returns false as soon as yield does.
func (tree *Tree[K, V]) ascendCeil(pivot K, yield func(K, V) bool) bool {
	for ; tree != nil; tree = tree.right {
		if cmp.Less(tree.key, pivot) {
			// This node and its left subtree are below pivot: skip right.
			continue
		}
		if !tree.left.ascendCeil(pivot, yield) {
			return false
		}
		if !yield(tree.key, tree.value) {
			return false
		}
		return tree.right.ascend(yield)
	}
	return true
}
// AscendFloor returns an ascending iterator for this tree,
// whose first key is the greatest key less-than or equal-to pivot.
func (tree *Tree[K, V]) AscendFloor(pivot K) iter.Seq2[K, V] {
	return func(yield func(K, V) bool) {
		_ = tree.ascendFloor(pivot, nil, yield)
	}
}
// ascendFloor yields, in ascending order, the floor of pivot
// followed by the pairs whose keys are greater than pivot.
//
// node is the best floor candidate found so far (nil if none);
// it is yielded once the search path runs out.
// It returns false as soon as yield does.
func (tree *Tree[K, V]) ascendFloor(pivot K, node *Tree[K, V], yield func(K, V) bool) bool {
	for tree != nil {
		if !cmp.Less(pivot, tree.key) {
			// Not greater than pivot: a better candidate, keep going right.
			node, tree = tree, tree.right
			continue
		}
		if !tree.left.ascendFloor(pivot, node, yield) {
			return false
		}
		if !yield(tree.key, tree.value) {
			return false
		}
		return tree.right.ascend(yield)
	}
	if node == nil {
		return true
	}
	return yield(node.key, node.value)
}
// DescendFloor returns a descending iterator for this tree,
// whose first key is the greatest key less-than or equal-to pivot.
func (tree *Tree[K, V]) DescendFloor(pivot K) iter.Seq2[K, V] {
	return func(yield func(K, V) bool) {
		_ = tree.descendFloor(pivot, yield)
	}
}
// descendFloor yields, in descending order, the pairs whose keys
// are not greater than pivot. It returns false as soon as yield does.
func (tree *Tree[K, V]) descendFloor(pivot K, yield func(K, V) bool) bool {
	for ; tree != nil; tree = tree.left {
		if cmp.Less(pivot, tree.key) {
			// This node and its right subtree are above pivot: skip left.
			continue
		}
		if !tree.right.descendFloor(pivot, yield) {
			return false
		}
		if !yield(tree.key, tree.value) {
			return false
		}
		return tree.left.descend(yield)
	}
	return true
}
// DescendCeil returns a descending iterator for this tree,
// whose first key is the least key greater-than or equal-to pivot.
func (tree *Tree[K, V]) DescendCeil(pivot K) iter.Seq2[K, V] {
	return func(yield func(K, V) bool) {
		_ = tree.descendCeil(pivot, nil, yield)
	}
}
// descendCeil yields, in descending order, the ceiling of pivot
// followed by the pairs whose keys are less than pivot.
//
// node is the best ceiling candidate found so far (nil if none);
// it is yielded once the search path runs out.
// It returns false as soon as yield does.
func (tree *Tree[K, V]) descendCeil(pivot K, node *Tree[K, V], yield func(K, V) bool) bool {
	for tree != nil {
		if !cmp.Less(tree.key, pivot) {
			// Not less than pivot: a better candidate, keep going left.
			node, tree = tree, tree.left
			continue
		}
		if !tree.right.descendCeil(pivot, node, yield) {
			return false
		}
		if !yield(tree.key, tree.value) {
			return false
		}
		return tree.left.descend(yield)
	}
	if node == nil {
		return true
	}
	return yield(node.key, node.value)
}
package aa
import "cmp"
// Split divides this tree in two around key. It returns
// a left tree holding the keys less than key,
// a right tree holding the keys greater than key,
// and the node that holds key itself
// (nil when key is not in this tree).
func (tree *Tree[K, V]) Split(key K) (left, node, right *Tree[K, V]) {
	if tree == nil {
		return nil, nil, nil
	}
	order := cmp.Compare(key, tree.key)
	if order < 0 {
		// key is on the left: the root and its right subtree go right.
		left, node, right = tree.left.Split(key)
		return left, node, join(right, tree, tree.right)
	}
	if order > 0 {
		// key is on the right: the root and its left subtree go left.
		left, node, right = tree.right.Split(key)
		return join(tree.left, tree, left), node, right
	}
	return tree.left, tree, tree.right
}
// Filter returns a tree with only the nodes that pred accepts.
// Both subtrees are filtered before pred is called on their root.
func (tree *Tree[K, V]) Filter(pred func(node *Tree[K, V]) bool) *Tree[K, V] {
	if tree == nil {
		return nil
	}
	keptLeft := tree.left.Filter(pred)
	keptRight := tree.right.Filter(pred)
	if !pred(tree) {
		return join2(keptLeft, keptRight)
	}
	return join(keptLeft, tree, keptRight)
}
// Partition splits this tree in two according to pred:
// t holds the nodes that pred accepts, f the nodes that it rejects.
// Both subtrees are partitioned before pred is called on their root.
func (tree *Tree[K, V]) Partition(pred func(node *Tree[K, V]) bool) (t, f *Tree[K, V]) {
	if tree == nil {
		return nil, nil
	}
	leftYes, leftNo := tree.left.Partition(pred)
	rightYes, rightNo := tree.right.Partition(pred)
	if !pred(tree) {
		return join2(leftYes, rightYes), join(leftNo, tree, rightNo)
	}
	return join(leftYes, tree, rightYes), join2(leftNo, rightNo)
}
// join builds a tree out of left, the key/value pair of node, and right.
// node itself is returned when it already has those exact children.
// NOTE(review): presumably callers ensure left's keys are below node's key
// and right's keys above it; nothing here checks that.
func join[K cmp.Ordered, V any](left, node, right *Tree[K, V]) *Tree[K, V] {
	if node.left == left && node.right == right {
		return node
	}

	lhs, rhs := left.Level(), right.Level()
	//nolint:nilaway
	if rhs == lhs+1 && right.right.Level() == lhs { // Can we create a 3-node?
		rhs = lhs // Avoid recursion, rebalancing.
	}

	if lhs < rhs {
		// Right is taller: descend along its left spine.
		root := *right //nolint:nilaway
		root.left = join(left, node, right.left)
		return root.ins_rebalance()
	}
	if lhs > rhs {
		// Left is taller: descend along its right spine.
		root := *left //nolint:nilaway
		root.right = join(left.right, node, right)
		return root.ins_rebalance()
	}
	return makeNode(node.key, node.value, left, right)
}
// join2 builds a tree out of left and right, with no node in between:
// the least node of right is removed from it and used to join both.
func join2[K cmp.Ordered, V any](left, right *Tree[K, V]) *Tree[K, V] {
	if left == nil {
		return right
	}
	if right == nil {
		return left
	}
	// Both DeleteMin/DeleteMax work; DeleteMin minimizes allocs.
	rest, least := right.DeleteMin()
	return join(left, least, rest)
}
package aa
import (
"cmp"
"slices"
)
// MakeSet builds a tree whose keys are the given keys,
// all with empty struct values.
//
// Keys already in strictly increasing order are used as is;
// otherwise a sorted, deduplicated copy is made, and keys is left untouched.
func MakeSet[K cmp.Ordered](keys ...K) *Tree[K, struct{}] {
	if increasing(keys) {
		return makeTree[K, struct{}](keys, nil)
	}
	sorted := slices.Clone(keys)
	slices.Sort(sorted)
	sorted = slices.Compact(sorted)
	return makeTree[K, struct{}](sorted, nil)
}
// MakeMap builds a tree holding the key/value pairs of m.
func MakeMap[K cmp.Ordered, V any](m map[K]V) *Tree[K, V] {
	keys := make([]K, 0, len(m))
	for key := range m {
		keys = append(keys, key)
	}
	// makeTree is handed the keys in sorted order.
	slices.Sort(keys)
	return makeTree(keys, m)
}
// makeTree recursively builds a tree from keys, looking up values in m
// (a nil m yields zero values). The middle key becomes the root.
// NOTE(review): presumably keys must be sorted and free of duplicates;
// both visible callers ensure that.
func makeTree[K cmp.Ordered, V any](keys []K, m map[K]V) *Tree[K, V] {
	if len(keys) == 0 {
		return nil
	}
	// AA trees lean right, so round down.
	mid := (len(keys) - 1) / 2
	root := keys[mid]
	lower := makeTree(keys[:mid], m)
	upper := makeTree(keys[mid+1:], m)
	return makeNode(root, m[root], lower, upper)
}
// makeNode allocates a node for the k/v pair with the given children,
// and computes its balance from them.
func makeNode[K cmp.Ordered, V any](k K, v V, left, right *Tree[K, V]) *Tree[K, V] {
	node := new(Tree[K, V])
	node.key, node.value = k, v
	node.left, node.right = left, right
	return node.fixup()
}
// increasing reports whether keys are in strictly increasing order,
// that is, whether each key is less than the one that follows it.
// Empty and single element slices are trivially increasing.
func increasing[K cmp.Ordered](keys []K) bool {
	for i := 1; i < len(keys); i++ {
		//nolint:nilaway
		if !cmp.Less(keys[i-1], keys[i]) {
			return false
		}
	}
	return true
}
// Collect returns a new map holding
// all the key-value pairs from this tree.
func (tree *Tree[K, V]) Collect() map[K]V {
	size := tree.Len()
	pairs := make(map[K]V, size)
	tree.collect(pairs)
	return pairs
}
// collect stores every key-value pair from this tree into m.
func (tree *Tree[K, V]) collect(m map[K]V) {
	// AA trees lean right, so recurse left.
	for ; tree != nil; tree = tree.right {
		tree.left.collect(m)
		m[tree.key] = tree.value
	}
}
package aa
// ins_rebalance restores the balance of this node after an insertion
// into one of its subtrees. The receiver is updated in place,
// so it must be a node the caller owns.
func (tree *Tree[K, V]) ins_rebalance() *Tree[K, V] {
	if !tree.need_raise() {
		return tree.skew().split()
	}
	// Both children sit at this node's level:
	// recomputing the level avoids 2 rotations and allocs.
	return tree.fixup()
}
func (tree *Tree[K, V]) del_rebalance() *Tree[K, V] {
max := 1 + min(tree.left.Level(), tree.right.Level())
if tree.Level() > max {
tree.setLevel(max)
if tree.right.Level() > max {
copy := *tree.right
copy.setLevel(max)
tree.right = ©
}
return tree.skew_rec().split_rec()
}
return tree.fixup()
}
// need_skew reports whether the left child sits at this node's level.
// It is false for the empty tree.
func (tree *Tree[K, V]) need_skew() bool {
	if tree == nil {
		return false
	}
	return tree.left.Level() == tree.Level()
}
// need_split reports whether the right child of the right child
// sits at this node's level.
// It is false for the empty tree, and for nodes without a right child.
func (tree *Tree[K, V]) need_split() bool {
	if tree == nil || tree.right == nil {
		return false
	}
	return tree.right.right.Level() == tree.Level()
}
// need_raise reports whether both children sit at this node's level.
// It is false for the empty tree.
func (tree *Tree[K, V]) need_raise() bool {
	if tree == nil {
		return false
	}
	level := tree.Level()
	return level == tree.left.Level() && level == tree.right.Level()
}
func (tree *Tree[K, V]) skew() *Tree[K, V] {
if tree.need_skew() {
// Rotate right.
copy := *tree.left
tree.left = copy.right
copy.right = tree.fixup()
tree = ©
}
return tree.fixup()
}
func (tree *Tree[K, V]) skew_rec() *Tree[K, V] {
if tree.need_skew() {
// Rotate right.
copy := *tree.left
tree.left = copy.right
copy.right = tree.skew_rec() // Recurse.
tree = ©
}
if tree.right.need_skew() {
node := *tree.right
tree.right = node.skew_rec() // Recurse.
}
return tree.fixup()
}
func (tree *Tree[K, V]) split() *Tree[K, V] {
if tree.need_split() {
// Rotate left.
copy := *tree.right
tree.right = copy.left
copy.left = tree.fixup()
tree = ©
}
return tree.fixup()
}
func (tree *Tree[K, V]) split_rec() *Tree[K, V] {
if tree.need_split() {
// Rotate left.
copy := *tree.right
tree.right = copy.left
copy.left = tree.fixup()
tree = ©
}
if tree.right.need_split() {
node := *tree.right
tree.right = node.split() // Recurse once.
}
return tree.fixup()
}
package aa
import "cmp"
// Union returns a tree with the keys found in either tree.
// For keys found in both, the value from t2 is kept (last value wins).
func Union[K cmp.Ordered, V any](t1, t2 *Tree[K, V]) *Tree[K, V] {
	if t1 == nil || t1 == t2 {
		return t2
	}
	if t2 == nil {
		return t1
	}
	// Divide t1 around the root of t2, and merge each side separately.
	lower, _, upper := t1.Split(t2.key)
	lower = Union(lower, t2.left)
	upper = Union(upper, t2.right)
	return join(lower, t2, upper)
}
// Intersection returns a tree with the keys found in both trees.
// Values are taken from t1 (first value wins).
func Intersection[K cmp.Ordered, V any](t1, t2 *Tree[K, V]) *Tree[K, V] {
	if t1 == t2 {
		return t1
	}
	if t1 == nil || t2 == nil {
		return nil
	}
	// Divide t1 around the root of t2, and intersect each side separately.
	lower, hit, upper := t1.Split(t2.key)
	lower = Intersection(lower, t2.left)
	upper = Intersection(upper, t2.right)
	if hit != nil {
		return join(lower, hit, upper)
	}
	return join2(lower, upper)
}
// Difference returns a tree with the keys of t1 that are not found in t2.
func Difference[K cmp.Ordered, V any](t1, t2 *Tree[K, V]) *Tree[K, V] {
	if t1 == nil || t1 == t2 {
		return nil
	}
	if t2 == nil {
		return t1
	}
	// Divide t1 around the root of t2, dropping that key if present.
	lower, _, upper := t1.Split(t2.key)
	lower = Difference(lower, t2.left)
	upper = Difference(upper, t2.right)
	return join2(lower, upper)
}
// SymmetricDifference returns a tree with the keys found in
// exactly one of the two trees.
func SymmetricDifference[K cmp.Ordered, V any](t1, t2 *Tree[K, V]) *Tree[K, V] {
	if t1 == t2 {
		return nil
	}
	if t1 == nil {
		return t2
	}
	if t2 == nil {
		return t1
	}
	// Divide t1 around the root of t2; that key is kept only if t1 lacks it.
	lower, hit, upper := t1.Split(t2.key)
	lower = SymmetricDifference(lower, t2.left)
	upper = SymmetricDifference(upper, t2.right)
	if hit != nil {
		return join2(lower, upper)
	}
	return join(lower, t2, upper)
}
package aa
import "cmp"
// Select returns the node at index i of this tree, in ascending key order
// (or nil if i is out of range).
func (tree *Tree[K, V]) Select(i int) *Tree[K, V] {
	for tree != nil {
		// The size of the left subtree is the index of this node.
		p := tree.left.Len()
		if i == p {
			return tree
		}
		if i < p {
			tree = tree.left
			continue
		}
		// Skip the left subtree and this node.
		i -= p + 1
		tree = tree.right
	}
	return nil
}
// Rank returns the rank of key:
// the number of nodes in this tree with keys less than key.
//
//	tree.Rank(tree.Select(i).Key()) ⟹ i, iff 0 ≤ i < tree.Len()
func (tree *Tree[K, V]) Rank(key K) int {
	rank := 0
	for tree != nil {
		order := cmp.Compare(key, tree.key)
		if order < 0 {
			tree = tree.left
			continue
		}
		// Everything in the left subtree is less than key.
		rank += tree.left.Len()
		if order == 0 {
			return rank
		}
		// So is this node.
		rank++
		tree = tree.right
	}
	return rank
}
// Package aa implements immutable AA trees.
package aa
import "cmp"
// Tree is an immutable AA tree,
// a form of self-balancing binary search tree.
//
// Use *Tree as a reference type; the zero value for *Tree (nil) is the empty tree:
//
// var empty *aa.Tree[int, string]
// one := empty.Put(1, "one")
// one.Has(1) ⟹ true
//
// Note: the zero value for Tree{} is a valid — but non-empty — tree.
type Tree[K cmp.Ordered, V any] struct {
// left is the subtree with the keys less than key (nil if empty).
left *Tree[K, V]
// right is the subtree with the keys greater than key (nil if empty).
right *Tree[K, V]
// key and value are the pair stored at the root of this tree.
key K
value V
// balance packs the level and the number of nodes of this tree.
balance
}
// Key returns the key stored at the root of this tree.
//
// Note: the empty tree (nil) has no root key;
// asking for it causes a runtime panic.
func (tree *Tree[K, V]) Key() K {
	key := tree.key
	return key
}
// Value returns the value stored at the root of this tree.
//
// Note: the empty tree (nil) has no root value;
// asking for it causes a runtime panic.
func (tree *Tree[K, V]) Value() V {
	value := tree.value
	return value
}
// Left returns the left subtree of this tree,
// which holds all the keys less than its root key.
//
// Note: the empty tree (nil) has an empty left subtree (nil).
func (tree *Tree[K, V]) Left() *Tree[K, V] {
	if tree != nil {
		return tree.left
	}
	return nil
}
// Right returns the right subtree of this tree,
// which holds all the keys greater than its root key.
//
// Note: the empty tree (nil) has an empty right subtree (nil).
func (tree *Tree[K, V]) Right() *Tree[K, V] {
	if tree != nil {
		return tree.right
	}
	return nil
}
// Level returns the AA tree level of this tree.
//
// Notes:
//   - the empty tree (nil) has level 0.
//   - a tree of level n has a height between n and 2·n.
//   - a tree of level n has between 2ⁿ-1 and 3ⁿ-1 nodes.
func (tree *Tree[K, V]) Level() int {
	if tree != nil {
		return tree.balance.level()
	}
	return 0
}
// Len returns how many nodes this tree has; 0 for the empty tree (nil).
func (tree *Tree[K, V]) Len() int {
	if tree != nil {
		return tree.balance.len()
	}
	return 0
}
// Min returns the node for the least key in this tree,
// or nil if this tree is empty.
func (tree *Tree[K, V]) Min() *Tree[K, V] {
	// The least key is at the end of the left spine.
	for tree != nil && tree.left != nil {
		tree = tree.left
	}
	return tree
}
// Max returns the node for the greatest key in this tree,
// or nil if this tree is empty.
func (tree *Tree[K, V]) Max() *Tree[K, V] {
	// The greatest key is at the end of the right spine.
	for tree != nil && tree.right != nil {
		tree = tree.right
	}
	return tree
}
// Floor returns the node for the greatest key in this tree
// that is less-than or equal-to key,
// or nil if this tree has no such key.
func (tree *Tree[K, V]) Floor(key K) *Tree[K, V] {
	var best *Tree[K, V]
	for tree != nil {
		if cmp.Less(key, tree.key) {
			tree = tree.left
			continue
		}
		// Not greater than key: a candidate; look for a greater one.
		best, tree = tree, tree.right
	}
	return best
}
// Ceil returns the node for the least key in this tree
// that is greater-than or equal-to key,
// or nil if this tree has no such key.
func (tree *Tree[K, V]) Ceil(key K) *Tree[K, V] {
	var best *Tree[K, V]
	for tree != nil {
		if cmp.Less(tree.key, key) {
			tree = tree.right
			continue
		}
		// Not less than key: a candidate; look for a lesser one.
		best, tree = tree, tree.left
	}
	return best
}
// Get returns the value stored for key;
// found reports whether this tree has key.
func (tree *Tree[K, V]) Get(key K) (value V, found bool) {
	// Floor uses 2-way search, which is faster for strings:
	// https://go.dev/issue/71270
	// https://user.it.uu.se/~arnea/ps/searchproc.pdf
	// Both Floor/Ceil work; Floor is faster since AA trees lean right.
	node := tree.Floor(key)
	if node == nil {
		return value, false
	}
	// A key that is not equal to itself (such as NaN)
	// matches whatever node Floor found for it.
	if key != node.key && key == key {
		return value, false
	}
	return node.value, true
}
// Has reports whether this tree has key.
func (tree *Tree[K, V]) Has(key K) bool {
	_, ok := tree.Get(key)
	return ok
}
// Put returns a modified tree in which key is set to value.
//
//	tree.Put(key, value).Get(key) ⟹ (value, true)
func (tree *Tree[K, V]) Put(key K, value V) *Tree[K, V] {
	// Set the value unconditionally.
	set := func(*Tree[K, V]) (V, bool) {
		return value, true
	}
	return tree.Patch(key, set)
}
// Add returns a (possibly) modified tree that has key.
// A missing key is added with the zero value;
// an existing key is left untouched.
//
//	tree.Add(key).Has(key) ⟹ true
func (tree *Tree[K, V]) Add(key K) *Tree[K, V] {
	// Only opt to set the (zero) value if key was not found.
	addMissing := func(node *Tree[K, V]) (V, bool) {
		var zero V
		return zero, node == nil
	}
	return tree.Patch(key, addMissing)
}
// Patch finds key in this tree, calls update with the node for that key
// (or nil, if key is not found), and returns a (possibly) modified tree.
//
// The update callback can opt to set/update the value for the key,
// by returning (value, true), or not, by returning false.
func (tree *Tree[K, V]) Patch(key K, update func(node *Tree[K, V]) (value V, ok bool)) *Tree[K, V] {
if tree == nil {
if value, ok := update(tree); ok {
return &Tree[K, V]{key: key, value: value}
}
return nil
}
switch cmp.Compare(key, tree.key) {
default:
if value, ok := update(tree); ok {
copy := *tree
copy.value = value
return ©
}
return tree
case -1:
left := tree.left.Patch(key, update)
if left == tree.left {
return tree
}
copy := *tree
copy.left = left
return copy.ins_rebalance()
case +1:
right := tree.right.Patch(key, update)
if right == tree.right {
return tree
}
copy := *tree
copy.right = right
return copy.ins_rebalance()
}
}
// Delete returns a (possibly) modified tree from which key was removed.
// If given, the first pred is called to confirm deletion;
// any other preds are ignored.
//
//	tree.Delete(key).Has(key) ⟹ false
func (tree *Tree[K, V]) Delete(key K, pred ...func(node *Tree[K, V]) bool) *Tree[K, V] {
	if len(pred) == 0 {
		return tree.delete(key, nil)
	}
	return tree.delete(key, pred[0])
}
// delete returns a (possibly) modified tree from which key was removed.
// A non-nil pred is called with the node for key, if found,
// and can veto the deletion by returning false.
//
// This tree is returned unchanged if key is not found, or pred vetoes.
// Otherwise, the nodes on the path to key are copied,
// never modified in place.
func (tree *Tree[K, V]) delete(key K, pred func(node *Tree[K, V]) bool) *Tree[K, V] {
	if tree == nil {
		return nil
	}

	order := cmp.Compare(key, tree.key)
	if order < 0 {
		left := tree.left.delete(key, pred)
		if left == tree.left {
			return tree // Nothing changed.
		}
		node := *tree
		node.left = left
		return node.del_rebalance()
	}
	if order > 0 {
		right := tree.right.delete(key, pred)
		if right == tree.right {
			return tree // Nothing changed.
		}
		node := *tree
		node.right = right
		return node.del_rebalance()
	}

	// Key found.
	if pred != nil && !pred(tree) {
		return tree
	}
	// If tree.right is nil, tree.left is too.
	if tree.left == nil {
		return tree.right
	}

	// Replace the pair at this node with a pair taken from a subtree.
	node := *tree
	var heir *Tree[K, V]
	// Either works; this saves a few allocs.
	if node.Level() != node.right.Level() {
		node.left, heir = node.left.DeleteMax()
	} else {
		node.right, heir = node.right.DeleteMin()
	}
	node.key, node.value = heir.key, heir.value
	return node.del_rebalance()
}
// DeleteMin returns a modified tree from which the least key was removed,
// along with the removed node.
// Both are nil for the empty tree.
func (tree *Tree[K, V]) DeleteMin() (_, node *Tree[K, V]) {
	switch {
	case tree == nil:
		return nil, nil
	case tree.left == nil:
		// This is the least node: what remains is its right subtree.
		return tree.right, tree
	}
	root := *tree
	root.left, node = tree.left.DeleteMin()
	return root.del_rebalance(), node
}
// DeleteMax returns a modified tree from which the greatest key was removed,
// along with the removed node.
// Both are nil for the empty tree.
func (tree *Tree[K, V]) DeleteMax() (_, node *Tree[K, V]) {
	switch {
	case tree == nil:
		return nil, nil
	case tree.right == nil:
		// This is the greatest node. Nothing remains:
		// nil stands in for tree.left here.
		return nil, tree // tree.left, tree
	}
	root := *tree
	root.right, node = tree.right.DeleteMax()
	return root.del_rebalance(), node
}