;;;; mm-gambit.scm
;;;; Matrix-multiplication benchmark for Gambit Scheme.
;;;;
;;;; Usage: <program> TIMES X-FILE Y-FILE
;;;; Each input file holds the row count, the column count, and then the
;;;; matrix elements in row-major order, all readable with `read`.
;;;;
;;;; Matrix representation: an f64vector of length m*n + 1.  Slots
;;;; 0 .. m*n-1 hold the elements in row-major order; the final slot holds
;;;; the column count n, stored as a flonum.

(declare
  (standard-bindings)
  (extended-bindings)
  (block)
  (mostly-fixnum-flonum))

;; Allocate an M x N matrix filled with 0.0 and record N in the trailing slot.
(define make-matrix
  (lambda (m n)
    (let ((res (make-f64vector (+ 1 (* m n)) 0.0)))
      (f64vector-set! res (* m n) (exact->inexact n))
      res)))

;; Number of element slots (excludes the trailing column-count slot).
(define (matrix-size x)
  (- (f64vector-length x) 1))

;; Row count, derived from the element count and the stored column count.
;; A zero-column matrix has no elements, so report zero rows rather than
;; dividing by zero.
(define (matrix-num-rows x)
  (let ((cols (matrix-num-cols x)))
    (if (= cols 0)
        0
        (quotient (matrix-size x) cols))))

;; Column count, read back from the trailing slot as an exact integer.
(define (matrix-num-cols x)
  (inexact->exact (f64vector-ref x (- (f64vector-length x) 1))))

;; Element (I, J); looks up the column count on every call.
(define (array-ref mat i j)
  (f64vector-ref mat (+ j (* i (matrix-num-cols mat)))))

;; Store X at (I, J); looks up the column count on every call.
(define (array-set! mat i j x)
  (f64vector-set! mat (+ j (* i (matrix-num-cols mat))) x))

;; Element (I, J) with the column count supplied by the caller, so the hot
;; loop in `multiply` does not re-read the trailing slot.
(define (matrix-ref mat num-cols i j)
  (f64vector-ref mat (+ j (* i num-cols))))

;; Store X at (I, J) with the column count supplied by the caller.
(define (matrix-set! mat num-cols i j x)
  (f64vector-set! mat (+ j (* i num-cols)) x))

;; Store the product M1 x M2 into R.
;;   m1 : NR1 x NC1    m2 : NR2 x NC2    r : NR1 x NC2 (preallocated)
;; The inner dimensions must agree (NC1 = NR2); otherwise an error is raised.
;; One tail-recursive loop walks (i, j, k), accumulating the dot product for
;; element (i, j) in A.
(define multiply
  (lambda (m1 m2 r nr1 nc1 nr2 nc2)
    (if (not (= nc1 nr2))
        (error "multiply: inner dimensions differ" nc1 nr2))
    ;; An empty result has no element slots; without this guard the loop
    ;; would write element (0, 0) into R's column-count slot.
    (if (and (> nr1 0) (> nc2 0))
        (let lp ((i 0) (j 0) (k 0) (a 0.0))
          (cond ((= k nr2)
                 ;; Dot product for (i, j) is complete: store it, then move to
                 ;; the next column, or to the next row after the last column.
                 (matrix-set! r nc2 i j a)
                 (if (= (+ j 1) nc2)
                     (if (not (= (+ i 1) nr1))
                         (lp (+ i 1) 0 0 0.0))
                     (lp i (+ j 1) 0 0.0)))
                (else
                 (lp i j (+ k 1)
                     (+ a (* (matrix-ref m1 nc1 i k)
                             (matrix-ref m2 nc2 k j))))))))))

;; Multiply A by B repeatedly, reusing one result matrix, and return it.
;; NOTE(review): the exit test is (> i times), so the multiply runs TIMES + 1
;; times; confirm that extra iteration is intended before comparing timings
;; with other implementations.
(define bench-multiply
  (lambda (a b times)
    (let* ((nr1 (matrix-num-rows a))
           (nr2 (matrix-num-rows b))
           (nc1 (matrix-num-cols a))
           (nc2 (matrix-num-cols b))
           (m (make-matrix nr1 nc2)))
      (do ((i 0 (+ i 1)))
          ((> i times))
        (multiply a b m nr1 nc1 nr2 nc2))
      m)))

;; Entry point.  ARGS is (TIMES X-FILE Y-FILE).  Reads both matrices, echoes
;; the first one when it has fewer than 100 elements, times the benchmark
;; with `cpu-time`, prints the elapsed value, and echoes the product when it
;; has fewer than 100 elements.  Returns 0.
(define (main args)
  (let ((times (string->number (car args)))
        (xm (cadr args))
        (ym (caddr args)))
    ;; call-with-input-file closes each port once read-matrix returns,
    ;; instead of leaving both input ports open for the whole run.
    (let ((a (call-with-input-file xm read-matrix))
          (b (call-with-input-file ym read-matrix)))
      (if (< (matrix-size a) 100)
          (begin
            (display (format-matrix a))
            (newline)))
      (display "------------")
      (newline)
      (let* ((t0 (cpu-time))
             (c (bench-multiply a b times)))
        ;; NOTE(review): presumably `print` adds no newline, so the elapsed
        ;; time shares a line with the separator below — confirm intended.
        (print (- (cpu-time) t0))
        (display "------------")
        (newline)
        (if (< (matrix-size c) 100)
            (display (format-matrix c)))
        (newline)
        0))))

;; Read a matrix from PORT: row count M, column count N, then M*N elements
;; in row-major order.
(define read-matrix
  (lambda (port)
    (let* ((m (read port))
           (n (read port))
           (mat (make-matrix m n)))
      (do ((i 0 (+ i 1)))
          ((= i m) mat)
        (do ((j 0 (+ j 1)))
            ((= j n))
          ;; `read` returns exact numbers for integer literals; convert so the
          ;; value stored in the f64vector is always a flonum.
          (array-set! mat i j (exact->inexact (read port))))))))

;; Render MAT as text: every row begins with a newline and every element is
;; followed by one space.  A string port replaces repeated string-append,
;; which recopied the whole accumulated string for each element.
(define format-matrix
  (lambda (mat)
    (let ((m (matrix-num-rows mat))
          (n (matrix-num-cols mat))
          (out (open-output-string)))
      (do ((i 0 (+ i 1)))
          ((= i m) (get-output-string out))
        (display "\n" out)
        (do ((j 0 (+ j 1)))
            ((= j n))
          (display (number->string (array-ref mat i j)) out)
          (display " " out))))))

(newline)
(main (cdr (command-line)))