;; http://en.wikipedia.org/wiki/K-d_tree

(module kd-tree

  (<Point> default-<Point> Point3d Point2d point? make-point
   <KdTree> default-<KdTree> KdTree3d KdTree2d
   kd-tree? kd-tree-empty? kd-tree->list kd-tree->list* kd-tree-map
   kd-tree-for-each kd-tree-for-each* kd-tree-fold-right kd-tree-fold-right*
   kd-tree-subtrees kd-tree-points kd-tree-indices)

  (import scheme chicken data-structures foreign)

  (require-library srfi-1 srfi-4 extras cis)
  (require-extension typeclass datatype)

  (import (only srfi-1 xcons fold list-tabulate split-at every fold-right take filter filter-map remove)
          (only srfi-4 f64vector-ref f64vector-length make-f64vector f64vector->list)
          (only extras fprintf pp)
          (only foreign foreign-lambda)
          (prefix cis cis:))

#>
/* Copies X[M..N] (inclusive) into Y[0..N-M].  Does nothing when M > N. */
void cdslice(const int M, const int N, const double *X, double *Y)
{
  int i, j;
  if (M > N) return;
  for (j = 0, i = M; i <= N; j++, i++)
  {
    Y[j] = X[i];
  }
}
<#

  (define dslice (foreign-lambda void "cdslice" int int f64vector f64vector))

  ;; Returns a fresh f64vector holding elements M..N (inclusive) of X.
  ;; A single-element slice (M = N) is allowed; the indices are checked
  ;; against the length of X so the C copy never reads out of bounds.
  (define (f64vector-slice x m n)
    (cond ((> m n)
           (error 'f64vector-slice "argument m is greater than n"))
          ((or (< m 0) (>= n (f64vector-length x)))
           (error 'f64vector-slice "index out of range" m n))
          (else
           (let ((y (make-f64vector (+ 1 (- n m)))))
             (dslice m n x y)
             y))))

  (define-class <Point>
    ;; dimension returns the number of coordinates of a point.
    dimension          ;; Point -> Int
    ;; gets the k'th coordinate, starting from 0.
    coord              ;; Int * Point -> Double
    ;; true when the c'th coordinate of the first point is less than
    ;; the c'th coordinate of the second point.
    compare-coord      ;; Int * Point * Point -> Bool
    ;; returns the squared distance between two points.
    dist2              ;; Point * Point -> Double
    ;; given per-axis scale factors, returns a procedure computing the
    ;; scaled squared distance between two points.
    sdist2             ;; [Number] -> (Point * Point -> Double)
    ;; returns 0, a negative or a positive number depending on whether
    ;; the first point is as close to, closer to or farther from the
    ;; probe than the second point.
    compare-distance)  ;; Point * Point * Point [* Number] -> Number

  ;; Returns the element of LST that is minimal under LESS?, or #f when
  ;; LST is empty.  Ties keep the earliest element.
  ;;
  ;; When additional lists REST (parallel to LST) are given, returns a
  ;; list made of the minimal element of LST followed by the elements
  ;; at the same position in each list of REST.
  (define (minimum-by lst less? . rest)
    (if (null? lst)
        #f
        (if (null? rest)
            (let recur ((lst (cdr lst)) (m (car lst)))
              (if (null? lst)
                  m
                  (if (less? (car lst) m)
                      (recur (cdr lst) (car lst))
                      (recur (cdr lst) m))))
            (let recur ((lst (cdr lst))
                        (rest (map cdr rest))
                        (m (map car (cons lst rest))))
              (if (null? lst)
                  m
                  (if (less? (car lst) (car m))
                      (recur (cdr lst) (map cdr rest) (map car (cons lst rest)))
                      (recur (cdr lst) (map cdr rest) m)))))))

  ;; Flonum sum of a list of numbers.
  (define (sum lst) (fold + 0. lst))

  ;; Builds a <Point> instance from a DIMENSION procedure (Point -> Int)
  ;; and a COORD procedure (Int * Point -> Number); the distance and
  ;; comparison methods are derived from these two.
  (define (default-<Point> dimension coord)
    (let* ((dist2
            (lambda (a b)
              (let ((diff2 (lambda (i)
                             (let ((v (- (coord i a) (coord i b))))
                               (* v v)))))
                (sum (list-tabulate (dimension a) diff2)))))
           ;; Each squared axis difference is weighted by factor^2.
           (sdist2
            (lambda (factors)
              (let ((factors2 (map (lambda (n) (* n n)) factors)))
                (lambda (a b)
                  (let ((diff2 (lambda (i)
                                 (let ((v (- (coord i a) (coord i b))))
                                   (* v v)))))
                    (sum (map * (list-tabulate (dimension a) diff2) factors2)))))))
           ;; dist2(p,a) - dist2(p,b); with an optional tolerance, any
           ;; difference whose magnitude is within it is reported as 0.
           (compare-distance
            (lambda (p a b . reltol)
              (let ((delta (- (dist2 p a) (dist2 p b))))
                (if (and (pair? reltol) (<= (abs delta) (car reltol)))
                    0
                    delta))))
           (compare-coord
            (lambda (c a b)
              (< (coord c a) (coord c b)))))
      (make-<Point> dimension coord compare-coord dist2 sdist2 compare-distance)))

  ;; Points are plain Scheme vectors of coordinates.
  (define point? vector?)
  (define make-point vector)

  (define Point3d
    (default-<Point>
      (lambda (p) (and (point? p) 3))
      (lambda (i p) (vector-ref p i))))

  (define Point2d
    (default-<Point>
      (lambda (p) (and (point? p) 2))
      (lambda (i p) (vector-ref p i))))

  (define-class <KdTree>
    ;; constructs a kd-tree from a list of points
    list->kd-tree
    ;; constructs a kd-tree from a list of f64vectors
    f64vector->kd-tree
    ;; nearest neighbor of a point
    kd-tree-nearest-neighbor
    ;; the index of the nearest neighbor of a point
    kd-tree-nearest-neighbor*
    ;; neighbors of a point within radius r
    kd-tree-near-neighbors
    ;; neighbors of a point within radius r (using point indices)
    kd-tree-near-neighbors*
    ;; k nearest neighbors of a point
    kd-tree-k-nearest-neighbors
    ;; removes a point from the tree
    kd-tree-remove
    ;; retrieves all points between two planes
    kd-tree-slice
    ;; retrieves all points between two planes (using point indices)
    kd-tree-slice*
    ;; checks that the kd-tree properties are preserved
    kd-tree-is-valid?
    kd-tree-all-subtrees-are-valid?)

  ;; KdNode: an interior node holding one point P (with index/value I)
  ;; that splits LEFT and RIGHT along AXIS.
  ;; KdLeaf: a bucket of points PP, their index set II and values VV.
  ;; Within a leaf, PP and VV are ordered like (reverse (cis:elements II)).
  (define-datatype kd-tree kd-tree?
    (KdNode (left kd-tree?)
            (p point?)
            (i (lambda (v) (or (integer? v) (and (pair? v) (integer? (car v))))))
            (right kd-tree?)
            (axis integer?))
    (KdLeaf (ii cis:cis?)
            (pp (lambda (lst) (every point? lst)))
            (vv list?)
            (axis integer?)))

  ;; True only for a leaf whose index set is empty.
  (define (kd-tree-empty? t)
    (cases kd-tree t
      (KdLeaf (ii pp vv axis) (cis:empty? ii))
      (else #f)))

  ;; All points of T, in in-order (left, node, right) sequence.
  (define (kd-tree->list t)
    (kd-tree-fold-right cons '() t))

  ;; All (index point) lists of T, in in-order sequence.
  (define (kd-tree->list* t)
    (kd-tree-fold-right* (lambda (i x ax) (cons (list i x) ax)) '() t))

  ;; Returns a tree of the same shape with F applied to every point.
  (define (kd-tree-map f t)
    (cases kd-tree t
      (KdLeaf (ii pp vv axis)
        (KdLeaf ii (map f pp) vv axis))
      (KdNode (l x i r axis)
        (KdNode (kd-tree-map f l) (f x) i (kd-tree-map f r) axis))))

  ;; Calls (F point) for every point of T, in in-order sequence.
  (define (kd-tree-for-each f t)
    (cases kd-tree t
      (KdLeaf (ii pp vv axis)
        (for-each f pp))
      (KdNode (l x i r axis)
        (begin
          (kd-tree-for-each f l)
          (f x)
          (kd-tree-for-each f r)))))

  ;; Calls (F index point) for every point of T, in in-order sequence.
  ;; Leaf indices use the same ordering as leaf construction.
  (define (kd-tree-for-each* f t)
    (cases kd-tree t
      (KdLeaf (ii pp vv axis)
        (for-each f (reverse (cis:elements ii)) pp))
      (KdNode (l x i r axis)
        (begin
          (kd-tree-for-each* f l)
          (f i x)
          (kd-tree-for-each* f r)))))

  ;; Right fold of (F point acc) over the points of T.
  (define (kd-tree-fold-right f init t)
    (cases kd-tree t
      (KdLeaf (ii pp vv axis)
        (fold-right f init pp))
      (KdNode (l x i r _)
        (let* ((init2 (kd-tree-fold-right f init r))
               (init3 (f x init2)))
          (kd-tree-fold-right f init3 l)))))

  ;; Right fold of (F index point acc) over the points of T.
  (define (kd-tree-fold-right* f init t)
    (cases kd-tree t
      (KdLeaf (ii pp vv axis)
        (fold-right f init (reverse (cis:elements ii)) pp))
      (KdNode (l x i r _)
        (let* ((init2 (kd-tree-fold-right* f init r))
               (init3 (f i x init2)))
          (kd-tree-fold-right* f init3 l)))))

  ;; Returns a list containing t and all its subtrees, including the
  ;; leaf nodes.
  (define (kd-tree-subtrees t)
    (cases kd-tree t
      (KdLeaf (ii pp vv axis) (list t))
      (KdNode (l x i r axis)
        (append (kd-tree-subtrees l) (list t) (kd-tree-subtrees r)))))

  ;; Points stored directly in T (the bucket of a leaf, or the single
  ;; splitting point of a node).
  (define (kd-tree-points t)
    (cases kd-tree t
      (KdLeaf (ii pp vv axis) pp)
      (KdNode (l x i r axis) (list x))))

  ;; Indices stored directly in T, in the same order as kd-tree-points.
  (define (kd-tree-indices t)
    (cases kd-tree t
      (KdLeaf (ii pp vv axis) (reverse (cis:elements ii)))
      (KdNode (l x i r axis) (list i))))

  ;; construct a kd-tree from a list of points
  ;;
  ;; Returns (lambda (make-point make-value) ...) yielding a builder
  ;; (m n points depth #!key leaf-factor).  The builder assigns indices
  ;; m..n-1 to POINTS, stores (make-point e) as the point and
  ;; (make-value index e) as the value of each element e, and uses
  ;; buckets of at most LEAF-FACTOR points.
  (define=> (make-list->kd-tree/depth <Point>)
    (lambda (make-point make-value)
      (letrec
          (;; Sorts POINTS along the axis for DEPTH and splits them at the
           ;; median.  Returns median, its offset from m, the elements
           ;; before it and the elements after it.
           (split
            (lambda (m n points depth)
              (let* ((axis (modulo depth (dimension (make-point (car points)))))
                     (cmpfn (lambda (p0 p1)
                              (compare-coord axis (make-point p0) (make-point p1))))
                     (sorted (sort points cmpfn))
                     (median-index (quotient (- n m) 2)))
                (let-values (((lt gte) (split-at sorted median-index)))
                  (values (car gte) median-index lt (cdr gte))))))
           (list->kd-tree/depth
            (lambda (m n points depth #!key (leaf-factor 10))
              (cond
               ((null? points)
                (KdLeaf cis:empty '() '() depth))
               ((<= (- n m) leaf-factor)
                (let* ((k (- n m))
                       (es (take points k))
                       (ps (map make-point es))
                       (ii (cis:interval m (- n 1)))
                       (vs (map make-value (reverse (cis:elements ii)) es)))
                  (KdLeaf ii ps vs (modulo depth (dimension (car ps))))))
               ;; Only reached when leaf-factor < 1; store the single point
               ;; the same way a regular leaf does.
               ((null? (cdr points))
                (let* ((e (car points))
                       (p (make-point e))
                       (v (make-value m e)))
                  (KdLeaf (cis:singleton m) (list p) (list v)
                          (modulo depth (dimension p)))))
               (else
                (let-values (((median median-index lt gte) (split m n points depth)))
                  (let* ((depth1 (+ 1 depth))
                         (p (make-point median))
                         (v (make-value (+ m median-index) median))
                         (axis (modulo depth (dimension p))))
                    (KdNode (list->kd-tree/depth m (+ m median-index) lt depth1
                                                 leaf-factor: leaf-factor)
                            p
                            v
                            (list->kd-tree/depth (+ m median-index 1) n gte depth1
                                                 leaf-factor: leaf-factor)
                            axis))))))))
        list->kd-tree/depth)))

  ;; construct a kd-tree from f64vectors with point coordinates
  ;; the points argument must be a list of f64vector for each axis:
  ;;
  ;; ( F64VECTOR-X F64VECTOR-Y F64VECTOR-Z ... )
  ;;
  ;; NOTE: this procedure assumes that the axial vectors are sorted in
  ;; increasing order.
  ;;
  ;; The resulting builder takes inclusive bounds (m n).  make-value
  ;; always receives the index and the raw coordinate list of a point.
  (define=> (make-f64vector->kd-tree/depth <Point>)
    (lambda (dimensions make-point make-value)
      (letrec
          (;; List of the i'th coordinate from every axial vector.
           (axial-vectors-ref
            (lambda (axv i)
              (map (lambda (x) (f64vector-ref x i)) axv)))
           ;; Elements m..n (inclusive) of every axial vector.
           (axial-vectors-slice
            (lambda (axv m n)
              (map (lambda (x) (f64vector-slice x m n)) axv)))
           (f64vector->kd-tree/depth
            (lambda (m n axial-vectors depth #!key (leaf-factor 10))
              (cond
               ((> m n)
                (KdLeaf cis:empty '() '() depth))
               ((<= (- n m) leaf-factor)
                (let* ((sl (axial-vectors-slice axial-vectors m n))
                       (ii (cis:interval m n))
                       (es (reverse (cis:elements ii)))
                       (cs (map (lambda (i) (axial-vectors-ref sl (- i m))) es))
                       (ps (map make-point cs))
                       (vs (map make-value es cs)))
                  (KdLeaf ii ps vs (modulo depth (dimension (car ps))))))
               ;; Only reached when leaf-factor < 0.
               ((= m n)
                (let* ((e (axial-vectors-ref axial-vectors m))
                       (p (make-point e))
                       (v (make-value m e)))
                  (KdLeaf (cis:singleton m) (list p) (list v)
                          (modulo depth (dimension p)))))
               (else
                (let* ((depth1 (+ 1 depth))
                       (median-index (+ m (quotient (- n m) 2)))
                       (median (axial-vectors-ref axial-vectors median-index))
                       (p (make-point median))
                       (v (make-value median-index median))
                       (axis (modulo depth dimensions)))
                  (KdNode (f64vector->kd-tree/depth m (- median-index 1) axial-vectors depth1
                                                    leaf-factor: leaf-factor)
                          p
                          v
                          (f64vector->kd-tree/depth (+ median-index 1) n axial-vectors depth1
                                                    leaf-factor: leaf-factor)
                          axis)))))))
        f64vector->kd-tree/depth)))

  ;; Returns the nearest neighbor of p in tree t (#f for an empty tree).
  (define=> (make-kd-tree-nearest-neighbor <Point>)

    (define (tree-empty? t)
      (cases kd-tree t
        (KdLeaf (ii pp vv axis) (cis:empty? ii))
        (else #f)))

    (letrec
        (;; Searches T1 (the side containing the probe) and, when the
         ;; best distance so far reaches across the splitting plane, T2.
         (find-nearest
          (lambda (t1 t2 p probe xp x-probe)
            (let* ((candidates1
                    (let ((best1 (nearest-neighbor t1 probe)))
                      (or (and best1 (list best1 p))
                          (list p))))
                   (sphere-intersects-plane?
                    (let ((v (- x-probe xp)))
                      (< (* v v) (dist2 probe (car candidates1)))))
                   (candidates2
                    (if sphere-intersects-plane?
                        (let ((nn (nearest-neighbor t2 probe)))
                          (if nn (append candidates1 (list nn)) candidates1))
                        candidates1)))
              (minimum-by candidates2
                          (lambda (a b) (negative? (compare-distance probe a b)))))))
         (nearest-neighbor
          (lambda (t probe)
            (cases kd-tree t
              (KdLeaf (ii pp vv axis)
                (minimum-by pp (lambda (a b) (negative? (compare-distance probe a b)))))
              (KdNode (l p i r axis)
                (if (and (tree-empty? l) (tree-empty? r))
                    p
                    (let ((x-probe (coord axis probe))
                          (xp (coord axis p)))
                      (if (<= x-probe xp)
                          (find-nearest l r p probe xp x-probe)
                          (find-nearest r l p probe xp x-probe)))))))))
      nearest-neighbor))

  ;; Returns the index of the nearest neighbor of p in tree t, as a list
  ;; (index point), or #f for an empty tree.  PROBE is a pair whose cdr
  ;; is the query point; its car is not used here.
  (define=> (make-kd-tree-nearest-neighbor* <Point>)

    (define (tree-empty? t)
      (cases kd-tree t
        (KdLeaf (ii pp vv axis) (cis:empty? ii))
        (else #f)))

    (letrec
        ((find-nearest
          (lambda (t1 t2 i p probe xp x-probe)
            (let* ((candidates1
                    (let ((best1 (nearest-neighbor t1 probe)))
                      (or (and best1 (list best1 (list i p)))
                          (list (list i p)))))
                   (sphere-intersects-plane?
                    (let ((v (- x-probe xp)))
                      (< (* v v) (dist2 (cdr probe) (cadar candidates1)))))
                   (candidates2
                    (if sphere-intersects-plane?
                        (let ((nn (nearest-neighbor t2 probe)))
                          (if nn (append candidates1 (list nn)) candidates1))
                        candidates1)))
              (minimum-by candidates2
                          (lambda (a b)
                            (negative? (compare-distance (cdr probe) (cadr a) (cadr b))))))))
         (nearest-neighbor
          (lambda (t probe)
            (cases kd-tree t
              (KdLeaf (ii pp vv axis)
                ;; minimum-by returns (point index); flip to (index point).
                (let ((v (minimum-by pp
                                     (lambda (a b) (negative? (compare-distance (cdr probe) a b)))
                                     (reverse (cis:elements ii)))))
                  (and v (reverse v))))
              (KdNode (l p i r axis)
                (if (and (tree-empty? l) (tree-empty? r))
                    (list i p)
                    (let ((x-probe (coord axis (cdr probe)))
                          (xp (coord axis p)))
                      (if (<= x-probe xp)
                          (find-nearest l r i p probe xp x-probe)
                          (find-nearest r l i p probe xp x-probe)))))))))
      nearest-neighbor))

  ;; Returns a procedure mapping an axis index to the largest distance,
  ;; measured along that axis, at which a point can still lie within
  ;; RADIUS of a probe.  Without FACTORS (plain dist2) the reach is
  ;; |RADIUS| on every axis.  With FACTORS (sdist2 weights each squared
  ;; axis difference by factor^2) the reach is |RADIUS| / |factor|; #f
  ;; means unbounded (zero or absent factor), so the far side of the
  ;; splitting plane must always be searched.
  (define (make-axis-reach radius factors)
    (if (not factors)
        (lambda (axis) (abs radius))
        (lambda (axis)
          (let ((f (and (< axis (length factors))
                        (abs (list-ref factors axis)))))
            (and f (not (zero? f)) (/ (abs radius) f))))))

  ;; nearNeighbors tree p returns all neighbors within distance r from p
  ;; in tree t.  The optional FACTORS keyword selects the scaled metric
  ;; sdist2.
  (define=> (make-kd-tree-near-neighbors <Point>)

    (define (tree-empty? t)
      (cases kd-tree t
        (KdLeaf (ii pp vv axis) (cis:empty? ii))
        (else #f)))

    (letrec
        ((near-neighbors
          (lambda (t radius probe fdist reach)
            (cases kd-tree t
              (KdLeaf (ii pp vv axis)
                (let ((r2 (* radius radius)))
                  (filter (lambda (p) (<= (fdist probe p) r2)) pp)))
              (KdNode (l p i r axis)
                (let ((maybe-pivot
                       (if (<= (fdist probe p) (* radius radius)) (list p) '())))
                  (if (and (tree-empty? l) (tree-empty? r))
                      maybe-pivot
                      (let ((x-probe (coord axis probe))
                            (xp (coord axis p))
                            (d (reach axis)))
                        ;; The far side is searched whenever the ball around
                        ;; the probe touches the splitting plane (inclusive).
                        (if (<= x-probe xp)
                            (let ((nearest (append maybe-pivot
                                                   (near-neighbors l radius probe fdist reach))))
                              (if (or (not d) (>= (+ x-probe d) xp))
                                  (append (near-neighbors r radius probe fdist reach) nearest)
                                  nearest))
                            (let ((nearest (append maybe-pivot
                                                   (near-neighbors r radius probe fdist reach))))
                              (if (or (not d) (<= (- x-probe d) xp))
                                  (append (near-neighbors l radius probe fdist reach) nearest)
                                  nearest)))))))))))
      (lambda (t radius probe #!key (factors #f))
        (near-neighbors t radius probe
                        (if factors (sdist2 factors) dist2)
                        (make-axis-reach radius factors)))))

  ;; Like kd-tree-near-neighbors, but every result is a list
  ;; (index point).
  (define=> (make-kd-tree-near-neighbors* <Point>)

    (define (tree-empty? t)
      (cases kd-tree t
        (KdLeaf (ii pp vv axis) (cis:empty? ii))
        (else #f)))

    (letrec
        ((near-neighbors
          (lambda (t radius probe fdist reach)
            (cases kd-tree t
              (KdLeaf (ii pp vv axis)
                (let ((r2 (* radius radius)))
                  (filter-map (lambda (p i) (and (<= (fdist probe p) r2) (list i p)))
                              pp
                              (reverse (cis:elements ii)))))
              (KdNode (l p i r axis)
                (let ((maybe-pivot
                       (if (<= (fdist probe p) (* radius radius)) (list (list i p)) '())))
                  (if (and (tree-empty? l) (tree-empty? r))
                      maybe-pivot
                      (let ((x-probe (coord axis probe))
                            (xp (coord axis p))
                            (d (reach axis)))
                        (if (<= x-probe xp)
                            (let ((nearest (append maybe-pivot
                                                   (near-neighbors l radius probe fdist reach))))
                              (if (or (not d) (>= (+ x-probe d) xp))
                                  (append (near-neighbors r radius probe fdist reach) nearest)
                                  nearest))
                            (let ((nearest (append maybe-pivot
                                                   (near-neighbors r radius probe fdist reach))))
                              (if (or (not d) (<= (- x-probe d) xp))
                                  (append (near-neighbors l radius probe fdist reach) nearest)
                                  nearest)))))))))))
      (lambda (t radius probe #!key (factors #f))
        (near-neighbors t radius probe
                        (if factors (sdist2 factors) dist2)
                        (make-axis-reach radius factors)))))

  ;; Returns the k nearest points to p within tree, nearest first.
  ;; Interior nodes repeatedly take the nearest neighbor and remove it.
  (define=> (make-kd-tree-k-nearest-neighbors <Point>)
    (lambda (kd-tree-remove kd-tree-nearest-neighbor)
      (letrec
          ((k-nearest-neighbors
            (lambda (t k probe)
              (cases kd-tree t
                (KdLeaf (ii pp vv axis)
                  (let recur ((res '()) (pp pp) (k k))
                    (if (or (<= k 0) (null? pp))
                        (reverse res)
                        (let ((nearest (minimum-by pp (lambda (a b) (negative? (compare-distance probe a b))))))
                          (recur (cons nearest res)
                                 (remove (lambda (p) (equal? p nearest)) pp)
                                 (- k 1))))))
                (else
                 (if (<= k 0)
                     '()
                     (let* ((nearest (kd-tree-nearest-neighbor t probe))
                            (tree1 (kd-tree-remove t nearest)))
                       (cons nearest (k-nearest-neighbors tree1 (- k 1) probe)))))))))
        k-nearest-neighbors)))

  ;; removes the point p from t.
  ;;
  ;; In a leaf, every point equal to P-KILL is removed together with its
  ;; index and value.  When an interior node's own point is removed, its
  ;; subtrees are rebuilt with list->kd-tree/depth starting at index 0,
  ;; so indices in the rebuilt subtree are renumbered.  Points equal to
  ;; the node on the split axis may lie on either side, so ties descend
  ;; into both subtrees.
  (define=> (make-kd-tree-remove <Point>)
    (lambda (list->kd-tree/depth)
      (letrec
          ((tree-remove
            (lambda (t p-kill)
              (cases kd-tree t
                (KdLeaf (ii pp vv axis)
                  (let* ((ipv (filter-map (lambda (p i v) (and (equal? p p-kill) (list i p v)))
                                          pp (reverse (cis:elements ii)) vv))
                         (ii1 (fold (lambda (i ax) (cis:remove i ax)) ii (map car ipv)))
                         (pp1 (fold (lambda (x ax) (remove (lambda (p) (equal? p x)) ax))
                                    pp (map cadr ipv)))
                         (vv1 (fold (lambda (x ax) (remove (lambda (p) (equal? p x)) ax))
                                    vv (map caddr ipv))))
                    (KdLeaf ii1 pp1 vv1 axis)))
                (KdNode (l p i r axis)
                  (if (equal? p p-kill)
                      (let ((pts1 (append (kd-tree->list l) (kd-tree->list r))))
                        (list->kd-tree/depth 0 (length pts1) pts1 axis))
                      (let ((xk (coord axis p-kill))
                            (xp (coord axis p)))
                        (cond ((< xk xp)
                               (KdNode (tree-remove l p-kill) p i r axis))
                              ((> xk xp)
                               (KdNode l p i (tree-remove r p-kill) axis))
                              (else
                               (KdNode (tree-remove l p-kill) p i (tree-remove r p-kill) axis)))))))))))
        tree-remove)))

  ;; Checks whether the K-D tree property holds for a given tree.
  ;;
  ;; Specifically, it tests that all points in the left subtree lie to
  ;; the left of the plane, p is on the plane, and all points in the
  ;; right subtree lie to the right.
  (define=> (make-kd-tree-is-valid? <Point>)
    (lambda (t)
      (cases kd-tree t
        (KdLeaf (ii pp vv axis) #t)
        (KdNode (l p i r axis)
          (let ((x (coord axis p)))
            (and (every (lambda (y) (<= (coord axis y) x)) (kd-tree->list l))
                 (every (lambda (y) (>= (coord axis y) x)) (kd-tree->list r))))))))

  ;; Checks whether the K-D tree property holds for the given tree and
  ;; all subtrees.
  (define (make-kd-tree-all-subtrees-are-valid? kd-tree-is-valid?)
    (lambda (t)
      (every kd-tree-is-valid? (kd-tree-subtrees t))))

  ;; Returns all points whose X-AXIS coordinate lies in [X1, X2].
  ;; Subtrees that cannot intersect the slab are pruned when the node
  ;; splits on X-AXIS.
  (define=> (make-kd-tree-slice <Point>)
    (lambda (x-axis x1 x2 t)
      (let recur ((t t) (pts '()))
        (cases kd-tree t
          (KdLeaf (ii pp vv axis)
            (append (filter (lambda (p)
                              (and (<= x1 (coord x-axis p))
                                   (<= (coord x-axis p) x2)))
                            pp)
                    pts))
          (KdNode (l p i r axis)
            (if (= axis x-axis)
                (cond ((and (<= x1 (coord axis p)) (<= (coord axis p) x2))
                       (recur l (cons p (recur r pts))))
                      ((< (coord axis p) x1)
                       (recur r pts))
                      ((< x2 (coord axis p))
                       (recur l pts)))
                (if (and (<= x1 (coord x-axis p)) (<= (coord x-axis p) x2))
                    (recur l (cons p (recur r pts)))
                    (recur l (recur r pts)))))))))

  ;; Like kd-tree-slice, but every result is a pair (index . point).
  (define=> (make-kd-tree-slice* <Point>)
    (lambda (x-axis x1 x2 t)
      (let recur ((t t) (pts '()))
        (cases kd-tree t
          (KdLeaf (ii pp vv axis)
            (append (filter-map (lambda (p i)
                                  (and (<= x1 (coord x-axis p))
                                       (<= (coord x-axis p) x2)
                                       (cons i p)))
                                pp
                                (reverse (cis:elements ii)))
                    pts))
          (KdNode (l p i r axis)
            (if (= axis x-axis)
                (cond ((and (<= x1 (coord axis p)) (<= (coord axis p) x2))
                       (recur l (cons (cons i p) (recur r pts))))
                      ((< (coord axis p) x1)
                       (recur r pts))
                      ((< x2 (coord axis p))
                       (recur l pts)))
                (if (and (<= x1 (coord x-axis p)) (<= (coord x-axis p) x2))
                    (recur l (cons (cons i p) (recur r pts)))
                    (recur l (recur r pts)))))))))

  ;; Builds a <KdTree> instance for the given <Point> instance.
  (define (default-<KdTree> point-class)
    (let* ((list->kd-tree/depth (make-list->kd-tree/depth point-class))
           (f64vector->kd-tree/depth (make-f64vector->kd-tree/depth point-class))
           (kd-tree-remove ((make-kd-tree-remove point-class)
                            (list->kd-tree/depth identity (lambda (i v) i))))
           (kd-tree-nearest-neighbor (make-kd-tree-nearest-neighbor point-class))
           (kd-tree-is-valid? (make-kd-tree-is-valid? point-class)))
      (make-<KdTree>
       (lambda (points #!key (leaf-factor 10) (point-ref identity) (make-value (lambda (i v) i)))
         ((list->kd-tree/depth point-ref make-value) 0 (length points) points 0
          leaf-factor: leaf-factor))
       (lambda (axial-vectors #!key (leaf-factor 10) (point-ref list->vector) (make-value (lambda (i v) i)))
         (let ((dimensions (length axial-vectors))
               (len (f64vector-length (car axial-vectors))))
           (if (zero? len)
               (KdLeaf cis:empty '() '() 0)
               ((f64vector->kd-tree/depth dimensions point-ref make-value) 0 (- len 1) axial-vectors 0
                leaf-factor: leaf-factor))))
       kd-tree-nearest-neighbor
       (make-kd-tree-nearest-neighbor* point-class)
       (make-kd-tree-near-neighbors point-class)
       (make-kd-tree-near-neighbors* point-class)
       ((make-kd-tree-k-nearest-neighbors point-class) kd-tree-remove kd-tree-nearest-neighbor)
       kd-tree-remove
       (make-kd-tree-slice point-class)
       (make-kd-tree-slice* point-class)
       kd-tree-is-valid?
       (make-kd-tree-all-subtrees-are-valid? kd-tree-is-valid?))))

  (define KdTree3d (default-<KdTree> Point3d))
  (define KdTree2d (default-<KdTree> Point2d))

  )