Browse Source

Add daemonization support to wsproxy.*.

Refactor how settings are passed around.
Joel Martin 15 years ago
parent
commit
6ee61a4cf6
7 changed files with 230 additions and 85 deletions
  1. 4 1
      tests/ws.py
  2. 4 1
      tests/wsencoding.py
  3. 68 23
      utils/websocket.c
  4. 11 0
      utils/websocket.h
  5. 73 23
      utils/websocket.py
  6. 50 27
      utils/wsproxy.c
  7. 20 10
      utils/wsproxy.py

+ 4 - 1
tests/ws.py

@@ -159,4 +159,7 @@ if __name__ == '__main__':
     for i in range(0, 100000):
     for i in range(0, 100000):
         rand_array.append(random.randint(0, 9))
         rand_array.append(random.randint(0, 9))
 
 
-    start_server(listen_port, test_handler)
+    settings['listen_port'] = listen_port
+    settings['daemon'] = False
+    settings['handler'] = test_handler
+    start_server()

+ 4 - 1
tests/wsencoding.py

@@ -81,4 +81,7 @@ if __name__ == '__main__':
         print "Usage: <listen_port>"
         print "Usage: <listen_port>"
         sys.exit(1)
         sys.exit(1)
 
 
-    start_server(listen_port, responder)
+    settings['listen_port'] = listen_port
+    settings['daemon'] = False
+    settings['handler'] = responder
+    start_server()

+ 68 - 23
utils/websocket.c

@@ -14,6 +14,8 @@
 #include <netinet/in.h>
 #include <netinet/in.h>
 #include <arpa/inet.h>
 #include <arpa/inet.h>
 #include <netdb.h>
 #include <netdb.h>
+#include <signal.h> // daemonizing
+#include <fcntl.h>  // daemonizing
 #include <openssl/err.h>
 #include <openssl/err.h>
 #include <openssl/ssl.h>
 #include <openssl/ssl.h>
 #include <resolv.h>      /* base64 encode/decode */
 #include <resolv.h>      /* base64 encode/decode */
@@ -37,6 +39,7 @@ const char policy_response[] = "<cross-domain-policy><allow-access-from domain=\
 int ssl_initialized = 0;
 int ssl_initialized = 0;
 char *tbuf, *cbuf, *tbuf_tmp, *cbuf_tmp;
 char *tbuf, *cbuf, *tbuf_tmp, *cbuf_tmp;
 unsigned int bufsize, dbufsize;
 unsigned int bufsize, dbufsize;
+settings_t settings;
 client_settings_t client_settings;
 client_settings_t client_settings;
 
 
 void traffic(char * token) {
 void traffic(char * token) {
@@ -269,7 +272,7 @@ int decode(char *src, size_t srclength, u_char *target, size_t targsize) {
     return retlen;
     return retlen;
 }
 }
 
 
-ws_ctx_t *do_handshake(int sock, int ssl_only) {
+ws_ctx_t *do_handshake(int sock) {
     char handshake[4096], response[4096];
     char handshake[4096], response[4096];
     char *scheme, *line, *path, *host, *origin;
     char *scheme, *line, *path, *host, *origin;
     char *args_start, *args_end, *arg_idx;
     char *args_start, *args_end, *arg_idx;
@@ -281,6 +284,7 @@ ws_ctx_t *do_handshake(int sock, int ssl_only) {
     client_settings.do_seq_num = 0;
     client_settings.do_seq_num = 0;
     client_settings.seq_num = 0;
     client_settings.seq_num = 0;
 
 
+    // Peek, but don't read the data
     len = recv(sock, handshake, 1024, MSG_PEEK);
     len = recv(sock, handshake, 1024, MSG_PEEK);
     handshake[len] = 0;
     handshake[len] = 0;
     if (bcmp(handshake, "<policy-file-request/>", 22) == 0) {
     if (bcmp(handshake, "<policy-file-request/>", 22) == 0) {
@@ -292,11 +296,11 @@ ws_ctx_t *do_handshake(int sock, int ssl_only) {
         return NULL;
         return NULL;
     } else if (bcmp(handshake, "\x16", 1) == 0) {
     } else if (bcmp(handshake, "\x16", 1) == 0) {
         // SSL
         // SSL
-        ws_ctx = ws_socket_ssl(sock, "self.pem");
+        ws_ctx = ws_socket_ssl(sock, settings.cert);
         if (! ws_ctx) { return NULL; }
         if (! ws_ctx) { return NULL; }
         scheme = "wss";
         scheme = "wss";
-        printf("Using SSL socket\n");
-    } else if (ssl_only) {
+        printf("  using SSL socket\n");
+    } else if (settings.ssl_only) {
         printf("Non-SSL connection disallowed");
         printf("Non-SSL connection disallowed");
         close(sock);
         close(sock);
         return NULL;
         return NULL;
@@ -304,7 +308,7 @@ ws_ctx_t *do_handshake(int sock, int ssl_only) {
         ws_ctx = ws_socket(sock);
         ws_ctx = ws_socket(sock);
         if (! ws_ctx) { return NULL; }
         if (! ws_ctx) { return NULL; }
         scheme = "ws";
         scheme = "ws";
-        printf("Using plain (not SSL) socket\n");
+        printf("  using plain (not SSL) socket\n");
     }
     }
     len = ws_recv(ws_ctx, handshake, 4096);
     len = ws_recv(ws_ctx, handshake, 4096);
     handshake[len] = 0;
     handshake[len] = 0;
@@ -327,7 +331,7 @@ ws_ctx_t *do_handshake(int sock, int ssl_only) {
     //printf("host: %s\n", host);
     //printf("host: %s\n", host);
     //printf("origin: %s\n", origin);
     //printf("origin: %s\n", origin);
     
     
-    // TODO: parse out client settings
+    // Parse client settings from the GET path
     args_start = strstr(path, "?");
     args_start = strstr(path, "?");
     if (args_start) {
     if (args_start) {
         if (strstr(args_start, "#")) {
         if (strstr(args_start, "#")) {
@@ -337,31 +341,70 @@ ws_ctx_t *do_handshake(int sock, int ssl_only) {
         }
         }
         arg_idx = strstr(args_start, "b64encode");
         arg_idx = strstr(args_start, "b64encode");
         if (arg_idx && arg_idx < args_end) {
         if (arg_idx && arg_idx < args_end) {
-            //printf("setting b64encode\n");
+            printf("  b64encode=1\n");
             client_settings.do_b64encode = 1;
             client_settings.do_b64encode = 1;
         }
         }
         arg_idx = strstr(args_start, "seq_num");
         arg_idx = strstr(args_start, "seq_num");
         if (arg_idx && arg_idx < args_end) {
         if (arg_idx && arg_idx < args_end) {
-            //printf("setting seq_num\n");
+            printf("  seq_num=1\n");
             client_settings.do_seq_num = 1;
             client_settings.do_seq_num = 1;
         }
         }
     }
     }
 
 
     sprintf(response, server_handshake, origin, scheme, host, path);
     sprintf(response, server_handshake, origin, scheme, host, path);
-    printf("response: %s\n", response);
+    //printf("response: %s\n", response);
     ws_send(ws_ctx, response, strlen(response));
     ws_send(ws_ctx, response, strlen(response));
 
 
     return ws_ctx;
     return ws_ctx;
 }
 }
 
 
-void start_server(int listen_port,
-                  void (*handler)(ws_ctx_t*),
-                  char *listen_host,
-                  int ssl_only) {
+void signal_handler(sig) {
+    switch (sig) {
+        case SIGHUP: break; // ignore
+        case SIGTERM: exit(0); break;
+    }
+}
+
+void daemonize() {
+    int pid, i;
+
+    umask(0);
+    chdir('/');
+    setgid(getgid());
+    setuid(getuid());
+
+    /* Double fork to daemonize */
+    pid = fork();
+    if (pid<0) { fatal("fork error"); }
+    if (pid>0) { exit(0); }  // parent exits
+    setsid();                // Obtain new process group
+    pid = fork();
+    if (pid<0) { fatal("fork error"); }
+    if (pid>0) { exit(0); }  // parent exits
+
+    /* Signal handling */
+    signal(SIGHUP, signal_handler);   // catch HUP
+    signal(SIGTERM, signal_handler);  // catch kill
+
+    /* Close open files */
+    for (i=getdtablesize(); i>=0; --i) {
+        close(i);
+    }
+    i=open("/dev/null", O_RDWR);  // Redirect stdin
+    dup(i);                       // Redirect stdout
+    dup(i);                       // Redirect stderr
+}
+
+
+void start_server() {
     int lsock, csock, clilen, sopt = 1, i;
     int lsock, csock, clilen, sopt = 1, i;
     struct sockaddr_in serv_addr, cli_addr;
     struct sockaddr_in serv_addr, cli_addr;
     ws_ctx_t *ws_ctx;
     ws_ctx_t *ws_ctx;
 
 
+    if (settings.daemon) {
+        daemonize();
+    }
+
     /* Initialize buffers */
     /* Initialize buffers */
     bufsize = 65536;
     bufsize = 65536;
     if (! (tbuf = malloc(bufsize)) )
     if (! (tbuf = malloc(bufsize)) )
@@ -377,15 +420,15 @@ void start_server(int listen_port,
     if (lsock < 0) { error("ERROR creating listener socket"); }
     if (lsock < 0) { error("ERROR creating listener socket"); }
     bzero((char *) &serv_addr, sizeof(serv_addr));
     bzero((char *) &serv_addr, sizeof(serv_addr));
     serv_addr.sin_family = AF_INET;
     serv_addr.sin_family = AF_INET;
-    serv_addr.sin_port = htons(listen_port);
+    serv_addr.sin_port = htons(settings.listen_port);
 
 
     /* Resolve listen address */
     /* Resolve listen address */
-    if ((listen_host == NULL) || (listen_host[0] == '\0')) {
-        serv_addr.sin_addr.s_addr = INADDR_ANY;
-    } else {
-        if (resolve_host(&serv_addr.sin_addr, listen_host) < -1) {
+    if (settings.listen_host && (settings.listen_host[0] != '\0')) {
+        if (resolve_host(&serv_addr.sin_addr, settings.listen_host) < -1) {
             fatal("Could not resolve listen address");
             fatal("Could not resolve listen address");
         }
         }
+    } else {
+        serv_addr.sin_addr.s_addr = INADDR_ANY;
     }
     }
 
 
     setsockopt(lsock, SOL_SOCKET, SO_REUSEADDR, (char *)&sopt, sizeof(sopt));
     setsockopt(lsock, SOL_SOCKET, SO_REUSEADDR, (char *)&sopt, sizeof(sopt));
@@ -396,10 +439,12 @@ void start_server(int listen_port,
 
 
     while (1) {
     while (1) {
         clilen = sizeof(cli_addr);
         clilen = sizeof(cli_addr);
-        if (listen_host) {
-            printf("waiting for connection on %s:%d\n", listen_host, listen_port);
+        if (settings.listen_host && settings.listen_host[0] != '\0') {
+            printf("waiting for connection on %s:%d\n",
+                   settings.listen_host, settings.listen_port);
         } else {
         } else {
-            printf("waiting for connection on port %d\n", listen_port);
+            printf("waiting for connection on port %d\n",
+                   settings.listen_port);
         }
         }
         csock = accept(lsock, 
         csock = accept(lsock, 
                        (struct sockaddr *) &cli_addr, 
                        (struct sockaddr *) &cli_addr, 
@@ -409,7 +454,7 @@ void start_server(int listen_port,
             continue;
             continue;
         }
         }
         printf("Got client connection from %s\n", inet_ntoa(cli_addr.sin_addr));
         printf("Got client connection from %s\n", inet_ntoa(cli_addr.sin_addr));
-        ws_ctx = do_handshake(csock, ssl_only);
+        ws_ctx = do_handshake(csock);
         if (ws_ctx == NULL) {
         if (ws_ctx == NULL) {
             close(csock);
             close(csock);
             continue;
             continue;
@@ -425,7 +470,7 @@ void start_server(int listen_port,
             dbufsize = (bufsize/2) - 20;
             dbufsize = (bufsize/2) - 20;
         }
         }
 
 
-        handler(ws_ctx);
+        settings.handler(ws_ctx);
         close(csock);
         close(csock);
     }
     }
 
 

+ 11 - 0
utils/websocket.h

@@ -1,4 +1,5 @@
 #include <openssl/ssl.h>
 #include <openssl/ssl.h>
+#include <unistd.h>
 
 
 typedef struct {
 typedef struct {
     int      sockfd;
     int      sockfd;
@@ -6,6 +7,16 @@ typedef struct {
     SSL     *ssl;
     SSL     *ssl;
 } ws_ctx_t;
 } ws_ctx_t;
 
 
+typedef struct {
+    char listen_host[256];
+    int listen_port;
+    void (*handler)(ws_ctx_t*);
+    int ssl_only;
+    int daemon;
+    char record[1024];
+    char cert[1024];
+} settings_t;
+
 typedef struct {
 typedef struct {
     int do_b64encode;
     int do_b64encode;
     int do_seq_num;
     int do_seq_num;

+ 73 - 23
utils/websocket.py

@@ -10,9 +10,21 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates
 '''
 '''
 
 
 import sys, socket, ssl, traceback
 import sys, socket, ssl, traceback
+import os, resource, errno, signal # daemonizing
 from base64 import b64encode, b64decode
 from base64 import b64encode, b64decode
 
 
-client_settings = {}
+settings = {
+    'listen_host' : '',
+    'listen_port' : None,
+    'handler'     : None,
+    'cert'        : None,
+    'ssl_only'    : False,
+    'daemon'      : True,
+    'record'      : None, }
+client_settings = {
+    'b64encode'   : False,
+    'seq_num'     : False, }
+
 send_seq = 0
 send_seq = 0
 
 
 server_handshake = """HTTP/1.1 101 Web Socket Protocol Handshake\r
 server_handshake = """HTTP/1.1 101 Web Socket Protocol Handshake\r
@@ -33,35 +45,39 @@ def traffic(token="."):
 def decode(buf):
 def decode(buf):
     """ Parse out WebSocket packets. """
     """ Parse out WebSocket packets. """
     if buf.count('\xff') > 1:
     if buf.count('\xff') > 1:
-        if client_settings["b64encode"]:
+        if client_settings['b64encode']:
             return [b64decode(d[1:]) for d in buf.split('\xff')]
             return [b64decode(d[1:]) for d in buf.split('\xff')]
         else:
         else:
             # Modified UTF-8 decode
             # Modified UTF-8 decode
             return [d[1:].replace("\xc4\x80", "\x00").decode('utf-8').encode('latin-1') for d in buf.split('\xff')]
             return [d[1:].replace("\xc4\x80", "\x00").decode('utf-8').encode('latin-1') for d in buf.split('\xff')]
     else:
     else:
-        if client_settings["b64encode"]:
+        if client_settings['b64encode']:
             return [b64decode(buf[1:-1])]
             return [b64decode(buf[1:-1])]
         else:
         else:
             return [buf[1:-1].replace("\xc4\x80", "\x00").decode('utf-8').encode('latin-1')]
             return [buf[1:-1].replace("\xc4\x80", "\x00").decode('utf-8').encode('latin-1')]
 
 
 def encode(buf):
 def encode(buf):
     global send_seq
     global send_seq
-    if client_settings["b64encode"]:
+    if client_settings['b64encode']:
         buf = b64encode(buf)
         buf = b64encode(buf)
     else:
     else:
         # Modified UTF-8 encode
         # Modified UTF-8 encode
         buf = buf.decode('latin-1').encode('utf-8').replace("\x00", "\xc4\x80")
         buf = buf.decode('latin-1').encode('utf-8').replace("\x00", "\xc4\x80")
 
 
-    if client_settings["seq_num"]:
+    if client_settings['seq_num']:
         send_seq += 1
         send_seq += 1
         return "\x00%d:%s\xff" % (send_seq-1, buf)
         return "\x00%d:%s\xff" % (send_seq-1, buf)
     else:
     else:
         return "\x00%s\xff" % buf
         return "\x00%s\xff" % buf
 
 
 
 
-def do_handshake(sock, ssl_only=False):
+def do_handshake(sock):
     global client_settings, send_seq
     global client_settings, send_seq
+
+    client_settings['b64encode'] = False
+    client_settings['seq_num'] = False
     send_seq = 0
     send_seq = 0
+
     # Peek, but don't read the data
     # Peek, but don't read the data
     handshake = sock.recv(1024, socket.MSG_PEEK)
     handshake = sock.recv(1024, socket.MSG_PEEK)
     #print "Handshake [%s]" % repr(handshake)
     #print "Handshake [%s]" % repr(handshake)
@@ -75,54 +91,88 @@ def do_handshake(sock, ssl_only=False):
         retsock = ssl.wrap_socket(
         retsock = ssl.wrap_socket(
                 sock,
                 sock,
                 server_side=True,
                 server_side=True,
-                certfile='self.pem',
+                certfile=settings['cert'],
                 ssl_version=ssl.PROTOCOL_TLSv1)
                 ssl_version=ssl.PROTOCOL_TLSv1)
         scheme = "wss"
         scheme = "wss"
-        print "Using SSL/TLS"
-    elif ssl_only:
+        print "  using SSL/TLS"
+    elif settings['ssl_only']:
         print "Non-SSL connection disallowed"
         print "Non-SSL connection disallowed"
         sock.close()
         sock.close()
         return False
         return False
     else:
     else:
         retsock = sock
         retsock = sock
         scheme = "ws"
         scheme = "ws"
-        print "Using plain (not SSL) socket"
+        print "  using plain (not SSL) socket"
     handshake = retsock.recv(4096)
     handshake = retsock.recv(4096)
     req_lines = handshake.split("\r\n")
     req_lines = handshake.split("\r\n")
     _, path, _ = req_lines[0].split(" ")
     _, path, _ = req_lines[0].split(" ")
     _, origin = req_lines[4].split(" ")
     _, origin = req_lines[4].split(" ")
     _, host = req_lines[3].split(" ")
     _, host = req_lines[3].split(" ")
 
 
-    # Parse settings from the path
+    # Parse client settings from the GET path
     cvars = path.partition('?')[2].partition('#')[0].split('&')
     cvars = path.partition('?')[2].partition('#')[0].split('&')
-    client_settings = {'b64encode': None, 'seq_num': None}
     for cvar in [c for c in cvars if c]:
     for cvar in [c for c in cvars if c]:
-        name, _, value = cvar.partition('=')
-        client_settings[name] = value and value or True
-
-    print "client_settings:", client_settings
+        name, _, val = cvar.partition('=')
+        if name not in ['b64encode', 'seq_num']: continue
+        value = val and val or True
+        client_settings[name] = value
+        print "  %s=%s" % (name, value)
 
 
     retsock.send(server_handshake % (origin, scheme, host, path))
     retsock.send(server_handshake % (origin, scheme, host, path))
     return retsock
     return retsock
 
 
-def start_server(listen_port, handler, listen_host='', ssl_only=False):
+def daemonize():
+    os.umask(0)
+    os.chdir('/')
+    os.setgid(os.getgid())  # relinquish elevations
+    os.setuid(os.getuid())  # relinquish elevations
+
+    # Double fork to daemonize
+    if os.fork() > 0: os._exit(0)  # Parent exits
+    os.setsid()                    # Obtain new process group
+    if os.fork() > 0: os._exit(0)  # Parent exits
+
+    # Signal handling
+    def terminate(a,b): os._exit(0)
+    signal.signal(signal.SIGTERM, terminate)
+    signal.signal(signal.SIGINT, signal.SIG_IGN)
+
+    # Close open files
+    maxfd = resource.getrlimit(resource.RLIMIT_NOFILE)[1]
+    if maxfd == resource.RLIM_INFINITY: maxfd = 256
+    for fd in reversed(range(maxfd)):
+        try:
+            os.close(fd)
+        except OSError, exc:
+            if exc.errno != errno.EBADF: raise
+
+    # Redirect I/O to /dev/null
+    os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdin.fileno())
+    os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdout.fileno())
+    os.dup2(os.open(os.devnull, os.O_RDWR), sys.stderr.fileno())
+
+
+def start_server():
+
+    if settings['daemon']: daemonize()
+
     lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
     lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
     lsock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
     lsock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-    lsock.bind((listen_host, listen_port))
+    lsock.bind((settings['listen_host'], settings['listen_port']))
     lsock.listen(100)
     lsock.listen(100)
     while True:
     while True:
         try:
         try:
-            csock = None
-            print 'waiting for connection on port %s' % listen_port
+            csock = startsock = None
+            print 'waiting for connection on port %s' % settings['listen_port']
             startsock, address = lsock.accept()
             startsock, address = lsock.accept()
             print 'Got client connection from %s' % address[0]
             print 'Got client connection from %s' % address[0]
-            csock = do_handshake(startsock, ssl_only=ssl_only)
+            csock = do_handshake(startsock)
             if not csock: continue
             if not csock: continue
 
 
-            handler(csock)
+            settings['handler'](csock)
 
 
         except Exception:
         except Exception:
             print "Ignoring exception:"
             print "Ignoring exception:"
             print traceback.format_exc()
             print traceback.format_exc()
             if csock: csock.close()
             if csock: csock.close()
-
+            if startsock and startsock != csock: startsock.close()

+ 50 - 27
utils/wsproxy.c

@@ -36,11 +36,11 @@ void usage() {
     exit(1);
     exit(1);
 }
 }
 
 
-char *target_host;
+char target_host[256];
 int target_port;
 int target_port;
-char *record_filename = NULL;
 int recordfd = 0;
 int recordfd = 0;
 
 
+extern settings_t settings;
 extern char *tbuf, *cbuf, *tbuf_tmp, *cbuf_tmp;
 extern char *tbuf, *cbuf, *tbuf_tmp, *cbuf_tmp;
 extern unsigned int bufsize, dbufsize;
 extern unsigned int bufsize, dbufsize;
 
 
@@ -198,6 +198,11 @@ void proxy_handler(ws_ctx_t *ws_ctx) {
     int tsock = 0;
     int tsock = 0;
     struct sockaddr_in taddr;
     struct sockaddr_in taddr;
 
 
+    if (settings.record) {
+        recordfd = open(settings.record, O_WRONLY | O_CREAT | O_TRUNC,
+                        S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH);
+    }
+
     printf("Connecting to: %s:%d\n", target_host, target_port);
     printf("Connecting to: %s:%d\n", target_host, target_port);
 
 
     tsock = socket(AF_INET, SOCK_STREAM, 0);
     tsock = socket(AF_INET, SOCK_STREAM, 0);
@@ -220,11 +225,6 @@ void proxy_handler(ws_ctx_t *ws_ctx) {
         return;
         return;
     }
     }
 
 
-    if (record_filename) {
-        recordfd = open(record_filename, O_WRONLY | O_CREAT | O_TRUNC,
-                        S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH);
-    }
-
     printf("%s", traffic_legend);
     printf("%s", traffic_legend);
 
 
     do_proxy(ws_ctx, tsock);
     do_proxy(ws_ctx, tsock);
@@ -239,52 +239,74 @@ void proxy_handler(ws_ctx_t *ws_ctx) {
 int main(int argc, char *argv[])
 int main(int argc, char *argv[])
 {
 {
     int listen_port, c, option_index = 0;
     int listen_port, c, option_index = 0;
-    static int ssl_only = 0;
-    char *listen_host;
+    static int ssl_only = 0, foreground = 0;
+    char *found;
     static struct option long_options[] = {
     static struct option long_options[] = {
-        {"ssl-only", no_argument, &ssl_only, 1},
+        {"ssl-only",   no_argument,       &ssl_only,    1 },
+        {"foreground", no_argument,       &foreground, 'f'},
         /* ---- */
         /* ---- */
-        {"record",   required_argument, 0, 'r'},
+        {"record",     required_argument, 0,           'r'},
+        {"cert",       required_argument, 0,           'c'},
         {0, 0, 0, 0}
         {0, 0, 0, 0}
     };
     };
 
 
+    settings.record[0] = '\0';
+    strcpy(settings.cert, "self.pem");
+
     while (1) {
     while (1) {
-        c = getopt_long (argc, argv, "r:",
+        c = getopt_long (argc, argv, "fr:c:",
                          long_options, &option_index);
                          long_options, &option_index);
 
 
         /* Detect the end */
         /* Detect the end */
         if (c == -1) { break; }
         if (c == -1) { break; }
 
 
         switch (c) {
         switch (c) {
-            case 0:    break; // ignore
-            case 1:    break; // ignore
-            case 'r':  record_filename = optarg; break;
-            default:   usage();
+            case 0:
+                break; // ignore
+            case 1:
+                break; // ignore
+            case 'f':
+                foreground = 1;
+                break;
+            case 'r':
+                memcpy(settings.record, optarg, sizeof(settings.record));
+                break;
+            case 'c':
+                memcpy(settings.cert, optarg, sizeof(settings.cert));
+                break;
+            default:
+                usage();
         }
         }
     }
     }
+    settings.ssl_only  = ssl_only;
+    settings.daemon    = foreground ? 0: 1;
 
 
-    printf("ssl_only: %d\n", ssl_only);
-    printf("record_filename: %s\n", record_filename);
+    printf("  ssl_only: %d\n", settings.ssl_only);
+    printf("  daemon: %d\n",   settings.daemon);
+    printf("  record: %s\n",   settings.record);
+    printf("  cert: %s\n",     settings.cert);
 
 
     if ((argc-optind) != 2) {
     if ((argc-optind) != 2) {
         usage();
         usage();
     }
     }
 
 
-    if (strstr(argv[optind], ":")) {
-        listen_host = strtok(argv[optind], ":");
-        listen_port = strtol(strtok(NULL, ":"), NULL, 10);
+    found = strstr(argv[optind], ":");
+    if (found) {
+        memcpy(settings.listen_host, argv[optind], found-argv[optind]);
+        settings.listen_port = strtol(found+1, NULL, 10);
     } else {
     } else {
-        listen_host = NULL;
-        listen_port = strtol(argv[optind], NULL, 10);
+        settings.listen_host[0] = '\0';
+        settings.listen_port = strtol(argv[optind], NULL, 10);
     }
     }
     optind++;
     optind++;
     if ((errno != 0) || (listen_port == 0)) {
     if ((errno != 0) || (listen_port == 0)) {
         usage();
         usage();
     }
     }
 
 
-    if (strstr(argv[optind], ":")) {
-        target_host = strtok(argv[optind], ":");
-        target_port = strtol(strtok(NULL, ":"), NULL, 10);
+    found = strstr(argv[optind], ":");
+    if (found) {
+        memcpy(target_host, argv[optind], found-argv[optind]);
+        target_port = strtol(found+1, NULL, 10);
     } else {
     } else {
         usage();
         usage();
     }
     }
@@ -303,7 +325,8 @@ int main(int argc, char *argv[])
     if (! (cbuf_tmp = malloc(bufsize)) )
     if (! (cbuf_tmp = malloc(bufsize)) )
             { fatal("malloc()"); }
             { fatal("malloc()"); }
 
 
-    start_server(listen_port, &proxy_handler, listen_host, ssl_only);
+    settings.handler = proxy_handler; 
+    start_server();
 
 
     free(tbuf);
     free(tbuf);
     free(cbuf);
     free(cbuf);

+ 20 - 10
utils/wsproxy.py

@@ -99,14 +99,14 @@ def do_proxy(client, target):
 def proxy_handler(client):
 def proxy_handler(client):
     global target_host, target_port, options, rec
     global target_host, target_port, options, rec
 
 
+    if settings['record']:
+        print "Opening record file: %s" % settings['record']
+        rec = open(settings['record'], 'w')
+
     print "Connecting to: %s:%s" % (target_host, target_port)
     print "Connecting to: %s:%s" % (target_host, target_port)
     tsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
     tsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
     tsock.connect((target_host, target_port))
     tsock.connect((target_host, target_port))
 
 
-    if options.record:
-        print "Opening record file: %s" % options.record
-        rec = open(options.record, 'w')
-
     print traffic_legend
     print traffic_legend
 
 
     try:
     try:
@@ -122,25 +122,35 @@ if __name__ == '__main__':
     parser = optparse.OptionParser(usage=usage)
     parser = optparse.OptionParser(usage=usage)
     parser.add_option("--record",
     parser.add_option("--record",
             help="record session to a file", metavar="FILE")
             help="record session to a file", metavar="FILE")
+    parser.add_option("--foreground", "-f",
+            dest="daemon", default=True, action="store_false",
+            help="stay in foreground, do not daemonize")
     parser.add_option("--ssl-only", action="store_true",
     parser.add_option("--ssl-only", action="store_true",
             help="disallow non-encrypted connections")
             help="disallow non-encrypted connections")
+    parser.add_option("--cert", default="self.pem",
+            help="SSL certificate")
     (options, args) = parser.parse_args()
     (options, args) = parser.parse_args()
 
 
     if len(args) > 2: parser.error("Too many arguments")
     if len(args) > 2: parser.error("Too many arguments")
     if len(args) < 2: parser.error("Too few arguments")
     if len(args) < 2: parser.error("Too few arguments")
     if args[0].count(':') > 0:
     if args[0].count(':') > 0:
-        listen_host,listen_port = args[0].split(':')
+        host,port = args[0].split(':')
     else:
     else:
-        listen_host = ''
-        listen_port = args[0]
+        host,port = '',args[0]
     if args[1].count(':') > 0:
     if args[1].count(':') > 0:
         target_host,target_port = args[1].split(':')
         target_host,target_port = args[1].split(':')
     else:
     else:
         parser.error("Error parsing target")
         parser.error("Error parsing target")
-    try:    listen_port = int(listen_port)
+    try:    port = int(port)
     except: parser.error("Error parsing listen port")
     except: parser.error("Error parsing listen port")
     try:    target_port = int(target_port)
     try:    target_port = int(target_port)
     except: parser.error("Error parsing target port")
     except: parser.error("Error parsing target port")
 
 
-    start_server(listen_port, proxy_handler, listen_host=listen_host,
-            ssl_only=options.ssl_only)
+    settings['listen_host'] = host
+    settings['listen_port'] = port
+    settings['handler'] = proxy_handler
+    settings['cert'] = os.path.abspath(options.cert)
+    settings['ssl_only'] = options.ssl_only
+    settings['daemon'] = options.daemon
+    settings['record'] = os.path.abspath(options.record)
+    start_server()