package rtree import ( "math" "sync" "github.com/tidwall/rtree/base" ) type Iterator func(item Item) bool type Item interface { Rect(ctx interface{}) (min []float64, max []float64) } type RTree struct { dims int maxEntries int ctx interface{} trs []*base.RTree used int } func New(ctx interface{}) *RTree { tr := &RTree{ ctx: ctx, dims: 20, maxEntries: 13, } tr.trs = make([]*base.RTree, 20) return tr } func (tr *RTree) Insert(item Item) { if item == nil { panic("nil item") } min, max := item.Rect(tr.ctx) if len(min) != len(max) { return // just return panic("invalid item rectangle") } if len(min) < 1 || len(min) > len(tr.trs) { return // just return panic("invalid dimension") } btr := tr.trs[len(min)-1] if btr == nil { btr = base.New(len(min), tr.maxEntries) tr.trs[len(min)-1] = btr tr.used++ } amin := make([]float64, len(min)) amax := make([]float64, len(max)) for i := 0; i < len(min); i++ { amin[i], amax[i] = min[i], max[i] } btr.Insert(amin, amax, item) } func (tr *RTree) Remove(item Item) { if item == nil { panic("nil item") } min, max := item.Rect(tr.ctx) if len(min) != len(max) { return // just return panic("invalid item rectangle") } if len(min) < 1 || len(min) > len(tr.trs) { return // just return panic("invalid dimension") } btr := tr.trs[len(min)-1] if btr == nil { return } amin := make([]float64, len(min)) amax := make([]float64, len(max)) for i := 0; i < len(min); i++ { amin[i], amax[i] = min[i], max[i] } btr.Remove(amin, amax, item) if btr.IsEmpty() { tr.trs[len(min)-1] = nil tr.used-- } } func (tr *RTree) Reset() { for i := 0; i < len(tr.trs); i++ { tr.trs[i] = nil } tr.used = 0 } func (tr *RTree) Count() int { var count int for _, btr := range tr.trs { if btr != nil { count += btr.Count() } } return count } func (tr *RTree) Search(bounds Item, iter Iterator) { if bounds == nil { panic("nil bounds being used for search") } min, max := bounds.Rect(tr.ctx) if len(min) != len(max) { return // just return panic("invalid item rectangle") } if len(min) < 1 || len(min) > len(tr.trs) { return // just return panic("invalid dimension") } used := tr.used for i, btr := range tr.trs { if used == 0 { break } if btr != nil { if !search(btr, min, max, i+1, iter) { return } used-- } } } func search(btr *base.RTree, min, max []float64, dims int, iter Iterator) bool { amin := make([]float64, dims) amax := make([]float64, dims) for i := 0; i < dims; i++ { if i < len(min) { amin[i] = min[i] amax[i] = max[i] } else { amin[i] = math.Inf(-1) amax[i] = math.Inf(+1) } } var ended bool btr.Search(amin, amax, func(item interface{}) bool { if !iter(item.(Item)) { ended = true return false } return true }) return !ended } func (tr *RTree) KNN(bounds Item, center bool, iter func(item Item, dist float64) bool) { if bounds == nil { panic("nil bounds being used for search") } min, max := bounds.Rect(tr.ctx) if len(min) != len(max) { return // just return panic("invalid item rectangle") } if len(min) < 1 || len(min) > len(tr.trs) { return // just return panic("invalid dimension") } if tr.used == 0 { return } if tr.used == 1 { for i, btr := range tr.trs { if btr != nil { knn(btr, min, max, center, i+1, func(item interface{}, dist float64) bool { return iter(item.(Item), dist) }) break } } return } type queueT struct { done bool step int item Item dist float64 } var mu sync.Mutex var ended bool queues := make(map[int][]queueT) cond := sync.NewCond(&mu) for i, btr := range tr.trs { if btr != nil { dims := i + 1 mu.Lock() queues[dims] = []queueT{} cond.Signal() mu.Unlock() go func(dims int, btr *base.RTree) { knn(btr, min, max, center, dims, func(item interface{}, dist float64) bool { mu.Lock() if ended { mu.Unlock() return false } queues[dims] = append(queues[dims], queueT{item: item.(Item), dist: dist}) cond.Signal() mu.Unlock() return true }) mu.Lock() queues[dims] = append(queues[dims], queueT{done: true}) cond.Signal() mu.Unlock() }(dims, btr) } } mu.Lock() for { ready := true for i := range queues { if len(queues[i]) == 0 { ready = false break } if queues[i][0].done { delete(queues, i) } } if len(queues) == 0 { break } if ready { var j int var minDist float64 var minItem Item var minQueue int for i := range queues { if j == 0 || queues[i][0].dist < minDist { minDist = queues[i][0].dist minItem = queues[i][0].item minQueue = i } } queues[minQueue] = queues[minQueue][1:] if !iter(minItem, minDist) { ended = true break } continue } cond.Wait() } mu.Unlock() } func knn(btr *base.RTree, min, max []float64, center bool, dims int, iter func(item interface{}, dist float64) bool) bool { amin := make([]float64, dims) amax := make([]float64, dims) for i := 0; i < dims; i++ { if i < len(min) { amin[i] = min[i] amax[i] = max[i] } else { amin[i] = math.Inf(-1) amax[i] = math.Inf(+1) } } var ended bool btr.KNN(amin, amax, center, func(item interface{}, dist float64) bool { if !iter(item.(Item), dist) { ended = true return false } return true }) return !ended }