package wbt
import "cmp"
// Balance parameters, see:
// https://yoichihirai.com/bst.pdf
//
// Each parameter is a rational number stored as a numerator (_p) and a
// denominator (_q), so the balance tests can use integer arithmetic only.
const (
// delta = delta_p/delta_q = 5/2 is the weight ratio tested by is_heavy.
delta_p, delta_q = 5, 2
// gamma = gamma_p/gamma_q = 3/2 is the weight ratio tested by is_single.
gamma_p, gamma_q = 3, 2
)
// left_rebalance recomputes the size of tree and, if its left subtree
// is too heavy compared to its right subtree, rotates nodes to the right.
// It returns the root of the rebalanced tree.
//
// The receiver is modified in place, so callers pass a node they own
// (a fresh copy). The left child, and for a double rotation its right
// child, are copied before being relinked.
func (tree *Tree[K, V]) left_rebalance() *Tree[K, V] {
	if l, r := tree.Left(), tree.Right(); !is_heavy(l, r) {
		return tree.fixup()
	}

	child := *tree.left
	if is_single(child.Right(), child.Left()) {
		// Single rotation: the left child becomes the root.
		tree.left = child.right
		child.right = tree.fixup()
		return child.fixup()
	}

	// Double rotation: the left child's right child becomes the root.
	grand := *child.right
	tree.left, child.right = grand.right, grand.left
	grand.right = tree.fixup()
	grand.left = child.fixup()
	return grand.fixup()
}
// right_rebalance recomputes the size of tree and, if its right subtree
// is too heavy compared to its left subtree, rotates nodes to the left.
// It returns the root of the rebalanced tree.
//
// The receiver is modified in place, so callers pass a node they own
// (a fresh copy). The right child, and for a double rotation its left
// child, are copied before being relinked.
func (tree *Tree[K, V]) right_rebalance() *Tree[K, V] {
	if l, r := tree.Left(), tree.Right(); !is_heavy(r, l) {
		return tree.fixup()
	}

	child := *tree.right
	if is_single(child.Left(), child.Right()) {
		// Single rotation: the right child becomes the root.
		tree.right = child.left
		child.left = tree.fixup()
		return child.fixup()
	}

	// Double rotation: the right child's left child becomes the root.
	grand := *child.left
	tree.right, child.left = grand.left, grand.right
	grand.left = tree.fixup()
	grand.right = child.fixup()
	return grand.fixup()
}
// is_heavy reports whether the weight of a (its size plus one)
// is more than delta_p/delta_q times the weight of b.
func is_heavy[K cmp.Ordered, V any](a, b *Tree[K, V]) bool {
	// Nodes are at least 4 machine words,
	// so we'd run out of memory before this overflows.
	wa, wb := a.Len()+1, b.Len()+1
	return delta_q*wa > delta_p*wb
}
// is_single reports whether the weight of a (its size plus one)
// is less than gamma_p/gamma_q times the weight of b.
// The rebalance methods use it to choose a single over a double rotation.
func is_single[K cmp.Ordered, V any](a, b *Tree[K, V]) bool {
	// Nodes are at least 4 machine words,
	// so we'd run out of memory before this overflows.
	wa, wb := a.Len()+1, b.Len()+1
	return gamma_q*wa < gamma_p*wb
}
// fixup recomputes the cached descendant count of tree
// from the sizes of its two subtrees, and returns tree.
func (tree *Tree[K, V]) fixup() *Tree[K, V] {
	l, r := tree.Left(), tree.Right()
	tree.childs = l.Len() + r.Len()
	return tree
}
package wbt
import (
"cmp"
"iter"
)
// Equal reports whether two trees contain the same key/value pairs.
//
// Both trees are walked in ascending order and compared pairwise.
// Two keys that are both NaN are considered the same key;
// values are compared with ==.
func Equal[K cmp.Ordered, V comparable](t1, t2 *Tree[K, V]) bool {
	if t1 == t2 {
		return true
	}
	size := t1.Len()
	if size != t2.Len() {
		return false
	}

	next1, stop1 := iter.Pull2(t1.Ascend())
	defer stop1()
	next2, stop2 := iter.Pull2(t2.Ascend())
	defer stop2()

	for i := 0; i < size; i++ {
		k1, v1, _ := next1()
		k2, v2, _ := next2()
		// Keys match if they are equal, or if neither equals itself (NaN).
		sameKey := k1 == k2 || (k1 != k1 && k2 != k2)
		if !sameKey || v1 != v2 {
			return false
		}
	}
	return true
}
// Subset reports whether t2 contains every t1 key/value pair.
//
// Both trees are walked in ascending order. Keys are matched with
// cmp.Compare and values with ==. The empty tree is a subset of any tree.
func Subset[K cmp.Ordered, V comparable](t1, t2 *Tree[K, V]) bool {
	if t1 == t2 {
		return true
	}
	l1 := t1.Len()
	if l1 == 0 {
		return true
	}
	l2 := t2.Len()
	if l1 > l2 {
		return false
	}

	n1, s1 := iter.Pull2(t1.Ascend())
	defer s1()
	n2, s2 := iter.Pull2(t2.Ascend())
	defer s2()

	// Invariants: l1 counts the t1 pairs still to be matched
	// (including the current one), l2 counts the t2 pairs not yet pulled.
	// Keeping l2 exact ensures we never read past the end of t2,
	// where the pull iterator would hand back zero values.
	for ; l1 > 0; l1-- {
		k1, v1, _ := n1()
		for {
			if l1 > l2 {
				return false // not enough pairs left in t2
			}
			k2, v2, _ := n2()
			l2--
			c := cmp.Compare(k1, k2)
			if c < 0 {
				return false // won't find k1 in t2
			}
			if c == 0 {
				if v1 != v2 {
					return false // found k1, v1 != v2
				}
				break // matched, move on to the next t1 pair
			}
			// k2 < k1: skip k2 and keep looking.
		}
	}
	return true // checked all k1
}
// Overlap reports whether t1 contains any keys from t2.
//
// Both trees are walked in ascending order, always advancing
// the iterator that is behind, until a common key is found
// or either tree is exhausted. Values are ignored.
func Overlap[K cmp.Ordered, V any](t1, t2 *Tree[K, V]) bool {
	switch {
	case t1 == nil, t2 == nil:
		return false
	case t1 == t2:
		return true
	}

	next1, stop1 := iter.Pull2(t1.Ascend())
	defer stop1()
	next2, stop2 := iter.Pull2(t2.Ascend())
	defer stop2()

	k1, _, more1 := next1()
	k2, _, more2 := next2()
	for more1 && more2 {
		c := cmp.Compare(k1, k2)
		if c == 0 {
			return true
		}
		if c < 0 {
			k1, _, more1 = next1()
		} else {
			k2, _, more2 = next2()
		}
	}
	return false
}
package wbt
import (
"cmp"
"iter"
)
// Ascend returns an iterator over the key/value pairs of this tree,
// in ascending key order.
func (tree *Tree[K, V]) Ascend() iter.Seq2[K, V] {
	return func(yield func(K, V) bool) {
		tree.ascend(yield)
	}
}
// ascend yields every pair of this tree in ascending key order.
// It returns false as soon as yield does, true otherwise.
func (tree *Tree[K, V]) ascend(yield func(K, V) bool) bool {
	// Recurse into left subtrees, iterate down the right spine.
	for tree != nil {
		if !tree.left.ascend(yield) || !yield(tree.key, tree.value) {
			return false
		}
		tree = tree.right
	}
	return true
}
// Descend returns an iterator over the key/value pairs of this tree,
// in descending key order.
func (tree *Tree[K, V]) Descend() iter.Seq2[K, V] {
	return func(yield func(K, V) bool) {
		tree.descend(yield)
	}
}
// descend yields every pair of this tree in descending key order.
// It returns false as soon as yield does, true otherwise.
func (tree *Tree[K, V]) descend(yield func(K, V) bool) bool {
	// Recurse into right subtrees, iterate down the left spine.
	for tree != nil {
		if !tree.right.descend(yield) || !yield(tree.key, tree.value) {
			return false
		}
		tree = tree.left
	}
	return true
}
// AscendCeil returns an ascending iterator for this tree.
// Iteration starts at the least key in this tree
// that is greater-than or equal-to pivot.
func (tree *Tree[K, V]) AscendCeil(pivot K) iter.Seq2[K, V] {
	return func(yield func(K, V) bool) {
		tree.ascendCeil(pivot, yield)
	}
}
// ascendCeil yields, in ascending order, every pair of this tree
// whose key is not less than pivot.
// It returns false as soon as yield does, true otherwise.
func (tree *Tree[K, V]) ascendCeil(pivot K, yield func(K, V) bool) bool {
	// Skip nodes below pivot; their left subtrees are below pivot too.
	for tree != nil && cmp.Less(tree.key, pivot) {
		tree = tree.right
	}
	if tree == nil {
		return true
	}
	// This node qualifies, and so does its whole right subtree.
	return tree.left.ascendCeil(pivot, yield) &&
		yield(tree.key, tree.value) &&
		tree.right.ascend(yield)
}
// AscendFloor returns an ascending iterator for this tree.
// Iteration starts at the greatest key in this tree
// that is less-than or equal-to pivot.
func (tree *Tree[K, V]) AscendFloor(pivot K) iter.Seq2[K, V] {
	return func(yield func(K, V) bool) {
		tree.ascendFloor(pivot, nil, yield)
	}
}
// ascendFloor yields the floor of pivot, followed in ascending order
// by every pair of this tree whose key is greater than pivot.
//
// node is the best floor candidate found so far by the callers
// (nil if none); it is yielded once the search bottoms out.
// It returns false as soon as yield does, true otherwise.
func (tree *Tree[K, V]) ascendFloor(pivot K, node *Tree[K, V], yield func(K, V) bool) bool {
	// Nodes not above pivot are floor candidates; the rightmost one wins.
	for tree != nil && !cmp.Less(pivot, tree.key) {
		node = tree
		tree = tree.right
	}
	if tree != nil {
		// This node is above pivot: the floor, if any, is to its left.
		return tree.left.ascendFloor(pivot, node, yield) &&
			yield(tree.key, tree.value) &&
			tree.right.ascend(yield)
	}
	if node == nil {
		return true
	}
	return yield(node.key, node.value)
}
// DescendFloor returns a descending iterator for this tree.
// Iteration starts at the greatest key in this tree
// that is less-than or equal-to pivot.
func (tree *Tree[K, V]) DescendFloor(pivot K) iter.Seq2[K, V] {
	return func(yield func(K, V) bool) {
		tree.descendFloor(pivot, yield)
	}
}
// descendFloor yields, in descending order, every pair of this tree
// whose key is not greater than pivot.
// It returns false as soon as yield does, true otherwise.
func (tree *Tree[K, V]) descendFloor(pivot K, yield func(K, V) bool) bool {
	// Skip nodes above pivot; their right subtrees are above pivot too.
	for tree != nil && cmp.Less(pivot, tree.key) {
		tree = tree.left
	}
	if tree == nil {
		return true
	}
	// This node qualifies, and so does its whole left subtree.
	return tree.right.descendFloor(pivot, yield) &&
		yield(tree.key, tree.value) &&
		tree.left.descend(yield)
}
// DescendCeil returns a descending iterator for this tree.
// Iteration starts at the least key in this tree
// that is greater-than or equal-to pivot.
func (tree *Tree[K, V]) DescendCeil(pivot K) iter.Seq2[K, V] {
	return func(yield func(K, V) bool) {
		tree.descendCeil(pivot, nil, yield)
	}
}
// descendCeil yields the ceiling of pivot, followed in descending order
// by every pair of this tree whose key is less than pivot.
//
// node is the best ceiling candidate found so far by the callers
// (nil if none); it is yielded once the search bottoms out.
// It returns false as soon as yield does, true otherwise.
func (tree *Tree[K, V]) descendCeil(pivot K, node *Tree[K, V], yield func(K, V) bool) bool {
	// Nodes not below pivot are ceiling candidates; the leftmost one wins.
	for tree != nil && !cmp.Less(tree.key, pivot) {
		node = tree
		tree = tree.left
	}
	if tree != nil {
		// This node is below pivot: the ceiling, if any, is to its right.
		return tree.right.descendCeil(pivot, node, yield) &&
			yield(tree.key, tree.value) &&
			tree.left.descend(yield)
	}
	if node == nil {
		return true
	}
	return yield(node.key, node.value)
}
package wbt
import "cmp"
// Split partitions this tree around a key. It returns:
//   - left, a tree with the keys less than key;
//   - node, the node for key (or nil if no such key exists in this tree);
//   - right, a tree with the keys greater than key.
func (tree *Tree[K, V]) Split(key K) (left, node, right *Tree[K, V]) {
	if tree == nil {
		return nil, nil, nil
	}
	c := cmp.Compare(key, tree.key)
	if c < 0 {
		// key is to the left: this node and its right subtree
		// go to the right-hand result.
		l, n, r := tree.left.Split(key)
		return l, n, join(r, tree, tree.right)
	}
	if c > 0 {
		// key is to the right: this node and its left subtree
		// go to the left-hand result.
		l, n, r := tree.right.Split(key)
		return join(tree.left, tree, l), n, r
	}
	return tree.left, tree, tree.right
}
// Filter returns a tree with the nodes of this tree
// for which pred returns true.
//
// pred is called once per node, for the nodes of the left subtree,
// then for those of the right subtree, then for the root.
func (tree *Tree[K, V]) Filter(pred func(node *Tree[K, V]) bool) *Tree[K, V] {
	if tree == nil {
		return nil
	}
	kept_l := tree.left.Filter(pred)
	kept_r := tree.right.Filter(pred)
	if !pred(tree) {
		return join2(kept_l, kept_r)
	}
	return join(kept_l, tree, kept_r)
}
// Partition splits this tree in two: t has the nodes for which
// pred returns true, f has the nodes for which it returns false.
//
// pred is called once per node, for the nodes of the left subtree,
// then for those of the right subtree, then for the root.
func (tree *Tree[K, V]) Partition(pred func(node *Tree[K, V]) bool) (t, f *Tree[K, V]) {
	if tree == nil {
		return nil, nil
	}
	yes_l, no_l := tree.left.Partition(pred)
	yes_r, no_r := tree.right.Partition(pred)
	if !pred(tree) {
		return join2(yes_l, yes_r), join(no_l, tree, no_r)
	}
	return join(yes_l, tree, yes_r), join2(no_l, no_r)
}
// join builds a tree from left, the key/value of node, and right.
// Callers pass a left with keys below node's key and a right with keys above it.
//
// node itself is returned when left and right already are its subtrees;
// otherwise new nodes are created and none of the arguments is modified.
func join[K cmp.Ordered, V any](left, node, right *Tree[K, V]) *Tree[K, V] {
	if node.left == left && node.right == right {
		return node
	}
	if is_heavy(right, left) {
		// Push left and node down the left spine of right.
		root := *right //nolint:nilaway
		root.left = join(left, node, right.left)
		return root.left_rebalance()
	}
	if is_heavy(left, right) {
		// Push node and right down the right spine of left.
		root := *left //nolint:nilaway
		root.right = join(left.right, node, right)
		return root.right_rebalance()
	}
	// Neither side is too heavy: node can be the root.
	return makeNode(node.key, node.value, left, right)
}
// join2 builds a tree from left and right, without a middle node.
// Callers pass a left whose keys are all below those of right.
//
// When both are non-empty, an extreme node is removed from one of them
// and used as the middle node of a join.
func join2[K cmp.Ordered, V any](left, right *Tree[K, V]) *Tree[K, V] {
	if left == nil {
		return right
	}
	if right == nil {
		return left
	}
	var heir *Tree[K, V]
	// Either works; this saves a few allocs.
	if right.childs >= left.childs {
		right, heir = right.DeleteMin()
	} else {
		left, heir = left.DeleteMax()
	}
	return join(left, heir, right)
}
package wbt
import (
"cmp"
"slices"
)
// MakeSet builds a tree from a set of keys.
//
// Keys already in strictly increasing order are used as they are.
// Otherwise a copy of keys is sorted and deduplicated;
// the caller's slice is never modified.
func MakeSet[K cmp.Ordered](keys ...K) *Tree[K, struct{}] {
	if !increasing(keys) {
		keys = slices.Clone(keys)
		slices.Sort(keys)
		// Deduplicate with cmp.Compare rather than ==, so that repeated
		// NaN keys collapse into one, like everywhere else in this package.
		keys = slices.CompactFunc(keys, func(a, b K) bool {
			return cmp.Compare(a, b) == 0
		})
	}
	return makeTree[K, struct{}](keys, nil)
}
// MakeMap builds a tree from a key-value map.
// The map is only read, never modified.
func MakeMap[K cmp.Ordered, V any](m map[K]V) *Tree[K, V] {
	sorted := make([]K, 0, len(m))
	for k := range m {
		sorted = append(sorted, k)
	}
	slices.Sort(sorted)
	return makeTree(sorted, m)
}
// makeTree builds a tree from keys, which callers pass sorted and
// deduplicated, taking the value for each key from m.
// Keys missing from m (or all keys, if m is nil) get the zero value.
//
// The middle key becomes the root, and each half is built recursively.
func makeTree[K cmp.Ordered, V any](keys []K, m map[K]V) *Tree[K, V] {
	n := len(keys)
	if n == 0 {
		return nil
	}
	mid := n / 2
	root := keys[mid]
	return makeNode(root, m[root],
		makeTree(keys[:mid], m),
		makeTree(keys[mid+1:], m))
}
// makeNode allocates a node for k and v with the given subtrees,
// and computes its descendant count.
func makeNode[K cmp.Ordered, V any](k K, v V, left, right *Tree[K, V]) *Tree[K, V] {
	node := &Tree[K, V]{key: k, value: v, left: left, right: right}
	return node.fixup()
}
// increasing reports whether keys are in strictly increasing order.
// Empty and single-element slices are increasing.
func increasing[K cmp.Ordered](keys []K) bool {
	for i := 1; i < len(keys); i++ {
		//nolint:nilaway
		if prev, cur := keys[i-1], keys[i]; !cmp.Less(prev, cur) {
			return false
		}
	}
	return true
}
// Collect returns a new map holding the key-value pairs of this tree.
func (tree *Tree[K, V]) Collect() map[K]V {
	pairs := make(map[K]V, tree.Len())
	tree.collect(pairs)
	return pairs
}
// collect stores the key-value pair of every node of this tree into m.
func (tree *Tree[K, V]) collect(m map[K]V) {
	if tree == nil {
		return
	}
	m[tree.key] = tree.value
	tree.left.collect(m)
	tree.right.collect(m)
}
package wbt
import "cmp"
// Union returns the set union of two trees.
// For keys in both trees, the value from t2 wins.
func Union[K cmp.Ordered, V any](t1, t2 *Tree[K, V]) *Tree[K, V] {
	if t1 == nil || t1 == t2 {
		return t2
	}
	if t2 == nil {
		return t1
	}
	// Split t1 around the root of t2, and merge each side recursively;
	// any node of t1 with that key is dropped in favor of the root of t2.
	lo, _, hi := t1.Split(t2.key)
	lo = Union(lo, t2.left)
	hi = Union(hi, t2.right)
	return join(lo, t2, hi)
}
// Intersection returns the set intersection of two trees.
// For keys in both trees, the value from t1 wins.
func Intersection[K cmp.Ordered, V any](t1, t2 *Tree[K, V]) *Tree[K, V] {
	if t1 == t2 {
		return t1
	}
	if t1 == nil || t2 == nil {
		return nil
	}
	// Split t1 around the root of t2, and intersect each side recursively.
	lo, hit, hi := t1.Split(t2.key)
	lo = Intersection(lo, t2.left)
	hi = Intersection(hi, t2.right)
	if hit != nil {
		// The root key of t2 is also in t1: keep the node from t1.
		return join(lo, hit, hi)
	}
	return join2(lo, hi)
}
// Difference returns the set difference of two trees:
// the nodes of t1 whose keys are not in t2.
func Difference[K cmp.Ordered, V any](t1, t2 *Tree[K, V]) *Tree[K, V] {
	if t1 == nil || t1 == t2 {
		return nil
	}
	if t2 == nil {
		return t1
	}
	// Split t1 around the root of t2, dropping any node with that key,
	// and subtract each side recursively.
	lo, _, hi := t1.Split(t2.key)
	lo = Difference(lo, t2.left)
	hi = Difference(hi, t2.right)
	return join2(lo, hi)
}
// SymmetricDifference returns the set symmetric difference of two trees:
// the nodes whose keys are in exactly one of t1 and t2.
func SymmetricDifference[K cmp.Ordered, V any](t1, t2 *Tree[K, V]) *Tree[K, V] {
	if t1 == t2 {
		return nil
	}
	if t1 == nil {
		return t2
	}
	if t2 == nil {
		return t1
	}
	// Split t1 around the root of t2, and recurse on each side.
	lo, hit, hi := t1.Split(t2.key)
	lo = SymmetricDifference(lo, t2.left)
	hi = SymmetricDifference(hi, t2.right)
	if hit != nil {
		// The root key of t2 is in both trees: drop it.
		return join2(lo, hi)
	}
	return join(lo, t2, hi)
}
package wbt
import "cmp"
// Select finds the node at index i of this tree, counting from zero
// in ascending key order, and returns it (or nil if i is out of range).
func (tree *Tree[K, V]) Select(i int) *Tree[K, V] {
	for tree != nil {
		// The index of this node is the size of its left subtree.
		p := tree.left.Len()
		switch {
		case i < p:
			tree = tree.left
		case i > p:
			// Skip the left subtree and this node.
			i -= p + 1
			tree = tree.right
		default:
			return tree
		}
	}
	return nil
}
// Rank finds the rank of key,
// the number of nodes in this tree less than key.
// The key does not need to exist in this tree.
//
//	tree.Rank(tree.Select(i).Key()) ⟹ i, iff 0 ≤ i < tree.Len()
func (tree *Tree[K, V]) Rank(key K) int {
	rank := 0
	for tree != nil {
		c := cmp.Compare(key, tree.key)
		if c == 0 {
			return rank + tree.left.Len()
		}
		if c < 0 {
			tree = tree.left
			continue
		}
		// This node and its left subtree are all less than key.
		rank += tree.left.Len() + 1
		tree = tree.right
	}
	return rank
}
// Package wbt implements immutable weight-balanced trees.
package wbt
import "cmp"
// Tree is an immutable weight-balanced tree,
// a form of self-balancing binary search tree.
//
// Use *Tree as a reference type; the zero value for *Tree (nil) is the empty tree:
//
//	var empty *wbt.Tree[int, string]
//	one := empty.Put(1, "one")
//	one.Has(1) ⟹ true
//
// Note: the zero value for Tree{} is a valid — but non-empty — tree:
// it has a single node, with the zero key and the zero value.
type Tree[K cmp.Ordered, V any] struct {
// left is the subtree with the keys less than key (nil if empty).
left *Tree[K, V]
// right is the subtree with the keys greater than key (nil if empty).
right *Tree[K, V]
// key and value are the pair stored at this node.
key K
value V
// childs is the number of nodes in left and right combined;
// the size of this tree is 1 + childs.
childs int
}
// Key returns the key stored at the root node of this tree.
//
// Note: the empty tree (nil) has no root,
// so calling Key on it causes a runtime panic.
func (tree *Tree[K, V]) Key() K {
	root := tree
	return root.key
}
// Value returns the value stored at the root node of this tree.
//
// Note: the empty tree (nil) has no root,
// so calling Value on it causes a runtime panic.
func (tree *Tree[K, V]) Value() V {
	root := tree
	return root.value
}
// Left returns the left subtree of this tree,
// which holds all the keys less than its root key.
//
// Note: it is safe to call Left on the empty tree (nil);
// its left subtree is the empty tree.
func (tree *Tree[K, V]) Left() *Tree[K, V] {
	if tree != nil {
		return tree.left
	}
	return nil
}
// Right returns the right subtree of this tree,
// which holds all the keys greater than its root key.
//
// Note: it is safe to call Right on the empty tree (nil);
// its right subtree is the empty tree.
func (tree *Tree[K, V]) Right() *Tree[K, V] {
	if tree != nil {
		return tree.right
	}
	return nil
}
// Len returns the number of nodes in this tree;
// zero for the empty tree (nil).
func (tree *Tree[K, V]) Len() int {
	if tree != nil {
		// The root, plus its cached number of descendants.
		return tree.childs + 1
	}
	return 0
}
// Min returns the node with the least key in this tree,
// or nil if this tree is empty.
func (tree *Tree[K, V]) Min() *Tree[K, V] {
	// Follow left links down to the leftmost node.
	for tree != nil && tree.left != nil {
		tree = tree.left
	}
	return tree
}
// Max returns the node with the greatest key in this tree,
// or nil if this tree is empty.
func (tree *Tree[K, V]) Max() *Tree[K, V] {
	// Follow right links down to the rightmost node.
	for tree != nil && tree.right != nil {
		tree = tree.right
	}
	return tree
}
// Floor returns the node with the greatest key in this tree
// that is less-than or equal-to key,
// or nil if no such key exists in this tree.
func (tree *Tree[K, V]) Floor(key K) *Tree[K, V] {
	var best *Tree[K, V]
	for node := tree; node != nil; {
		if cmp.Less(key, node.key) {
			node = node.left
			continue
		}
		// node is a candidate; look for a greater one to its right.
		best, node = node, node.right
	}
	return best
}
// Ceil returns the node with the least key in this tree
// that is greater-than or equal-to key,
// or nil if no such key exists in this tree.
func (tree *Tree[K, V]) Ceil(key K) *Tree[K, V] {
	var best *Tree[K, V]
	for node := tree; node != nil; {
		if cmp.Less(node.key, key) {
			node = node.right
			continue
		}
		// node is a candidate; look for a lesser one to its left.
		best, node = node, node.left
	}
	return best
}
// Get retrieves the value for a given key.
// It returns the zero value and false if key does not exist in this tree.
func (tree *Tree[K, V]) Get(key K) (value V, found bool) {
	// Floor uses 2-way search, which is faster for strings:
	// https://go.dev/issue/71270
	// https://user.it.uu.se/~arnea/ps/searchproc.pdf
	node := tree.Floor(key)
	if node == nil {
		return value, false
	}
	// Floor(NaN) returns either nil or NaN,
	// so a NaN key (key != key) matches whatever node was found.
	if key == node.key || key != key {
		return node.value, true
	}
	return value, false
}
// Has reports whether this tree contains key.
func (tree *Tree[K, V]) Has(key K) bool {
	_, ok := tree.Get(key)
	return ok
}
// Put returns a modified tree with key set to value,
// whether or not key already exists in this tree.
//
//	tree.Put(key, value).Get(key) ⟹ (value, true)
func (tree *Tree[K, V]) Put(key K, value V) *Tree[K, V] {
	// Always set the value, regardless of the node found.
	set := func(*Tree[K, V]) (V, bool) {
		return value, true
	}
	return tree.Patch(key, set)
}
// Add returns a (possibly) modified tree that contains key.
// A missing key is inserted with the zero value;
// an existing key keeps its value and this tree is returned as is.
//
//	tree.Add(key).Has(key) ⟹ true
func (tree *Tree[K, V]) Add(key K) *Tree[K, V] {
	return tree.Patch(key, func(node *Tree[K, V]) (V, bool) {
		var zero V
		// Only insert when the key was not found.
		return zero, node == nil
	})
}
// Patch finds key in this tree, calls update with the node for that key
// (or nil, if key is not found), and returns a (possibly) modified tree.
//
// The update callback can opt to set/update the value for the key,
// by returning (value, true), or not, by returning false.
func (tree *Tree[K, V]) Patch(key K, update func(node *Tree[K, V]) (value V, ok bool)) *Tree[K, V] {
if tree == nil {
if value, ok := update(tree); ok {
return &Tree[K, V]{key: key, value: value}
}
return nil
}
switch cmp.Compare(key, tree.key) {
default:
if value, ok := update(tree); ok {
copy := *tree
copy.value = value
return ©
}
return tree
case -1:
left := tree.left.Patch(key, update)
if left == tree.left {
return tree
}
copy := *tree
copy.left = left
return copy.left_rebalance()
case +1:
right := tree.right.Patch(key, update)
if right == tree.right {
return tree
}
copy := *tree
copy.right = right
return copy.right_rebalance()
}
}
// Delete returns a (possibly) modified tree with key removed from it.
//
// The optional pred is called with the node found for key to confirm
// deletion; the node is kept if it returns false.
// Only the first pred is used, any others are ignored.
//
//	tree.Delete(key).Has(key) ⟹ false
func (tree *Tree[K, V]) Delete(key K, pred ...func(node *Tree[K, V]) bool) *Tree[K, V] {
	var confirm func(*Tree[K, V]) bool
	if len(pred) != 0 {
		confirm = pred[0]
	}
	return tree.delete(key, confirm)
}
// delete returns a (possibly) modified tree with key removed from it.
// If pred is not nil, it is called with the node found for key,
// and the node is only removed if it returns true.
//
// This tree is never modified: nodes on the path to key are copied
// as needed, and this tree itself is returned when nothing changes.
func (tree *Tree[K, V]) delete(key K, pred func(node *Tree[K, V]) bool) *Tree[K, V] {
	if tree == nil {
		return nil
	}

	c := cmp.Compare(key, tree.key)
	if c < 0 {
		left := tree.left.delete(key, pred)
		if left == tree.left {
			return tree // nothing changed below
		}
		clone := *tree
		clone.left = left
		// The left side shrank, so the right side may now be too heavy.
		return clone.right_rebalance()
	}
	if c > 0 {
		right := tree.right.delete(key, pred)
		if right == tree.right {
			return tree // nothing changed below
		}
		clone := *tree
		clone.right = right
		// The right side shrank, so the left side may now be too heavy.
		return clone.left_rebalance()
	}

	// Key found at this node.
	if pred != nil && !pred(tree) {
		return tree
	}
	if tree.left == nil {
		return tree.right
	}
	if tree.right == nil {
		return tree.left
	}

	// Two children: replace this node's pair with that of a neighbor
	// removed from one of the subtrees.
	clone := *tree
	var heir *Tree[K, V]
	// Either works; this saves a few allocs.
	if clone.left.childs > clone.right.childs {
		clone.left, heir = clone.left.DeleteMax()
		clone.key, clone.value = heir.key, heir.value
		return clone.right_rebalance()
	}
	clone.right, heir = clone.right.DeleteMin()
	clone.key, clone.value = heir.key, heir.value
	return clone.left_rebalance()
}
// DeleteMin returns a modified tree with its least key removed from it,
// and the removed node. Both are nil if this tree is empty.
func (tree *Tree[K, V]) DeleteMin() (_, node *Tree[K, V]) {
	switch {
	case tree == nil:
		return nil, nil
	case tree.left == nil:
		// This node is the minimum: its right subtree replaces it.
		return tree.right, tree
	}
	clone := *tree
	clone.left, node = tree.left.DeleteMin()
	// The left side shrank, so the right side may now be too heavy.
	return clone.right_rebalance(), node
}
// DeleteMax returns a modified tree with its greatest key removed from it,
// and the removed node. Both are nil if this tree is empty.
func (tree *Tree[K, V]) DeleteMax() (_, node *Tree[K, V]) {
	switch {
	case tree == nil:
		return nil, nil
	case tree.right == nil:
		// This node is the maximum: its left subtree replaces it.
		return tree.left, tree
	}
	clone := *tree
	clone.right, node = tree.right.DeleteMax()
	// The right side shrank, so the left side may now be too heavy.
	return clone.left_rebalance(), node
}