Эх сурвалжийг харах

websockify --run-once, --timeout, numpy fallback

Pull websockify 724aa3a.

- Use array module for unmasking HyBi when no numpy module is
    available.

- Detect client close properly when using python 3.

- Print request URL path is specified.

- New option --run-once will exit after handling a single WebSocket
  connection (but not ater flash policy or normal web requests).

- New option --timeout TIME will stop listening for new connections
  after exit after TIME seconds (the master process shuts down).
  Existing WebSocket connections will continue but once all
  connections are closed all processes will terminate.
Joel Martin 14 жил өмнө
parent
commit
1e50871599

+ 65 - 27
utils/websocket.py

@@ -16,7 +16,8 @@ as taken from http://docs.python.org/dev/library/ssl.html#certificates
 
 
 '''
 '''
 
 
-import os, sys, time, errno, signal, socket, struct, traceback, select
+import os, sys, time, errno, signal, socket, traceback, select
+import struct, array
 from cgi import parse_qsl
 from cgi import parse_qsl
 from base64 import b64encode, b64decode
 from base64 import b64encode, b64decode
 
 
@@ -28,6 +29,7 @@ if sys.hexversion > 0x3000000:
     from urllib.parse import urlsplit
     from urllib.parse import urlsplit
     b2s = lambda buf: buf.decode('latin_1')
     b2s = lambda buf: buf.decode('latin_1')
     s2b = lambda s: s.encode('latin_1')
     s2b = lambda s: s.encode('latin_1')
+    s2a = lambda s: s
 else:
 else:
     # python 2.X
     # python 2.X
     from cStringIO import StringIO
     from cStringIO import StringIO
@@ -36,6 +38,7 @@ else:
     # No-ops
     # No-ops
     b2s = lambda buf: buf
     b2s = lambda buf: buf
     s2b = lambda s: s
     s2b = lambda s: s
+    s2a = lambda s: [ord(c) for c in s]
 
 
 if sys.hexversion >= 0x2060000:
 if sys.hexversion >= 0x2060000:
     # python >= 2.6
     # python >= 2.6
@@ -54,7 +57,7 @@ for mod, sup in [('numpy', 'HyBi protocol'),
         globals()[mod] = __import__(mod)
         globals()[mod] = __import__(mod)
     except ImportError:
     except ImportError:
         globals()[mod] = None
         globals()[mod] = None
-        print("WARNING: no '%s' module, %s support disabled" % (
+        print("WARNING: no '%s' module, %s decode may be slower" % (
             mod, sup))
             mod, sup))
 
 
 
 
@@ -88,7 +91,8 @@ Sec-WebSocket-Accept: %s\r
 
 
     def __init__(self, listen_host='', listen_port=None, source_is_ipv6=False,
     def __init__(self, listen_host='', listen_port=None, source_is_ipv6=False,
             verbose=False, cert='', key='', ssl_only=None,
             verbose=False, cert='', key='', ssl_only=None,
-            daemon=False, record='', web=''):
+            daemon=False, record='', web='',
+            run_once=False, timeout=0):
 
 
         # settings
         # settings
         self.verbose        = verbose
         self.verbose        = verbose
@@ -96,6 +100,11 @@ Sec-WebSocket-Accept: %s\r
         self.listen_port    = listen_port
         self.listen_port    = listen_port
         self.ssl_only       = ssl_only
         self.ssl_only       = ssl_only
         self.daemon         = daemon
         self.daemon         = daemon
+        self.run_once       = run_once
+        self.timeout        = timeout
+
+        self.launch_time    = time.time()
+        self.ws_connection  = False
         self.handler_id     = 1
         self.handler_id     = 1
 
 
         # Make paths settings absolute
         # Make paths settings absolute
@@ -207,6 +216,38 @@ Sec-WebSocket-Accept: %s\r
         os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdout.fileno())
         os.dup2(os.open(os.devnull, os.O_RDWR), sys.stdout.fileno())
         os.dup2(os.open(os.devnull, os.O_RDWR), sys.stderr.fileno())
         os.dup2(os.open(os.devnull, os.O_RDWR), sys.stderr.fileno())
 
 
+    @staticmethod
+    def unmask(buf, f):
+        pstart = f['hlen'] + 4
+        pend = pstart + f['length']
+        if numpy:
+            b = c = s2b('')
+            if f['length'] >= 4:
+                mask = numpy.frombuffer(buf, dtype=numpy.dtype('<u4'),
+                        offset=f['hlen'], count=1)
+                data = numpy.frombuffer(buf, dtype=numpy.dtype('<u4'),
+                        offset=pstart, count=int(f['length'] / 4))
+                #b = numpy.bitwise_xor(data, mask).data
+                b = numpy.bitwise_xor(data, mask).tostring()
+
+            if f['length'] % 4:
+                #print("Partial unmask")
+                mask = numpy.frombuffer(buf, dtype=numpy.dtype('B'),
+                        offset=f['hlen'], count=(f['length'] % 4))
+                data = numpy.frombuffer(buf, dtype=numpy.dtype('B'),
+                        offset=pend - (f['length'] % 4),
+                        count=(f['length'] % 4))
+                c = numpy.bitwise_xor(data, mask).tostring()
+            return b + c
+        else:
+            # Slower fallback
+            data = array.array('B')
+            mask = s2a(f['mask'])
+            data.fromstring(buf[pstart:pend])
+            for i in range(len(data)):
+                data[i] ^= mask[i % 4]
+            return data.tostring()
+
     @staticmethod
     @staticmethod
     def encode_hybi(buf, opcode, base64=False):
     def encode_hybi(buf, opcode, base64=False):
         """ Encode a HyBi style WebSocket frame.
         """ Encode a HyBi style WebSocket frame.
@@ -295,24 +336,7 @@ Sec-WebSocket-Accept: %s\r
         if has_mask:
         if has_mask:
             # unmask payload
             # unmask payload
             f['mask'] = buf[f['hlen']:f['hlen']+4]
             f['mask'] = buf[f['hlen']:f['hlen']+4]
-            b = c = s2b('')
-            if f['length'] >= 4:
-                mask = numpy.frombuffer(buf, dtype=numpy.dtype('<u4'),
-                        offset=f['hlen'], count=1)
-                data = numpy.frombuffer(buf, dtype=numpy.dtype('<u4'),
-                        offset=f['hlen'] + 4, count=int(f['length'] / 4))
-                #b = numpy.bitwise_xor(data, mask).data
-                b = numpy.bitwise_xor(data, mask).tostring()
-
-            if f['length'] % 4:
-                #print("Partial unmask")
-                mask = numpy.frombuffer(buf, dtype=numpy.dtype('B'),
-                        offset=f['hlen'], count=(f['length'] % 4))
-                data = numpy.frombuffer(buf, dtype=numpy.dtype('B'),
-                        offset=full_len - (f['length'] % 4),
-                        count=(f['length'] % 4))
-                c = numpy.bitwise_xor(data, mask).tostring()
-            f['payload'] = b + c
+            f['payload'] = WebSocketServer.unmask(buf, f)
         else:
         else:
             print("Unmasked frame: %s" % repr(buf))
             print("Unmasked frame: %s" % repr(buf))
             f['payload'] = buf[(f['hlen'] + has_mask * 4):full_len]
             f['payload'] = buf[(f['hlen'] + has_mask * 4):full_len]
@@ -468,11 +492,11 @@ Sec-WebSocket-Accept: %s\r
                         break
                         break
 
 
             else:
             else:
-                if buf[0:2] == '\xff\x00':
+                if buf[0:2] == s2b('\xff\x00'):
                     closed = "Client sent orderly close frame"
                     closed = "Client sent orderly close frame"
                     break
                     break
 
 
-                elif buf[0:2] == '\x00\xff':
+                elif buf[0:2] == s2b('\x00\xff'):
                     buf = buf[2:]
                     buf = buf[2:]
                     continue # No-op
                     continue # No-op
 
 
@@ -611,9 +635,6 @@ Sec-WebSocket-Accept: %s\r
         if ver:
         if ver:
             # HyBi/IETF version of the protocol
             # HyBi/IETF version of the protocol
 
 
-            if sys.hexversion < 0x2060000 or not numpy:
-                raise self.EClose("Python >= 2.6 and numpy module is required for HyBi-07 or greater")
-
             # HyBi-07 report version 7
             # HyBi-07 report version 7
             # HyBi-08 - HyBi-12 report version 8
             # HyBi-08 - HyBi-12 report version 8
             # HyBi-13 reports version 13
             # HyBi-13 reports version 13
@@ -669,6 +690,9 @@ Sec-WebSocket-Accept: %s\r
         self.msg("%s: %s WebSocket connection" % (address[0], stype))
         self.msg("%s: %s WebSocket connection" % (address[0], stype))
         self.msg("%s: Version %s, base64: '%s'" % (address[0],
         self.msg("%s: Version %s, base64: '%s'" % (address[0],
             self.version, self.base64))
             self.version, self.base64))
+        if self.path != '/':
+            self.msg("%s: Path: '%s'" % (address[0], self.path))
+
 
 
         # Send server WebSockets handshake response
         # Send server WebSockets handshake response
         #self.msg("sending response [%s]" % response)
         #self.msg("sending response [%s]" % response)
@@ -727,6 +751,7 @@ Sec-WebSocket-Accept: %s\r
                     self.rec = open(fname, 'w+')
                     self.rec = open(fname, 'w+')
                     self.rec.write("var VNC_frame_data = [\n")
                     self.rec.write("var VNC_frame_data = [\n")
 
 
+                self.ws_connection = True
                 self.new_client()
                 self.new_client()
             except self.EClose:
             except self.EClose:
                 _, exc, _ = sys.exc_info()
                 _, exc, _ = sys.exc_info()
@@ -777,6 +802,12 @@ Sec-WebSocket-Accept: %s\r
                     startsock = None
                     startsock = None
                     pid = err = 0
                     pid = err = 0
 
 
+                    time_elapsed = time.time() - self.launch_time
+                    if self.timeout and time_elapsed > self.timeout:
+                        self.msg('listener exit due to --timeout %s'
+                                % self.timeout)
+                        break
+
                     try:
                     try:
                         self.poll()
                         self.poll()
 
 
@@ -799,7 +830,14 @@ Sec-WebSocket-Accept: %s\r
                         else:
                         else:
                             raise
                             raise
 
 
-                    if Process:
+                    if self.run_once:
+                        # Run in same process if run_once
+                        self.top_new_client(startsock, address)
+                        if self.ws_connection :
+                            self.msg('%s: exiting due to --run-once'
+                                    % address[0])
+                            break
+                    elif Process:
                         self.vmsg('%s: new handler Process' % address[0])
                         self.vmsg('%s: new handler Process' % address[0])
                         p = Process(target=self.top_new_client,
                         p = Process(target=self.top_new_client,
                                 args=(startsock, address))
                                 args=(startsock, address))

+ 4 - 0
utils/websockify

@@ -226,6 +226,10 @@ if __name__ == '__main__':
     parser.add_option("--daemon", "-D",
     parser.add_option("--daemon", "-D",
             dest="daemon", action="store_true",
             dest="daemon", action="store_true",
             help="become a daemon (background process)")
             help="become a daemon (background process)")
+    parser.add_option("--run-once", action="store_true",
+            help="handle a single WebSocket connection and exit")
+    parser.add_option("--timeout", type=int, default=0,
+            help="after TIMEOUT seconds exit when not connected")
     parser.add_option("--cert", default="self.pem",
     parser.add_option("--cert", default="self.pem",
             help="SSL certificate file")
             help="SSL certificate file")
     parser.add_option("--key", default=None,
     parser.add_option("--key", default=None,