package skipfilter
import (
        "fmt"
        "runtime"
        "sync"
        "sync/atomic"
        "github.com/MauriceGit/skiplist"
        "github.com/RoaringBitmap/roaring/roaring64"
        "github.com/hashicorp/golang-lru"
)
// SkipFilter combines a skip list with an lru cache of roaring bitmaps
type SkipFilter struct {
        i     uint64
        idx   map[interface{}]uint64
        list  skiplist.SkipList
        cache *lru.Cache
        test  func(interface{}, interface{}) bool
        mutex sync.RWMutex
}
// New creates a new SkipFilter.
//   test - should return true if the value passes the provided filter.
//   size - controls the size of the LRU cache. Defaults to 100,000 if 0 or less.
//          should be tuned to match or exceed the expected filter cardinality.
func New(test func(value interface{}, filter interface{}) bool, size int) *SkipFilter {
        if size <= 0 {
                size = 1e5
        }
        cache, _ := lru.New(size)
        return &SkipFilter{
                idx:   make(map[interface{}]uint64),
                list:  skiplist.New(),
                cache: cache,
                test:  test,
        }
}
// Add adds a value to the set
func (sf *SkipFilter) Add(value interface{}) {
        sf.mutex.Lock()
        defer sf.mutex.Unlock()
        el := &entry{sf.i, value}
        sf.list.Insert(el)
        sf.idx[value] = sf.i
        sf.i++
}
// Remove removes a value from the set
func (sf *SkipFilter) Remove(value interface{}) {
        sf.mutex.Lock()
        defer sf.mutex.Unlock()
        if id, ok := sf.idx[value]; ok {
                sf.list.Delete(&entry{id: id})
                delete(sf.idx, value)
        }
}
// Len returns the number of values in the set
func (sf *SkipFilter) Len() int {
        sf.mutex.RLock()
        defer sf.mutex.RUnlock()
        return sf.list.GetNodeCount()
}
// MatchAny returns a slice of values in the set matching any of the provided filters
func (sf *SkipFilter) MatchAny(filterKeys ...interface{}) []interface{} {
        sf.mutex.RLock()
        defer sf.mutex.RUnlock()
        var sets = make([]*roaring64.Bitmap, len(filterKeys))
        var filters = make([]*filter, len(filterKeys))
        for i, k := range filterKeys {
                filters[i] = sf.getFilter(k)
                sets[i] = filters[i].set
        }
        var set = roaring64.ParOr(runtime.NumCPU(), sets...)
        values, notfound := sf.getValues(set)
        if len(notfound) > 0 {
                // Clean up references to removed values
                for _, f := range filters {
                        f.mutex.Lock()
                        for _, id := range notfound {
                                f.set.Remove(id)
                        }
                        f.mutex.Unlock()
                }
        }
        return values
}
// Walk executes callback for each value in the set beginning at `start` index.
// Return true in callback to continue iterating, false to stop.
// Returned uint64 is index of `next` element (send as `start` to continue iterating)
func (sf *SkipFilter) Walk(start uint64, callback func(val interface{}) bool) uint64 {
        sf.mutex.RLock()
        defer sf.mutex.RUnlock()
        var i uint64
        var id = start
        var prev uint64
        var first = true
        el, ok := sf.list.FindGreaterOrEqual(&entry{id: start})
        for ok && el != nil {
                if id = el.GetValue().(*entry).id; !first && id <= prev {
                        // skiplist loops back to first element so we have to detect loop and break manually
                        id = prev + 1
                        break
                }
                i++
                if !callback(el.GetValue().(*entry).val) {
                        id++
                        break
                }
                prev = id
                el = sf.list.Next(el)
                first = false
        }
        return id
}
func (sf *SkipFilter) getFilter(k interface{}) *filter {
        var f *filter
        val, ok := sf.cache.Get(k)
        if ok {
                f = val.(*filter)
        } else {
                f = &filter{i: 0, set: roaring64.New()}
                sf.cache.Add(k, f)
        }
        var id uint64
        var prev uint64
        var first = true
        if atomic.LoadUint64(&f.i) < sf.i {
                f.mutex.Lock()
                defer f.mutex.Unlock()
                for el, ok := sf.list.FindGreaterOrEqual(&entry{id: f.i}); ok && el != nil; el = sf.list.Next(el) {
                        if id = el.GetValue().(*entry).id; !first && id <= prev {
                                // skiplist loops back to first element so we have to detect loop and break manually
                                break
                        }
                        if sf.test(el.GetValue().(*entry).val, k) {
                                f.set.Add(id)
                        }
                        prev = id
                        first = false
                }
                f.i = sf.i
        }
        return f
}
func (sf *SkipFilter) getValues(set *roaring64.Bitmap) ([]interface{}, []uint64) {
        idBuf := make([]uint64, 512)
        iter := set.ManyIterator()
        values := []interface{}{}
        notfound := []uint64{}
        e := &entry{}
        for n := iter.NextMany(idBuf); n > 0; n = iter.NextMany(idBuf) {
                for i := 0; i < n; i++ {
                        e.id = idBuf[i]
                        el, ok := sf.list.Find(e)
                        if ok {
                                values = append(values, el.GetValue().(*entry).val)
                        } else {
                                notfound = append(notfound, idBuf[i])
                        }
                }
        }
        return values, notfound
}
type entry struct {
        id  uint64
        val interface{}
}
func (e *entry) ExtractKey() float64 {
        return float64(e.id)
}
func (e *entry) String() string {
        return fmt.Sprintf("%16x", e.id)
}
type filter struct {
        i     uint64
        mutex sync.RWMutex
        set   *roaring64.Bitmap
}