Răsfoiți Sursa

First pass at working C wsproxy.

No sequence numbering and only support b64encoding at the moment.
Joel Martin 15 ani în urmă
părinte
comite
8e1aa95ba3
5 a modificat fișierele cu 571 adăugiri și 0 ștergeri
  1. 2 0
      .gitignore
  2. 12 0
      Makefile
  3. 266 0
      websocket.c
  4. 22 0
      websocket.h
  5. 269 0
      wsproxy.c

+ 2 - 0
.gitignore

@@ -1 +1,3 @@
 *.pyc
 *.pyc
+*.o
+wsproxy

+ 12 - 0
Makefile

@@ -0,0 +1,12 @@
+wsproxy: wsproxy.o websocket.o
+	$(CC) $^ -l ssl -l resolv -o $@
+
+#websocket.o: websocket.c
+#	$(CC) -c $^ -o $@
+#
+#wsproxy.o: wsproxy.c
+#	$(CC) -c $^ -o $@
+
+clean:
+	rm -f wsproxy wsproxy.o websocket.o
+

+ 266 - 0
websocket.c

@@ -0,0 +1,266 @@
+/*
+ * WebSocket lib with support for "wss://" encryption.
+ *
+ * You can make a cert/key with openssl using:
+ * openssl req -new -x509 -days 365 -nodes -out self.pem -keyout self.pem
+ * as taken from http://docs.python.org/dev/library/ssl.html#certificates
+ */
+#include <stdio.h>
+#include <stdlib.h>
+#include <errno.h>
+#include <strings.h>
+#include <sys/types.h> 
+#include <sys/socket.h>
+#include <netinet/in.h>
+#include <arpa/inet.h>
+#include <openssl/err.h>
+#include <openssl/ssl.h>
+#include "websocket.h"
+
+const char server_handshake[] = "HTTP/1.1 101 Web Socket Protocol Handshake\r\n\
+Upgrade: WebSocket\r\n\
+Connection: Upgrade\r\n\
+WebSocket-Origin: %s\r\n\
+WebSocket-Location: %s://%s%s\r\n\
+WebSocket-Protocol: sample\r\n\
+\r\n";
+
+const char policy_response[] = "<cross-domain-policy><allow-access-from domain=\"*\" to-ports=\"*\" /></cross-domain-policy>\n";
+
+void traffic(char * token) {
+    fprintf(stdout, "%s", token);
+    fflush(stdout);
+}
+
+void error(char *msg)
+{
+    perror(msg);
+}
+
+void fatal(char *msg)
+{
+    perror(msg);
+    exit(1);
+}
+
+/*
+ * SSL Wrapper Code
+ */
+
+/*   Warning: not thread safe */
+int ssl_initialized = 0;
+
+ssize_t ws_recv(ws_ctx_t *ctx, void *buf, size_t len) {
+    if (ctx->ssl) {
+        //printf("SSL recv\n");
+        return SSL_read(ctx->ssl, buf, len);
+    } else {
+        return recv(ctx->sockfd, buf, len, 0);
+    }
+}
+
+ssize_t ws_send(ws_ctx_t *ctx, const void *buf, size_t len) {
+    if (ctx->ssl) {
+        //printf("SSL send\n");
+        return SSL_write(ctx->ssl, buf, len);
+    } else {
+        return send(ctx->sockfd, buf, len, 0);
+    }
+}
+
+ws_ctx_t *ws_socket(int socket) {
+    ws_ctx_t *ctx;
+    ctx = malloc(sizeof(ws_ctx_t));
+    ctx->sockfd = socket;
+    ctx->ssl = NULL;
+    ctx->ssl_ctx = NULL;
+    return ctx;
+}
+
+ws_ctx_t *ws_socket_ssl(int socket, char * certfile) {
+    int ret;
+    char msg[1024];
+    ws_ctx_t *ctx;
+    ctx = ws_socket(socket);
+
+    // Initialize the library
+    if (! ssl_initialized) {
+        SSL_library_init();
+        OpenSSL_add_all_algorithms();
+        SSL_load_error_strings();
+        ssl_initialized = 1;
+
+    }
+
+    ctx->ssl_ctx = SSL_CTX_new(TLSv1_server_method());
+    if (ctx->ssl_ctx == NULL) {
+        ERR_print_errors_fp(stderr);
+        fatal("Failed to configure SSL context");
+    }
+
+    if (SSL_CTX_use_PrivateKey_file(ctx->ssl_ctx, certfile,
+                                     SSL_FILETYPE_PEM) <= 0) {
+        sprintf(msg, "Unable to load private key file %s\n", certfile);
+        fatal(msg);
+    }
+
+    if (SSL_CTX_use_certificate_file(ctx->ssl_ctx, certfile,
+                                     SSL_FILETYPE_PEM) <= 0) {
+        sprintf(msg, "Unable to load certificate file %s\n", certfile);
+        fatal(msg);
+    }
+
+//    if (SSL_CTX_set_cipher_list(ctx->ssl_ctx, "DEFAULT") != 1) {
+//        sprintf(msg, "Unable to set cipher\n");
+//        fatal(msg);
+//    }
+
+    // Associate socket and ssl object
+    ctx->ssl = SSL_new(ctx->ssl_ctx);
+    SSL_set_fd(ctx->ssl, socket);
+
+    ret = SSL_accept(ctx->ssl);
+    if (ret < 0) {
+        ERR_print_errors_fp(stderr);
+        return NULL;
+    }
+
+    return ctx;
+}
+
+int ws_socket_free(ws_ctx_t *ctx) {
+    if (ctx->ssl) {
+        SSL_free(ctx->ssl);
+        ctx->ssl = NULL;
+    }
+    if (ctx->ssl_ctx) {
+        SSL_CTX_free(ctx->ssl_ctx);
+        ctx->ssl_ctx = NULL;
+    }
+    if (ctx->sockfd) {
+        close(ctx->sockfd);
+        ctx->sockfd = 0;
+    }
+    free(ctx);
+}
+
+/* ------------------------------------------------------- */
+
+
+ws_ctx_t *do_handshake(int sock, client_settings_t *client_settings) {
+    char handshake[4096], response[4096];
+    char *scheme, *line, *path, *host, *origin;
+    char *args_start, *args_end, *arg_idx;
+    int len;
+    ws_ctx_t * ws_ctx;
+
+    // Reset settings
+    client_settings->b64encode = 0;
+    client_settings->seq_num = 0;
+
+    len = recv(sock, handshake, 1024, MSG_PEEK);
+    handshake[len] = 0;
+    if (bcmp(handshake, "<policy-file-request/>", 22) == 0) {
+        len = recv(sock, handshake, 1024, 0);
+        handshake[len] = 0;
+        printf("Sending flash policy response\n");
+        send(sock, policy_response, sizeof(policy_response), 0);
+        close(sock);
+        return NULL;
+    } else if (bcmp(handshake, "\x16", 1) == 0) {
+        // SSL
+        ws_ctx = ws_socket_ssl(sock, "self.pem");
+        if (! ws_ctx) { return NULL; }
+        scheme = "wss";
+        printf("Using SSL socket\n");
+    } else {
+        ws_ctx = ws_socket(sock);
+        if (! ws_ctx) { return NULL; }
+        scheme = "ws";
+        printf("Using plain (not SSL) socket\n");
+    }
+    len = ws_recv(ws_ctx, handshake, 4096);
+    handshake[len] = 0;
+    //printf("handshake: %s\n", handshake);
+    if ((len < 92) || (bcmp(handshake, "GET ", 4) != 0)) {
+        fprintf(stderr, "Invalid WS request\n");
+        return NULL;
+    }
+    strtok(handshake, " ");      // Skip "GET "
+    path = strtok(NULL, " ");    // Extract path
+    strtok(NULL, "\n");          // Skip to Upgrade line
+    strtok(NULL, "\n");          // Skip to Connection line
+    strtok(NULL, "\n");          // Skip to Host line
+    strtok(NULL, " ");           // Skip "Host: "
+    host = strtok(NULL, "\r");   // Extract host
+    strtok(NULL, " ");           // Skip "Origin: "
+    origin = strtok(NULL, "\r"); // Extract origin
+
+    //printf("path: %s\n", path);
+    //printf("host: %s\n", host);
+    //printf("origin: %s\n", origin);
+    
+    // TODO: parse out client settings
+    args_start = strstr(path, "?");
+    if (args_start) {
+        if (strstr(args_start, "#")) {
+            args_end = strstr(args_start, "#");
+        } else {
+            args_end = args_start + strlen(args_start);
+        }
+        arg_idx = strstr(args_start, "b64encode");
+        if (arg_idx && arg_idx < args_end) {
+            //printf("setting b64encode\n");
+            client_settings->b64encode = 1;
+        }
+        arg_idx = strstr(args_start, "seq_num");
+        if (arg_idx && arg_idx < args_end) {
+            //printf("setting seq_num\n");
+            client_settings->seq_num = 1;
+        }
+    }
+
+    sprintf(response, server_handshake, origin, scheme, host, path);
+    printf("response: %s\n", response);
+    ws_send(ws_ctx, response, strlen(response));
+
+    return ws_ctx;
+}
+
+void start_server(int listen_port,
+                  void (*handler)(ws_ctx_t*),
+                  client_settings_t *client_settings) {
+    int lsock, csock, clilen, sopt = 1;
+    struct sockaddr_in serv_addr, cli_addr;
+    ws_ctx_t *ws_ctx;
+
+    lsock = socket(AF_INET, SOCK_STREAM, 0);
+    if (lsock < 0) { error("ERROR creating listener socket"); }
+    bzero((char *) &serv_addr, sizeof(serv_addr));
+    serv_addr.sin_family = AF_INET;
+    serv_addr.sin_addr.s_addr = INADDR_ANY;
+    serv_addr.sin_port = htons(listen_port);
+    setsockopt(lsock, SOL_SOCKET, SO_REUSEADDR, (char *)&sopt, sizeof(sopt));
+    if (bind(lsock, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) {
+        error("ERROR on binding listener socket");
+    }
+    listen(lsock,100);
+
+    while (1) {
+        clilen = sizeof(cli_addr);
+        printf("waiting for connection on port %d\n", listen_port);
+        csock = accept(lsock, 
+                       (struct sockaddr *) &cli_addr, 
+                       &clilen);
+        if (csock < 0) {
+            error("ERROR on accept");
+        }
+        printf("Got client connection from %s\n", inet_ntoa(cli_addr.sin_addr));
+        ws_ctx = do_handshake(csock, client_settings);
+        if (ws_ctx == NULL) { continue; }
+        handler(ws_ctx);
+        close(csock);
+    }
+
+}
+

+ 22 - 0
websocket.h

@@ -0,0 +1,22 @@
+#include <openssl/ssl.h>
+
+typedef struct {
+    int      sockfd;
+    SSL_CTX *ssl_ctx;
+    SSL     *ssl;
+} ws_ctx_t;
+
+typedef struct {
+    int b64encode;
+    int seq_num;
+} client_settings_t;
+
+
+ssize_t ws_recv(ws_ctx_t *ctx, void *buf, size_t len);
+
+ssize_t ws_send(ws_ctx_t *ctx, const void *buf, size_t len);
+
+/* base64.c declarations */
+//int b64_ntop(u_char const *src, size_t srclength, char *target, size_t targsize);
+//int b64_pton(char const *src, u_char *target, size_t targsize);
+

+ 269 - 0
wsproxy.c

@@ -0,0 +1,269 @@
+/*
+ * A WebSocket to TCP socket proxy with support for "wss://" encryption.
+ *
+ * You can make a cert/key with openssl using:
+ * openssl req -new -x509 -days 365 -nodes -out self.pem -keyout self.pem
+ * as taken from http://docs.python.org/dev/library/ssl.html#certificates
+ */
+#include <stdio.h>
+#include <errno.h>
+#include <sys/socket.h>
+#include <netinet/in.h>
+#include <netdb.h>
+#include <sys/select.h>
+#include <resolv.h>
+#include <fcntl.h>
+#include <sys/stat.h>
+#include "websocket.h"
+
+char traffic_legend[] = "\n\
+Traffic Legend:\n\
+    }  - Client receive\n\
+    }. - Client receive partial\n\
+    {  - Target receive\n\
+\n\
+    >  - Target send\n\
+    >. - Target send partial\n\
+    <  - Client send\n\
+    <. - Client send partial\n\
+";
+
+void usage() {
+    fprintf(stderr,"Usage: <listen_port> <target_host> <target_port>\n");
+    exit(1);
+}
+
+char *target_host;
+int target_port;
+client_settings_t client_settings;
+char *record_filename = NULL;
+int recordfd = 0;
+char *tbuf, *cbuf, *tbuf_tmp, *cbuf_tmp;
+unsigned int bufsize, dbufsize;
+
+void do_proxy(ws_ctx_t *ws_ctx, int target) {
+    fd_set rlist, wlist, elist;
+    struct timeval tv;
+    int maxfd, client = ws_ctx->sockfd;
+    unsigned int tstart, tend, cstart, cend, ret;
+    ssize_t len, bytes;
+
+    tstart = tend = cstart = cend = 0;
+    maxfd = client > target ? client+1 : target+1;
+    // Account for base64 encoding and WebSocket delims:
+    //     49150 = 65536 * 3/4 + 2 - 1
+
+    while (1) {
+        tv.tv_sec = 1;
+        tv.tv_usec = 0;
+
+        FD_ZERO(&rlist);
+        FD_ZERO(&wlist);
+        FD_ZERO(&elist);
+
+        FD_SET(client, &elist);
+        FD_SET(target, &elist);
+
+        if (tend == tstart) {
+            // Nothing queued for target, so read from client
+            FD_SET(client, &rlist);
+        } else {
+            // Data queued for target, so write to it
+            FD_SET(target, &wlist);
+        }
+        if (cend == cstart) {
+            // Nothing queued for client, so read from target
+            FD_SET(target, &rlist);
+        } else {
+            // Data queued for client, so write to it
+            FD_SET(client, &wlist);
+        }
+
+        ret = select(maxfd, &rlist, &wlist, &elist, &tv);
+
+        if (FD_ISSET(target, &elist)) {
+            fprintf(stderr, "target exception\n");
+            break;
+        }
+        if (FD_ISSET(client, &elist)) {
+            fprintf(stderr, "client exception\n");
+            break;
+        }
+
+        if (ret == -1) {
+            error("select()");
+            break;
+        } else if (ret == 0) {
+            //fprintf(stderr, "select timeout\n");
+            continue;
+        }
+
+        if (FD_ISSET(target, &wlist)) {
+            len = tend-tstart;
+            bytes = send(target, tbuf + tstart, len, 0);
+            if (bytes < 0) {
+                error("target connection error");
+                break;
+            }
+            tstart += bytes;
+            if (tstart >= tend) {
+                tstart = tend = 0;
+                traffic(">");
+            } else {
+                traffic(">.");
+            }
+        }
+
+        if (FD_ISSET(client, &wlist)) {
+            len = cend-cstart;
+            bytes = ws_send(ws_ctx, cbuf + cstart, len);
+            if (len < 3) {
+                fprintf(stderr, "len: %d, bytes: %d: %d\n", len, bytes, *(cbuf + cstart));
+            }
+            cstart += bytes;
+            if (cstart >= cend) {
+                cstart = cend = 0;
+                traffic("<");
+                if (recordfd) {
+                    write(recordfd, "'>", 2);
+                    write(recordfd, cbuf + cstart + 1, bytes - 2);
+                    write(recordfd, "',\n", 3);
+                }
+            } else {
+                traffic("<.");
+            }
+        }
+
+        if (FD_ISSET(target, &rlist)) {
+            bytes = recv(target, cbuf_tmp, dbufsize , 0);
+            if (bytes <= 0) {
+                error("target closed connection");
+                break;
+            }
+            cbuf[0] = '\x00';
+            cstart = 0;
+            len = b64_ntop(cbuf_tmp, bytes, cbuf+1, bufsize-1);
+            if (len < 0) {
+                fprintf(stderr, "base64 encoding error\n");
+                break;
+            }
+            cbuf[len+1] = '\xff';
+            cend = len+1+1;
+            traffic("{");
+        }
+
+        if (FD_ISSET(client, &rlist)) {
+            bytes = ws_recv(ws_ctx, tbuf_tmp, bufsize-1);
+            if (bytes <= 0) {
+                fprintf(stderr, "client closed connection\n");
+                break;
+            }
+            if (tbuf_tmp[bytes-1] != '\xff') {
+                //traffic(".}");
+                fprintf(stderr, "Malformed packet\n");
+                break;
+            }
+            if (recordfd) {
+                write(recordfd, "'", 1);
+                write(recordfd, tbuf_tmp + 1, bytes - 2);
+                write(recordfd, "',\n", 3);
+            }
+            tbuf_tmp[bytes-1] = '\0';
+            len = b64_pton(tbuf_tmp+1, tbuf, bufsize-1);
+            if (len < 0) {
+                fprintf(stderr, "base64 decoding error\n");
+                break;
+            }
+            traffic("}");
+            tstart = 0;
+            tend = len;
+        }
+    }
+}
+
+void proxy_handler(ws_ctx_t *ws_ctx) {
+    int tsock = 0;
+    struct sockaddr_in taddr;
+    struct hostent *thost;
+
+    printf("Connecting to: %s:%d\n", target_host, target_port);
+
+    if (client_settings.b64encode) {
+        dbufsize = (bufsize * 3)/4 + 2 - 10; // padding and for good measure
+    } else {
+    }
+
+    tsock = socket(AF_INET, SOCK_STREAM, 0);
+    if (tsock < 0) {
+        error("Could not create target socket");
+        return;
+    }
+    thost = gethostbyname(target_host);
+    if (thost == NULL) {
+        error("Could not resolve server");
+        close(tsock);
+        return;
+    }
+    bzero((char *) &taddr, sizeof(taddr));
+    taddr.sin_family = AF_INET;
+    bcopy((char *) thost->h_addr,
+          (char *) &taddr.sin_addr.s_addr,
+          thost->h_length);
+    taddr.sin_port = htons(target_port);
+
+    if (connect(tsock, (struct sockaddr *) &taddr, sizeof(taddr)) < 0) {
+        error("Could not connect to target");
+        close(tsock);
+        return;
+    }
+
+    if (record_filename) {
+        recordfd = open(record_filename, O_WRONLY | O_CREAT | O_TRUNC,
+                        S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH);
+    }
+
+    printf("%s", traffic_legend);
+
+    do_proxy(ws_ctx, tsock);
+
+    close(tsock);
+    if (recordfd) {
+        close(recordfd);
+        recordfd = 0;
+    }
+}
+
+int main(int argc, char *argv[])
+{
+    int listen_port, idx=1;
+
+    if (strcmp(argv[idx], "--record") == 0) {
+        idx++;
+        record_filename = argv[idx++];
+    }
+
+    if ((argc-idx) != 3) { usage(); }
+    listen_port = strtol(argv[idx++], NULL, 10);
+    if (errno != 0) { usage(); }
+    target_host = argv[idx++];
+    target_port = strtol(argv[idx++], NULL, 10);
+    if (errno != 0) { usage(); }
+
+    /* Initialize buffers */
+    bufsize = 65536;
+    if (! (tbuf = malloc(bufsize)) )
+            { fatal("malloc()"); }
+    if (! (cbuf = malloc(bufsize)) )
+            { fatal("malloc()"); }
+    if (! (tbuf_tmp = malloc(bufsize)) )
+            { fatal("malloc()"); }
+    if (! (cbuf_tmp = malloc(bufsize)) )
+            { fatal("malloc()"); }
+
+    start_server(listen_port, &proxy_handler, &client_settings);
+
+    free(tbuf);
+    free(cbuf);
+    free(tbuf_tmp);
+    free(cbuf_tmp);
+}