Selaa lähdekoodia

C wsproxy: seq numbers and decode multiple frames.

Joel Martin 15 vuotta sitten
vanhempi
commit
9e61a9c6f0
3 muutettua tiedostoa jossa 122 lisäystä ja 41 poistoa
  1. 95 12
      websocket.c
  2. 2 1
      websocket.h
  3. 25 28
      wsproxy.c

+ 95 - 12
websocket.c

@@ -15,6 +15,7 @@
 #include <arpa/inet.h>
 #include <arpa/inet.h>
 #include <openssl/err.h>
 #include <openssl/err.h>
 #include <openssl/ssl.h>
 #include <openssl/ssl.h>
+#include <resolv.h>      /* base64 encode/decode */
 #include "websocket.h"
 #include "websocket.h"
 
 
 const char server_handshake[] = "HTTP/1.1 101 Web Socket Protocol Handshake\r\n\
 const char server_handshake[] = "HTTP/1.1 101 Web Socket Protocol Handshake\r\n\
@@ -27,6 +28,16 @@ WebSocket-Protocol: sample\r\n\
 
 
 const char policy_response[] = "<cross-domain-policy><allow-access-from domain=\"*\" to-ports=\"*\" /></cross-domain-policy>\n";
 const char policy_response[] = "<cross-domain-policy><allow-access-from domain=\"*\" to-ports=\"*\" /></cross-domain-policy>\n";
 
 
+/*
+ * Global state
+ *
+ *   Warning: not thread safe
+ */
+int ssl_initialized = 0;
+char *tbuf, *cbuf, *tbuf_tmp, *cbuf_tmp;
+unsigned int bufsize, dbufsize;
+client_settings_t client_settings;
+
 void traffic(char * token) {
 void traffic(char * token) {
     fprintf(stdout, "%s", token);
     fprintf(stdout, "%s", token);
     fflush(stdout);
     fflush(stdout);
@@ -47,9 +58,6 @@ void fatal(char *msg)
  * SSL Wrapper Code
  * SSL Wrapper Code
  */
  */
 
 
-/*   Warning: not thread safe */
-int ssl_initialized = 0;
-
 ssize_t ws_recv(ws_ctx_t *ctx, void *buf, size_t len) {
 ssize_t ws_recv(ws_ctx_t *ctx, void *buf, size_t len) {
     if (ctx->ssl) {
     if (ctx->ssl) {
         //printf("SSL recv\n");
         //printf("SSL recv\n");
@@ -147,7 +155,56 @@ int ws_socket_free(ws_ctx_t *ctx) {
 /* ------------------------------------------------------- */
 /* ------------------------------------------------------- */
 
 
 
 
-ws_ctx_t *do_handshake(int sock, client_settings_t *client_settings) {
+int encode(u_char const *src, size_t srclength, char *target, size_t targsize) {
+    int sz = 0, len = 0;
+    target[sz++] = '\x00';
+    if (client_settings.do_seq_num) {
+        sz += sprintf(target+sz, "%d:", client_settings.seq_num);
+        client_settings.seq_num++;
+    }
+    if (client_settings.do_b64encode) {
+        len = __b64_ntop(src, srclength, target+sz, targsize-sz);
+    } else {
+        fatal("UTF-8 not yet implemented");
+    }
+    if (len < 0) {
+        return len;
+    }
+    sz += len;
+    target[sz++] = '\xff';
+    return sz;
+}
+
+int decode(char *src, size_t srclength, u_char *target, size_t targsize) {
+    char *start, *end;
+    int len, retlen = 0;
+    if ((src[0] != '\x00') || (src[srclength-1] != '\xff')) {
+        fprintf(stderr, "WebSocket framing error\n");
+        return -1;
+    }
+    start = src+1; // Skip '\x00' start
+    do {
+        /* We may have more than one frame */
+        end = strchr(start, '\xff');
+        if (end < (src+srclength-1)) {
+            printf("More than one frame to decode\n");
+        }
+        *end = '\x00';
+        if (client_settings.do_b64encode) {
+            len = __b64_pton(start, target+retlen, targsize-retlen);
+        } else {
+            fatal("UTF-8 not yet implemented");
+        }
+        if (len < 0) {
+            return len;
+        }
+        retlen += len;
+        start = end + 2; // Skip '\xff' end and '\x00' start 
+    } while (end < (src+srclength-1));
+    return retlen;
+}
+
+ws_ctx_t *do_handshake(int sock) {
     char handshake[4096], response[4096];
     char handshake[4096], response[4096];
     char *scheme, *line, *path, *host, *origin;
     char *scheme, *line, *path, *host, *origin;
     char *args_start, *args_end, *arg_idx;
     char *args_start, *args_end, *arg_idx;
@@ -155,8 +212,9 @@ ws_ctx_t *do_handshake(int sock, client_settings_t *client_settings) {
     ws_ctx_t * ws_ctx;
     ws_ctx_t * ws_ctx;
 
 
     // Reset settings
     // Reset settings
-    client_settings->b64encode = 0;
-    client_settings->seq_num = 0;
+    client_settings.do_b64encode = 0;
+    client_settings.do_seq_num = 0;
+    client_settings.seq_num = 0;
 
 
     len = recv(sock, handshake, 1024, MSG_PEEK);
     len = recv(sock, handshake, 1024, MSG_PEEK);
     handshake[len] = 0;
     handshake[len] = 0;
@@ -211,12 +269,12 @@ ws_ctx_t *do_handshake(int sock, client_settings_t *client_settings) {
         arg_idx = strstr(args_start, "b64encode");
         arg_idx = strstr(args_start, "b64encode");
         if (arg_idx && arg_idx < args_end) {
         if (arg_idx && arg_idx < args_end) {
             //printf("setting b64encode\n");
             //printf("setting b64encode\n");
-            client_settings->b64encode = 1;
+            client_settings.do_b64encode = 1;
         }
         }
         arg_idx = strstr(args_start, "seq_num");
         arg_idx = strstr(args_start, "seq_num");
         if (arg_idx && arg_idx < args_end) {
         if (arg_idx && arg_idx < args_end) {
             //printf("setting seq_num\n");
             //printf("setting seq_num\n");
-            client_settings->seq_num = 1;
+            client_settings.do_seq_num = 1;
         }
         }
     }
     }
 
 
@@ -228,12 +286,22 @@ ws_ctx_t *do_handshake(int sock, client_settings_t *client_settings) {
 }
 }
 
 
 void start_server(int listen_port,
 void start_server(int listen_port,
-                  void (*handler)(ws_ctx_t*),
-                  client_settings_t *client_settings) {
+                  void (*handler)(ws_ctx_t*)) {
     int lsock, csock, clilen, sopt = 1;
     int lsock, csock, clilen, sopt = 1;
     struct sockaddr_in serv_addr, cli_addr;
     struct sockaddr_in serv_addr, cli_addr;
     ws_ctx_t *ws_ctx;
     ws_ctx_t *ws_ctx;
 
 
+    /* Initialize buffers */
+    bufsize = 65536;
+    if (! (tbuf = malloc(bufsize)) )
+            { fatal("malloc()"); }
+    if (! (cbuf = malloc(bufsize)) )
+            { fatal("malloc()"); }
+    if (! (tbuf_tmp = malloc(bufsize)) )
+            { fatal("malloc()"); }
+    if (! (cbuf_tmp = malloc(bufsize)) )
+            { fatal("malloc()"); }
+
     lsock = socket(AF_INET, SOCK_STREAM, 0);
     lsock = socket(AF_INET, SOCK_STREAM, 0);
     if (lsock < 0) { error("ERROR creating listener socket"); }
     if (lsock < 0) { error("ERROR creating listener socket"); }
     bzero((char *) &serv_addr, sizeof(serv_addr));
     bzero((char *) &serv_addr, sizeof(serv_addr));
@@ -256,8 +324,23 @@ void start_server(int listen_port,
             error("ERROR on accept");
             error("ERROR on accept");
         }
         }
         printf("Got client connection from %s\n", inet_ntoa(cli_addr.sin_addr));
         printf("Got client connection from %s\n", inet_ntoa(cli_addr.sin_addr));
-        ws_ctx = do_handshake(csock, client_settings);
-        if (ws_ctx == NULL) { continue; }
+        ws_ctx = do_handshake(csock);
+        if (ws_ctx == NULL) {
+            close(csock);
+            continue;
+        }
+
+        /* Calculate dbufsize based on client_settings */
+        if (client_settings.do_b64encode) {
+            /* base64 is 4 bytes for every 3
+             *    20 for WS '\x00' / '\xff', seq_num and good measure  */
+            dbufsize = (bufsize * 3)/4 - 20;
+        } else {
+            fatal("UTF-8 not yet implemented");
+            /* UTF-8 encoding is up to 2X larger */
+            dbufsize = (bufsize/2) - 15;
+        }
+
         handler(ws_ctx);
         handler(ws_ctx);
         close(csock);
         close(csock);
     }
     }

+ 2 - 1
websocket.h

@@ -7,7 +7,8 @@ typedef struct {
 } ws_ctx_t;
 } ws_ctx_t;
 
 
 typedef struct {
 typedef struct {
-    int b64encode;
+    int do_b64encode;
+    int do_seq_num;
     int seq_num;
     int seq_num;
 } client_settings_t;
 } client_settings_t;
 
 

+ 25 - 28
wsproxy.c

@@ -11,7 +11,6 @@
 #include <netinet/in.h>
 #include <netinet/in.h>
 #include <netdb.h>
 #include <netdb.h>
 #include <sys/select.h>
 #include <sys/select.h>
-#include <resolv.h>
 #include <fcntl.h>
 #include <fcntl.h>
 #include <sys/stat.h>
 #include <sys/stat.h>
 #include "websocket.h"
 #include "websocket.h"
@@ -35,23 +34,21 @@ void usage() {
 
 
 char *target_host;
 char *target_host;
 int target_port;
 int target_port;
-client_settings_t client_settings;
 char *record_filename = NULL;
 char *record_filename = NULL;
 int recordfd = 0;
 int recordfd = 0;
-char *tbuf, *cbuf, *tbuf_tmp, *cbuf_tmp;
-unsigned int bufsize, dbufsize;
+
+extern char *tbuf, *cbuf, *tbuf_tmp, *cbuf_tmp;
+extern unsigned int bufsize, dbufsize;
 
 
 void do_proxy(ws_ctx_t *ws_ctx, int target) {
 void do_proxy(ws_ctx_t *ws_ctx, int target) {
     fd_set rlist, wlist, elist;
     fd_set rlist, wlist, elist;
     struct timeval tv;
     struct timeval tv;
-    int maxfd, client = ws_ctx->sockfd;
+    int i, maxfd, client = ws_ctx->sockfd;
     unsigned int tstart, tend, cstart, cend, ret;
     unsigned int tstart, tend, cstart, cend, ret;
     ssize_t len, bytes;
     ssize_t len, bytes;
 
 
     tstart = tend = cstart = cend = 0;
     tstart = tend = cstart = cend = 0;
     maxfd = client > target ? client+1 : target+1;
     maxfd = client > target ? client+1 : target+1;
-    // Account for base64 encoding and WebSocket delims:
-    //     49150 = 65536 * 3/4 + 2 - 1
 
 
     while (1) {
     while (1) {
         tv.tv_sec = 1;
         tv.tv_sec = 1;
@@ -137,18 +134,22 @@ void do_proxy(ws_ctx_t *ws_ctx, int target) {
         if (FD_ISSET(target, &rlist)) {
         if (FD_ISSET(target, &rlist)) {
             bytes = recv(target, cbuf_tmp, dbufsize , 0);
             bytes = recv(target, cbuf_tmp, dbufsize , 0);
             if (bytes <= 0) {
             if (bytes <= 0) {
-                error("target closed connection");
+                fprintf(stderr, "target closed connection");
                 break;
                 break;
             }
             }
-            cbuf[0] = '\x00';
             cstart = 0;
             cstart = 0;
-            len = b64_ntop(cbuf_tmp, bytes, cbuf+1, bufsize-1);
-            if (len < 0) {
-                fprintf(stderr, "base64 encoding error\n");
+            cend = encode(cbuf_tmp, bytes, cbuf, bufsize);
+            /*
+            printf("encoded: ");
+            for (i=0; i< bytes; i++) {
+                printf("%d,", *(cbuf+i));
+            }
+            printf("\n");
+            */
+            if (cend < 0) {
+                fprintf(stderr, "encoding error\n");
                 break;
                 break;
             }
             }
-            cbuf[len+1] = '\xff';
-            cend = len+1+1;
             traffic("{");
             traffic("{");
         }
         }
 
 
@@ -158,20 +159,21 @@ void do_proxy(ws_ctx_t *ws_ctx, int target) {
                 fprintf(stderr, "client closed connection\n");
                 fprintf(stderr, "client closed connection\n");
                 break;
                 break;
             }
             }
-            if (tbuf_tmp[bytes-1] != '\xff') {
-                //traffic(".}");
-                fprintf(stderr, "Malformed packet\n");
-                break;
-            }
             if (recordfd) {
             if (recordfd) {
                 write(recordfd, "'", 1);
                 write(recordfd, "'", 1);
                 write(recordfd, tbuf_tmp + 1, bytes - 2);
                 write(recordfd, tbuf_tmp + 1, bytes - 2);
                 write(recordfd, "',\n", 3);
                 write(recordfd, "',\n", 3);
             }
             }
-            tbuf_tmp[bytes-1] = '\0';
-            len = b64_pton(tbuf_tmp+1, tbuf, bufsize-1);
+            len = decode(tbuf_tmp, bytes, tbuf, bufsize-1);
+            /*
+            printf("decoded: ");
+            for (i=0; i< bytes; i++) {
+                printf("%d,", *(tbuf+i));
+            }
+            printf("\n");
+            */
             if (len < 0) {
             if (len < 0) {
-                fprintf(stderr, "base64 decoding error\n");
+                fprintf(stderr, "decoding error\n");
                 break;
                 break;
             }
             }
             traffic("}");
             traffic("}");
@@ -188,11 +190,6 @@ void proxy_handler(ws_ctx_t *ws_ctx) {
 
 
     printf("Connecting to: %s:%d\n", target_host, target_port);
     printf("Connecting to: %s:%d\n", target_host, target_port);
 
 
-    if (client_settings.b64encode) {
-        dbufsize = (bufsize * 3)/4 + 2 - 10; // padding and for good measure
-    } else {
-    }
-
     tsock = socket(AF_INET, SOCK_STREAM, 0);
     tsock = socket(AF_INET, SOCK_STREAM, 0);
     if (tsock < 0) {
     if (tsock < 0) {
         error("Could not create target socket");
         error("Could not create target socket");
@@ -260,7 +257,7 @@ int main(int argc, char *argv[])
     if (! (cbuf_tmp = malloc(bufsize)) )
     if (! (cbuf_tmp = malloc(bufsize)) )
             { fatal("malloc()"); }
             { fatal("malloc()"); }
 
 
-    start_server(listen_port, &proxy_handler, &client_settings);
+    start_server(listen_port, &proxy_handler);
 
 
     free(tbuf);
     free(tbuf);
     free(cbuf);
     free(cbuf);