;; http://en.wikipedia.org/wiki/K-d_tree
;;
;; K-d tree spatial index. The algorithms are written against the
;; <Point> type class and specialized to 3-D point records
;; (Point-point3d) at the end of the module.

(module kd-tree

  (Point-point3d
   point3d? make-point3d point3d-x point3d-y point3d-z
   kd-tree? KdNode KdEmpty
   kd-tree-empty?
   kd-tree->list
   kd-tree->list*
   kd-tree-map
   kd-tree-for-each
   kd-tree-for-each*
   kd-tree-fold-right
   kd-tree-fold-right*
   kd-tree-subtrees
   kd-tree-point
   list->kd-tree
   kd-tree-nearest-neighbor
   kd-tree-near-neighbors
   kd-tree-near-neighbors*
   kd-tree-k-nearest-neighbors
   kd-tree-remove
   kd-tree-slice
   kd-tree-slice*
   kd-tree-is-valid?
   kd-tree-all-subtrees-are-valid?)

  (import scheme chicken data-structures)

  (require-library srfi-1 extras)
  (require-extension typeclass datatype)

  (import (only srfi-1 xcons fold list-tabulate drop take every)
          (only extras fprintf))

  ;; The operations the k-d tree algorithms need from a point type.
  (define-class <Point>
    ;; dimension returns the number of coordinates of a point.
    dimension         ;; Point -> Int
    ;; gets the k'th coordinate, starting from 0.
    coord             ;; Int * Point -> Double
    ;; returns the squared distance between two points.
    dist2             ;; Point * Point -> Double
    ;; (compare-distance p a b [tol]) returns a negative, zero or
    ;; positive number depending on whether a is closer to p than b,
    ;; equally close, or farther.
    compare-distance)

  ;; Returns the first element of lst that is minimal under the strict
  ;; ordering less?, or #f when lst is empty.
  (define (minimum-by lst less?)
    (if (null? lst)
        #f
        (let recur ((lst (cdr lst)) (m (car lst)))
          (cond ((null? lst) m)
                ((less? (car lst) m) (recur (cdr lst) (car lst)))
                (else (recur (cdr lst) m))))))

  ;; Sum of a list of numbers, accumulated starting from the inexact 0.
  (define (sum lst) (fold + 0. lst))

  ;; Builds a <Point> instance from dimension and coord alone, deriving
  ;; dist2 (sum of squared coordinate differences) and compare-distance.
  ;;
  ;; compare-distance returns (dist2 p a) - (dist2 p b). When the optional
  ;; tolerance is supplied, a difference whose magnitude is at most the
  ;; tolerance is reported as 0. The tolerance is absolute and applies to
  ;; squared distances, despite the parameter name "reltol".
  (define (default-<Point> dimension coord)
    (let* ((dist2
            (lambda (a b)
              (let ((diff2 (lambda (i)
                             (let ((v (- (coord i a) (coord i b))))
                               (* v v)))))
                (sum (list-tabulate (dimension a) diff2)))))
           (compare-distance
            (lambda (p a b . reltol)
              (let ((delta (- (dist2 p a) (dist2 p b))))
                (if (null? reltol)
                    delta
                    ;; Compare the magnitude: a large negative delta means
                    ;; a is much closer than b, not "equally close".
                    (if (<= (abs delta) (car reltol)) 0 delta))))))
      (make-<Point> dimension coord dist2 compare-distance)))

  ;; A point in three dimensions.
  (define-record-type point3d
    (make-point3d x y z)
    point3d?
    (x point3d-x)
    (y point3d-y)
    (z point3d-z))

  (define-record-printer (point3d p out)
    (fprintf out "#<~a,~a,~a>" (point3d-x p) (point3d-y p) (point3d-z p)))

  ;; <Point> instance for point3d records. dimension yields #f for
  ;; non-point3d values; coord is unspecified for an index above 2.
  (define Point-point3d
    (default-<Point>
      (lambda (p) (and (point3d? p) 3))
      (lambda (i p)
        (case i
          ((0) (point3d-x p))
          ((1) (point3d-y p))
          ((2) (point3d-z p))))))

  ;; A tree is either empty or a node holding:
  ;;   - a splitting point p;
  ;;   - the node's index i;
  ;;   - left and right subtrees;
  ;;   - the splitting axis.
  ;; As built by list->kd-tree, points in left have a coordinate <= that
  ;; of p along axis, and points in right have one >= it. Points whose
  ;; coordinate equals p's may therefore sit on either side.
  (define-datatype kd-tree kd-tree?
    (KdNode (left kd-tree?) (p point3d?) (i integer?) (right kd-tree?) (axis integer?))
    (KdEmpty))

  ;; #t iff t is the empty tree.
  (define (kd-tree-empty? t)
    (cases kd-tree t
      (KdEmpty () #t)
      (else #f)))

  ;; Returns (list->kd-tree/depth m n points depth), which builds a tree
  ;; from points. The splitting axis is depth modulo the dimension, so it
  ;; cycles through all coordinates as the tree deepens. Each level splits
  ;; at the median. Nodes receive indices from the range [m, n); callers
  ;; pass (- n m) equal to (length points).
  (define=> (make-list->kd-tree/depth <Point>)
    (letrec ((list->kd-tree/depth
              (lambda (m n points depth)
                (if (null? points)
                    (KdEmpty)
                    (let* ((axis (modulo depth (dimension (car points))))
                           ;; Sort on the splitting axis; the median is the pivot.
                           (sorted-points
                            (sort points (lambda (a b) (< (coord axis a) (coord axis b)))))
                           (median-index (quotient (- (- n m) 1) 2)))
                      (KdNode (list->kd-tree/depth m (+ m median-index)
                                                   (take sorted-points median-index)
                                                   (+ 1 depth))
                              (list-ref sorted-points median-index)
                              (+ m median-index)
                              (list->kd-tree/depth (+ m median-index 1) n
                                                   (drop sorted-points (+ median-index 1))
                                                   (+ 1 depth))
                              axis))))))
      list->kd-tree/depth))

  ;; Returns (nearest-neighbor t probe), which gives the point of t
  ;; closest to probe, or #f when t is empty. Ties go to the candidate
  ;; found first.
  (define=> (make-kd-tree-nearest-neighbor <Point>)
    (letrec ((closest
              ;; First point of pts nearest to probe (#f if pts is empty).
              (lambda (probe pts)
                (minimum-by pts (lambda (a b) (negative? (compare-distance probe a b))))))
             (find-nearest
              ;; t1 is the subtree on probe's side of the plane through p;
              ;; t2 is the subtree on the other side.
              (lambda (t1 t2 p probe xp x-probe)
                (let* ((candidates1
                        (let ((best1 (nearest-neighbor t1 probe)))
                          (if best1 (list best1 p) (list p))))
                       (best-so-far (closest probe candidates1))
                       ;; t2 can only hold a strictly closer point if the
                       ;; sphere around probe through the best candidate so
                       ;; far crosses the splitting plane.
                       (sphere-intersects-plane?
                        (let ((v (- x-probe xp)))
                          (< (* v v) (dist2 probe best-so-far))))
                       (candidates2
                        (if sphere-intersects-plane?
                            (let ((nn (nearest-neighbor t2 probe)))
                              (if nn (append candidates1 (list nn)) candidates1))
                            candidates1)))
                  (closest probe candidates2))))
             (nearest-neighbor
              (lambda (t probe)
                (cases kd-tree t
                  (KdEmpty () #f)
                  (KdNode (l p i r axis)
                    (if (and (kd-tree-empty? l) (kd-tree-empty? r))
                        p
                        (let ((x-probe (coord axis probe))
                              (xp (coord axis p)))
                          (if (<= x-probe xp)
                              (find-nearest l r p probe xp x-probe)
                              (find-nearest r l p probe xp x-probe)))))))))
      nearest-neighbor))

  ;; Returns a procedure that takes wrap and yields
  ;; (search t radius probe). The search collects (wrap i p) for every
  ;; node whose point is within radius of probe, inclusive. A subtree
  ;; across a splitting plane is visited whenever the plane is within
  ;; radius of probe, again inclusive, so boundary points are not lost.
  (define=> (make-near-neighbors-search <Point>)
    (lambda (wrap)
      (letrec ((near-neighbors
                (lambda (t radius probe)
                  (cases kd-tree t
                    (KdEmpty () '())
                    (KdNode (l p i r axis)
                      (let ((maybe-pivot
                             (if (<= (dist2 probe p) (* radius radius))
                                 (list (wrap i p))
                                 '())))
                        (if (and (kd-tree-empty? l) (kd-tree-empty? r))
                            maybe-pivot
                            (let ((x-probe (coord axis probe))
                                  (xp (coord axis p))
                                  (rad (abs radius)))
                              (if (<= x-probe xp)
                                  (let ((nearest (append maybe-pivot (near-neighbors l radius probe))))
                                    ;; Right points have coordinate >= xp; one exactly
                                    ;; rad away still matches, hence >=.
                                    (if (>= (+ x-probe rad) xp)
                                        (append (near-neighbors r radius probe) nearest)
                                        nearest))
                                  (let ((nearest (append maybe-pivot (near-neighbors r radius probe))))
                                    ;; Left points have coordinate <= xp; inclusive likewise.
                                    (if (<= (- x-probe rad) xp)
                                        (append (near-neighbors l radius probe) nearest)
                                        nearest)))))))))))
        near-neighbors)))

  ;; (near-neighbors t radius probe): all points of t within radius of probe.
  (define (make-kd-tree-near-neighbors point-class)
    ((make-near-neighbors-search point-class) (lambda (i p) p)))

  ;; Like make-kd-tree-near-neighbors, but each result is the list (i p).
  (define (make-kd-tree-near-neighbors* point-class)
    ((make-near-neighbors-search point-class) (lambda (i p) (list i p))))

  ;; Returns (k-nearest-neighbors t k probe), which lists up to k points
  ;; of t, nearest first. It works by repeatedly taking the nearest
  ;; neighbor and removing it from the tree.
  ;; NOTE(review): this uses the module-level kd-tree-nearest-neighbor and
  ;; kd-tree-remove (the point3d specializations), not procedures built
  ;; from the <Point> instance passed in.
  (define=> (make-kd-tree-k-nearest-neighbors <Point>)
    (letrec ((k-nearest-neighbors
              (lambda (t k probe)
                (if (or (kd-tree-empty? t) (<= k 0))
                    '()
                    (let ((nearest (kd-tree-nearest-neighbor t probe)))
                      (cons nearest
                            (k-nearest-neighbors (kd-tree-remove t nearest)
                                                 (- k 1)
                                                 probe)))))))
      k-nearest-neighbors))

  ;; Returns (remove t p-kill), which gives t without one point that is
  ;; equal? to p-kill. If there is no such point, t is returned unchanged.
  ;; The subtree rooted at the removed node is rebuilt from its remaining
  ;; points.
  ;; NOTE(review): the rebuilt subtree's indices restart at 0, so they
  ;; may duplicate indices elsewhere in the tree.
  (define=> (make-kd-tree-remove <Point>)
    (letrec ((remove/found
              ;; Returns two values: the resulting tree and whether a
              ;; point was removed.
              (lambda (t p-kill)
                (cases kd-tree t
                  (KdEmpty () (values t #f))
                  (KdNode (l p i r axis)
                    (if (equal? p p-kill)
                        (let ((pts1 (append (kd-tree->list l) (kd-tree->list r))))
                          (values (list->kd-tree/depth 0 (length pts1) pts1 axis) #t))
                        (let ((xk (coord axis p-kill))
                              (xp (coord axis p)))
                          (define (in-left)
                            (receive (l1 found?) (remove/found l p-kill)
                              (values (if found? (KdNode l1 p i r axis) t) found?)))
                          (define (in-right)
                            (receive (r1 found?) (remove/found r p-kill)
                              (values (if found? (KdNode l p i r1 axis) t) found?)))
                          (cond ((< xk xp) (in-left))
                                ((> xk xp) (in-right))
                                ;; Points whose coordinate equals the pivot's can
                                ;; be in either subtree, so try both.
                                (else
                                 (receive (t1 found?) (in-left)
                                   (if found? (values t1 #t) (in-right))))))))))))
      (lambda (t p-kill)
        (receive (t1 found?) (remove/found t p-kill)
          t1))))

  ;; Returns a predicate that checks the k-d tree property at the root
  ;; of a tree. Every point in the left subtree must lie at or below the
  ;; pivot's coordinate along the node's axis, and every point in the
  ;; right subtree at or above it. An empty tree is valid.
  (define=> (make-kd-tree-is-valid? <Point>)
    (lambda (t)
      (cases kd-tree t
        (KdEmpty () #t)
        (KdNode (l p i r axis)
          (let ((x (coord axis p)))
            (and (every (lambda (y) (<= (coord axis y) x)) (kd-tree->list l))
                 (every (lambda (y) (>= (coord axis y) x)) (kd-tree->list r))))))))

  ;; Returns a predicate that checks kd-tree-is-valid? on t and on every
  ;; subtree of t.
  (define=> (make-kd-tree-all-subtrees-are-valid? <Point>)
    (lambda (t) (every kd-tree-is-valid? (kd-tree-subtrees t))))

  ;; Points of t in order (left subtree, node, right subtree).
  (define (kd-tree->list t)
    (kd-tree-fold-right cons '() t))

  ;; In-order list of (i p) two-element lists.
  (define (kd-tree->list* t)
    (kd-tree-fold-right* (lambda (i x ax) (cons (list i x) ax)) '() t))

  ;; Applies f to every point, keeping the tree shape, indices and axes.
  ;; f must return point3d records (the KdNode field type). The result
  ;; is not re-checked for the k-d tree property.
  (define (kd-tree-map f t)
    (cases kd-tree t
      (KdEmpty () (KdEmpty))
      (KdNode (l x i r axis)
        (KdNode (kd-tree-map f l) (f x) i (kd-tree-map f r) axis))))

  ;; Calls (f p) on every point, in order, for effect.
  (define (kd-tree-for-each f t)
    (cases kd-tree t
      (KdEmpty () (void))
      (KdNode (l x i r axis)
        (kd-tree-for-each f l)
        (f x)
        (kd-tree-for-each f r))))

  ;; Calls (f i p) on every node, in order, for effect.
  (define (kd-tree-for-each* f t)
    (cases kd-tree t
      (KdEmpty () (void))
      (KdNode (l x i r axis)
        (kd-tree-for-each* f l)
        (f i x)
        (kd-tree-for-each* f r))))

  ;; Right fold over the points in order: (f p1 (f p2 ... (f pn init))).
  (define (kd-tree-fold-right f init t)
    (cases kd-tree t
      (KdEmpty () init)
      (KdNode (l x i r _)
        (let* ((init2 (kd-tree-fold-right f init r))
               (init3 (f x init2)))
          (kd-tree-fold-right f init3 l)))))

  ;; Like kd-tree-fold-right, but f also receives the index: (f i p acc).
  (define (kd-tree-fold-right* f init t)
    (cases kd-tree t
      (KdEmpty () init)
      (KdNode (l x i r _)
        (let* ((init2 (kd-tree-fold-right* f init r))
               (init3 (f i x init2)))
          (kd-tree-fold-right* f init3 l)))))

  ;; Returns a procedure that takes wrap and yields
  ;; (slice x-axis x1 x2 t). The slice collects, in order, (wrap i p)
  ;; for every node whose point satisfies x1 <= coordinate x-axis <= x2.
  ;; At nodes that split on x-axis, a subtree that cannot hold such
  ;; points is skipped.
  (define=> (make-slice-search <Point>)
    (lambda (wrap)
      (lambda (x-axis x1 x2 t)
        (let recur ((t t) (pts '()))
          (cases kd-tree t
            (KdEmpty () pts)
            (KdNode (l p i r axis)
              (let ((x (coord x-axis p)))
                (cond ((and (<= x1 x) (<= x x2))
                       (recur l (cons (wrap i p) (recur r pts))))
                      ((not (= axis x-axis))
                       (recur l (recur r pts)))
                      ((< x x1) (recur r pts))
                      ((< x2 x) (recur l pts))
                      ;; Only reachable with incomparable coordinates
                      ;; (e.g. NaN): search both sides.
                      (else (recur l (recur r pts)))))))))))

  ;; (slice x-axis x1 x2 t): the points of t inside the slab.
  (define (make-kd-tree-slice point-class)
    ((make-slice-search point-class) (lambda (i p) p)))

  ;; Like make-kd-tree-slice, but every result is the pair (i . p).
  (define (make-kd-tree-slice* point-class)
    ((make-slice-search point-class) (lambda (i p) (cons i p))))

  ;; Returns a list containing t and all its subtrees, including the
  ;; empty leaf nodes, in order.
  (define (kd-tree-subtrees t)
    (cases kd-tree t
      (KdEmpty () (list (KdEmpty)))
      (KdNode (l x i r axis)
        (append (kd-tree-subtrees l) (list t) (kd-tree-subtrees r)))))

  ;; The splitting point at the root of t, or #f if t is empty.
  (define (kd-tree-point t)
    (cases kd-tree t
      (KdEmpty () #f)
      (KdNode (l x i r axis) x)))

  ;; Specializations to point3d records.

  (define list->kd-tree/depth (make-list->kd-tree/depth Point-point3d))

  ;; Builds a tree from a list of point3d records, splitting first on x.
  (define (list->kd-tree points)
    (list->kd-tree/depth 0 (length points) points 0))

  (define kd-tree-nearest-neighbor (make-kd-tree-nearest-neighbor Point-point3d))
  (define kd-tree-near-neighbors (make-kd-tree-near-neighbors Point-point3d))
  (define kd-tree-near-neighbors* (make-kd-tree-near-neighbors* Point-point3d))
  (define kd-tree-k-nearest-neighbors (make-kd-tree-k-nearest-neighbors Point-point3d))
  (define kd-tree-remove (make-kd-tree-remove Point-point3d))
  (define kd-tree-slice (make-kd-tree-slice Point-point3d))
  (define kd-tree-slice* (make-kd-tree-slice* Point-point3d))
  (define kd-tree-is-valid? (make-kd-tree-is-valid? Point-point3d))
  (define kd-tree-all-subtrees-are-valid? (make-kd-tree-all-subtrees-are-valid? Point-point3d))
)