;;
;;
;; Utility procedures for NEMO code generators.
;;
;; Copyright 2008-2009 Ivan Raikov and the Okinawa Institute of Science and Technology
;;
;; This program is free software: you can redistribute it and/or
;; modify it under the terms of the GNU General Public License as
;; published by the Free Software Foundation, either version 3 of the
;; License, or (at your option) any later version.
;;
;; This program is distributed in the hope that it will be useful, but
;; WITHOUT ANY WARRANTY; without even the implied warranty of
;; MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
;; General Public License for more details.
;;
;; A full copy of the GPL license can be found at
;; <http://www.gnu.org/licenses/>.
;;

(module nemo-utils

 (lookup-def enum-bnds enum-freevars sum
  if-convert let-enum let-elim let-lift
  ;; NOTE: the procedure name is the symbol `sl\'; a backslash escapes
  ;; the following character in a symbol, so it must be written `sl\\'.
  s+ sw+ sl\\ nl spaces ppf
  transitions-graph state-lineqs
  differentiate simplify)

 (import scheme chicken data-structures srfi-1 srfi-13)

 (require-extension matchable strictly-pretty varsubst digraph nemo-core)


;; Looks up key K in the association list LST (compared with equal?,
;; via assoc).  An entry of the form (K V) yields V; any other entry
;; yields its cdr.  When K is absent, returns the optional DEFAULT
;; argument, or #f if none was given.
(define (lookup-def k lst . rest)
  (let-optionals rest ((default #f))
    (let ((kv (assoc k lst)))
      (if (not kv)
          default
          (match kv
            ((_ v) v)
            (else (cdr kv)))))))


;; Accumulates onto AX the names bound by every `let' form in EXPR,
;; descending into `if' forms, let binding expressions and let bodies,
;; and the arguments of applications whose operator is a symbol.
;; The names of a let are placed in front of those collected from its
;; binding expressions.
(define (enum-bnds expr ax)
  (match expr
    (('if . es) (fold enum-bnds ax es))
    (('let bnds body)
     (enum-bnds body (append (map car bnds)
                             (fold enum-bnds ax (map cadr bnds)))))
    ((s . es) (if (symbol? s) (fold enum-bnds ax es) ax))
    (else ax)))


;; Accumulates onto AX the free variables of EXPR: symbols that are
;; neither members of the list BNDS nor bound by an enclosing `let'
;; inside EXPR.  Symbols in operator position are not counted, and a
;; symbol occurring several times may be added several times.
(define (enum-freevars expr bnds ax)
  (match expr
    (('if . es)
     (fold (lambda (x ax) (enum-freevars x bnds ax)) ax es))
    ;; The let's own binding list gets a distinct name so it does not
    ;; shadow BNDS (the variables bound outside this let).
    (('let let-bnds body)
     (let ((bnds1 (append (map first let-bnds) bnds)))
       ;; binding expressions see only the outer BNDS; the body also
       ;; sees the names introduced by this let
       (enum-freevars body bnds1
                      (fold (lambda (x ax) (enum-freevars x bnds ax))
                            ax (map second let-bnds)))))
    ((s . es)
     (if (symbol? s)
         (fold (lambda (x ax) (enum-freevars x bnds ax)) ax es)
         ax))
    (id
     (if (and (symbol? id) (not (member id bnds)))
         (cons id ax)
         ax))))


;; Builds a `+' expression over the terms in LST out of binary `+'
;; nodes:  (a) => a,  (a b) => (+ a b),
;; (a b c ...) => (+ (+ a b) <sum of (c ...)>).
;; An empty LST is returned unchanged.
(define (sum lst)
  (if (null? lst)
      lst
      (match lst
        ((x) x)
        ((x y) `(+ ,x ,y))
        ((x y . rest) `(+ (+ ,x ,y) ,(sum rest))))))


;; Rewrites every (if c t e) in EXPR into
;;   (let ((rN (if c' t' e'))) rN)
;; where rN is a fresh gensym with prefix "if" and c', t', e' are the
;; recursively converted subterms.  Let binding expressions, let bodies
;; and application arguments are converted recursively; atoms are
;; returned unchanged.
(define (if-convert expr)
  (match expr
    (('if c t e)
     (let ((r (gensym "if")))
       `(let ((,r (if ,(if-convert c) ,(if-convert t) ,(if-convert e))))
          ,r)))
    (('let bs e)
     `(let ,(map (lambda (b) `(,(car b) ,(if-convert (cadr b)))) bs)
        ,(if-convert e)))
    ((f . es) (cons f (map if-convert es)))
    ((? atom?) expr)))


;; Collects the let bindings of EXPR, appending them to AX in order.
;; A let of the shape (let ((x (if c t e))) x) -- as produced by
;; if-convert -- contributes the bindings found in C followed by its
;; own binding (x (if c t e)); its branches are not scanned here
;; (let-elim lifts them separately).  Other lets contribute their
;; bindings and then those of their body.  For a bare `if' only the
;; condition is scanned.
(define (let-enum expr ax)
  (match expr
    (('let ((x ('if c t e))) y)
     (let ((ax (fold let-enum ax (list c))))
       (if (eq? x y)
           (append ax (list (list x `(if ,c ,t ,e))))
           ax)))
    (('let bnds body) (let-enum body (append ax bnds)))
    ;; argument order is (expr ax)
    (('if c t e) (let-enum c ax))
    ((f . es) (fold let-enum ax es))
    (else ax)))


;; Removes let forms from EXPR, on the assumption that their bindings
;; are hoisted by let-lift:
;;  - (let ((x (if ...))) x) becomes x;
;;  - any other let becomes its (let-eliminated) body, except a
;;    single-binding if-let whose body is not its own variable, which
;;    is returned unchanged;
;;  - in (if c t e) the condition is let-eliminated and each branch is
;;    let-lifted, so branch-local bindings stay inside the branch.
(define (let-elim expr)
  (match expr
    (('let ((x ('if c t e))) y)
     (if (eq? x y) y expr))
    (('let bnds body) (let-elim body))
    (('if c t e)
     `(if ,(let-elim c) ,(let-lift t) ,(let-lift e)))
    ((f . es) `(,f . ,(map let-elim es)))
    (else expr)))


;; Hoists the bindings collected by let-enum into a single `let'
;; wrapped around the let-eliminated EXPR.  When there are no bindings
;; the let-eliminated EXPR is returned directly.
(define (let-lift expr)
  (let ((bnds (let-enum expr (list))))
    (if (null? bnds)
        (let-elim expr)
        `(let ,(map (lambda (b) (list (car b) (let-elim (cadr b)))) bnds)
           ,(let-elim expr)))))


;; Concatenates the ->string representations of all arguments.
(define (s+ . lst)
  (string-concatenate (map ->string lst)))

;; Joins the ->string representations of the non-#f elements of LST,
;; separated by single spaces.
(define (sw+ lst)
  (string-intersperse (filter-map (lambda (x) (and x (->string x))) lst) " "))

;; Joins the ->string representations of the elements of LST,
;; separated by the string P.
(define (sl\\ p lst)
  (string-intersperse (map ->string lst) p))

;; Newline string.
(define nl "\n")

;; Returns a string of N spaces.
(define (spaces n)
  (make-string n #\space))


;; Prints each element of LST on its own line, prefixed by INDENT
;; spaces.  #f elements are skipped.  An element (I . ITEMS) whose head
;; I is a positive number prints each of ITEMS recursively at
;; INDENT + I; other pairs are joined with sw+; all remaining values
;; are printed as-is (lists via sw+).
(define (ppf indent . lst)
  (let ((sp (spaces indent)))
    (for-each
     (lambda (x)
       (and x
            (match x
              ((i . x1)
               (if (and (number? i) (positive? i))
                   (for-each (lambda (x) (ppf (+ indent i) x)) x1)
                   (print sp (sw+ x))))
              (else (print sp (if (list? x) (sw+ x) x))))))
     lst)))


;; Returns a fresh substitution converter built with subst-driver from
;; the nemo-core binding primitives (shared by the procedures below).
(define (make-subst-convert)
  (subst-driver (lambda (s) (and (symbol? s) s))
                nemo:binding? identity nemo:bind nemo:subst-term))


;; Returns the distinct state names (compared with eq?) mentioned in
;; TRANSITIONS, whose elements must be of one of the forms
;;   (-> s0 s1 rate)           (s0 -> s1 rate)
;;   (<-> s0 s1 rate1 rate2)   (s0 <-> s1 rate1 rate2)
;; with s0, s1 symbols.  Any other form signals an error through
;; nemo:error, tagged with LOC and naming the state complex N.
(define (transition-states n transitions loc)
  (let loop ((lst '()) (tlst transitions))
    (if (null? tlst)
        (delete-duplicates lst eq?)
        (match (car tlst)
          (('-> (? symbol? s0) (? symbol? s1) _)
           (loop (cons* s0 s1 lst) (cdr tlst)))
          (((? symbol? s0) '-> (? symbol? s1) _)
           (loop (cons* s0 s1 lst) (cdr tlst)))
          (('<-> (? symbol? s0) (? symbol? s1) _ _)
           (loop (cons* s0 s1 lst) (cdr tlst)))
          (((? symbol? s0) '<-> (? symbol? s1) _ _)
           (loop (cons* s0 s1 lst) (cdr tlst)))
          (else
           (nemo:error loc ": invalid transition equation " (car tlst)
                       " in state complex " n))))))


;; Builds the substitution that maps each state symbol S in STATES to
;; (STATE-NAME N S).
(define (state-substitution n states state-name)
  (fold (lambda (s ax) (subst-extend s (state-name n s) ax))
        subst-empty states))


;; Builds the transition graph of the state complex N.
;;
;;   N           - name of the state complex (also used in the graph label)
;;   OPEN        - name of the open state
;;   TRANSITIONS - transition equations (forms accepted by transition-states)
;;   STATE-NAME  - procedure; (STATE-NAME N S) is the name substituted
;;                 for state S in edge expressions
;;
;; Every state becomes a node, numbered from 0 in the order returned by
;; transition-states.  A transition s0 -> s1 adds the edge
;; (i j (* s0 rate)); a reversible one also adds (j i (* s1 rate2)).
;; The first node (in the graph's node order) whose name is not OPEN is
;; eliminated: wherever it appears as an edge source it is replaced by
;;   (let ((sumN <sum of the other states>)) (- 1 sumN)).
;; State symbols in the edge expressions are rewritten with the
;; substitution built from STATE-NAME.
;;
;; Returns (list G NODE-SUBS): the digraph and that substitution.
(define (transitions-graph n open transitions state-name)
  (let* ((subst-convert (make-subst-convert))
         (g (make-digraph n (string-append (->string n) " transitions graph")))
         (add-node! (g 'add-node!))
         (add-edge! (g 'add-edge!))
         (node-list (transition-states n transitions 'state-eqs))
         (node-ids (list-tabulate (length node-list) identity))
         (name->id-map (zip node-list node-ids))
         (node-subs (state-substitution n node-list state-name)))

    ;; insert state nodes in the dependency graph
    (for-each add-node! node-ids node-list)

    (let* ((nodes ((g 'nodes)))
           ;; node entries are used as (id name), matching add-node! above
           (snode (or (find (lambda (s) (not (eq? (second s) open))) nodes)
                      (nemo:error 'state-eqs
                                  ": no state other than the open state in state complex "
                                  n)))
           (snex (let ((nodes/s (filter-map (lambda (s)
                                              (and (not (= (first s) (first snode)))
                                                   (second s)))
                                            nodes))
                       (sumvar (gensym "sum")))
                   `(let ((,sumvar ,(sum nodes/s))) (- 1 ,sumvar))))
           (add-tredge
            (lambda (s0 s1 rexpr1 rexpr2)
              (let* ((i (car (alist-ref s0 name->id-map)))
                     (j (car (alist-ref s1 name->id-map)))
                     (x0 (if (eq? s0 (second snode)) snex s0))
                     (x1 (if (eq? s1 (second snode)) snex s1))
                     (ij-expr `(* ,(subst-convert x0 node-subs)
                                  ,(subst-convert rexpr1 node-subs)))
                     (ji-expr (and rexpr2
                                   `(* ,(subst-convert x1 node-subs)
                                       ,(subst-convert rexpr2 node-subs)))))
                (add-edge! (list i j ij-expr))
                (if rexpr2 (add-edge! (list j i ji-expr)))))))

      ;; create rate edges in the graph
      (for-each
       (lambda (e)
         (match e
           (('-> s0 s1 rexpr) (add-tredge s0 s1 rexpr #f))
           ((s0 '-> s1 rexpr) (add-tredge s0 s1 rexpr #f))
           (('<-> s0 s1 rexpr1 rexpr2) (add-tredge s0 s1 rexpr1 rexpr2))
           ((s0 '<-> s1 rexpr1 rexpr2) (add-tredge s0 s1 rexpr1 rexpr2))))
       transitions)

      (list g node-subs))))


;; Rewrites the linear equations LINEQS of the state complex N: in the
;; right-hand side of each equation (i = . expr), every state symbol
;; mentioned in TRANSITIONS is replaced by (STATE-NAME N S).
;; Returns (list N rewritten-equations).
(define (state-lineqs n transitions lineqs state-name)
  (let* ((subst-convert (make-subst-convert))
         (state-list (transition-states n transitions 'nemo:state-lineq))
         (state-subs (state-substitution n state-list state-name))
         (lineqs1
          (map (lambda (lineq)
                 (match lineq
                   ((i '= . expr) `(,i = . ,(subst-convert expr state-subs)))
                   (else
                    (nemo:error 'nemo:state-lineq ": invalid linear equation "
                                lineq " in state complex " n))))
               lineqs)))
    (list n lineqs1)))


;; Operator vocabulary of NEMO expressions, kept for reference:
;;   `(+ - * / pow neg abs atan asin acos sin cos exp ln
;;     sqrt tan cosh sinh tanh hypot gamma lgamma log10 log2 log1p ldexp cube
;;     > < <= >= = and or round ceiling floor max min
;;     fpvector-ref)

;; log10(e) = 1/ln(10) and log2(e) = 1/ln(2)
(define LOG10E 0.434294481903252)
(define LOG2E 1.44269504088896)


;; Returns an expression for the derivative of the expression T with
;; respect to the symbol X, or #f when T contains a construct with no
;; differentiation rule.  Numbers and symbols other than X are treated
;; as constants.
;;
;; FENV is an association list consulted (via lookup-def) for operators
;; without a built-in rule.  Its value is either a single derivative
;; operator (used for one-argument operators) or a list of per-argument
;; derivative operators whose chained terms are summed.
;; NOTE(review): the exact FENV convention is inferred from this code;
;; confirm against the callers.
(define (differentiate fenv x t)
  (cond
   ((number? t) 0.0)
   ((symbol? t) (if (equal? x t) 1.0 0.0))
   (else
    (match t
      (('neg u) `(neg ,(differentiate fenv x u)))
      (('+ u v) `(+ ,(differentiate fenv x u) ,(differentiate fenv x v)))
      (('- u v) `(- ,(differentiate fenv x u) ,(differentiate fenv x v)))
      (('* (and u (? number?)) v) `(* ,u ,(differentiate fenv x v)))
      (('* v (and u (? number?))) `(* ,u ,(differentiate fenv x v)))
      (('* u v)
       `(+ (* ,(differentiate fenv x u) ,v) (* ,u ,(differentiate fenv x v))))
      (('/ u v)
       `(/ (- (* ,(differentiate fenv x u) ,v) (* ,u ,(differentiate fenv x v)))
           (pow ,v 2.0)))
      (('cube u) (differentiate fenv x `(pow ,u 3.0)))
      ;; the exponent N is treated as independent of X
      (('pow u n) (chain fenv x u `(* ,n (pow ,u (- ,n 1.0)))))
      (('sqrt u) (chain fenv x u `(/ 1.0 (* 2.0 (sqrt ,u)))))
      (('exp u) (chain fenv x u `(exp ,u)))
      (('log u) (chain fenv x u `(/ 1.0 ,u)))
      (('ln u) (chain fenv x u `(/ 1.0 ,u)))
      ;; chain supplies the inner derivative; only the outer one goes here
      (('log10 u) (chain fenv x u `(* ,LOG10E (/ 1.0 ,u))))
      (('log2 u) (chain fenv x u `(* ,LOG2E (/ 1.0 ,u))))
      (('log1p u) (differentiate fenv x `(log (+ 1.0 ,u))))
      ;; ldexp(u, n) = u * 2^n; fold 2^n when n is a literal number
      (('ldexp u n)
       (differentiate fenv x `(* ,u ,(if (number? n) (expt 2 n) `(pow 2.0 ,n)))))
      (('sin u) (chain fenv x u `(cos ,u)))
      (('cos u) (chain fenv x u `(neg (sin ,u))))
      (('tan u) (differentiate fenv x `(* (sin ,u) (/ 1.0 (cos ,u)))))
      (('asin u) (chain fenv x u `(/ 1.0 (sqrt (- 1.0 (pow ,u 2.0))))))
      (('acos u) (chain fenv x u `(/ (neg 1.0) (sqrt (- 1.0 (pow ,u 2.0))))))
      (('atan u) (chain fenv x u `(/ 1.0 (+ 1.0 (pow ,u 2.0)))))
      (('sinh u) (differentiate fenv x `(/ (- (exp ,u) (exp (neg ,u))) 2.0)))
      (('cosh u) (differentiate fenv x `(/ (+ (exp ,u) (exp (neg ,u))) 2.0)))
      (('tanh u) (differentiate fenv x `(/ (sinh ,u) (cosh ,u))))
      ;; inline the let bindings into the body, then differentiate it
      (('let bnds body)
       (differentiate fenv x ((make-subst-convert) body bnds)))
      ((op . us)
       (if (member x (enum-freevars t '() '()))
           (cond ((lookup-def op fenv)
                  => (lambda (fs)
                       (cond ((and (pair? fs) (pair? us))
                              `(+ . ,(map (lambda (fu u) (chain fenv x u `(,fu ,u)))
                                          fs us)))
                             ;; single derivative operator: one-argument case
                             ((and (pair? us) (null? (cdr us)))
                              (chain fenv x (car us) `(,fs ,(car us))))
                             (else #f))))
                 (else #f))
           0.0))
      (else #f)))))


;; Chain rule: multiplies the outer derivative U by the derivative of
;; the inner expression T with respect to X.  A symbol T contributes a
;; factor of 1.0 when it is X and 0.0 otherwise.
(define (chain fenv x t u)
  (if (symbol? t)
      (if (equal? x t) u 0.0)
      `(* ,(differentiate fenv x t) ,u)))


;; Applies local algebraic simplification rules (identity and zero
;; elements, double negation, constant folding) to the expression T,
;; bottom-up: subterms are simplified first, then the rules are applied
;; to the rebuilt node.  Let binding expressions and bodies are
;; simplified in place.
(define (simplify t)
  (match t
    (('let bnds body)
     `(let ,(map (match-lambda
                  ((v b) `(,v ,(simplify b)))
                  (else #f))
                 bnds)
        ,(simplify body)))
    ((op . ts) (simplify-node `(,op . ,(map simplify ts))))
    (else t)))

;; Applies the simplification rules to the root of T, whose subterms are
;; assumed to be simplified already.  Rewrites that build a new node are
;; re-simplified; each such rewrite shrinks the term, so this terminates.
(define (simplify-node t)
  (match t
    (('neg 0.0) 0.0)
    (('+ 0.0 0.0) 0.0)
    (('+ 0.0 t1) t1)
    (('+ t1 0.0) t1)
    (('+ t1 ('neg t2)) (simplify-node `(- ,t1 ,t2)))
    (('+ (and t1 (? number?)) (and t2 (? number?))) (+ t1 t2))
    (('- 0.0 0.0) 0.0)
    (('- 0.0 t1) (simplify-node `(neg ,t1)))
    (('- t1 0.0) t1)
    (('neg ('neg t1)) t1)
    (('- (and t1 (? number?)) (and t2 (? number?))) (- t1 t2))
    (('* 0.0 0.0) 0.0)
    (('* 0.0 t1) 0.0)
    (('* t1 0.0) 0.0)
    (('* 1.0 t1) t1)
    (('* t1 1.0) t1)
    (('* ('neg t1) ('neg t2)) (simplify-node `(* ,t1 ,t2)))
    (('* (and t1 (? number?)) (and t2 (? number?))) (* t1 t2))
    (('/ 0.0 t1) 0.0)
    (('pow t1 0.0) 1.0)
    (('pow t1 1.0) t1)
    (('pow (and t1 (? number?)) (and t2 (? number?))) (expt t1 t2))
    (else t)))

)