;; WebSocket (RFC 6455) server support on top of spiffy.
;;
;; The high level API (`with-websocket`, `send-message`, `receive-message`)
;; upgrades the current spiffy request and exchanges whole messages; the low
;; level API exposes individual frames ("fragments").
(module websockets
  (
   ; parameters
   ping-interval close-timeout
   connection-timeout accept-connection
   drop-incoming-pings propagate-common-errors
   max-frame-size max-message-size

   ; high level API
   with-websocket with-concurrent-websocket
   send-message receive-message

   ; low level API
   send-frame read-frame read-frame-payload
   receive-fragments valid-utf8?
   control-frame? upgrade-to-websocket
   current-websocket unmask close-websocket
   process-fragments

   ; fragment
   make-fragment fragment? fragment-payload fragment-length
   fragment-masked? fragment-masking-key fragment-last?
   fragment-optype
   )

(import chicken scheme data-structures extras ports posix foreign)
(use srfi-1 srfi-4 spiffy intarweb uri-common base64 simple-sha1 srfi-18
     srfi-13 miscmacros mailbox)

; TODO make sure all C operations check args to prevent overflows
(foreign-declare "#include \"utf8validator.c\"")

(define-inline (neq? obj1 obj2) (not (eq? obj1 obj2)))

;;; Parameters

(define current-websocket (make-parameter #f))
(define ping-interval (make-parameter 15))
(define close-timeout (make-parameter 5))
(define connection-timeout (make-parameter 58))
(define accept-connection (make-parameter (lambda (origin) #t)))
(define drop-incoming-pings (make-parameter #t))
(define propagate-common-errors (make-parameter #f))

(define max-frame-size (make-parameter 65536)) ; 64KiB
(define max-message-size (make-parameter 1048576)) ; 1MiB

;;; Conditions

;; Build a composite condition that always carries the 'websocket kind,
;; followed by the given extra conditions.
(define (make-websocket-exception . conditions)
  (apply make-composite-condition
         (cons (make-property-condition 'websocket) conditions)))

;; Condition for a rejected handshake header. `type` becomes a condition kind
;; holding property `k` with value `v`; the result also has the
;; 'invalid-header kind.
(define (make-invalid-header-exception type k v)
  (make-composite-condition
   (make-websocket-exception (make-property-condition type k v))
   (make-property-condition 'invalid-header)))

;; Condition of kinds (websocket protocol-error) with `msg` stored under 'msg.
(define (make-protocol-violation-exception msg)
  (make-composite-condition
   (make-property-condition 'websocket)
   (make-property-condition 'protocol-error 'msg msg)))

;;; Opcodes

;; Map a numeric frame opcode to its symbolic type; signals a protocol
;; violation for any other opcode.
(define (opcode->optype op)
  (case op
    ((0) 'continuation)
    ((1) 'text)
    ((2) 'binary)
    ((8) 'connection-close)
    ((9) 'ping)
    ((10) 'pong)
    (else (signal (make-protocol-violation-exception "bad opcode")))))

;; Inverse of `opcode->optype`. The clauses use proper datum lists: the
;; quoted form ('text 1) is read as ((quote text) 1) and would also match the
;; symbol `quote`.
(define (optype->opcode t)
  (case t
    ((continuation) 0)
    ((text) 1)
    ((binary) 2)
    ((connection-close) 8)
    ((ping) 9)
    ((pong) 10)
    (else (error "bad optype")))) ; TODO

;; #t for ping, pong and connection-close frames.
(define (control-frame? optype)
  (or (eq? optype 'ping) (eq? optype 'pong) (eq? optype 'connection-close)))

;;; Records

(define-record-type websocket
  (make-websocket inbound-port outbound-port user-thread
                  send-mutex read-mutex last-message-timestamp
                  state send-mailbox read-mailbox concurrent)
  websocket?
  (inbound-port websocket-inbound-port)
  (outbound-port websocket-outbound-port)
  (user-thread websocket-user-thread)
  (send-mutex websocket-send-mutex)
  (read-mutex websocket-read-mutex)
  (last-message-timestamp websocket-last-message-timestamp
                          set-websocket-last-message-timestamp!)
  ;; one of 'open, 'closing, 'closed or 'error
  (state websocket-state set-websocket-state!)
  (send-mailbox websocket-send-mailbox)
  (read-mailbox websocket-read-mailbox)
  (concurrent websocket-concurrent?))

;; A single frame as read from the wire. `payload` is kept exactly as
;; received (still masked when `masked` is true); use `unmask` to decode it.
(define-record-type websocket-fragment
  (make-fragment payload length masked masking-key
                 fin optype)
  fragment?
  (payload fragment-payload)
  (length fragment-length)
  (masked fragment-masked?)
  (masking-key fragment-masking-key)
  (fin fragment-last?)
  (optype fragment-optype))

;;; Byte helpers

;; Convert a string into a u8vector holding the character code of each char.
(define (string->bytes str)
  (let* ((lst (map char->integer (string->list str)))
         (bv (make-u8vector (length lst))))
    (let loop ((lst lst)
               (pos 0))
      (if (null? lst)
          bv
          (begin
            (u8vector-set! bv pos (car lst))
            (loop (cdr lst) (+ pos 1)))))))

;; Convert a string like "a745ff12" into the string whose characters have
;; those byte values. A trailing unpaired digit is ignored.
(define (hex-string->string hexstr)
  (let ((result (make-string (quotient (string-length hexstr) 2))))
    (let loop ((hexs (string->list hexstr))
               (i 0))
      ;; Test for "fewer than two digits left" directly instead of calling
      ;; `length` on every iteration.
      (if (or (null? hexs) (null? (cdr hexs)))
          result
          (let ((ascii (string->number (string (car hexs) (cadr hexs)) 16)))
            (string-set! result i (integer->char ascii))
            (loop (cddr hexs) (+ i 1)))))))

;; Encode the non-negative integer `n` as `count` bytes, most significant
;; byte first, in a fresh u8vector. Bits above `count` bytes are dropped.
(define (integer->network-bytes n count)
  (let ((bytes (make-u8vector count 0)))
    (let loop ((i (- count 1))
               (rest n))
      (if (< i 0)
          bytes
          (begin
            (u8vector-set! bytes i (bitwise-and rest 255))
            (loop (- i 1) (arithmetic-shift rest -8)))))))

;; Read `count` bytes from `port`, strictly in order, and combine them into
;; an integer, most significant byte first.
(define (read-network-integer port count)
  (let loop ((remaining count)
             (acc 0))
    (if (zero? remaining)
        acc
        (loop (- remaining 1)
              (+ (arithmetic-shift acc 8) (read-byte port))))))

;;; Sending

;; Write one unmasked frame to the websocket's outbound port.
;;   optype     - symbol accepted by `optype->opcode`
;;   data       - payload, a string or a u8vector
;;   last-frame - true sets the FIN bit
;; Returns #t. Does not take the send mutex; see `send-message`.
(define (send-frame ws optype data last-frame)
  (let* ((payload (if (u8vector? data)
                      (blob->string (u8vector->blob/shared data))
                      data))
         (len (string-length payload))
         (frame-fin (if last-frame 1 0))
         (frame-rsv1 0)
         (frame-rsv2 0)
         (frame-rsv3 0)
         (frame-opcode (optype->opcode optype))
         (octet0 (bitwise-ior (arithmetic-shift frame-fin 7)
                              (arithmetic-shift frame-rsv1 6)
                              (arithmetic-shift frame-rsv2 5)
                              (arithmetic-shift frame-rsv3 4)
                              frame-opcode))
         (frame-masked 0)
         ;; 7-bit length field: the length itself, or 126/127 to announce a
         ;; 16-bit/64-bit extended length.
         (frame-payload-length (cond ((< len 126) len)
                                     ((< len 65536) 126)
                                     (else 127)))
         (octet1 (bitwise-ior (arithmetic-shift frame-masked 7)
                              frame-payload-length))
         (outbound-port (websocket-outbound-port ws)))

    (write-u8vector (u8vector octet0 octet1) outbound-port)

    (write-u8vector
     (cond
      ((= frame-payload-length 126) (integer->network-bytes len 2))
      ((= frame-payload-length 127) (integer->network-bytes len 8))
      (else (u8vector)))
     outbound-port)

    (write-string payload len outbound-port)
    #t))

;; Send `data` as a single, final frame of type `optype`, holding the
;; websocket's send mutex for the duration of the write.
(define (send-message optype #!optional (data "") (ws (current-websocket)))
  ;; TODO break up large data into multiple frames?
  (dynamic-wind
      (lambda () (mutex-lock! (websocket-send-mutex ws)))
      (lambda () (send-frame ws optype data #t))
      (lambda () (mutex-unlock! (websocket-send-mutex ws)))))

;;; Receiving

;; XOR the first `len` bytes of the string `payload` IN PLACE with the 4 byte
;; `frame-masking-key` (a vector of four byte values) and return `payload`.
(define (websocket-unmask-frame-payload payload len frame-masking-key)
  (define tmaskkey (make-u8vector 4 #f #t #t))
  (u8vector-set! tmaskkey 0 (vector-ref frame-masking-key 0))
  (u8vector-set! tmaskkey 1 (vector-ref frame-masking-key 1))
  (u8vector-set! tmaskkey 2 (vector-ref frame-masking-key 2))
  (u8vector-set! tmaskkey 3 (vector-ref frame-masking-key 3))

  (define-external wsmaskkey blob (u8vector->blob/shared tmaskkey))
  (define-external wslen int len)
  (define-external wsv scheme-pointer payload)

  ((foreign-lambda* void ()
"
    const unsigned char* kb = (const unsigned char*)wsmaskkey;
    const unsigned int kd = *(const unsigned int*)kb;
    unsigned char* p = (unsigned char*)wsv;

    /* whole 4 byte words first */
    for (int i = wslen >> 2; i != 0; --i)
      {
        *((unsigned int*)p) ^= kd;
        p += 4;
      }

    /* remaining 0..3 bytes one at a time, so nothing past the end of the
       payload is read or written */
    const int rem = wslen & 3;
    for (int i = 0; i < rem; ++i)
      {
        p[i] ^= kb[i];
      }
"
))
  payload)

;; Return the decoded payload of `fragment`. A masked payload is decoded into
;; a fresh copy, so the fragment is left untouched and `unmask` may be called
;; on the same fragment any number of times; an unmasked payload is returned
;; as is.
(define (unmask fragment)
  (if (fragment-masked? fragment)
      (websocket-unmask-frame-payload
       (string-copy (fragment-payload fragment))
       (fragment-length fragment)
       (fragment-masking-key fragment))
      (fragment-payload fragment)))

;; Read `frame-payload-length` bytes from `inbound-port` into a new string.
;; The bytes are returned exactly as received; unmasking is done by `unmask`.
(define (read-frame-payload inbound-port frame-payload-length)
  (let ((masked-data (make-string frame-payload-length)))
    (read-string! frame-payload-length masked-data inbound-port)
    masked-data))

;; Read one frame from the websocket's inbound port.
;;   total-size - bytes already accumulated for the current message; used
;;                together with the frame length for the max-message-size check
;; Returns a fragment record, or the EOF object when the port is at EOF
;; before the first header byte.
;; Signals websocket conditions for reserved bits, bad opcodes and frames or
;; messages exceeding `max-frame-size` / `max-message-size`.
(define (read-frame total-size ws)
  (let* ((inbound-port (websocket-inbound-port ws))
         (b0 (read-byte inbound-port)))
    (cond
     ;; EOF has to be tested before any bit operation on b0.
     ((eof-object? b0) b0)
     ; we don't support reserved bits yet
     ((or (> (bitwise-and b0 64) 0)
          (> (bitwise-and b0 32) 0)
          (> (bitwise-and b0 16) 0))
      (signal (make-websocket-exception
               (make-property-condition 'reserved-bits-not-supported)
               (make-property-condition 'protocol-error))))
     (else
      (let* ((frame-fin (> (bitwise-and b0 128) 0))
             (frame-opcode (bitwise-and b0 15))
             (frame-optype (opcode->optype frame-opcode))
             ;; second byte
             (b1 (read-byte inbound-port))
             ; TODO die on unmasked frame?
             (frame-masked (> (bitwise-and b1 128) 0))
             (short-length (bitwise-and b1 127))
             ;; 126 and 127 announce a 2 byte / 8 byte extended length.
             (frame-payload-length
              (cond ((= short-length 126) (read-network-integer inbound-port 2))
                    ((= short-length 127) (read-network-integer inbound-port 8))
                    (else short-length))))

        (when (or (> frame-payload-length (max-frame-size))
                  (> (+ frame-payload-length total-size) (max-message-size)))
          (signal (make-websocket-exception
                   (make-property-condition 'message-too-large))))

        (let ((frame-masking-key
               (and frame-masked
                    (let* ((fm0 (read-byte inbound-port))
                           (fm1 (read-byte inbound-port))
                           (fm2 (read-byte inbound-port))
                           (fm3 (read-byte inbound-port)))
                      (vector fm0 fm1 fm2 fm3)))))
          (cond
           ((memq frame-optype
                  '(text binary continuation ping pong connection-close))
            (make-fragment
             (read-frame-payload inbound-port frame-payload-length)
             frame-payload-length frame-masked
             frame-masking-key frame-fin frame-optype))
           (else
            (thread-signal!
             (websocket-user-thread ws)
             (make-websocket-exception
              (make-property-condition 'unhandled-opcode
                                       'optype frame-optype)))
            (signal (make-websocket-exception
                     (make-property-condition 'unhandled-opcode
                                              'optype frame-optype)))))))))))

;; Table driven UTF-8 validation of the string `s`; returns #t when the
;; automaton ends in its accepting state (0). Not used by the rest of this
;; module, which validates with `valid-utf8?`.
(define (valid-utf8-2? s)
  (define-external str c-string s)
  (define-external len int (string-length s))

  (zero?
   ((foreign-lambda* int ()
"
static const uint8_t utf8d[] = {
  0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, // 00..1f
  0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, // 20..3f
  0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, // 40..5f
  0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0, // 60..7f
  1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9,9, // 80..9f
  7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7,7, // a0..bf
  8,8,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2,2, // c0..df
  0xa,0x3,0x3,0x3,0x3,0x3,0x3,0x3,0x3,0x3,0x3,0x3,0x3,0x4,0x3,0x3, // e0..ef
  0xb,0x6,0x6,0x6,0x5,0x8,0x8,0x8,0x8,0x8,0x8,0x8,0x8,0x8,0x8,0x8, // f0..ff
  0x0,0x1,0x2,0x3,0x5,0x8,0x7,0x1,0x1,0x1,0x4,0x6,0x1,0x1,0x1,0x1, // s0..s0
  1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,1,1,1,1,1,0,1,0,1,1,1,1,1,1, // s1..s2
  1,2,1,1,1,1,1,2,1,2,1,1,1,1,1,1,1,1,1,1,1,1,1,2,1,1,1,1,1,1,1,1, // s3..s4
  1,2,1,1,1,1,1,1,1,2,1,1,1,1,1,1,1,1,1,1,1,1,1,3,1,3,1,1,1,1,1,1, // s5..s6
  1,3,1,1,1,1,1,3,1,3,1,1,1,1,1,1,1,3,1,1,1,1,1,1,1,1,1,1,1,1,1,1, // s7..s8
};

uint32_t si;
uint32_t *state;
si = 0;
state = &si;
uint32_t type;

for (int i = 0; i < len; i++) {
  /* classify the i-th byte, not always the first one */
  type = utf8d[((const uint8_t*)str)[i]];
  *state = utf8d[256 + (*state) * 16 + type];

  if (*state == 1) // reject; other non-zero states are mid-sequence
    break;
}

C_return(*state);
"
))
   ))

;; Validate `s` with the C function `utf8_valid` from utf8validator.c and
;; return its integer result.
;; NOTE(review): the foreign return type is `int`, and every integer
;; (including 0) is a true value in Scheme, so the `(not (valid-utf8? ...))`
;; test in `process-fragments` can never fire as written. The return
;; convention of `utf8_valid` is not visible from this file - confirm it and
;; compare the result explicitly (or declare the return type as `bool`).
(define (valid-utf8? s)
  (let ((len (string-length s)))
    ((foreign-lambda int "utf8_valid" scheme-pointer int) s len)))

;; Decode the 2 byte big-endian status code at the start of an (unmasked)
;; close frame body. An empty body counts as 1000.
(define (close-code->integer s)
  (if (string-null? s)
      1000
      (+ (arithmetic-shift (char->integer (string-ref s 0)) 8)
         (char->integer (string-ref s 1)))))

;; Map an (unmasked) close frame body to a symbolic close reason. Codes
;; 3000-4999 map to 'unknown, anything else unrecognised to
;; 'invalid-close-code.
(define (close-code-string->close-reason s)
  (let ((c (close-code->integer s)))
    (case c
      ((1000) 'normal)
      ((1001) 'going-away)
      ((1002) 'protocol-error)
      ((1003) 'unknown-data-type)
      ((1007) 'invalid-data)
      ((1008) 'violated-policy)
      ((1009) 'message-too-large)
      ((1010) 'extension-negotiation-failed)
      ((1011) 'unexpected-error)
      (else
       (if (and (>= c 3000) (< c 5000))
           'unknown
           'invalid-close-code)))))

(define (valid-close-code? s)
  (neq? 'invalid-close-code (close-code-string->close-reason s)))

;; Read frames until one complete message is available, holding the read
;; mutex while doing so.
;; Returns two values: the list of fragments in reverse order of arrival and
;; the message type. Both values are the EOF object when the websocket is
;; closing, closed or in error, or when the inbound port is at EOF.
;; Pings are answered (unless `drop-incoming-pings` is true) and pongs are
;; skipped; protocol violations set the state to 'error and are signalled.
(define (receive-fragments #!optional (ws (current-websocket)))
  (dynamic-wind
      (lambda () (mutex-lock! (websocket-read-mutex ws)))
      (lambda ()
        (if (or (eq? (websocket-state ws) 'closing)
                (eq? (websocket-state ws) 'closed)
                (eq? (websocket-state ws) 'error))
            (values #!eof #!eof)
            (let loop ((fragments '())
                       (first #t)
                       (type 'text)
                       (total-size 0))
              (let ((fragment (read-frame total-size ws)))
                (if (eof-object? fragment)
                    ;; The peer went away without sending a frame.
                    (values #!eof #!eof)
                    (let ((optype (fragment-optype fragment))
                          (len (fragment-length fragment))
                          (last-frame (fragment-last? fragment)))
                      (set-websocket-last-message-timestamp! ws (current-time))
                      (cond
                       ((and (control-frame? optype) (> len 125))
                        (set-websocket-state! ws 'error)
                        (signal (make-protocol-violation-exception
                                 "control frame bodies must be less than 126 octets")))

                       ; connection close
                       ((and (eq? optype 'connection-close) (= len 1))
                        (set-websocket-state! ws 'error)
                        (signal (make-protocol-violation-exception
                                 "close frames must not have a length of 1")))
                       ((and (eq? optype 'connection-close)
                             (not (valid-close-code? (unmask fragment))))
                        (set-websocket-state! ws 'error)
                        (signal (make-protocol-violation-exception
                                 (string-append
                                  "invalid close code "
                                  (number->string
                                   (close-code->integer (unmask fragment)))))))
                       ((eq? optype 'connection-close)
                        (set-websocket-state! ws 'closing)
                        (values (list fragment) optype))

                       ; immediate response
                       ((and (eq? optype 'ping) last-frame (<= len 125))
                        (unless (drop-incoming-pings)
                          ;; reply on the websocket we are reading from
                          (send-message 'pong (unmask fragment) ws))
                        (loop fragments first type total-size))

                       ; protocol violation checks
                       ((or (and first (eq? optype 'continuation))
                            (and (not first) (neq? optype 'continuation)))
                        (set-websocket-state! ws 'error)
                        (signal (make-protocol-violation-exception
                                 "continuation frame out-of-order")))
                       ((and (not last-frame) (control-frame? optype))
                        (set-websocket-state! ws 'error)
                        (signal (make-protocol-violation-exception
                                 "control frames can't be fragmented")))

                       ((eq? optype 'pong)
                        (loop fragments first type total-size))

                       (else
                        (if last-frame
                            (values (cons fragment fragments)
                                    (if (null? fragments) optype type))
                            (loop (cons fragment fragments)
                                  #f
                                  (if first optype type)
                                  (+ total-size len)))))))))))
      (lambda () (mutex-unlock! (websocket-read-mutex ws)))))

;; Turn the fragments returned by `receive-fragments` (reverse arrival order)
;; into one message. Returns two values: the concatenated, unmasked body and
;; `optype`. For 'text messages failing UTF-8 validation the state is set to
;; 'error and a (websocket invalid-data) condition is signalled.
(define (process-fragments fragments optype #!optional (ws (current-websocket)))
  (let ((message-body (string-concatenate/shared
                       (reverse (map unmask fragments)))))
    (when (and (eq? optype 'text)
               (not (valid-utf8? message-body)))
      (set-websocket-state! ws 'error)
      (signal (make-websocket-exception
               (make-property-condition 'invalid-data
                                        'msg "invalid UTF-8"))))
    (values message-body optype)))

;; Receive the next message. Returns two values, body and type; the body is
;; the EOF object when no further message can be read. On a concurrent
;; websocket the message is taken from the read mailbox instead.
(define (receive-message #!optional (ws (current-websocket)))
  (if (websocket-concurrent? ws)
      (let ((msg (mailbox-receive! (websocket-read-mailbox ws))))
        (values (car msg) (cdr msg)))
      (receive (fragments optype) (receive-fragments ws)
        (if (eof-object? fragments)
            (values #!eof optype)
            (process-fragments fragments optype ws)))))

;;; Closing

;; Send a close frame carrying the status code for `close-reason`.
;;   close-reason - one of normal, going-away, protocol-error,
;;                  unknown-data-type, invalid-data, violated-policy,
;;                  message-too-large, unexpected-error; any other value is
;;                  sent as unexpected-error
;;   data         - accepted for interface compatibility, currently unused
;; The work runs in a separate thread that is joined for at most
;; `(close-timeout)` seconds (without limit when the timeout is 0).
(define (close-websocket #!optional (ws (current-websocket))
                         #!key (close-reason 'normal) (data (make-u8vector 0)))
  (define invalid-close-reason #f)

  (define (close-reason->close-code reason)
    (case reason
      ((normal) 1000)
      ((going-away) 1001)
      ((protocol-error) 1002)
      ((unknown-data-type) 1003)
      ((invalid-data) 1007)
      ((violated-policy) 1008)
      ((message-too-large) 1009)
      ((unexpected-error) 1011)
      (else
       (set! invalid-close-reason reason)
       (close-reason->close-code 'unexpected-error))))

  ;; The close body is the status code as two bytes, most significant first.
  (define (send-close-frame)
    (send-frame ws 'connection-close
                (integer->network-bytes
                 (close-reason->close-code close-reason) 2)
                #t))

  ; Use thread timeout to handle the close-timeout
  (let ((close-thread
         (make-thread
          (lambda ()
            (if (eq? (websocket-state ws) 'open)
                (begin
                  (set-websocket-state! ws 'closed)
                  (send-close-frame)
                  (let loop ()
                    (receive (data type) (receive-message ws)
                      ;; Stop on EOF as well: with the state already
                      ;; 'closed a non-concurrent receive returns EOF at
                      ;; once, which would otherwise spin here forever.
                      (unless (or (eof-object? type)
                                  (eq? type 'connection-close))
                        (loop)))))
                (send-close-frame))))))
    (thread-start! close-thread)
    (if (> (close-timeout) 0)
        ; TODO actually signal error when the join times out?
        ;; (thread-signal! (websocket-user-thread (current-websocket))
        ;;                 (make-websocket-exception
        ;;                  (make-property-condition 'close-timeout)))
        (thread-join! close-thread (close-timeout) #f)
        (thread-join! close-thread))
    (log-to (error-log) "closed")))

;;; Handshake

;; SHA-1 digest of `in-bv` as a raw (not hex encoded) string.
(define (sha1-sum in-bv)
  (hex-string->string (string->sha1sum in-bv)))

;; Compute the Sec-WebSocket-Accept value for `client-key`: base64 of the
;; SHA-1 of the key concatenated with the protocol's fixed GUID.
(define (websocket-compute-handshake client-key)
  (let* ((key-and-magic
          ; TODO generate new, random, secure key every time
          (string-append client-key "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
         (key-and-magic-sha1 (sha1-sum key-and-magic)))
    (base64-encode key-and-magic-sha1)))

;; Header unparser emitting the raw accept value(s) unchanged.
(define (sec-websocket-accept-unparser header-contents)
  (map (lambda (header-content)
         (car (vector-ref header-content 0)))
       header-contents))

(header-unparsers
 (alist-update! 'sec-websocket-accept
                sec-websocket-accept-unparser
                (header-unparsers)))

;; Validate the current spiffy request as a websocket upgrade, send the
;; 101 response and return a new websocket record in state 'open.
;; Starts the connection-timeout watchdog and the ping thread when the
;; corresponding parameters are greater than 0.
;; Signals an invalid-header condition when the Upgrade/Connection headers,
;; the protocol version or the origin are not acceptable.
(define (websocket-accept #!optional (concurrent #f))
  (let* ((user-thread (current-thread))
         (headers (request-headers (current-request)))
         (client-key (header-value 'sec-websocket-key headers))
         (ws-handshake (websocket-compute-handshake client-key))
         (ws (make-websocket
              (request-port (current-request))
              (response-port (current-response))
              user-thread
              (make-mutex "send")
              (make-mutex "read")
              (current-time)
              'open ; websocket state
              (make-mailbox "send")
              (make-mailbox "read")
              concurrent))
         (ping-thread
          (make-thread
           (lambda ()
             (let loop ()
               (thread-sleep! (ping-interval))
               (send-message 'ping "" ws)
               (loop))))))

    ; make sure the request meets the spec for websockets
    (cond ((not (and (eq? (header-value 'connection headers #f) 'upgrade)
                     (string-ci= (car (header-value 'upgrade headers '("")))
                                 "websocket")))
           (signal (make-invalid-header-exception
                    'upgrade 'value (header-value 'upgrade headers #f))))
          ((not (string= (header-value 'sec-websocket-version headers "") "13"))
           (signal (make-invalid-header-exception
                    'websocket-version 'version
                    (header-value 'sec-websocket-version headers #f))))
          ((not ((accept-connection) (header-value 'origin headers "")))
           (signal (make-invalid-header-exception
                    'origin 'value (header-value 'origin headers #f)))))

    (with-headers
     `((upgrade ("WebSocket" . #f))
       (connection (upgrade . #t))
       (sec-websocket-accept (,ws-handshake . #t)))
     (lambda ()
       (send-response status: 'switching-protocols)))
    (flush-output (response-port (current-response)))

    ; connection timeout thread
    (when (> (connection-timeout) 0)
      (thread-start!
       (lambda ()
         (let loop ()
           ; Add one to attempt to alleviate checking the timestamp
           ; right before when the timeout should happen.
           (thread-sleep! (+ 1 (connection-timeout)))
           (if (< (- (time->seconds (current-time))
                     (time->seconds (websocket-last-message-timestamp ws)))
                  (connection-timeout))
               (loop)
               (begin
                 (thread-signal!
                  (websocket-user-thread ws)
                  (make-websocket-exception
                   (make-property-condition 'connection-timeout)))
                 (close-websocket ws close-reason: 'going-away)))))))

    (when (> (ping-interval) 0)
      (thread-start! ping-thread))

    ws))

;; Close the request and response ports of the current request unless they
;; are already closed.
(define (close-request-ports!)
  (unless (port-closed? (request-port (current-request)))
    (close-input-port (request-port (current-request))))
  (unless (port-closed? (response-port (current-response)))
    (close-output-port (response-port (current-response)))))

;; Shared handler for the expected websocket failures: mark the current
;; websocket as closing, send a close frame with `reason`, close the ports
;; and re-signal `exn` when `propagate-common-errors` is true.
(define (abort-websocket! exn reason)
  (set-websocket-state! (current-websocket) 'closing)
  (close-websocket (current-websocket) close-reason: reason)
  (close-request-ports!)
  (when (propagate-common-errors)
    (signal exn)))

;; Accept a websocket on the current request, run the thunk `proc` with
;; `current-websocket` bound to it, then close the websocket and the ports.
;; Protocol errors, invalid data, connection timeouts and oversized messages
;; close the connection with the matching close reason; any other condition
;; closes it as unexpected-error and signals a
;; (websocket unexpected-error) condition.
(define (with-websocket proc #!optional (concurrent #f))
  (parameterize
      ((current-websocket (websocket-accept concurrent)))
    (condition-case
        (begin (proc)
               (close-websocket)
               (close-input-port (request-port (current-request)))
               (close-output-port (response-port (current-response))))
      (exn (websocket protocol-error)
           (abort-websocket! exn 'protocol-error))
      (exn (websocket invalid-data)
           (abort-websocket! exn 'invalid-data))
      (exn (websocket connection-timeout)
           (abort-websocket! exn 'going-away))
      (exn (websocket message-too-large)
           (abort-websocket! exn 'message-too-large))
      (exn ()
           (close-websocket (current-websocket)
                            close-reason: 'unexpected-error)
           (close-request-ports!)
           (signal (make-websocket-exception
                    (make-property-condition 'unexpected-error)))))))

;; Like `with-websocket`, but messages are read by a background thread and
;; delivered through the read mailbox. Each received message is assembled in
;; its own thread; conditions raised while reading or assembling are
;; forwarded to the thread that called this procedure.
(define (with-concurrent-websocket proc)
  (let ((parent-thread (current-thread)))
    (with-websocket
     (lambda ()
       (thread-start!
        (lambda ()
          (handle-exceptions
           exn
           (thread-signal! parent-thread exn)
           (let loop ()
             (receive (fragments optype) (receive-fragments)
               (unless (eof-object? fragments)
                 (thread-start!
                  (lambda ()
                    (handle-exceptions
                     exn
                     (thread-signal! parent-thread exn)
                     (mailbox-send!
                      (websocket-read-mailbox (current-websocket))
                      (receive (msg-body optype)
                          (process-fragments fragments optype)
                        `(,msg-body . ,optype))))))
                 (loop)))))))
       (proc))
     #t)))

;; Low level entry point: perform the handshake and return the websocket
;; without running any handler or installing any cleanup.
(define (upgrade-to-websocket #!optional (concurrent #f))
  (websocket-accept concurrent))

)