Переглянути джерело

wswrapper: Allow multiple WebSockets connections.

Allocate buffer and state memory for each accepted connection. This
allows all WebSockets connections to a given listen port to be wrapped
with WebSockets support.
Joel Martin 14 роки тому
батько
коміт
c99124b527
1 змінених файлів з 86 додано та 79 видалено
  1. 86 79
      utils/wswrapper.c

+ 86 - 79
utils/wswrapper.c

@@ -48,7 +48,6 @@
     return -1;
 
 
-
 const char _WS_response[] = "HTTP/1.1 101 Web Socket Protocol Handshake\r\n\
 Upgrade: WebSocket\r\n\
 Connection: Upgrade\r\n\
@@ -57,6 +56,17 @@ Connection: Upgrade\r\n\
 %sWebSocket-Protocol: sample\r\n\
 \r\n%s";
 
+#define WS_BUFSIZE 65536
+
+typedef struct {
+    char rbuf[WS_BUFSIZE];
+    char sbuf[WS_BUFSIZE];
+    int  rcarry_cnt;
+    char rcarry[3];
+    int  newframe;
+} _WS_connection;
+
+
 /*
  * If WSWRAP_PORT environment variable is set then listen to the bind fd that
  * matches WSWRAP_PORT, otherwise listen to the first socket fd that bind is
@@ -65,26 +75,12 @@ Connection: Upgrade\r\n\
 int   _WS_listen_fd  = 0;
 int   _WS_sockfd     = 0;
 
-typedef struct {
-    char _WS_rbuf[65536];
-    char _WS_sbuf[65536];
-} _WS_connection;
+_WS_connection * _WS_connections[65546];
 
-int   _WS_bufsize    = 65536;
-char *_WS_rbuf       = NULL;
-char *_WS_sbuf       = NULL;
-int   _WS_rcarry_cnt = 0;
-char  _WS_rcarry[3]  = "";
-int   _WS_newframe   = 1;
 
-int _WS_init() {
-    if (! (_WS_rbuf = malloc(_WS_bufsize)) ) {
-        return 0;
-    }
-    if (! (_WS_sbuf = malloc(_WS_bufsize)) ) {
-        return 0;
-    }
-}
+/* 
+ * WebSocket handshake routines
+ */
 
 int _WS_gen_md5(char *key1, char *key2, char *key3, char *target) {
     unsigned int i, spaces1 = 0, spaces2 = 0;
@@ -246,13 +242,17 @@ int _WS_handshake(int sockfd)
     return ret;
 }
 
+/*
+ * WebSockets recv and read interposer routine
+ */
 ssize_t _WS_recv(int recvf, int sockfd, const void *buf,
                  size_t len, int flags)
 {
+    _WS_connection *ws = _WS_connections[sockfd];
     int rawcount, deccount, left, rawlen, retlen, decodelen;
     int sockflags;
     int i;
-    char * fstart, * fend, * cstart;
+    char *fstart, *fend, *cstart;
 
     static void * (*rfunc)(), * (*rfunc2)();
     if (!rfunc) rfunc = (void *(*)()) dlsym(RTLD_NEXT, "recv");
@@ -262,7 +262,7 @@ ssize_t _WS_recv(int recvf, int sockfd, const void *buf,
         return 0;
     }
 
-    if ((_WS_sockfd == 0) || (_WS_sockfd != sockfd)) {
+    if (! ws) {
         // Not our file descriptor, just pass through
         if (recvf) {
             return (ssize_t) rfunc(sockfd, buf, len, flags);
@@ -277,26 +277,26 @@ ssize_t _WS_recv(int recvf, int sockfd, const void *buf,
     retlen = 0;
 
     // first copy in any carry-over bytes
-    if (_WS_rcarry_cnt) {
-        if (_WS_rcarry_cnt == 1) {
-            DEBUG("Using carry byte: %u (", _WS_rcarry[0]);
-        } else if (_WS_rcarry_cnt == 2) {
-            DEBUG("Using carry bytes: %u,%u (", _WS_rcarry[0],
-                    _WS_rcarry[1]);
+    if (ws->rcarry_cnt) {
+        if (ws->rcarry_cnt == 1) {
+            DEBUG("Using carry byte: %u (", ws->rcarry[0]);
+        } else if (ws->rcarry_cnt == 2) {
+            DEBUG("Using carry bytes: %u,%u (", ws->rcarry[0],
+                    ws->rcarry[1]);
         } else {
             RET_ERROR(EIO, "Too many carry-over bytes\n");
         }
-        if (len <= _WS_rcarry_cnt) {
+        if (len <= ws->rcarry_cnt) {
             DEBUG("final)\n");
-            memcpy((char *) buf, _WS_rcarry, len);
-            _WS_rcarry_cnt -= len;
+            memcpy((char *) buf, ws->rcarry, len);
+            ws->rcarry_cnt -= len;
             return len;
         } else {
             DEBUG("prepending)\n");
-            memcpy((char *) buf, _WS_rcarry, _WS_rcarry_cnt);
-            retlen += _WS_rcarry_cnt;
-            left -= _WS_rcarry_cnt;
-            _WS_rcarry_cnt = 0;
+            memcpy((char *) buf, ws->rcarry, ws->rcarry_cnt);
+            retlen += ws->rcarry_cnt;
+            left -= ws->rcarry_cnt;
+            ws->rcarry_cnt = 0;
         }
     }
 
@@ -304,20 +304,20 @@ ssize_t _WS_recv(int recvf, int sockfd, const void *buf,
     rawcount = (left * 4) / 3 + 3;
     rawcount -= rawcount%4;
 
-    if (rawcount > _WS_bufsize - 1) {
+    if (rawcount > WS_BUFSIZE - 1) {
         RET_ERROR(ENOMEM, "recv of %d bytes is larger than buffer\n", rawcount);
     }
 
     i = 0;
     while (1) {
         // Peek at everything available
-        rawlen = (int) rfunc(sockfd, _WS_rbuf, _WS_bufsize-1,
+        rawlen = (int) rfunc(sockfd, ws->rbuf, WS_BUFSIZE-1,
                             flags | MSG_PEEK);
         if (rawlen <= 0) {
             DEBUG("_WS_recv: returning because rawlen %d\n", rawlen);
             return (ssize_t) rawlen;
         }
-        fstart = _WS_rbuf;
+        fstart = ws->rbuf;
 
         /*
         while (rawlen >= 2 && fstart[0] == '\x00' && fstart[1] == '\xff') {
@@ -326,7 +326,7 @@ ssize_t _WS_recv(int recvf, int sockfd, const void *buf,
         }
         */
         if (rawlen >= 2 && fstart[0] == '\x00' && fstart[1] == '\xff') {
-            rawlen = (int) rfunc(sockfd, _WS_rbuf, 2, flags);
+            rawlen = (int) rfunc(sockfd, ws->rbuf, 2, flags);
             if (rawlen != 2) {
                 RET_ERROR(EIO, "Could not strip empty frame headers\n");
             }
@@ -335,7 +335,7 @@ ssize_t _WS_recv(int recvf, int sockfd, const void *buf,
 
         fstart[rawlen] = '\x00';
 
-        if (rawlen - _WS_newframe >= 4) {
+        if (rawlen - ws->newframe >= 4) {
             // We have enough to base64 decode at least 1 byte
             break;
         }
@@ -362,19 +362,19 @@ ssize_t _WS_recv(int recvf, int sockfd, const void *buf,
     DEBUG("\n");
     */
 
-    if (_WS_newframe) {
+    if (ws->newframe) {
         if (fstart[0] != '\x00') {
             RET_ERROR(EPROTO, "Missing frame start\n");
         }
         fstart++;
         rawlen--;
-        _WS_newframe = 0;
+        ws->newframe = 0;
     }
 
     fend = memchr(fstart, '\xff', rawlen);
 
     if (fend) {
-        _WS_newframe = 1;
+        ws->newframe = 1;
         if ((fend - fstart) % 4) {
             RET_ERROR(EPROTO, "Frame length is not multiple of 4\n");
         }
@@ -387,7 +387,7 @@ ssize_t _WS_recv(int recvf, int sockfd, const void *buf,
 
     // How much should we consume
     if (rawcount < fend - fstart) {
-        _WS_newframe = 0;
+        ws->newframe = 0;
         deccount = rawcount;
     } else {
         deccount = fend - fstart;
@@ -397,7 +397,7 @@ ssize_t _WS_recv(int recvf, int sockfd, const void *buf,
     if (flags & MSG_PEEK) {
         MSG("*** Got MSG_PEEK ***\n");
     } else {
-        rfunc(sockfd, _WS_rbuf, fstart - _WS_rbuf + deccount + _WS_newframe, flags);
+        rfunc(sockfd, ws->rbuf, fstart - ws->rbuf + deccount + ws->newframe, flags);
     }
 
     fstart[deccount] = '\x00'; // base64 terminator
@@ -415,16 +415,16 @@ ssize_t _WS_recv(int recvf, int sockfd, const void *buf,
 
         if (! (flags & MSG_PEEK)) {
             // Add anything left over to the carry-over
-            _WS_rcarry_cnt = decodelen - left;
-            if (_WS_rcarry_cnt > 2) {
+            ws->rcarry_cnt = decodelen - left;
+            if (ws->rcarry_cnt > 2) {
                 RET_ERROR(EPROTO, "Got too much base64 data\n");
             }
-            memcpy(_WS_rcarry, buf + retlen, _WS_rcarry_cnt);
-            if (_WS_rcarry_cnt == 1) {
-                DEBUG("Saving carry byte: %u\n", _WS_rcarry[0]);
-            } else if (_WS_rcarry_cnt == 2) {
-                DEBUG("Saving carry bytes: %u,%u\n", _WS_rcarry[0],
-                        _WS_rcarry[1]);
+            memcpy(ws->rcarry, buf + retlen, ws->rcarry_cnt);
+            if (ws->rcarry_cnt == 1) {
+                DEBUG("Saving carry byte: %u\n", ws->rcarry[0]);
+            } else if (ws->rcarry_cnt == 2) {
+                DEBUG("Saving carry bytes: %u,%u\n", ws->rcarry[0],
+                        ws->rcarry[1]);
             } else {
                 MSG("Waah2!\n");
             }
@@ -442,9 +442,13 @@ ssize_t _WS_recv(int recvf, int sockfd, const void *buf,
     return retlen;
 }
 
+/*
+ * WebSockets send and write interposer routine
+ */
 ssize_t _WS_send(int sendf, int sockfd, const void *buf,
                  size_t len, int flags)
 {
+    _WS_connection *ws = _WS_connections[sockfd];
     int rawlen, enclen, rlen, over, left, clen, retlen, dbufsize;
     int sockflags;
     char * target;
@@ -453,7 +457,7 @@ ssize_t _WS_send(int sendf, int sockfd, const void *buf,
     if (!sfunc) sfunc = (void *(*)()) dlsym(RTLD_NEXT, "send");
     if (!sfunc2) sfunc2 = (void *(*)()) dlsym(RTLD_NEXT, "write");
 
-    if ((_WS_sockfd == 0) || (_WS_sockfd != sockfd)) {
+    if (! ws) {
         // Not our file descriptor, just pass through
         if (sendf) {
             return (ssize_t) sfunc(sockfd, buf, len, flags);
@@ -465,22 +469,22 @@ ssize_t _WS_send(int sendf, int sockfd, const void *buf,
 
     sockflags = fcntl(sockfd, F_GETFL, 0);
 
-    dbufsize = (_WS_bufsize * 3)/4 - 2;
+    dbufsize = (WS_BUFSIZE * 3)/4 - 2;
     if (len > dbufsize) {
         RET_ERROR(ENOMEM, "send of %d bytes is larger than send buffer\n", len);
     }
 
     // base64 encode and add frame markers
     rawlen = 0;
-    _WS_sbuf[rawlen++] = '\x00';
-    enclen = b64_ntop(buf, len, _WS_sbuf+rawlen, _WS_bufsize-rawlen);
+    ws->sbuf[rawlen++] = '\x00';
+    enclen = b64_ntop(buf, len, ws->sbuf+rawlen, WS_BUFSIZE-rawlen);
     if (enclen < 0) {
         RET_ERROR(EPROTO, "Base64 encoding error\n");
     }
     rawlen += enclen;
-    _WS_sbuf[rawlen++] = '\xff';
+    ws->sbuf[rawlen++] = '\xff';
 
-    rlen = (int) sfunc(sockfd, _WS_sbuf, rawlen, flags);
+    rlen = (int) sfunc(sockfd, ws->sbuf, rawlen, flags);
 
     if (rlen <= 0) {
         return rlen;
@@ -490,11 +494,11 @@ ssize_t _WS_send(int sendf, int sockfd, const void *buf,
         left = (4 - over) % 4 + 1; // left to send
         DEBUG("_WS_send: rlen: %d (over: %d, left: %d), rawlen: %d\n", rlen, over, left, rawlen);
         rlen += left;
-        _WS_sbuf[rlen-1] = '\xff';
+        ws->sbuf[rlen-1] = '\xff';
         i = 0;
         do {
             i++;
-            clen = (int) sfunc(sockfd, _WS_sbuf + rlen - left, left, flags);
+            clen = (int) sfunc(sockfd, ws->sbuf + rlen - left, left, flags);
             if (clen > 0) {
                 left -= clen;
             } else if (clen == 0) {
@@ -518,8 +522,8 @@ ssize_t _WS_send(int sendf, int sockfd, const void *buf,
     // Adjust for framing
     retlen = rlen - 2;
     // Adjust for base64 padding
-    if (_WS_sbuf[rlen-1] == '=') { retlen --; }
-    if (_WS_sbuf[rlen-2] == '=') { retlen --; }
+    if (ws->sbuf[rlen-1] == '=') { retlen --; }
+    if (ws->sbuf[rlen-2] == '=') { retlen --; }
 
     // Adjust for base64 encoding
     retlen = (retlen*3)/4;
@@ -529,13 +533,15 @@ ssize_t _WS_send(int sendf, int sockfd, const void *buf,
     for (i = 0; i < retlen; i++) {
         DEBUG("%u,", (unsigned char) ((char *)buf)[i]);
     }
-    DEBUG(" as '%s' (%d)\n", _WS_sbuf+1, rlen);
+    DEBUG(" as '%s' (%d)\n", ws->sbuf+1, rlen);
     */
     return (ssize_t) retlen;
 }
 
 
-/* Override network routines */
+/*
+ * Overload (LD_PRELOAD) standard library network routines
+ */
 
 /*
 int socket(int domain, int type, int protocol)
@@ -603,24 +609,24 @@ int accept(int sockfd, struct sockaddr *addr, socklen_t *addrlen)
         return fd;
     }
 
-    if (_WS_sockfd == 0) {
-        // TODO: not just first connection
-        _WS_sockfd = fd;
-
-        if (!_WS_rbuf) {
-            if (! _WS_init()) {
-                RET_ERROR(ENOMEM, "Could not allocate interposer buffer\n");
-            }
+    if (_WS_connections[fd]) {
+        MSG("error, already interposing on fd %d\n", fd);
+    } else {
+        if (! (_WS_connections[fd] = malloc(sizeof(_WS_connection)))) {
+            RET_ERROR(ENOMEM, "Could not allocate interposer memory\n");
         }
+        _WS_connections[fd]->rcarry_cnt = 0;
+        _WS_connections[fd]->rcarry[0]  = '\0';
+        _WS_connections[fd]->newframe   = 1;
 
-        ret = _WS_handshake(_WS_sockfd);
+        ret = _WS_handshake(fd);
         if (ret < 0) {
+            free(_WS_connections[fd]);
+            _WS_connections[fd] = NULL;
             errno = EPROTO;
             return ret;
         }
-        MSG("interposing on fd %d\n", _WS_sockfd);
-    } else {
-        DEBUG("already interposing on fd %d\n", _WS_sockfd);
+        MSG("interposing on fd %d (allocated memory)\n", fd);
     }
 
     return fd;
@@ -631,9 +637,10 @@ int close(int fd)
     static void * (*func)();
     if (!func) func = (void *(*)()) dlsym(RTLD_NEXT, "close");
 
-    if ((_WS_sockfd != 0) && (_WS_sockfd == fd)) {
-        MSG("finished interposing on fd %d\n", _WS_sockfd);
-        _WS_sockfd = 0;
+    if (_WS_connections[fd]) {
+        free(_WS_connections[fd]);
+        _WS_connections[fd] = NULL;
+        MSG("finished interposing on fd %d (freed memory)\n", fd);
     }
     return (int) func(fd);
 }