;;;; Copyright (c) 2017, Jeremy Steward
;;;; All rights reserved.
;;;;
;;;; Redistribution and use in source and binary forms, with or without
;;;; modification, are permitted provided that the following conditions are met:
;;;;
;;;; 1. Redistributions of source code must retain the above copyright notice,
;;;; this list of conditions and the following disclaimer.
;;;;
;;;; 2. Redistributions in binary form must reproduce the above copyright notice,
;;;; this list of conditions and the following disclaimer in the documentation
;;;; and/or other materials provided with the distribution.
;;;;
;;;; 3. Neither the name of the copyright holder nor the names of its
;;;; contributors may be used to endorse or promote products derived from this
;;;; software without specific prior written permission.
;;;;
;;;; THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
;;;; AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
;;;; IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
;;;; ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
;;;; LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
;;;; CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
;;;; SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
;;;; INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
;;;; CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
;;;; ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
;;;; POSSIBILITY OF SUCH DAMAGE.

;;; Base interfaces and helper procs

;; Record describing one array: a storage class, a shape and a stride vector
;; (one fixnum per axis), a fixnum offset into the storage object, a
;; mutability flag, and the storage object itself. Shape, stride and the
;; mutability flag are settable through SRFI-17 style setters.
(define-record-type array
  (%make-array storage-class shape stride offset mutable? storage-object)
  array?
  (storage-class %array-storage-class : (struct storage-class))
  (shape %array-shape (setter %array-shape) : (vector-of fixnum))
  (stride %array-stride (setter %array-stride) : (vector-of fixnum))
  (offset %array-offset : fixnum)
  (mutable? %array-mutable? (setter %array-mutable?) : boolean)
  (storage-object %array-storage-object))

(define-check+error-type array array?)

;; An array index is a vector whose elements are all fixnums.
(define-check+error-type array-index
  (lambda (obj)
    (and (vector? obj)
         (vector-every fixnum? obj))))

;; Unique sentinel (compared with eq?) used to detect an omitted optional
;; argument.
(define *UNSPECIFIED* (list 'unspecified))

;; Calculates the stride of an array based on the shape. Implicitly assumes
;; row-major ordering (as does most of this API), but can optionally be used to
;; specify column major ordering if necessary.
;;
;; shape            - vector of fixnum extents, one per axis.
;; make-row-major?  - when true (the default) the last axis gets stride 1 and
;;                    every earlier axis gets the product of the extents that
;;                    follow it; when false the first axis gets stride 1 and
;;                    every later axis gets the product of the extents that
;;                    precede it.
;;
;; Returns a freshly allocated vector of the same length as shape.
;;
;;   (shape->stride #(2 3 4))     => #(12 4 1)
;;   (shape->stride #(2 3 4) #f)  => #(1 2 6)
;;
;; The strides are accumulated as a running product rather than by dividing
;; cumulative products, so shapes whose extents differ, shapes containing a
;; zero extent, and the empty shape are all handled.
(define (shape->stride shape #!optional (make-row-major? #t))
  (let* ((rank (vector-length shape))
         (stride (make-vector rank 1)))
    (if make-row-major?
        ;; Walk from the last axis to the first; `extent` is the number of
        ;; elements spanned by one step along the current axis.
        (let loop ((dim (sub1 rank))
                   (extent 1))
          (when (fx>= dim 0)
            (vector-set! stride dim extent)
            (loop (sub1 dim) (fx* extent (vector-ref shape dim)))))
        ;; Column-major: same accumulation, walking from the first axis.
        (let loop ((dim 0)
                   (extent 1))
          (when (fx< dim rank)
            (vector-set! stride dim extent)
            (loop (add1 dim) (fx* extent (vector-ref shape dim))))))
    stride))

;; True when the array's strides are in non-increasing order.
(define (row-major? array)
  (apply >= (vector->list (%array-stride array))))

;; Quick and dirty way to get the "rank" of an array. Rank here refers to the
;; same notation as in Fortran or C, not actual mathematical rank.
(define (%array-rank array)
  (vector-length (%array-stride array)))

;; Fresh all-zero index vector with one entry per axis.
(define (%array-lower-bound array)
  (make-vector (%array-rank array) 0))

;; Fresh copy of the shape; used as the exclusive upper bound when iterating.
(define (%array-upper-bound array)
  (vector-copy (%array-shape array)))

;; Converts an array index to a storage index.
;; This API is unsafe, see the public version without the % prefix for more
;; information.
(define (%array-index->storage-index array index)
  ;; Result is the array's offset plus the sum over all axes of
  ;; stride * index component. Ranks 1 to 4 are unrolled by pattern matching
  ;; on the two vectors; any other combination falls through to the fold.
  ;; No bounds or rank checking is performed here.
  (fx+ (%array-offset array)
       (match (list (%array-stride array) index)
         ;; Rank-1 Case
         ((#(si) #(i))
          (fx* si i))
         ;; Rank-2 Case
         ((#(si sj) #(i j))
          (fx+ (fx* si i)
               (fx* sj j)))
         ;; Rank-3 Case
         ((#(si sj sk) #(i j k))
          (fx+ (fx* si i)
               (fx+ (fx* sj j)
                    (fx* sk k))))
         ;; Rank-4 Case
         ((#(si sj sk sv) #(i j k v))
          (fx+ (fx* si i)
               (fx+ (fx* sj j)
                    (fx+ (fx* sk k)
                         (fx* sv v)))))
         ;; Rank-N Case
         (_
          (vector-fold fx+ 0
                       (vector-map fx* (%array-stride array) index))))))

;; Increments an index up to the upper-bound.
;; Some examples assume upper bound of #(3 3 3):
;;
;; #(0 0 0) -> #(0 0 1)
;; #(0 0 1) -> #(0 0 2)
;; #(0 0 2) -> #(0 1 0)
;;
;; Note how in every case it increments up to the upper bound of a given
;; dimension, but never to it. If the index cannot be incremented any further
;; beyond its max value, then the upper bound is returned.
;;
;; `index` is updated in place and returned; the exhausted case returns a
;; fresh copy of `upper-bound` instead. Components that wrap around are reset
;; to 0.
(define (%increment-index index upper-bound)
  (let ((limit (vector-copy upper-bound)))
    (let loop ((dim (sub1 (vector-length index))))
      (cond
        ;; Either every dimension has wrapped, or the leading component is
        ;; already at/past its bound: nothing left to visit.
        ((or (fx< dim 0)
             (fx>= (vector-ref index 0) (vector-ref limit 0)))
         limit)
        ;; Carry into the next-slower dimension. The comparison is >= rather
        ;; than = so that a component which is already at or beyond its bound
        ;; (out-of-range index, or a zero-length dimension) still carries
        ;; instead of being incremented without end.
        ((fx>= (add1 (vector-ref index dim)) (vector-ref limit dim))
         (vector-set! index dim 0)
         (loop (sub1 dim)))
        (else
         (vector-set! index dim (add1 (vector-ref index dim)))
         index)))))

;; Executes proc on every possible index value between a start (inclusive)
;; and end (exclusive) index, in lexicographic order. e.g.
;; (%array-for-each-index print start end) prints all the indices in
;; lexicographic order.
(define (%array-for-each-index proc start end)
  ;; proc  - called with one index vector per visited position.
  ;; start - vector of inclusive lower bounds, one per axis.
  ;; end   - vector of exclusive upper bounds, one per axis.
  ;; Ranks 1 to 4 are unrolled; every other rank uses the generic odometer
  ;; below. Signals an error when start and end are not vectors of the same
  ;; length.
  (match (list start end)
    ;; Rank-1 Case
    ((#(li) #(ui))
     (let i-loop ((i li))
       (when (< i ui)
         (proc (vector i))
         (i-loop (add1 i)))))
    ;; Rank-2 Case
    ((#(li lj) #(ui uj))
     (let i-loop ((i li))
       (when (< i ui)
         (let j-loop ((j lj))
           (if (< j uj)
               (begin
                 (proc (vector i j))
                 (j-loop (add1 j)))
               (i-loop (add1 i)))))))
    ;; Rank-3 Case
    ((#(li lj lk) #(ui uj uk))
     (let i-loop ((i li))
       (when (< i ui)
         (let j-loop ((j lj))
           (if (< j uj)
               (let k-loop ((k lk))
                 (if (< k uk)
                     (begin
                       (proc (vector i j k))
                       (k-loop (add1 k)))
                     (j-loop (add1 j))))
               (i-loop (add1 i)))))))
    ;; Rank-4 Case
    ((#(li lj lk lv) #(ui uj uk uv))
     (let i-loop ((i li))
       (when (< i ui)
         (let j-loop ((j lj))
           (if (< j uj)
               (let k-loop ((k lk))
                 (if (< k uk)
                     (let v-loop ((v lv))
                       (if (< v uv)
                           (begin
                             (proc (vector i j k v))
                             (v-loop (add1 v)))
                           (k-loop (add1 k))))
                     (j-loop (add1 j))))
               (i-loop (add1 i)))))))
    ;; Rank-N Case
    ;; The pattern must be a two-element list; a one-element pattern can never
    ;; match (list start end) and would send every higher rank to the error
    ;; clause.
    (((? vector?) (? vector?))
     (let ((rank (vector-length start)))
       (unless (fx= rank (vector-length end))
         (error "start and end must be vectors of the same rank." start end))
       ;; Nothing to visit when any axis has an empty range (or rank is 0).
       (when (and (fx> rank 0)
                  (vector-every fx< start end))
         (let ((index (vector-copy start)))
           (let visit ()
             ;; Hand proc its own copy, as the unrolled cases do, because
             ;; `index` keeps being mutated by the odometer.
             (proc (vector-copy index))
             (let carry ((dim (sub1 rank)))
               (cond
                 ;; Every axis wrapped: iteration is complete.
                 ((fx< dim 0) (void))
                 ((fx< (add1 (vector-ref index dim)) (vector-ref end dim))
                  (vector-set! index dim (add1 (vector-ref index dim)))
                  (visit))
                 (else
                  ;; Wrap this axis back to its own lower bound and carry.
                  (vector-set! index dim (vector-ref start dim))
                  (carry (sub1 dim))))))))))
    ;; Default rule
    (_
     (error "start and end must be vectors of the same rank." start end))))

;; Reads the element at `index` (a vector) through the storage class's
;; accessor. Performs no checking of its own.
(define (%array-ref array index)
  (let ((storage-index (%array-index->storage-index array index))
        (ref (storage-class-accessor (%array-storage-class array)))
        (storage-object (%array-storage-object array)))
    (ref storage-object storage-index)))

;; Writes `value` at `index` (a vector) through the storage class's mutator.
;; Performs no checking of its own (the mutable? flag is not consulted here).
(define (%array-set! array index value)
  (let ((storage-index (%array-index->storage-index array index))
        (storage-set! (storage-class-mutator (%array-storage-class array)))
        (storage-object (%array-storage-object array)))
    (storage-set! storage-object storage-index value)))

;; Allocates a new mutable array of the given shape with offset 0 and the
;; default (row-major) stride. When `fill` is omitted the storage class's
;; default fill is passed to make-storage-object instead.
(define (%%dummy-make-array storage-class shape #!optional (fill *UNSPECIFIED*))
  (let* ((nelems (vector-fold * 1 shape))
         (stride (shape->stride shape))
         (storage-object
           (make-storage-object storage-class
                                nelems
                                (if (eq? fill *UNSPECIFIED*)
                                    (storage-class-default-fill storage-class)
                                    fill))))
    (%make-array storage-class shape stride 0 #t storage-object)))

;;; Derived interfaces and helper procs

;; Builds a new array of class `storage-class` with the shape of `array`,
;; whose element at each index is (proc a0 a1 ...) applied to the elements of
;; `array` and of each member of the list `arrays` at that index. The result
;; is marked immutable when `array` is immutable.
(define (%array-map storage-class proc array arrays)
  (let ((all-arrays (cons array arrays))
        (new-array (%%dummy-make-array storage-class (%array-shape array))))
    (let ((data (%array-storage-object new-array))
          (storage-set! (storage-class-mutator (%array-storage-class new-array))))
      (%array-for-each-index
        (lambda (index)
          ;; The destination position must come from new-array's own
          ;; stride/offset: `array` may be a view whose layout differs from
          ;; the freshly allocated result.
          (let ((i (%array-index->storage-index new-array index)))
            (storage-set! data
                          i
                          (apply proc
                                 (map (lambda (a) (%array-ref a index))
                                      all-arrays)))))
        (%array-lower-bound new-array)
        (%array-upper-bound new-array)))
    (unless (%array-mutable? array)
      (set! (%array-mutable? new-array) #f))
    new-array))

;; In-place variant of %array-map: each element of `array` is overwritten
;; with proc applied to the elements of `array` and of `arrays` at the same
;; index. Returns `array`.
(define (%array-map! proc array arrays)
  (let ((all-arrays (cons array arrays))
        (data (%array-storage-object array))
        (storage-set! (storage-class-mutator (%array-storage-class array))))
    (%array-for-each-index
      (lambda (index)
        (let ((i (%array-index->storage-index array index)))
          (storage-set! data
                        i
                        (apply proc
                               (map (lambda (a) (%array-ref a index))
                                    all-arrays)))))
      (%array-lower-bound array)
      (%array-upper-bound array))
    array))

;; Left fold over the indices of `array` in lexicographic order. At each
;; index the accumulator becomes (proc acc a0 a1 ...), starting from `seed`.
(define (%array-fold proc seed array arrays)
  (let ((acc seed)
        (all-arrays (cons array arrays)))
    (%array-for-each-index
      (lambda (index)
        (let ((args (map (lambda (a) (%array-ref a index)) all-arrays)))
          (set! acc (apply proc acc args))))
      (%array-lower-bound array)
      (%array-upper-bound array))
    acc))

;; Returns a new vector equal to `index` with component `axis` dropped.
;; A rank-1 index collapses to #(0) rather than to an empty vector.
(define (remove-axis-from-index index axis)
  (match (list index axis)
    ((#(i) 0) (vector 0))
    ((#(i j) 0) (vector j))
    ((#(i j) 1) (vector i))
    ((#(i j k) 0) (vector j k))
    ((#(i j k) 1) (vector i k))
    ((#(i j k) 2) (vector i j))
    ((#(i j k v) 0) (vector j k v))
    ((#(i j k v) 1) (vector i k v))
    ((#(i j k v) 2) (vector i j v))
    ((#(i j k v) 3) (vector i j k))
    (_
     (let ((new-index (make-vector (sub1 (vector-length index)))))
       ;; The fold state `i` is the position currently being visited.
       (vector-fold
         (lambda (i x)
           (cond
             ((eq? i axis)
              (add1 i))
             ((> i axis)
              (vector-set! new-index (sub1 i) (vector-ref index i))
              (add1 i))
             (else
              (vector-set! new-index i (vector-ref index i))
              (add1 i))))
         0
         index)
       new-index))))

;; Reduces `array` along `axis`: the result has that axis removed (a rank-1
;; input yields shape #(1)), is pre-filled with `seed`, and each cell is
;; updated with (proc current-cell element) for every element that maps to it.
(define (%array-reduce proc seed array axis)
  (let* ((shape (%array-shape array))
         (new-array
           (%%dummy-make-array (%array-storage-class array)
                               (if (= (%array-rank array) 1)
                                   (vector 1)
                                   (remove-axis-from-index shape axis))
                               seed)))
    (%array-for-each-index
      (lambda (index)
        (let ((new-index (remove-axis-from-index index axis)))
          (%array-set! new-array
                       new-index
                       (proc (%array-ref new-array new-index)
                             (%array-ref array index)))))
      (%array-lower-bound array)
      (%array-upper-bound array))
    new-array))

;; Running accumulation along `axis`: the result has the shape of `array`,
;; and each cell is (proc previous-cell-along-axis element). The first cell
;; along the axis uses its own pre-filled `seed` as the previous value.
(define (%array-cumulate proc seed array axis)
  (let* ((shape (%array-shape array))
         (new-array (%%dummy-make-array (%array-storage-class array) shape seed)))
    (%array-for-each-index
      (lambda (index)
        (let ((last-index (vector-copy index)))
          (unless (zero? (vector-ref index axis))
            (vector-set! last-index axis (sub1 (vector-ref last-index axis))))
          (%array-set! new-array
                       index
                       (proc (%array-ref new-array last-index)
                             (%array-ref array index)))))
      (%array-lower-bound array)
      (%array-upper-bound array))
    new-array))

;; Maps a source index to its position after compression along `axis`: the
;; axis component is lowered by the number of false entries of `booleans`
;; that precede it.
(define (compress-index index booleans axis)
  (let* ((new-index (vector-copy index))
         (axis-value (vector-ref index axis))
         (nitems-skipped (vector-count not (subvector booleans 0 axis-value))))
    (vector-set! new-index axis (fx- axis-value nitems-skipped))
    new-index))

;; Copies into a new array only those positions along `axis` whose entry in
;; `booleans` is true; the new axis length is the number of true entries.
(define (%array-compress array booleans axis)
  (let ((new-axis-length (vector-count identity booleans))
        (new-shape (vector-copy (%array-shape array))))
    (vector-set! new-shape axis new-axis-length)
    (let ((new-array (%%dummy-make-array (%array-storage-class array) new-shape)))
      (%array-for-each-index
        (lambda (index)
          (when (vector-ref booleans (vector-ref index axis))
            (let ((compressed-index (compress-index index booleans axis)))
              (%array-set! new-array compressed-index (%array-ref array index)))))
        (%array-lower-bound array)
        (%array-upper-bound array))
      new-array)))

;; Maps a source index to its position after expansion along `axis`: the
;; axis component is raised by the number of true entries of `booleans` that
;; precede it.
(define (expand-index index booleans axis)
  (let* ((new-index (vector-copy index))
         (axis-value (vector-ref index axis))
         (nitems-inserted (vector-count identity (subvector booleans 0 axis-value))))
    (vector-set! new-index axis (fx+ axis-value nitems-inserted))
    new-index))

;; Builds an array whose `axis` is lengthened by the number of true entries
;; in `booleans`. Every source element is copied to its expanded position;
;; where the boolean for that position is true, the following slot along the
;; axis is filled from the array `nil`, indexed with the axis removed.
(define (%array-expand array booleans nil axis)
  (let ((new-axis-length (fx+ (vector-ref (%array-shape array) axis)
                              (vector-count identity booleans)))
        (new-shape (vector-copy (%array-shape array))))
    (vector-set! new-shape axis new-axis-length)
    (let ((new-array (%%dummy-make-array (%array-storage-class array) new-shape)))
      (%array-for-each-index
        (lambda (index)
          (let ((expanded-index (expand-index index booleans axis)))
            (%array-set! new-array expanded-index (%array-ref array index))
            (when (vector-ref booleans (vector-ref index axis))
              (vector-set! expanded-index
                           axis
                           (add1 (vector-ref expanded-index axis)))
              (let ((nil-index (remove-axis-from-index index axis)))
                (%array-set! new-array
                             expanded-index
                             (%array-ref nil nil-index))))))
        (%array-lower-bound array)
        (%array-upper-bound array))
      new-array)))

;; Builds a new array of the same shape in which successive positions along
;; `axis` are filled, in order, with the slices of `array` at the axis
;; positions listed in the vector `order`.
(define (%array-rearrange array order axis)
  (let ((new-array (%%dummy-make-array (%array-storage-class array)
                                       (%array-shape array)))
        (lower-bound (%array-lower-bound array))
        (upper-bound (%array-upper-bound array))
        (new-lower-bound (%array-lower-bound array))
        (new-upper-bound (%array-upper-bound array)))
    (vector-for-each
      (lambda (i)
        ;; Source window: the single position i along the axis.
        (vector-set! lower-bound axis i)
        (vector-set! upper-bound axis (add1 i))
        ;; Destination window: the next unfilled position along the axis.
        (vector-set! new-upper-bound
                     axis
                     (add1 (vector-ref new-lower-bound axis)))
        (%array-copy!
          (%array-slice array
                        lower-bound
                        upper-bound
                        (vector-map (lambda _ 1) lower-bound))
          (%array-slice new-array
                        new-lower-bound
                        new-upper-bound
                        (vector-map (lambda _ 1) new-lower-bound)))
        (vector-set! new-lower-bound
                     axis
                     (add1 (vector-ref new-lower-bound axis))))
      order)
    new-array))

;; Builds a new array of the given `shape` whose element at each index is the
;; element of `array` at (proc index).
(define (%array-transform proc shape array)
  (let* ((new-array (%%dummy-make-array (%array-storage-class array) shape))
         (new-data (%array-storage-object new-array))
         (data (%array-storage-object array))
         (storage-set! (storage-class-mutator (%array-storage-class new-array)))
         (storage-ref (storage-class-accessor (%array-storage-class array))))
    (%array-for-each-index
      (lambda (index)
        (let ((from-i (%array-index->storage-index array (proc index)))
              (to-i (%array-index->storage-index new-array index)))
          (storage-set! new-data to-i (storage-ref data from-i))))
      (%array-lower-bound new-array)
      (%array-upper-bound new-array))
    new-array))

;; Returns an array sharing the storage object of `array` with its axes
;; permuted: new axis n takes the shape and stride of old axis order[n].
(define (%array-rearrange-axes array order)
  (let ((shape (%array-shape array))
        (stride (%array-stride array)))
    (%make-array (%array-storage-class array)
                 (vector-map (cute vector-ref shape <>) order)
                 (vector-map (cute vector-ref stride <>) order)
                 (%array-offset array)
                 (%array-mutable? array)
                 (%array-storage-object array))))

;; Returns an array sharing the storage object of `array` and covering the
;; indices from `start` (inclusive) to `end` (exclusive) taking every
;; `step`-th element per axis. Each extent is ceiling((end - start) / step).
(define (%array-slice array start end step)
  (%make-array (%array-storage-class array)
               (vector-map (compose inexact->exact ceiling /)
                           (vector-map fx- end start)
                           step)
               (vector-map fx* (%array-stride array) step)
               (%array-index->storage-index array start)
               (%array-mutable? array)
               (%array-storage-object array)))

;; Returns a new vector containing the elements of `vec` whose positions do
;; not appear in the vector `axes`.
(define (squeeze-vector vec axes)
  (let ((limit (vector-length vec)))
    (let loop ((i 0)
               (acc '()))
      (cond
        ((fx>= i limit)
         (reverse-list->vector acc))
        ((vector-index (cute = <> i) axes)
         (loop (add1 i) acc))
        (else
         (loop (add1 i) (cons (vector-ref vec i) acc)))))))

;; Returns an array sharing the storage object of `array` with the axes
;; listed in `axes` dropped from both shape and stride.
(define (%array-squeeze array axes)
  (%make-array (%array-storage-class array)
               (squeeze-vector (%array-shape array) axes)
               (squeeze-vector (%array-stride array) axes)
               (%array-offset array)
               (%array-mutable? array)
               (%array-storage-object array)))

;; Returns a new vector one element longer than `vec`, with a 1 inserted at
;; position `rank` and the remaining elements of `vec` kept in order.
(define (unsqueeze-vector vec rank)
  (let* ((new-length (add1 (vector-length vec)))
         (result (make-vector new-length)))
    (let loop ((i 0))
      (cond
        ;; Valid positions are 0 .. new-length - 1, so stop at new-length.
        ((fx>= i new-length)
         result)
        ((fx= i rank)
         (vector-set! result i 1)
         (loop (add1 i)))
        (else
         ;; Positions after the inserted slot read from one place earlier.
         (vector-set! result i (vector-ref vec (if (fx> rank i) i (sub1 i))))
         (loop (add1 i)))))))

;; Returns an array sharing the storage object of `array` with a new axis of
;; length 1 inserted at position `rank`.
;; NOTE(review): the stride is recomputed from the new shape rather than
;; derived from the existing stride, which looks correct only for arrays
;; whose stride already equals shape->stride of their shape -- confirm
;; against callers that pass slices.
(define (%array-unsqueeze array rank)
  (let ((shape (unsqueeze-vector (%array-shape array) rank)))
    (%make-array (%array-storage-class array)
                 shape
                 (shape->stride shape (row-major? array))
                 (%array-offset array)
                 (%array-mutable? array)
                 (%array-storage-object array))))

;; Returns a freshly allocated array with the same storage class, shape and
;; elements as `array`, whose mutability flag is set to `mutable?`.
(define (%array-copy array mutable?)
  (let* ((new-array (%%dummy-make-array (%array-storage-class array)
                                        (%array-shape array)))
         (new-data (%array-storage-object new-array))
         (data (%array-storage-object array))
         (storage-set! (storage-class-mutator (%array-storage-class new-array)))
         (storage-ref (storage-class-accessor (%array-storage-class array))))
    (%array-for-each-index
      (lambda (index)
        (let ((from-i (%array-index->storage-index array index))
              (to-i (%array-index->storage-index new-array index)))
          (storage-set! new-data to-i (storage-ref data from-i))))
      (%array-lower-bound array)
      (%array-upper-bound array))
    (set! (%array-mutable? new-array) mutable?)
    new-array))

;; Copies every element of `from` into the same index of `to`, iterating over
;; the bounds of `from`. Returns `to`. No shape or mutability check is made.
(define (%array-copy! from to)
  (let ((from-data (%array-storage-object from))
        (to-data (%array-storage-object to))
        (storage-set! (storage-class-mutator (%array-storage-class to)))
        (storage-ref (storage-class-accessor (%array-storage-class from))))
    (%array-for-each-index
      (lambda (index)
        (let ((from-i (%array-index->storage-index from index))
              (to-i (%array-index->storage-index to index)))
          (storage-set! to-data to-i (storage-ref from-data from-i))))
      (%array-lower-bound from)
      (%array-upper-bound from)))
  to)

;; Concatenates the list `arrays` along `axis` into a new array of class
;; `storage-class`. The other extents are taken from the first array.
(define (%array-append storage-class axis arrays)
  (let* ((new-axis-length
           (fold (lambda (x k)
                   (fx+ k (vector-ref (%array-shape x) axis)))
                 0
                 arrays))
         (new-shape
           (let ((shape (vector-copy (%array-shape (car arrays)))))
             (vector-set! shape axis new-axis-length)
             shape))
         (new-array (%%dummy-make-array storage-class new-shape)))
    ;; The fold state is the position along `axis` at which the next array
    ;; is written.
    (fold (lambda (array length)
            (let ((offset (%array-lower-bound array)))
              (vector-set! offset axis length)
              (%array-copy!
                array
                (%array-slice new-array
                              offset
                              (vector-map fx+ offset (%array-shape array))
                              (vector-map (lambda _ 1) offset))))
            (fx+ length (vector-ref (%array-shape array) axis)))
          0
          arrays)
    new-array))

;; Converts `array` into nested lists, one nesting level per axis, by
;; repeatedly slicing along the first axis and squeezing it away.
(define (%array->nested-list array)
  (let loop ((subarray array))
    (let ((dim-length (vector-ref (%array-shape subarray) 0)))
      (if (= (%array-rank subarray) 1)
          (map (compose (cute %array-ref subarray <>) vector)
               (iota dim-length))
          (let ((lower-bound (vector-copy (%array-lower-bound subarray)))
                (upper-bound (vector-copy (%array-shape subarray))))
            (map (lambda (i)
                   (vector-set! lower-bound 0 i)
                   (vector-set! upper-bound 0 (add1 i))
                   (loop
                     (%array-squeeze
                       (%array-slice subarray
                                     lower-bound
                                     upper-bound
                                     ;; One step entry per axis: the step
                                     ;; vector's length must equal the rank,
                                     ;; not the first extent.
                                     (make-vector (%array-rank subarray) 1))
                       (vector 0))))
                 (iota dim-length)))))))

;; Prints an array as "#a" followed by its rank, the storage class's short
;; id, and its contents as nested lists.
(define-record-printer (array a output-port)
  (fprintf output-port
           "#a~S~S~S"
           (%array-rank a)
           (storage-class-short-id (%array-storage-class a))
           (%array->nested-list a)))