;;
;;
;; Utility procedures for NEMO code generators.
;;
;; Copyright 2008-2012 Ivan Raikov and the Okinawa Institute of Science and Technology
;;
;; This program is free software: you can redistribute it and/or
;; modify it under the terms of the GNU General Public License as
;; published by the Free Software Foundation, either version 3 of the
;; License, or (at your option) any later version.
;;
;; This program is distributed in the hope that it will be useful, but
;; WITHOUT ANY WARRANTY; without even the implied warranty of
;; MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
;; General Public License for more details.
;;
;; A full copy of the GPL license can be found at
;; <http://www.gnu.org/licenses/>.
;;

(module nemo-utils

 (lookup-def enum-bnds enum-freevars sum if-convert let-enum let-elim let-lift
  s+ sw+ slp nl spaces ppf transitions-graph state-conseqs
  differentiate simplify distribute make-output-fname)

 (import scheme chicken data-structures files srfi-1 srfi-13)

 (require-extension matchable strictly-pretty varsubst digraph nemo-core)


 ;; Looks up the entry named K in LST.  Each entry is either (name value)
 ;; or (name . value); names are compared by their string form, so
 ;; symbols and strings are interchangeable.  Returns the value of the
 ;; first matching entry, or the optional DEFAULT (#f) when none matches.
 (define (lookup-def k lst . rest)
   (let-optionals rest ((default #f))
     (let ((k (->string k)))
       (let recur ((kv #f) (lst lst))
         (if (or kv (null? lst))
             (if (not kv)
                 default
                 (match kv ((k v) v) (else (cdr kv))))
             (let ((kv (car lst)))
               (recur (and (string=? (->string (car kv)) k) kv)
                      (cdr lst))))))))


 ;; Collects the variables bound by `let' forms anywhere in EXPR
 ;; (including inside `if' branches and applications whose head is a
 ;; symbol), adding them onto the accumulator AX.
 (define (enum-bnds expr ax)
   (match expr
     (('if . es)        (fold enum-bnds ax es))
     (('let bnds body)  (enum-bnds body (append (map car bnds)
                                                (fold enum-bnds ax (map cadr bnds)))))
     ((s . es)          (if (symbol? s) (fold enum-bnds ax es) ax))
     (else ax)))


 ;; Collects the symbols occurring free in EXPR, i.e. not listed in BNDS
 ;; and not bound by an enclosing `let', consing them onto AX.  The
 ;; operator position of an application is not counted, and the result
 ;; may contain duplicates.
 (define (enum-freevars expr bnds ax)
   (match expr
     (('if . es)
      (fold (lambda (x ax) (enum-freevars x bnds ax)) ax es))
     (('let lbnds body)
      ;; binding right-hand sides are scanned in the outer scope,
      ;; the body in the extended scope
      (let ((bnds1 (append (map first lbnds) bnds)))
        (enum-freevars body bnds1
                       (fold (lambda (x ax) (enum-freevars x bnds ax)) ax
                             (map second lbnds)))))
     ((s . es)
      (if (symbol? s) (fold (lambda (x ax) (enum-freevars x bnds ax)) ax es) ax))
     (id
      (if (and (symbol? id) (not (member id bnds))) (cons id ax) ax))))


 ;; Builds a nested binary `+' expression adding up the terms in LST.
 ;; Returns '() for an empty list and the term itself for a singleton.
 (define (sum lst)
   (if (null? lst) lst
       (match lst
         ((x)            x)
         ((x y)          `(+ ,x ,y))
         ((x y . rest)   `(+ (+ ,x ,y) ,(sum rest)))
         ((x . rest)     `(+ ,x ,(sum rest))))))


 ;; Replaces every (if c t e) in EXPR by (let ((ifN (if c t e))) ifN),
 ;; where ifN is a fresh gensym, converting the subterms recursively.
 (define (if-convert expr)
   (match expr
     (('if c t e)
      (let ((r (gensym "if")))
        `(let ((,r (if ,(if-convert c) ,(if-convert t) ,(if-convert e)))) ,r)))
     (('let bs e)
      `(let ,(map (lambda (b) `(,(car b) ,(if-convert (cadr b)))) bs) ,(if-convert e)))
     ((f . es)
      (cons f (map if-convert es)))
     ((? atom?) expr)))


 ;; Appends onto AX the `let' bindings of EXPR that let-lift hoists:
 ;; the binding of an if-converted conditional (let ((x (if ...))) x),
 ;; and the bindings of any other `let' form encountered.
 (define (let-enum expr ax)
   (match expr
     (('let ((x ('if c t e))) y)
      (let ((ax (fold let-enum ax (list c))))
        (if (eq? x y) (append ax (list (list x `(if ,c ,t ,e)))) ax)))
     (('let bnds body) (append ax bnds))
     (('if c t e)      (let-enum c ax))
     ((f . es)         (fold let-enum ax es))
     (else ax)))


 ;; Drops the bindings of the `let' forms in EXPR, keeping their bodies;
 ;; let-lift pairs this with let-enum, which collects those bindings.
 ;; The branches of an `if' are let-lifted instead.
 (define (let-elim expr)
   (match expr
     (('let ((x ('if c t e))) y)
      (if (eq? x y) y expr))
     (('let bnds body) body)
     (('if c t e)
      `(if ,(let-elim c) ,(let-lift t) ,(let-lift e)))
     ((f . es)
      `(,f . ,(map let-elim es)))
     (else expr)))


 ;; Hoists nested `let' bindings in EXPR outward, repeating until a pass
 ;; leaves the expression unchanged (compared with equal?).
 (define (let-lift expr)
   ;; Collects the bindings nested inside the right-hand sides of BNDS.
   (define (fbnds bnds)
     (let ((bnds0
            (fold (lambda (b ax)
                    (let ((bexpr (cadr b)))
                      (match bexpr
                        (('let bnds expr) (append bnds ax))
                        (else (append (let-enum bexpr (list)) ax)))))
                  '() bnds)))
       bnds0))
   (let ((expr1
          (match expr
            (('let bnds expr)
             (let ((bnds0 (fbnds bnds))
                   (expr1 `(let ,(map (lambda (b) (list (car b) (let-elim (cadr b)))) bnds)
                             ,(let-lift expr))))
               (if (null? bnds0) expr1 `(let ,bnds0 ,expr1))))
            (else
             (let ((bnds (let-enum expr (list))))
               (if (null? bnds)
                   (let-elim expr)
                   (let ((bnds0 (fbnds bnds))
                         (expr1 `(let ,(map (lambda (b) (list (car b) (let-elim (cadr b)))) bnds)
                                   ,(let-elim expr))))
                     (if (null? bnds0) expr1 `(let ,bnds0 ,expr1)))))))))
     (if (equal? expr expr1) expr1 (let-lift expr1))))


 ;; Concatenates the string forms of all arguments.
 (define (s+ . lst)    (string-concatenate (map ->string lst)))

 ;; Joins the string forms of the non-#f elements of LST with spaces.
 (define (sw+ lst)     (string-intersperse (filter-map (lambda (x) (and x (->string x))) lst) " "))

 ;; Joins the string forms of the elements of LST with separator P.
 (define (slp p lst)   (string-intersperse (map ->string lst) p))

 (define nl "\n")

 ;; Returns a string of N spaces.
 (define (spaces n)    (make-string n #\space))


 ;; Prints each element of LST on its own line, indented by INDENT
 ;; spaces; #f elements are skipped.  A list whose head is a positive
 ;; number I is a nested group: each remaining element is printed at
 ;; indentation INDENT+I.  Other lists are joined with spaces (sw+).
 (define (ppf indent . lst)
   (let ((sp (spaces indent)))
     (for-each
      (lambda (x)
        (and x
             (match x
               ((i . x1)
                (if (and (number? i) (positive? i))
                    (for-each (lambda (x) (ppf (+ indent i) x)) x1)
                    (print sp (sw+ x))))
               (else (print sp (if (list? x) (sw+ x) x))))))
      lst)))


 ;; Returns the distinct state names mentioned in the transition
 ;; equations TRANSITIONS of state complex N.  Accepted forms are
 ;; (-> s0 s1 rate), (s0 -> s1 rate), (<-> s0 s1 rate1 rate2) and
 ;; (s0 <-> s1 rate1 rate2) with symbolic state names -- the same forms
 ;; transitions-graph turns into edges.  Any other form is reported via
 ;; nemo:error, tagged with ERROR-TAG.
 (define (transition-states n transitions error-tag)
   (let loop ((lst (list)) (tlst transitions))
     (if (null? tlst)
         (delete-duplicates lst eq?)
         (match (car tlst)
           (('-> (and (? symbol?) s0) (and (? symbol?) s1) _)
            (loop (cons* s0 s1 lst) (cdr tlst)))
           (((and (? symbol?) s0) '-> (and (? symbol?) s1) _)
            (loop (cons* s0 s1 lst) (cdr tlst)))
           (('<-> (and (? symbol?) s0) (and (? symbol?) s1) _ _)
            (loop (cons* s0 s1 lst) (cdr tlst)))
           (((and (? symbol?) s0) '<-> (and (? symbol?) s1) _ _)
            (loop (cons* s0 s1 lst) (cdr tlst)))
           (else
            (nemo:error error-tag ": invalid transition equation "
                        (car tlst) " in state complex " n))))))


 ;; Builds the transition digraph of kinetic state complex N.  Nodes are
 ;; the states named in TRANSITIONS, numbered 0, 1, ...; every transition
 ;; adds an edge (i j flux) with flux = (* source-state rate), and a
 ;; reversible transition also adds the reverse edge.  State names are
 ;; rewritten through (STATE-NAME n s).
 ;; CONSERVE is a list whose first element, when present, is a
 ;; conservation equation (presumably (total = expr) -- confirm against
 ;; nemo-core).  In that case, one state occurring in the equation other
 ;; than OPEN is replaced in the fluxes by total minus the sum of the
 ;; other states.
 ;; Returns (list graph conserved-node node-substitution).
 (define (transitions-graph n open transitions conserve state-name)
   (let* ((subst-convert (subst-driver (lambda (x) (and (symbol? x) x))
                                       nemo:binding? identity nemo:bind nemo:subst-term))
          (g            (make-digraph n (string-append (->string n) " transitions graph")))
          (add-node!    (g 'add-node!))
          (add-edge!    (g 'add-edge!))
          (node-list    (transition-states n transitions 'state-eqs))
          (node-ids     (list-tabulate (length node-list) identity))
          (name->id-map (zip node-list node-ids))
          (node-subs    (fold (lambda (s ax) (subst-extend s (state-name n s) ax))
                              subst-empty node-list)))

     ;; insert state nodes in the dependency graph
     (for-each (lambda (i s) (add-node! i s)) node-ids node-list)

     (let* ((nodes      ((g 'nodes)))
            (conserve   (and (pair? conserve) (car conserve)))
            ;; if a conservation equation is present, we eliminate one
            ;; transition equation from the system
            (cvars      (and conserve (enum-freevars (third conserve) '() '())))
            (cnode      (and conserve
                             (find (lambda (s)
                                     (let ((n (second s)))
                                       (and (member n cvars) (not (eq? n open)))))
                                   nodes)))
            (cname      (and cnode (second cnode)))
            (cnexpr     (and cnode
                             (let* ((cvars1 (filter-map (lambda (n) (and (not (eq? n cname)) n)) cvars))
                                    (sumvar (gensym "sum")))
                               `(let ((,sumvar ,(sum cvars1)))
                                  (- ,(first conserve) ,sumvar)))))
            (add-tredge
             (lambda (s0 s1 rexpr1 rexpr2)
               (let* ((i       (car (alist-ref s0 name->id-map)))
                      (j       (car (alist-ref s1 name->id-map)))
                      (x0      (if (and cnode (eq? s0 cname)) cnexpr s0))
                      (x1      (if (and cnode (eq? s1 cname)) cnexpr s1))
                      (ij-expr `(* ,(subst-convert x0 node-subs)
                                   ,(subst-convert rexpr1 node-subs)))
                      (ji-expr (and rexpr2
                                    `(* ,(subst-convert x1 node-subs)
                                        ,(subst-convert rexpr2 node-subs)))))
                 (add-edge! (list i j ij-expr))
                 (if rexpr2 (add-edge! (list j i ji-expr)))))))

       ;; create rate edges in the graph
       (for-each (lambda (e)
                   (match e
                     (('-> s0 s1 rexpr)          (add-tredge s0 s1 rexpr #f))
                     ((s0 '-> s1 rexpr)          (add-tredge s0 s1 rexpr #f))
                     (('<-> s0 s1 rexpr1 rexpr2) (add-tredge s0 s1 rexpr1 rexpr2))
                     ((s0 '<-> s1 rexpr1 rexpr2) (add-tredge s0 s1 rexpr1 rexpr2))))
                 transitions)

       (list g cnode node-subs))))


 ;; Rewrites the right-hand sides of the consequence equations CONSEQS
 ;; (each of the form (name = . expr)) of state complex N, substituting
 ;; every state named in TRANSITIONS by (STATE-NAME n s).
 ;; Returns (list n rewritten-conseqs).
 (define (state-conseqs n transitions conseqs state-name)
   (let* ((subst-convert (subst-driver (lambda (x) (and (symbol? x) x))
                                       nemo:binding? identity nemo:bind nemo:subst-term))
          (state-list    (transition-states n transitions 'nemo:state-conseq))
          (state-subs    (fold (lambda (s ax) (subst-extend s (state-name n s) ax))
                               subst-empty state-list))
          (conseqs1      (map (lambda (conseq)
                                (match conseq
                                  ((i '= . expr) `(,i = . ,(subst-convert expr state-subs)))))
                              conseqs)))
     (list n conseqs1)))


 ;; Computes an output file name.  The optional fourth argument controls
 ;; the result: #f yields #f; a string is returned unchanged; anything
 ;; else (the default is #t) yields SYSNAME followed by SUFFIX, placed in
 ;; DIRNAME when DIRNAME is not #f.
 (define (make-output-fname dirname sysname suffix . rest)
   (let-optionals rest ((x #t))
     (and x
          (if (string? x) x
              (let ((fname (s+ sysname suffix)))
                (or (and dirname (make-pathname dirname fname)) fname))))))


 ;; log10(e) = 1/ln(10) and log2(e) = 1/ln(2), used by the derivative
 ;; rules for log10 and log2.
 (define LOG10E 0.434294481903252)
 (define LOG2E  1.44269504088896)


 ;; Symbolically differentiates the expression T with respect to the
 ;; variable X.  The result is unsimplified (see `simplify').
 ;; FENV is consulted (via lookup-def) for operators without a built-in
 ;; rule; presumably it maps a function name to the name(s) of its
 ;; partial derivative(s), one per argument -- TODO confirm with callers.
 ;; Returns #f when T contains such an operator applied to X and FENV
 ;; has no entry for it.  The exponents of pow and ldexp are treated as
 ;; independent of X.
 (define (differentiate fenv x t)
   (define subst-convert  (subst-driver (lambda (x) (and (symbol? x) x))
                                        nemo:binding? identity nemo:bind nemo:subst-term))
   (cond ((number? t)  0.0)
         ((symbol? t)  (if (string=? (->string x) (->string t)) 1.0 0.0))
         (else
          (match t
            (('neg u)  `(neg ,(differentiate fenv x u)))
            (('+ u v)  `(+ ,(differentiate fenv x u) ,(differentiate fenv x v)))
            (('- u v)  `(- ,(differentiate fenv x u) ,(differentiate fenv x v)))
            (('* (and u (? number?)) v)  `(* ,u ,(differentiate fenv x v)))
            (('* v (and u (? number?)))  `(* ,u ,(differentiate fenv x v)))
            (('* u v)
             `(+ (* ,(differentiate fenv x u) ,v) (* ,u ,(differentiate fenv x v))))
            (('/ u v)
             `(/ (- (* ,(differentiate fenv x u) ,v) (* ,u ,(differentiate fenv x v)))
                 (pow ,v 2.0)))
            (('cube u)  (differentiate fenv x `(pow ,u 3.0)))
            (('pow u n) (chain fenv x u `(* ,n (pow ,u (- ,n 1.0)))))
            (('sqrt u)  (chain fenv x u `(/ 1.0 (* 2.0 (sqrt ,u)))))
            (('exp u)   (chain fenv x u `(exp ,u)))
            (('log u)   (chain fenv x u `(/ 1.0 ,u)))
            ;; d/du log10(u) = log10(e)/u; chain supplies the du/dx factor
            (('log10 u) (chain fenv x u `(* ,LOG10E (/ 1.0 ,u))))
            ;; d/du log2(u) = log2(e)/u; chain supplies the du/dx factor
            (('log2 u)  (chain fenv x u `(* ,LOG2E (/ 1.0 ,u))))
            (('log1p u) (differentiate fenv x `(log (+ 1.0 ,u))))
            ;; ldexp(u, n) = u * 2^n; a symbolic exponent stays symbolic
            ;; instead of being evaluated here
            (('ldexp u n)
             (differentiate fenv x `(* ,u ,(if (number? n) (expt 2 n) `(pow 2.0 ,n)))))
            (('sin u)   (chain fenv x u `(cos ,u)))
            (('cos u)   (chain fenv x u `(neg (sin ,u))))
            (('tan u)   (differentiate fenv x `(* (sin ,u) (/ 1.0 (cos ,u)))))
            (('asin u)  (chain fenv x u `(/ 1.0 (sqrt (- 1.0 (pow ,u 2.0))))))
            (('acos u)  (chain fenv x u `(/ (neg 1.0) (sqrt (- 1.0 (pow ,u 2.0))))))
            (('atan u)  (chain fenv x u `(/ 1.0 (+ 1.0 (pow ,u 2.0)))))
            (('sinh u)  (differentiate fenv x `(/ (- (exp ,u) (exp (neg ,u))) 2.0)))
            (('cosh u)  (differentiate fenv x `(/ (+ (exp ,u) (exp (neg ,u))) 2.0)))
            (('tanh u)  (differentiate fenv x `(/ (sinh ,u) (cosh ,u))))
            (('let bnds body)
             (let* ((body1 (subst-convert body bnds)))
               (differentiate fenv x body1)))
            ((op . us)
             (let ((fv (enum-freevars t '() '())))
               (if (member x fv)
                   (cond ((lookup-def op fenv)
                          => (lambda (fs)
                               (cond ((and (pair? fs) (pair? us))
                                      `(+ . ,(map (lambda (fu u) (chain fenv x u `(,fu ,u))) fs us)))
                                     (else (chain fenv x us `(,fs ,us))))))
                         (else #f))
                   0.0)))
            (else #f)))))


 ;; Chain rule: given the inner term T and the outer derivative U
 ;; (evaluated at T), returns dT/dX * U.  For a bare symbol T the
 ;; derivative is 1 when T names X (so U is returned as-is) and 0
 ;; otherwise.
 (define (chain fenv x t u)
   (if (symbol? t)
       (if (string=? (->string x) (->string t)) u 0.0)
       `(* ,(differentiate fenv x t) ,u)))


 ;; Applies one pass of algebraic simplification to T: identities for
 ;; 0.0 and 1.0, double negation, and constant folding when both operands
 ;; are numbers.  A rule that fires at the root returns its result
 ;; without re-simplifying it; otherwise the subterms (and `let'
 ;; bindings) are simplified.
 (define (simplify t)
   (match t
     (('neg 0.0)            0.0)
     (('+ 0.0 0.0)          0.0)
     (('+ 0.0 t1)           t1)
     (('+ t1 0.0)           t1)
     (('+ t1 ('neg t2))     `(- ,t1 ,t2))
     (('+ (and t1 (? number?)) (and t2 (? number?)))  (+ t1 t2))
     (('- 0.0 0.0)          0.0)
     (('- 0.0 t1)           `(neg ,t1))
     (('- t1 0.0)           t1)
     (('neg ('neg t1))      t1)
     (('- (and t1 (? number?)) (and t2 (? number?)))  (- t1 t2))
     (('* 0.0 0.0)          0.0)
     (('* 0.0 t1)           0.0)
     (('* t1 0.0)           0.0)
     (('* 1.0 t1)           t1)
     (('* t1 1.0)           t1)
     (('* ('neg t1) ('neg t2))  `(* ,t1 ,t2))
     (('* (and t1 (? number?)) (and t2 (? number?)))  (* t1 t2))
     (('/ 0.0 t1)           0.0)
     (('pow t1 0.0)         1.0)
     (('pow t1 1.0)         t1)
     (('pow (and t1 (? number?)) (and t2 (? number?)))  (expt t1 t2))
     (('let () body)        (simplify body))
     (('let bnds body)
      ;; keep each bound variable's name; only its value is simplified
      `(let ,(map (match-lambda ((v b) `(,v ,(simplify b))) (else #f)) bnds)
         ,(simplify body)))
     ((op . ts)             `(,op . ,(map simplify ts)))
     (else t)))


 ;; Rewrites n-ary applications of + - * / in T into nested binary ones.
 ;; + and * are regrouped into a roughly balanced tree.  - and / are
 ;; left-associative, so (- a b c ...) becomes (- a (+ b c ...)) and
 ;; (/ a b c ...) becomes (/ a (* b c ...)).  Nullary and unary forms,
 ;; e.g. (- x), keep their shape; only their arguments are rewritten.
 (define (distribute t)
   (match t
     (((and (or '+ '- '* '/) op) x y)
      `(,op ,(distribute x) ,(distribute y)))
     (((and (or '- '/) op) x y . ys)
      (let ((inv (if (eq? op '-) '+ '*)))
        `(,op ,(distribute x) ,(distribute `(,inv ,y . ,ys)))))
     (((and (or '+ '*) op) x y z)
      `(,op ,(distribute x) (,op ,(distribute y) ,(distribute z))))
     (((and (or '+ '*) op) x y z . rest)
      (let* ((lst `(,x ,y ,z . ,rest))
             (n   (length lst))
             (n/2 (inexact->exact (round (/ n 2)))))
        `(,op ,(distribute `(,op . ,(take lst n/2)))
              ,(distribute `(,op . ,(drop lst n/2))))))
     (('let bnds body)
      `(let ,(map (match-lambda ((v b) `(,v ,(distribute b))) (else #f)) bnds)
         ,(distribute body)))
     ((op . ts) `(,op . ,(map distribute ts)))
     (else t)))

)