;; blas.scm
;;
;; CHICKEN bindings for BLAS routines.  Most routines are generated from a
;; single template as three Scheme procedures:
;;
;;   unsafe-NAME!  calls cblas_NAME with no size checks, updating in place
;;   NAME!         checks vector/matrix sizes first, then updates in place
;;   NAME          checks sizes, works on fresh copies of the output
;;                 arguments and returns those copies
;;
;; Routines that only compute a value (dot, nrm2, asum, amax) get just the
;; checked NAME form, which returns the C result.
;;
;; NOTE(review): this copy of the file is damaged -- see the NOTE(review)
;; markers below.  Two `bind*` C-declaration sections and the heads of the
;; `icopy-wrapper` and `blas-level3-wrap` macros are missing and must be
;; restored from upstream before the module can be compiled.

(module blas
  (RowMajor ColMajor NoTrans Trans ConjTrans Left Right Upper Lower Unit NonUnit
   sicopy dicopy cicopy zicopy
   scopy dcopy ccopy zcopy
   unsafe-sgemm! unsafe-dgemm! unsafe-cgemm! unsafe-zgemm! sgemm! dgemm! cgemm! zgemm! sgemm dgemm cgemm zgemm
   unsafe-ssymm! unsafe-dsymm! unsafe-csymm! unsafe-zsymm! ssymm! dsymm! csymm! zsymm! ssymm dsymm csymm zsymm
   unsafe-chemm! unsafe-zhemm! chemm! zhemm! chemm zhemm
   unsafe-ssyrk! unsafe-dsyrk! unsafe-csyrk! unsafe-zsyrk! ssyrk! dsyrk! csyrk! zsyrk! ssyrk dsyrk csyrk zsyrk
   unsafe-cherk! unsafe-zherk! cherk! zherk! cherk zherk
   unsafe-ssyr2k! unsafe-dsyr2k! unsafe-csyr2k! unsafe-zsyr2k! ssyr2k! dsyr2k! csyr2k! zsyr2k! ssyr2k dsyr2k csyr2k zsyr2k
   unsafe-cher2k! unsafe-zher2k! cher2k! zher2k! cher2k zher2k
   unsafe-strmm! unsafe-dtrmm! unsafe-ctrmm! unsafe-ztrmm! strmm! dtrmm! ctrmm! ztrmm! strmm dtrmm ctrmm ztrmm
   unsafe-strsm! unsafe-dtrsm! unsafe-ctrsm! unsafe-ztrsm! strsm! dtrsm! ctrsm! ztrsm! strsm dtrsm ctrsm ztrsm
   unsafe-sgemv! unsafe-dgemv! unsafe-cgemv! unsafe-zgemv! sgemv! dgemv! cgemv! zgemv! sgemv dgemv cgemv zgemv
   unsafe-chemv! unsafe-zhemv! chemv! zhemv! chemv zhemv
   unsafe-chbmv! unsafe-zhbmv! chbmv! zhbmv! chbmv zhbmv
   unsafe-chpmv! unsafe-zhpmv! chpmv! zhpmv! chpmv zhpmv
   unsafe-ssymv! unsafe-dsymv! ssymv! dsymv! ssymv dsymv
   unsafe-ssbmv! unsafe-dsbmv! ssbmv! dsbmv! ssbmv dsbmv
   unsafe-sspmv! unsafe-dspmv! sspmv! dspmv! sspmv dspmv
   unsafe-strmv! unsafe-dtrmv! unsafe-ctrmv! unsafe-ztrmv! strmv! dtrmv! ctrmv! ztrmv! strmv dtrmv ctrmv ztrmv
   unsafe-stbmv! unsafe-dtbmv! unsafe-ctbmv! unsafe-ztbmv! stbmv! dtbmv! ctbmv! ztbmv! stbmv dtbmv ctbmv ztbmv
   unsafe-stpmv! unsafe-dtpmv! unsafe-ctpmv! unsafe-ztpmv! stpmv! dtpmv! ctpmv! ztpmv! stpmv dtpmv ctpmv ztpmv
   unsafe-strsv! unsafe-dtrsv! unsafe-ctrsv! unsafe-ztrsv! strsv! dtrsv! ctrsv! ztrsv! strsv dtrsv ctrsv ztrsv
   unsafe-stbsv! unsafe-dtbsv! unsafe-ctbsv! unsafe-ztbsv! stbsv! dtbsv! ctbsv! ztbsv! stbsv dtbsv ctbsv ztbsv
   unsafe-stpsv! unsafe-dtpsv! unsafe-ctpsv! unsafe-ztpsv! stpsv! dtpsv! ctpsv! ztpsv! stpsv dtpsv ctpsv ztpsv
   unsafe-sger! unsafe-dger! sger! dger! sger dger
   unsafe-siger! unsafe-diger! siger! diger! siger diger
   unsafe-cgeru! unsafe-zgeru! cgeru! zgeru! cgeru zgeru
   unsafe-cgerc! unsafe-zgerc! cgerc! zgerc! cgerc zgerc
   unsafe-cher! unsafe-zher! cher! zher! cher zher
   unsafe-chpr! unsafe-zhpr! chpr! zhpr! chpr zhpr
   unsafe-cher2! unsafe-zher2! cher2! zher2! cher2 zher2
   unsafe-chpr2! unsafe-zhpr2! chpr2! zhpr2! chpr2 zhpr2
   unsafe-ssyr! unsafe-dsyr! ssyr! dsyr! ssyr dsyr
   unsafe-sspr! unsafe-dspr! sspr! dspr! sspr dspr
   unsafe-ssyr2! unsafe-dsyr2! ssyr2! dsyr2! ssyr2 dsyr2
   unsafe-sspr2! unsafe-dspr2! sspr2! dspr2! sspr2 dspr2
   unsafe-srot! unsafe-drot! srot! drot! srot drot
   unsafe-srotm! unsafe-drotm! srotm! drotm! srotm drotm
   unsafe-sswap! unsafe-dswap! unsafe-cswap! unsafe-zswap! sswap! dswap! cswap! zswap! sswap dswap cswap zswap
   unsafe-sscal! unsafe-dscal! unsafe-cscal! unsafe-zscal! sscal! dscal! cscal! zscal! sscal dscal cscal zscal
   unsafe-saxpy! unsafe-daxpy! unsafe-caxpy! unsafe-zaxpy! saxpy! daxpy! caxpy! zaxpy! saxpy daxpy caxpy zaxpy
   unsafe-siaxpy! unsafe-diaxpy! unsafe-ciaxpy! unsafe-ziaxpy! siaxpy! diaxpy! ciaxpy! ziaxpy! siaxpy diaxpy ciaxpy ziaxpy
   sdot ddot cdotu zdotu cdotc zdotc
   snrm2 dnrm2 cnrm2 znrm2
   sasum dasum casum zasum
   samax damax camax zamax
   )

  (import scheme chicken data-structures foreign)

  (require-extension srfi-4 bind)

  ;; Signal an error.  If X is a symbol it is used as the location passed to
  ;; `error`; otherwise the location is 'blas and X becomes the first object
  ;; of the message.  The message is every remaining object displayed,
  ;; each followed by a space, and terminated by a newline.
  (define (blas:error x . rest)
    (let ((port (open-output-string)))
      (let loop ((objs (if (symbol? x) rest (cons x rest))))
        (if (null? objs)
            (begin
              (newline port)
              (error (if (symbol? x) x 'blas) (get-output-string port)))
            (begin
              (display (car objs) port)
              (display " " port)
              (loop (cdr objs)))))))

  ;; NOTE(review): damaged region.  The first `bind*` C-declaration section
  ;; and the head of the `icopy-wrapper` macro are missing from this copy;
  ;; only the tail of that macro's argument checks and body survives.
  ;; Restore from upstream.  Also note: offsetY, unlike offsetX, is not
  ;; checked for being negative in the surviving text.
  (bind* #<= offsetX xlen) (blas:error ',name "offset of vector X (" offsetX ") is greater than or equal to its length: " xlen))
         ((fx< offsetX 0) (blas:error ',name "offset of vector X (" offsetX ") is negative"))
         ((fx>= offsetY ylen) (blas:error ',name "offset of vector Y (" offsetY ") is greater than or equal to its length: " ylen))
         ((fx> (- ylen offsetY) (- xlen offsetX)) (blas:error ',name "range of vector Y (" (- ylen offsetY) ") is greater than range of vector X: " ( - xlen offsetX))))
        (,%let ((y (,%or y (,make-vector ylen))))
          (,copy n x incX offsetX y incY offsetY)
          y)))))) )

  (icopy-wrapper sicopy f32vector-length make-f32vector)
  (icopy-wrapper dicopy f64vector-length make-f64vector)
  (icopy-wrapper cicopy (lambda (x) (fx/ (f32vector-length x) 2)) (lambda (n) (make-f32vector (fx* 2 n))))
  (icopy-wrapper zicopy (lambda (x) (fx/ (f64vector-length x) 2)) (lambda (n) (make-f64vector (fx* 2 n))))

  ;; NOTE(review): damaged region.  The second `bind*` C-declaration section
  ;; and the head of the `blas-level3-wrap` macro (its `define-syntax`, the
  ;; transformer lambda and the first `let*` bindings up to the start of
  ;; `cfname`) are missing.  The surviving text resumes inside the `cfname`
  ;; binding and expects `fn`, `ret`, `vsize`, `copy` and the renamer `r` to
  ;; be bound by the lost head -- presumably as in `blas-level2-wrap`;
  ;; confirm against upstream.
  ;;
  ;; Level-3 (matrix-matrix) wrapper generator.  lda/ldb/ldc are taken from
  ;; the optional REST arguments; their defaults are the row-major leading
  ;; dimensions implied by side/transA/transB/trans.  The checked variants
  ;; verify that a, b and c each hold at least ld * rows elements.
  (bind* #<symbol (conc "cblas_" (symbol->string (car fn)))))
         (fname (string->symbol
                 (conc (if vsize "" "unsafe-") (symbol->string (car fn)) (if copy "" "!"))))
         (%define (r 'define))
         (%begin (r 'begin))
         (%let (r 'let))
         (%if (r 'if))
         (%let-optionals (r 'let-optionals))
         (ka (r 'ka))
         (kb (r 'kb))
         (kc (r 'kc))
         (asize (r 'asize))
         (bsize (r 'bsize))
         (csize (r 'csize))
         (args (reverse (cdr fn)))
         ;; Public signature: the C argument list minus lda/ldb/ldc.
         (fsig (let loop ((args args) (sig 'rest))
                 (if (null? args)
                     (cons fname sig)
                     (loop (cdr args)
                           (case (car args)
                             ((lda ldb ldc) sig)
                             (else (cons (car args) sig)))))))
         (opts (append
                (if (memq 'lda fn)
                    `((lda ,(cond ((memq 'side fn) `(,%if (= side Left) m n))
                                  ((memq 'transA fn)
                                   `(,%if (= transA NoTrans) k ,(if (memq 'm fn) 'm 'n)))
                                  ((memq 'trans fn) `(,%if (= trans NoTrans) k n))
                                  ((memq 'm fn) 'm)
                                  (else 'n))))
                    '())
                (if (memq 'ldb fn)
                    `((ldb ,(cond ((memq 'transB fn) `(,%if (= transB NoTrans) n k))
                                  ((memq 'trans fn) `(,%if (= trans NoTrans) k n))
                                  (else 'n))))
                    '())
                (if (memq 'ldc fn) '((ldc n)) '()))))
    `(,%define ,fsig
       (,%let-optionals rest ,opts
         ,(if vsize
              `(,%begin
                 (,%let ((,asize (,vsize a))
                         (,ka ,(cond ((memq 'side fn) `(,%if (= side Left) m n))
                                     ((memq 'transA fn)
                                      `(,%if (= transA NoTrans) ,(if (memq 'm fn) 'm 'n) k))
                                     ((memq 'trans fn)
                                      `(,%if (= trans NoTrans) ,(if (memq 'm fn) 'm 'n) k))
                                     ((memq 'm fn) 'm)
                                     (else 'n))))
                   (,%if (< ,asize (fx* lda ,ka))
                         (blas:error ',fname (conc "matrix A is allocated " ,asize " elements "
                                                   "but given dimensions are " ,ka " by " lda))))
                 ,(if (memq 'b fn)
                      `(,%let ((,bsize (,vsize b))
                               (,kb ,(cond ((memq 'transB fn) `(,%if (= transB NoTrans) k n))
                                           ((memq 'trans fn) `(,%if (= trans NoTrans) n k))
                                           (else 'm))))
                         (,%if (< ,bsize (fx* ldb ,kb))
                               (blas:error ',fname (conc "matrix B is allocated " ,bsize " elements "
                                                         "but given dimensions are " ,kb " by " ldb))))
                      '(begin))
                 ;; Renamed let/if, like the A and B checks, so a user
                 ;; binding of `let`/`if` cannot capture the expansion.
                 ,(if (memq 'c fn)
                      `(,%let ((,csize (,vsize c))
                               (,kc ,(if (memq 'm fn) 'm 'n)))
                         (,%if (< ,csize (fx* ldc ,kc))
                               (blas:error ',fname (conc "matrix C is allocated " ,csize " elements "
                                                         "but given dimensions are " ,kc " by " ldc))))
                      '(begin)))
              '(begin))
         ;; Copying variants rebind each RET argument to a fresh copy.
         (,%let ,(let loop ((params fn) (bnds '()))
                   (cond ((null? params) bnds)
                         ((and copy ret (memq (car params) ret))
                          (loop (cdr params)
                                (cons `(,(car params) (,copy ,(car params))) bnds)))
                         (else (loop (cdr params) bnds))))
           (,%begin (,cfname . ,(cdr fn))
                    (values . ,ret))))))))

  ;; Expand one routine template into a family of wrappers.
  ;;
  ;;   (blas-wrap-variants WRAP PREFIXES MODES FN RET ERRS)
  ;;
  ;; For every mode in MODES (outer loop) and every precision prefix in
  ;; PREFIXES (inner loop; a subset of s d c z) this emits
  ;;   (WRAP (<prefix><name> . args) RET ERRS VSIZE COPY [MAKE-RESULT])
  ;; where the mode selects the trailing arguments:
  ;;   unsafe      -> #f #f                   generates unsafe-NAME!
  ;;   safe        -> VSIZE #f                generates NAME!
  ;;   copy        -> VSIZE COPY              generates NAME
  ;;   copy/result -> VSIZE COPY MAKE-RESULT  generates NAME (dotu/dotc)
  ;;
  ;; A complex element occupies two slots of the underlying srfi-4 vector
  ;; (compare cicopy/zicopy), so the element count is (fx/ len 2) -- note
  ;; the argument order of fx/.
  (define-syntax blas-wrap-variants
    (lambda (x r c)
      (let* ((wrap (list-ref x 1))
             (prefixes (list-ref x 2))
             (modes (list-ref x 3))
             (fn (list-ref x 4))
             (ret (list-ref x 5))
             (errs (list-ref x 6))
             ;; prefix -> (element-count copier result-allocator)
             (table '((s f32vector-length scopy #f)
                      (d f64vector-length dcopy #f)
                      (c (lambda (v) (fx/ (f32vector-length v) 2)) ccopy
                         (lambda () (make-f32vector 2)))
                      (z (lambda (v) (fx/ (f64vector-length v) 2)) zcopy
                         (lambda () (make-f64vector 2)))))
             (variant
              (lambda (mode prefix)
                (let* ((entry (cdr (assq prefix table)))
                       (vsize (car entry))
                       (copier (cadr entry))
                       (make-result (caddr entry))
                       (name (cons (string->symbol
                                    (string-append (symbol->string prefix)
                                                   (symbol->string (car fn))))
                                   (cdr fn))))
                  (case mode
                    ((unsafe) `(,wrap ,name ,ret ,errs #f #f))
                    ((safe) `(,wrap ,name ,ret ,errs ,vsize #f))
                    ((copy) `(,wrap ,name ,ret ,errs ,vsize ,copier))
                    ((copy/result) `(,wrap ,name ,ret ,errs ,vsize ,copier ,make-result))
                    (else (error 'blas-wrap-variants "unknown mode" mode)))))))
        `(begin
           ,@(apply append
                    (map (lambda (mode)
                           (map (lambda (prefix) (variant mode prefix)) prefixes))
                         modes))))))

  ;; (blas-level3-wrapx (NAME . C-ARGS) RET ERRS): s/d/c/z x unsafe/safe/copy.
  (define-syntax blas-level3-wrapx
    (lambda (x r c)
      `(blas-wrap-variants blas-level3-wrap (s d c z) (unsafe safe copy) ,@(cdr x))))

  ;; As blas-level3-wrapx, complex precisions only (Hermitian routines).
  (define-syntax blas-level3-cz-wrapx
    (lambda (x r c)
      `(blas-wrap-variants blas-level3-wrap (c z) (unsafe safe copy) ,@(cdr x))))

  (blas-level3-wrapx (gemm order transA transB m n k alpha a lda b ldb beta c ldc) (c)
    (lambda (i) (cond ((= i 3) "M < 0") ((= i 4) "N < 0") ((= i 5) "K < 0")
                      ((= i 8) "LDA < max(1, M or K)") ((= i 10) "LDB < max(1, N or K)")
                      ((= i 13) "LDC < max(1, M)") (else (conc "error code " i)))))
  (blas-level3-wrapx (symm order side uplo m n alpha a lda b ldb beta c ldc) (c)
    (lambda (i) (cond ((= i 3) "M < 0") ((= i 4) "N < 0") ((= i 5) "K < 0")
                      ((= i 8) "LDA < max(1, M or K)") ((= i 10) "LDB < max(1, N or K)")
                      ((= i 13) "LDC < max(1, M)") (else (conc "error code " i)))))
  (blas-level3-cz-wrapx (hemm order side uplo m n alpha a lda b ldb beta c ldc) (c)
    (lambda (i) (cond ((= i 3) "M < 0") ((= i 4) "N < 0") ((= i 5) "K < 0")
                      ((= i 8) "LDA < max(1, M or K)") ((= i 10) "LDB < max(1, N or K)")
                      ((= i 13) "LDC < max(1, M)") (else (conc "error code " i)))))
  (blas-level3-wrapx (syrk order uplo trans n k alpha a lda beta c ldc) (c)
    (lambda (i) (cond ((= i 3) "M < 0") ((= i 4) "N < 0") ((= i 5) "K < 0")
                      ((= i 8) "LDA < max(1, M or K)") ((= i 10) "LDB < max(1, N or K)")
                      ((= i 13) "LDC < max(1, M)") (else (conc "error code " i)))))
  (blas-level3-cz-wrapx (herk order uplo trans n k alpha a lda beta c ldc) (c)
    (lambda (i) (cond ((= i 3) "M < 0") ((= i 4) "N < 0") ((= i 5) "K < 0")
                      ((= i 8) "LDA < max(1, M or K)") ((= i 10) "LDB < max(1, N or K)")
                      ((= i 13) "LDC < max(1, M)") (else (conc "error code " i)))))
  (blas-level3-wrapx (syr2k order uplo trans n k alpha a lda b ldb beta c ldc) (c)
    (lambda (i) (cond ((= i 3) "M < 0") ((= i 4) "N < 0") ((= i 5) "K < 0")
                      ((= i 8) "LDA < max(1, M or K)") ((= i 10) "LDB < max(1, N or K)")
                      ((= i 13) "LDC < max(1, M)") (else (conc "error code " i)))))
  (blas-level3-cz-wrapx (her2k order uplo trans n k alpha a lda b ldb beta c ldc) (c)
    (lambda (i) (cond ((= i 3) "M < 0") ((= i 4) "N < 0") ((= i 5) "K < 0")
                      ((= i 8) "LDA < max(1, M or K)") ((= i 10) "LDB < max(1, N or K)")
                      ((= i 13) "LDC < max(1, M)") (else (conc "error code " i)))))
  (blas-level3-wrapx (trmm order side uplo transA diag m n alpha a lda b ldb) (b)
    (lambda (i) (cond ((= i 3) "M < 0") ((= i 4) "N < 0") ((= i 5) "K < 0")
                      ((= i 8) "LDA < max(1, M or K)") ((= i 10) "LDB < max(1, N or K)")
                      ((= i 13) "LDC < max(1, M)") (else (conc "error code " i)))))
  (blas-level3-wrapx (trsm order side uplo transA diag m n alpha a lda b ldb) (b)
    (lambda (i) (cond ((= i 3) "M < 0") ((= i 4) "N < 0") ((= i 5) "K < 0")
                      ((= i 8) "LDA < max(1, M or K)") ((= i 10) "LDB < max(1, N or K)")
                      ((= i 13) "LDC < max(1, M)") (else (conc "error code " i)))))

  ;; Level-2 (matrix-vector) wrapper generator.
  ;;
  ;;   (blas-level2-wrap (NAME . C-ARGS) RET ERRS VSIZE COPY)
  ;;
  ;; NAME, C-ARGS: the cblas_NAME routine and the names of its arguments.
  ;; RET:   arguments returned by the wrapper (fresh copies when COPY is set).
  ;; ERRS:  error-code description; not used by the expansion.
  ;; VSIZE: element-count procedure, or #f for the unchecked variant.
  ;; COPY:  copy procedure, or #f for the in-place (!) variant.
  ;;
  ;; Optional trailing arguments: lda (default k+1 for routines taking k,
  ;; otherwise n) and incx (1).  Routines taking y also take incy (1),
  ;; offx (0) and offy (0).
  (define-syntax blas-level2-wrap
    (lambda (x r c)
      (let* ((fn (cadr x))
             (ret (caddr x))
             (vsize (car (cddddr x)))
             (copy (cadr (cddddr x)))
             (cfname (string->symbol (conc "cblas_" (symbol->string (car fn)))))
             (fname (string->symbol
                     (conc (if vsize "" "unsafe-") (symbol->string (car fn)) (if copy "" "!"))))
             (%define (r 'define))
             (%begin (r 'begin))
             (%let (r 'let))
             (%if (r 'if))
             (%let-optionals (r 'let-optionals))
             (ka (r 'ka))
             (asize (r 'asize))
             (apsize (r 'apsize))
             (apdim (r 'apdim))
             (xsize (r 'xsize))
             (ysize (r 'ysize))
             (xdim (r 'xdim))
             (ydim (r 'ydim))
             ;; Public signature: the C argument list minus the arguments
             ;; supplied through the optional REST list.
             (fsig (let loop ((args (reverse (cdr fn))) (sig 'rest))
                     (if (null? args)
                         (cons fname sig)
                         (loop (cdr args)
                               (case (car args)
                                 ((lda incx incy offx offy) sig)
                                 (else (cons (car args) sig)))))))
             (opts (append (if (memq 'lda fn)
                               `((lda ,(if (memq 'k fn) '(fx+ 1 k) 'n)))
                               '())
                           (if (memq 'incy fn)
                               '((incx 1) (incy 1) (offx 0) (offy 0))
                               '((incx 1)))))
             (a-check
              (if (memq 'a fn)
                  `(,%let ((,asize (,vsize a))
                           (,ka ,(if (memq 'm fn) 'm 'n)))
                     (,%if (< ,asize (fx* lda ,ka))
                           (blas:error ',fname (conc "matrix A is allocated " ,asize " elements "
                                                     "but given dimensions are " ,ka " by " lda))))
                  '(begin)))
             (ap-check
              (if (memq 'ap fn)
                  `(,%let ((,apsize (,vsize ap))
                           (,apdim (fx/ (fx* n (fx+ n 1)) 2)))
                     (,%if (< ,apsize ,apdim)
                           (blas:error ',fname (conc "vector Ap is allocated " ,apsize " elements "
                                                     "but given dimension is " ,apdim))))
                  '(begin)))
             (y-check
              (if (memq 'y fn)
                  `(,%let ((,ysize (,vsize y))
                           (,ydim ,(if (and (memq 'm fn) (memq 'trans fn))
                                       ;; gemv: y has m elements unless A is transposed
                                       `(,%if (= trans NoTrans)
                                              (fx+ 1 (fx* (abs incy) (fx- (fx+ offy m) 1)))
                                              (fx+ 1 (fx* (abs incy) (fx- (fx+ offy n) 1))))
                                       '(fx+ 1 (fx* (abs incy) (fx- n 1))))))
                     (,%if (< ,ysize ,ydim)
                           (blas:error ',fname (conc "vector Y is allocated " ,ysize " elements "
                                                     "but given dimension is " ,ydim))))
                  '(begin)))
             (x-check
              (if (memq 'x fn)
                  `(,%let ((,xsize (,vsize x))
                           (,xdim ,(cond ((and (memq 'm fn) (memq 'trans fn))
                                          ;; gemv: x has n elements unless A is transposed
                                          `(,%if (= trans NoTrans)
                                                 (fx+ 1 (fx* (abs incx) (fx- (fx+ offx n) 1)))
                                                 (fx+ 1 (fx* (abs incx) (fx- (fx+ offx m) 1)))))
                                         ((memq 'm fn)
                                          ;; ger/geru/gerc/iger: x is the m-vector
                                          '(fx+ 1 (fx* (abs incx) (fx- m 1))))
                                         (else
                                          '(fx+ 1 (fx* (abs incx) (fx- n 1)))))))
                     (,%if (< ,xsize ,xdim)
                           (blas:error ',fname (conc "vector X is allocated " ,xsize " elements "
                                                     "but given dimension is " ,xdim))))
                  '(begin)))
             ;; Copying variants rebind each RET argument to a fresh copy.
             (copies
              (let loop ((params fn) (bnds '()))
                (cond ((null? params) bnds)
                      ((and copy ret (memq (car params) ret))
                       (loop (cdr params)
                             (cons `(,(car params) (,copy ,(car params))) bnds)))
                      (else (loop (cdr params) bnds))))))
        `(,%define ,fsig
           (,%let-optionals rest ,opts
             ,(if vsize
                  `(,%begin ,a-check ,ap-check ,y-check ,x-check)
                  '(begin))
             (,%let ,copies
               (,%begin (,cfname . ,(cdr fn))
                        (values . ,ret))))))))

  ;; (blas-level2-wrapx (NAME . C-ARGS) RET ERRS): s/d/c/z x unsafe/safe/copy.
  (define-syntax blas-level2-wrapx
    (lambda (x r c)
      `(blas-wrap-variants blas-level2-wrap (s d c z) (unsafe safe copy) ,@(cdr x))))

  ;; As blas-level2-wrapx, real precisions only.
  (define-syntax blas-level2-sd-wrapx
    (lambda (x r c)
      `(blas-wrap-variants blas-level2-wrap (s d) (unsafe safe copy) ,@(cdr x))))

  ;; As blas-level2-wrapx, complex precisions only.
  (define-syntax blas-level2-cz-wrapx
    (lambda (x r c)
      `(blas-wrap-variants blas-level2-wrap (c z) (unsafe safe copy) ,@(cdr x))))

  (blas-level2-wrapx (gemv order trans m n alpha a lda x incx beta y incy) (y)
    (lambda (i) (cond ((= i 2) "M < 0") ((= i 3) "N < 0") ((= i 6) "LDA < max(1, M)")
                      ((= i 8) "INCX = 0") ((= i 11) "INCY < = 0") (else (conc "error code " i)))))
  (blas-level2-cz-wrapx (hemv order uplo n alpha a lda x incx beta y incy) (y)
    (lambda (i) (cond ((= i 2) "M < 0") ((= i 3) "N < 0") ((= i 6) "LDA < max(1, M)")
                      ((= i 8) "INCX = 0") ((= i 11) "INCY < = 0") (else (conc "error code " i)))))
  (blas-level2-cz-wrapx (hbmv order uplo n k alpha a lda x incx beta y incy) (y)
    (lambda (i) (cond ((= i 2) "M < 0") ((= i 3) "N < 0") ((= i 6) "LDA < max(1, M)")
                      ((= i 8) "INCX = 0") ((= i 11) "INCY < = 0") (else (conc "error code " i)))))
  (blas-level2-cz-wrapx (hpmv order uplo n alpha ap x incx beta y incy) (y)
    (lambda (i) (cond ((= i 2) "M < 0") ((= i 3) "N < 0") ((= i 6) "LDA < max(1, M)")
                      ((= i 8) "INCX = 0") ((= i 11) "INCY < = 0") (else (conc "error code " i)))))
  (blas-level2-sd-wrapx (symv order uplo n alpha a lda x incx beta y incy) (y)
    (lambda (i) (cond ((= i 2) "M < 0") ((= i 3) "N < 0") ((= i 6) "LDA < max(1, M)")
                      ((= i 8) "INCX = 0") ((= i 11) "INCY < = 0") (else (conc "error code " i)))))
  (blas-level2-sd-wrapx (sbmv order uplo n k alpha a lda x incx beta y incy) (y)
    (lambda (i) (cond ((= i 2) "M < 0") ((= i 3) "N < 0") ((= i 6) "LDA < max(1, M)")
                      ((= i 8) "INCX = 0") ((= i 11) "INCY < = 0") (else (conc "error code " i)))))
  (blas-level2-sd-wrapx (spmv order uplo n alpha ap x incx beta y incy) (y)
    (lambda (i) (cond ((= i 2) "M < 0") ((= i 3) "N < 0") ((= i 6) "LDA < max(1, M)")
                      ((= i 8) "INCX = 0") ((= i 11) "INCY < = 0") (else (conc "error code " i)))))
  (blas-level2-wrapx (trmv order uplo trans diag n a lda x incx) (x)
    (lambda (i) (cond ((= i 2) "M < 0") ((= i 3) "N < 0") ((= i 6) "LDA < max(1, M)")
                      ((= i 8) "INCX = 0") ((= i 11) "INCY < = 0") (else (conc "error code " i)))))
  (blas-level2-wrapx (tbmv order uplo trans diag n k a lda x incx) (x)
    (lambda (i) (cond ((= i 2) "M < 0") ((= i 3) "N < 0") ((= i 6) "LDA < max(1, M)")
                      ((= i 8) "INCX = 0") ((= i 11) "INCY < = 0") (else (conc "error code " i)))))
  (blas-level2-wrapx (tpmv order uplo trans diag n ap x incx) (x)
    (lambda (i) (cond ((= i 2) "M < 0") ((= i 3) "N < 0") ((= i 6) "LDA < max(1, M)")
                      ((= i 8) "INCX = 0") ((= i 11) "INCY < = 0") (else (conc "error code " i)))))
  (blas-level2-wrapx (trsv order uplo trans diag n a lda x incx) (x)
    (lambda (i) (cond ((= i 2) "M < 0") ((= i 3) "N < 0") ((= i 6) "LDA < max(1, M)")
                      ((= i 8) "INCX = 0") ((= i 11) "INCY < = 0") (else (conc "error code " i)))))
  (blas-level2-wrapx (tbsv order uplo trans diag n k a lda x incx) (x)
    (lambda (i) (cond ((= i 2) "M < 0") ((= i 3) "N < 0") ((= i 6) "LDA < max(1, M)")
                      ((= i 8) "INCX = 0") ((= i 11) "INCY < = 0") (else (conc "error code " i)))))
  (blas-level2-wrapx (tpsv order uplo trans diag n ap x incx) (x)
    (lambda (i) (cond ((= i 2) "M < 0") ((= i 3) "N < 0") ((= i 6) "LDA < max(1, M)")
                      ((= i 8) "INCX = 0") ((= i 11) "INCY < = 0") (else (conc "error code " i)))))

  ;; Rank-1 and rank-2 updates.  Each routine is instantiated exactly once
  ;; (an earlier duplicate set of these definitions has been removed).
  (blas-level2-sd-wrapx (ger order m n alpha x incx y incy a lda) (a)
    (lambda (i) (cond ((= i 2) "M < 0") ((= i 3) "N < 0") ((= i 6) "LDA < max(1, M)")
                      ((= i 8) "INCX = 0") ((= i 11) "INCY < = 0") (else (conc "error code " i)))))
  (blas-level2-sd-wrapx (iger order m n alpha x incx offx y incy offy a lda) (a)
    (lambda (i) (cond ((= i 2) "M < 0") ((= i 3) "N < 0") ((= i 6) "LDA < max(1, M)")
                      ((= i 8) "INCX = 0") ((= i 11) "INCY < = 0") (else (conc "error code " i)))))
  (blas-level2-cz-wrapx (geru order m n alpha x incx y incy a lda) (a)
    (lambda (i) (cond ((= i 2) "M < 0") ((= i 3) "N < 0") ((= i 6) "LDA < max(1, M)")
                      ((= i 8) "INCX = 0") ((= i 11) "INCY < = 0") (else (conc "error code " i)))))
  (blas-level2-cz-wrapx (gerc order m n alpha x incx y incy a lda) (a)
    (lambda (i) (cond ((= i 2) "M < 0") ((= i 3) "N < 0") ((= i 6) "LDA < max(1, M)")
                      ((= i 8) "INCX = 0") ((= i 11) "INCY < = 0") (else (conc "error code " i)))))
  (blas-level2-cz-wrapx (her order uplo n alpha x incx a lda) (a)
    (lambda (i) (cond ((= i 2) "M < 0") ((= i 3) "N < 0") ((= i 6) "LDA < max(1, M)")
                      ((= i 8) "INCX = 0") ((= i 11) "INCY < = 0") (else (conc "error code " i)))))
  (blas-level2-cz-wrapx (hpr order uplo n alpha x incx ap) (ap)
    (lambda (i) (cond ((= i 2) "M < 0") ((= i 3) "N < 0") ((= i 6) "LDA < max(1, M)")
                      ((= i 8) "INCX = 0") ((= i 11) "INCY < = 0") (else (conc "error code " i)))))
  (blas-level2-cz-wrapx (her2 order uplo n alpha x incx y incy a lda) (a)
    (lambda (i) (cond ((= i 2) "M < 0") ((= i 3) "N < 0") ((= i 6) "LDA < max(1, M)")
                      ((= i 8) "INCX = 0") ((= i 11) "INCY < = 0") (else (conc "error code " i)))))
  (blas-level2-cz-wrapx (hpr2 order uplo n alpha x incx y incy ap) (ap)
    (lambda (i) (cond ((= i 2) "M < 0") ((= i 3) "N < 0") ((= i 6) "LDA < max(1, M)")
                      ((= i 8) "INCX = 0") ((= i 11) "INCY < = 0") (else (conc "error code " i)))))
  (blas-level2-sd-wrapx (syr order uplo n alpha x incx a lda) (a)
    (lambda (i) (cond ((= i 2) "M < 0") ((= i 3) "N < 0") ((= i 6) "LDA < max(1, M)")
                      ((= i 8) "INCX = 0") ((= i 11) "INCY < = 0") (else (conc "error code " i)))))
  (blas-level2-sd-wrapx (spr order uplo n alpha x incx ap) (ap)
    (lambda (i) (cond ((= i 2) "M < 0") ((= i 3) "N < 0") ((= i 6) "LDA < max(1, M)")
                      ((= i 8) "INCX = 0") ((= i 11) "INCY < = 0") (else (conc "error code " i)))))
  (blas-level2-sd-wrapx (syr2 order uplo n alpha x incx y incy a lda) (a)
    (lambda (i) (cond ((= i 2) "M < 0") ((= i 3) "N < 0") ((= i 6) "LDA < max(1, M)")
                      ((= i 8) "INCX = 0") ((= i 11) "INCY < = 0") (else (conc "error code " i)))))
  (blas-level2-sd-wrapx (spr2 order uplo n alpha x incx y incy ap) (ap)
    (lambda (i) (cond ((= i 2) "M < 0") ((= i 3) "N < 0") ((= i 6) "LDA < max(1, M)")
                      ((= i 8) "INCX = 0") ((= i 11) "INCY < = 0") (else (conc "error code " i)))))

  ;; Level-1 (vector) wrapper generator.
  ;;
  ;;   (blas-level1-wrap (NAME . C-ARGS) RET ERRS VSIZE COPY [MAKE-RESULT])
  ;;
  ;; Arguments are as for blas-level2-wrap, with these differences:
  ;; - RET may be #f.  The value of cblas_NAME is then returned directly.
  ;; - For dotu/dotc the result argument is not part of the Scheme
  ;;   signature.  It is allocated by calling MAKE-RESULT and returned.
  ;; Optional trailing arguments: incx (1) and offx (0), plus incy (1) and
  ;; offy (0) for routines taking y.
  (define-syntax blas-level1-wrap
    (lambda (x r c)
      (let* ((fn (cadr x))
             (ret (caddr x))
             (vsize (car (cddddr x)))
             (copy (cadr (cddddr x)))
             (make-return (cddr (cddddr x)))
             (cfname (string->symbol (conc "cblas_" (symbol->string (car fn)))))
             (fname (string->symbol
                     (conc (if vsize "" "unsafe-") (symbol->string (car fn)) (if copy "" "!"))))
             (%define (r 'define))
             (%begin (r 'begin))
             (%let (r 'let))
             (%if (r 'if))
             (%let-optionals (r 'let-optionals))
             (xsize (r 'xsize))
             (ysize (r 'ysize))
             (xdim (r 'xdim))
             (ydim (r 'ydim))
             (psize (r 'psize))
             (pdim (r 'pdim))
             ;; Public signature: the C argument list minus strides, offsets
             ;; and the dotu/dotc result buffer.
             (fsig (let loop ((args (reverse (cdr fn))) (sig 'rest))
                     (if (null? args)
                         (cons fname sig)
                         (loop (cdr args)
                               (case (car args)
                                 ((incx incy dotu dotc offx offy) sig)
                                 (else (cons (car args) sig)))))))
             (opts (if (memq 'incy fn)
                       '((incx 1) (incy 1) (offx 0) (offy 0))
                       '((incx 1) (offx 0))))
             (y-check
              (if (memq 'y fn)
                  `(,%let ((,ysize (,vsize y))
                           (,ydim (fx+ 1 (fx* (abs incy) (fx- (fx+ offy n) 1)))))
                     (,%if (< ,ysize ,ydim)
                           (blas:error ',fname (conc "vector Y is allocated " ,ysize " elements "
                                                     "but given dimension is " ,ydim))))
                  '(begin)))
             (x-check
              (if (memq 'x fn)
                  `(,%let ((,xsize (,vsize x))
                           (,xdim (fx+ 1 (fx* (abs incx) (fx- (fx+ offx n) 1)))))
                     (,%if (< ,xsize ,xdim)
                           (blas:error ',fname (conc "vector X is allocated " ,xsize " elements "
                                                     "but given dimension is " ,xdim))))
                  '(begin)))
             (param-check
              (if (memq 'param fn)
                  `(,%let ((,psize (,vsize param))
                           (,pdim 5))
                     (,%if (< ,psize ,pdim)
                           (blas:error ',fname (conc "vector PARAM is allocated " ,psize " elements "
                                                     "but dimension must be " ,pdim))))
                  '(begin)))
             ;; dotu/dotc get a freshly allocated result buffer; copying
             ;; variants rebind each RET argument to a fresh copy.
             (bindings
              (let loop ((params fn) (bnds '()))
                (if (null? params)
                    bnds
                    (let ((p (car params)))
                      (loop (cdr params)
                            (cond ((or (eq? p 'dotc) (eq? p 'dotu))
                                   (cons `(,p (,(car make-return))) bnds))
                                  ((and copy ret (memq p ret))
                                   (cons `(,p (,copy ,p)) bnds))
                                  (else bnds)))))))
             (c-call `(,cfname . ,(cdr fn))))
        `(,%define ,fsig
           (,%let-optionals rest ,opts
             ,(if vsize
                  `(,%begin ,y-check ,x-check ,param-check)
                  '(begin))
             (,%let ,bindings
               ,(cond ((memq 'dotc fn) `(,%begin ,c-call (values dotc)))
                      ((memq 'dotu fn) `(,%begin ,c-call (values dotu)))
                      ((not ret) c-call)
                      (else `(,%begin ,c-call (values . ,ret))))))))))

  ;; (blas-level1-wrapx (NAME . C-ARGS) RET ERRS): s/d/c/z.  With RET, the
  ;; unsafe/safe/copy variants are generated; without RET only the checked
  ;; NAME form is.
  (define-syntax blas-level1-wrapx
    (lambda (x r c)
      (let ((modes (if (caddr x) '(unsafe safe copy) '(copy))))
        `(blas-wrap-variants blas-level1-wrap (s d c z) ,modes ,@(cdr x)))))

  ;; As blas-level1-wrapx, real precisions only.
  (define-syntax blas-level1-sd-wrapx
    (lambda (x r c)
      (let ((modes (if (caddr x) '(unsafe safe copy) '(copy))))
        `(blas-wrap-variants blas-level1-wrap (s d) ,modes ,@(cdr x)))))

  ;; As blas-level1-wrapx, complex precisions only.  Without RET, the
  ;; generated NAME also allocates a one-element complex result buffer.
  (define-syntax blas-level1-cz-wrapx
    (lambda (x r c)
      (let ((modes (if (caddr x) '(unsafe safe copy) '(copy/result))))
        `(blas-wrap-variants blas-level1-wrap (c z) ,modes ,@(cdr x)))))

  (blas-level1-sd-wrapx (rot n x incx y incy c s) (x y) (lambda (i) (conc "error code " i)))
  (blas-level1-sd-wrapx (rotm n x incx y incy param) (x y) (lambda (i) (conc "error code " i)))
  (blas-level1-wrapx (swap n x incx y incy) (x y) (lambda (i) (conc "error code " i)))
  (blas-level1-wrapx (scal n alpha x incx) (x) (lambda (i) (conc "error code " i)))
  (blas-level1-wrapx (axpy n alpha x incx y incy) (y) (lambda (i) (conc "error code " i)))
  (blas-level1-wrapx (iaxpy n alpha x incx offx y incy offy) (y) (lambda (i) (conc "error code " i)))
  (blas-level1-sd-wrapx (dot n x incx y incy) #f (lambda (i) (conc "error code " i)))
  (blas-level1-cz-wrapx (dotu n x incx y incy dotu) #f (lambda (i) (conc "error code " i)))
  (blas-level1-cz-wrapx (dotc n x incx y incy dotc) #f (lambda (i) (conc "error code " i)))
  (blas-level1-wrapx (nrm2 n x incx) #f (lambda (i) (conc "error code " i)))
  (blas-level1-wrapx (asum n x incx) #f (lambda (i) (conc "error code " i)))
  (blas-level1-wrapx (amax n x incx) #f (lambda (i) (conc "error code " i)))
)