浏览代码

Pull websockify socket() static method.

Pull websockify 46e2fbe.

WebSocketServer.socket() is a static method takes a host and port and
an optional connect parameter. If connect is not set then it returns
a socket listening on host and port. If connect is set then
a connection will be made host and port and the socket returned. This
has IPv6 support like the addrinfo method it replaces.

Also, prefer IPv4 resolutions if they are in the list. This can be
overriden to prefer IPv6 resolutions for the same host using the
optional prefer_ipv6 parameter.
Joel Martin 14 年之前
父节点
当前提交
4f8c746518
共有 2 个文件被更改,包括 24 次插入15 次删除
  1. 22 12
      utils/websocket.py
  2. 2 3
      utils/websockify

+ 22 - 12
utils/websocket.py

@@ -144,16 +144,30 @@ Sec-WebSocket-Accept: %s\r
     #
     #
 
 
     @staticmethod
     @staticmethod
-    def addrinfo(host, port=None):
-        """ Resolve a host (and optional port) to an IPv4 or IPv6 address.
-        Returns: family, socktype, proto, canonname, sockaddr
+    def socket(host, port=None, connect=False, prefer_ipv6=False):
+        """ Resolve a host (and optional port) to an IPv4 or IPv6
+        address. Create a socket. Bind to it if listen is set. Return
+        a socket that is ready for listen or connect.
         """
         """
-        if not host:
-            host = 'localhost'
-        addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM, socket.IPPROTO_TCP)
+        flags = 0
+        if host == '': host = None
+        if not connect:
+            flags = flags | socket.AI_PASSIVE
+        addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM,
+                socket.IPPROTO_TCP, flags)
         if not addrs:
         if not addrs:
             raise Exception("Could resolve host '%s'" % host)
             raise Exception("Could resolve host '%s'" % host)
-        return addrs[0]
+        addrs.sort(key=lambda x: x[0])
+        if prefer_ipv6:
+            addrs.reverse()
+        sock = socket.socket(addrs[0][0], addrs[0][1])
+        if connect:
+            sock.connect(addrs[0][4])
+        else:
+            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+            sock.bind(addrs[0][4])
+            sock.listen(100)
+        return sock
 
 
     @staticmethod
     @staticmethod
     def daemonize(keepfd=None, chdir='/'):
     def daemonize(keepfd=None, chdir='/'):
@@ -738,11 +752,7 @@ Sec-WebSocket-Accept: %s\r
         is a WebSockets client then call new_client() method (which must
         is a WebSockets client then call new_client() method (which must
         be overridden) for each new client connection.
         be overridden) for each new client connection.
         """
         """
-        addr = self.addrinfo(self.listen_host, self.listen_port)
-        lsock = socket.socket(addr[0], addr[1])
-        lsock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-        lsock.bind((self.listen_host, self.listen_port))
-        lsock.listen(100)
+        lsock = self.socket(self.listen_host, self.listen_port)
 
 
         if self.daemon:
         if self.daemon:
             self.daemonize(keepfd=lsock.fileno(), chdir=self.web)
             self.daemonize(keepfd=lsock.fileno(), chdir=self.web)

+ 2 - 3
utils/websockify

@@ -141,9 +141,8 @@ Traffic Legend:
         # Connect to the target
         # Connect to the target
         self.msg("connecting to: %s:%s" % (
         self.msg("connecting to: %s:%s" % (
                  self.target_host, self.target_port))
                  self.target_host, self.target_port))
-        addr = self.addrinfo(self.target_host, self.target_port)
-        tsock = socket.socket(addr[0], addr[1])
-        tsock.connect((self.target_host, self.target_port))
+        tsock = self.socket(self.target_host, self.target_port,
+                connect=True)
 
 
         if self.verbose and not self.daemon:
         if self.verbose and not self.daemon:
             print(self.traffic_legend)
             print(self.traffic_legend)