;;; Dense matrix multiplication benchmark.
;;;
;;; An M x N matrix is stored as a flat vector of length M*N + 1:
;;; elements are laid out row by row in slots 0 .. M*N-1, and the final
;;; slot holds the column count N.  The row count is not stored; it is
;;; recovered as (M*N) / N, so a matrix with zero columns cannot report
;;; its row count.

;; make-matrix: allocate an M x N matrix.  Only the trailing
;; column-count slot is initialised; element slots keep make-vector's
;; default fill.
(define make-matrix
  (lambda (m n)
    (let ((res (make-vector (+ 1 (* m n)))))
      (vector-set! res (* m n) n)
      res)))

;; matrix-size: number of element slots (M*N), excluding the
;; column-count slot.
(define matrix-size
  (lambda (x) (- (vector-length x) 1)))

;; matrix-num-cols: column count, read from the trailing slot.
;; Defined before matrix-num-rows, which uses it.
(define-inline (matrix-num-cols x)
  (vector-ref x (- (vector-length x) 1)))

;; matrix-num-rows: row count, derived as size / columns.  Fails with a
;; division error when the matrix has zero columns.
(define-inline (matrix-num-rows x)
  (quotient (matrix-size x) (matrix-num-cols x)))

;; array-ref / array-set!: row-major element access that re-reads the
;; column count from MAT on every call.
(define-inline (array-ref mat i j)
  (vector-ref mat (+ j (* i (matrix-num-cols mat)))))

(define-inline (array-set! mat i j x)
  (vector-set! mat (+ j (* i (matrix-num-cols mat))) x))

;; matrix-ref / matrix-set!: row-major element access with the column
;; count supplied by the caller, so hot loops can hoist it.
(define-inline (matrix-ref mat num-cols i j)
  (vector-ref mat (+ j (* i num-cols))))

(define-inline (matrix-set! mat num-cols i j x)
  (vector-set! mat (+ j (* i num-cols)) x))

;; format-matrix: render MAT as a string.  Each row is preceded by a
;; newline and every element is followed by one space; e.g. a 1x2 matrix
;; holding 1 and 2 renders as "\n1 2 ".  A string port is used so that
;; the accumulated text is not re-copied for every element.
(define format-matrix
  (lambda (mat)
    (let ((m (matrix-num-rows mat))
          (n (matrix-num-cols mat))
          (out (open-output-string)))
      (do ((i 0 (+ i 1)))
          ((= i m) (get-output-string out))
        (newline out)
        (do ((j 0 (+ j 1)))
            ((= j n))
          (display (array-ref mat i j) out)
          (write-char #\space out))))))

;; read-matrix: read a matrix from PORT in the textual layout
;;   M N e11 e12 ... e1N e21 ... eMN
;; (row count, column count, then the elements in row-major order), each
;; datum parsed with `read`.  No validation is done.  PORT is not closed
;; here; the caller owns it.
(define read-matrix
  (lambda (port)
    (let* ((m (read port))
           (n (read port))
           (mat (make-matrix m n)))
      (do ((i 0 (+ i 1)))
          ((= i m) mat)
        (do ((j 0 (+ j 1)))
            ((= j n))
          (array-set! mat i j (read port)))))))

;; match-error is called to complain when multiply receives a pair of
;; incompatible operands.  The message is formatted here because `error`
;; does not necessarily interpret ~ directives in its message string.
(define match-error
  (lambda (what1 what2)
    (error 'multiply
           (format "~s and ~s are incompatible operands" what1 what2))))

;; multiply: return the product M1 * M2 as a fresh matrix.  Calls
;; match-error when M1's column count differs from M2's row count.  When
;; M1 has no rows the empty result is returned immediately; otherwise the
;; loop below would write into the result's column-count slot.
(define multiply
  (lambda (m1 m2)
    (let* ((nr1 (matrix-num-rows m1))
           (nr2 (matrix-num-rows m2))
           (nc1 (matrix-num-cols m1))
           (nc2 (matrix-num-cols m2))
           (r (make-matrix nr1 nc2)))
      (if (not (= nc1 nr2))
          (match-error m1 m2))
      (if (zero? nr1)
          r
          ;; One named-let loop over (i, j, k): A accumulates the dot
          ;; product of row i of M1 and column j of M2.  When k reaches
          ;; nr2, A is stored into r[i][j] and (i, j) advances in
          ;; row-major order; the loop returns r after the last cell.
          (let lp ((i 0) (j 0) (k 0) (a 0))
            (cond ((= k nr2)
                   (matrix-set! r nc2 i j a)
                   (if (= (+ j 1) nc2)
                       (if (= (+ i 1) nr1)
                           r
                           (lp (+ i 1) 0 0 0))
                       (lp i (+ j 1) 0 0)))
                  (else
                   (lp i j (+ k 1)
                       (+ a (* (matrix-ref m1 nc1 i k)
                               (matrix-ref m2 nc2 k j)))))))))))

;; bench-multiply: compute (multiply A B) exactly TIMES times and return
;; the last product, or #f when TIMES is not positive.  (The previous
;; `(> i times)` exit test ran TIMES + 1 multiplications.)
(define bench-multiply
  (lambda (a b times)
    (let ((c #f))
      (do ((i 0 (+ i 1)))
          ((>= i times) c)
        (set! c (multiply a b))))))

;; main: ARGS is (TIMES A-FILE B-FILE).  Reads the two matrices (closing
;; each file after reading), prints A when it has fewer than 100
;; elements, prints the difference of the first value of cpu-time taken
;; around the benchmark loop, then prints the product when it exists and
;; is small.  Returns 0.
(define (main args)
  (let ((times (string->number (car args)))
        (xm (cadr args))
        (ym (caddr args)))
    (let ((a (call-with-input-file xm read-matrix))
          (b (call-with-input-file ym read-matrix))
          (c #f))
      (if (< (matrix-size a) 100)
          (begin (display (format-matrix a)) (newline)))
      (display "------------")
      (newline)
      (let ((t0 (nth-value 0 (cpu-time))))
        (set! c (bench-multiply a b times))
        (print (- (nth-value 0 (cpu-time)) t0)))
      (display "------------")
      (newline)
      ;; c is #f when TIMES was 0, in which case there is nothing to show.
      (if (and c (< (matrix-size c) 100))
          (display (format-matrix c)))
      (newline)
      0)))

(newline)
(main (command-line-arguments))