Selaa lähdekoodia

Automatically detect TLS/SSL during handshake.

Use MSG_PEEK flag on recv to detect whether we are getting a flash
policy request, an SSL/TLS header, or a plain socket connection.
Joel Martin 15 vuotta sitten
vanhempi
commit
ca5785f570
1 muutettua tiedostoa jossa 45 lisäystä ja 29 poistoa
  1. 45 29
      wsproxy.py

+ 45 - 29
wsproxy.py

@@ -1,6 +1,6 @@
 #!/usr/bin/python
 #!/usr/bin/python
 
 
-import sys, os, socket, time, traceback, re
+import sys, os, socket, ssl, time, traceback, re
 from base64 import b64encode, b64decode
 from base64 import b64encode, b64decode
 from select import select
 from select import select
 
 
@@ -12,7 +12,7 @@ server_handshake = """HTTP/1.1 101 Web Socket Protocol Handshake\r
 Upgrade: WebSocket\r
 Upgrade: WebSocket\r
 Connection: Upgrade\r
 Connection: Upgrade\r
 WebSocket-Origin: %s\r
 WebSocket-Origin: %s\r
-WebSocket-Location: ws://%s%s\r
+WebSocket-Location: %s://%s%s\r
 WebSocket-Protocol: sample\r
 WebSocket-Protocol: sample\r
 \r
 \r
 """
 """
@@ -32,31 +32,6 @@ Traffic Legend:
 """
 """
 
 
 
 
-def do_handshake(client):
-    global client_settings
-    handshake = client.recv(1024)
-    #print "Handshake [%s]" % handshake
-    if handshake.startswith("<policy-file-request/>"):
-        print "Sending flash policy response"
-        client.send(policy_response)
-        client.close()
-        return False
-    req_lines = handshake.split("\r\n")
-    _, path, _ = req_lines[0].split(" ")
-    _, origin = req_lines[4].split(" ")
-    _, host = req_lines[3].split(" ")
-
-    # Parse settings from the path
-    cvars = path.partition('?')[2].partition('#')[0].split('&')
-    for cvar in [c for c in cvars if c]:
-        name, _, value = cvar.partition('=')
-        client_settings[name] = value and value or True
-
-    print "client_settings:", client_settings
-
-    client.send(server_handshake % (origin, host, path))
-    return True
-
 def traffic(token="."):
 def traffic(token="."):
     sys.stdout.write(token)
     sys.stdout.write(token)
     sys.stdout.flush()
     sys.stdout.flush()
@@ -139,6 +114,46 @@ def proxy(client, target):
                 cpartial = cpartial + buf
                 cpartial = cpartial + buf
 
 
 
 
+def do_handshake(sock):
+    global client_settings
+    # Peek, but don't read the data
+    handshake = sock.recv(1024, socket.MSG_PEEK)
+    #print "Handshake [%s]" % repr(handshake)
+    if handshake.startswith("<policy-file-request/>"):
+        handshake = sock.recv(1024)
+        print "Sending flash policy response"
+        sock.send(policy_response)
+        sock.close()
+        return False
+    elif handshake.startswith("\x16"):
+        retsock = ssl.wrap_socket(
+                sock,
+                server_side=True,
+                certfile='wsproxy.pem',
+                ssl_version=ssl.PROTOCOL_TLSv1)
+        scheme = "wss"
+        print "Using SSL/TLS"
+    else:
+        retsock = sock
+        scheme = "ws"
+        print "Using plain (not SSL) socket"
+    handshake = retsock.recv(4096)
+    req_lines = handshake.split("\r\n")
+    _, path, _ = req_lines[0].split(" ")
+    _, origin = req_lines[4].split(" ")
+    _, host = req_lines[3].split(" ")
+
+    # Parse settings from the path
+    cvars = path.partition('?')[2].partition('#')[0].split('&')
+    for cvar in [c for c in cvars if c]:
+        name, _, value = cvar.partition('=')
+        client_settings[name] = value and value or True
+
+    print "client_settings:", client_settings
+
+    retsock.send(server_handshake % (origin, scheme, host, path))
+    return retsock
+
 def start_server(listen_port, target_host, target_port):
 def start_server(listen_port, target_host, target_port):
     global send_seq
     global send_seq
     lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
     lsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@@ -150,9 +165,10 @@ def start_server(listen_port, target_host, target_port):
         try:
         try:
             csock = tsock = None
             csock = tsock = None
             print 'waiting for connection on port %s' % listen_port
             print 'waiting for connection on port %s' % listen_port
-            csock, address = lsock.accept()
+            startsock, address = lsock.accept()
             print 'Got client connection from %s' % address[0]
             print 'Got client connection from %s' % address[0]
-            if not do_handshake(csock): continue
+            csock = do_handshake(startsock)
+            if not csock: continue
             print "Connecting to: %s:%s" % (target_host, target_port)
             print "Connecting to: %s:%s" % (target_host, target_port)
             tsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
             tsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
             tsock.connect((target_host, target_port))
             tsock.connect((target_host, target_port))