websocket.c 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530
  1. /*
  2. * WebSocket lib with support for "wss://" encryption.
  3. * Copyright 2010 Joel Martin
  4. * Licensed under LGPL version 3 (see docs/LICENSE.LGPL-3)
  5. *
  6. * You can make a cert/key with openssl using:
  7. * openssl req -new -x509 -days 365 -nodes -out self.pem -keyout self.pem
  8. * as taken from http://docs.python.org/dev/library/ssl.html#certificates
  9. */
  #include <stdio.h>
  #include <stdlib.h>
  #include <string.h>
  #include <strings.h>
  #include <errno.h>
  #include <unistd.h>
  #include <sys/types.h>
  #include <sys/stat.h> /* umask */
  #include <sys/socket.h>
  #include <netinet/in.h>
  #include <arpa/inet.h>
  #include <netdb.h>
  #include <signal.h> // daemonizing
  #include <fcntl.h> // daemonizing
  #include <openssl/err.h>
  #include <openssl/ssl.h>
  #include <resolv.h> /* base64 encode/decode */
  #include "websocket.h"
  /*
   * Response template for the version 75/76 WebSocket upgrade (filled in by
   * do_handshake). The eight %s conversions are, in order: header prefix
   * ("Sec-" for v76, "" for v75), Origin, prefix, scheme ("ws"/"wss"), Host,
   * path, prefix, and the trailer.
   *
   * Written as adjacent string literals instead of backslash-newline
   * continuations so that indentation (or any other text at the start of a
   * source line) can never leak into the bytes sent on the wire.
   */
  const char server_handshake[] =
      "HTTP/1.1 101 Web Socket Protocol Handshake\r\n"
      "Upgrade: WebSocket\r\n"
      "Connection: Upgrade\r\n"
      "%sWebSocket-Origin: %s\r\n"
      "%sWebSocket-Location: %s://%s%s\r\n"
      "%sWebSocket-Protocol: sample\r\n"
      "\r\n%s";

  /*
   * Flash socket-policy reply returned for "<policy-file-request/>".
   * do_handshake sends sizeof(policy_response) bytes, so the terminating NUL
   * is transmitted as well.
   */
  const char policy_response[] = "<cross-domain-policy><allow-access-from domain=\"*\" to-ports=\"*\" /></cross-domain-policy>\n";
  33. /*
  34. * Global state
  35. *
  36. * Warning: not thread safe
  37. */
  38. int ssl_initialized = 0;
  39. int pipe_error = 0;
  40. char *tbuf, *cbuf, *tbuf_tmp, *cbuf_tmp;
  41. unsigned int bufsize, dbufsize;
  42. settings_t settings;
  43. void traffic(char * token) {
  44. if ((settings.verbose) && (! settings.daemon)) {
  45. fprintf(stdout, "%s", token);
  46. fflush(stdout);
  47. }
  48. }
  /*
   * Print msg followed by the description of the current errno to stderr,
   * then return; used for recoverable failures.
   */
  void error(char *msg)
  {
      perror(msg);
  }

  /* Same report as error(), then terminate the process with status 1. */
  void fatal(char *msg)
  {
      error(msg);
      exit(1);
  }
  /*
   * Resolve hostname into *sin_addr.
   *
   * A dotted IPv4 literal is parsed directly by inet_aton(). Anything else is
   * looked up with getaddrinfo() restricted to AF_INET, and the first IPv4
   * result is stored.
   *
   * Returns 0 on success, -1 if the name cannot be resolved to IPv4.
   */
  int resolve_host(struct in_addr *sin_addr, const char *hostname)
  {
      struct addrinfo hints, *res, *it;
      int rc = -1;

      if (inet_aton(hostname, sin_addr)) {
          return 0;   /* numeric address, no lookup needed */
      }

      memset(&hints, 0, sizeof(hints));
      hints.ai_family = AF_INET;
      if (getaddrinfo(hostname, NULL, &hints, &res)) {
          return -1;
      }

      /* Stop at the first AF_INET entry (rc flips to 0 when one is found). */
      for (it = res; it != NULL && rc != 0; it = it->ai_next) {
          if (it->ai_family != AF_INET) {
              continue;
          }
          *sin_addr = ((struct sockaddr_in *)it->ai_addr)->sin_addr;
          rc = 0;
      }
      freeaddrinfo(res);
      return rc;
  }
  80. /*
  81. * SSL Wrapper Code
  82. */
  83. ssize_t ws_recv(ws_ctx_t *ctx, void *buf, size_t len) {
  84. if (ctx->ssl) {
  85. //handler_msg("SSL recv\n");
  86. return SSL_read(ctx->ssl, buf, len);
  87. } else {
  88. return recv(ctx->sockfd, buf, len, 0);
  89. }
  90. }
  91. ssize_t ws_send(ws_ctx_t *ctx, const void *buf, size_t len) {
  92. if (ctx->ssl) {
  93. //handler_msg("SSL send\n");
  94. return SSL_write(ctx->ssl, buf, len);
  95. } else {
  96. return send(ctx->sockfd, buf, len, 0);
  97. }
  98. }
  99. ws_ctx_t *ws_socket(int socket) {
  100. ws_ctx_t *ctx;
  101. ctx = malloc(sizeof(ws_ctx_t));
  102. ctx->sockfd = socket;
  103. ctx->ssl = NULL;
  104. ctx->ssl_ctx = NULL;
  105. return ctx;
  106. }
  107. ws_ctx_t *ws_socket_ssl(int socket, char * certfile) {
  108. int ret;
  109. char msg[1024];
  110. ws_ctx_t *ctx;
  111. ctx = ws_socket(socket);
  112. // Initialize the library
  113. if (! ssl_initialized) {
  114. SSL_library_init();
  115. OpenSSL_add_all_algorithms();
  116. SSL_load_error_strings();
  117. ssl_initialized = 1;
  118. }
  119. ctx->ssl_ctx = SSL_CTX_new(TLSv1_server_method());
  120. if (ctx->ssl_ctx == NULL) {
  121. ERR_print_errors_fp(stderr);
  122. fatal("Failed to configure SSL context");
  123. }
  124. if (SSL_CTX_use_PrivateKey_file(ctx->ssl_ctx, certfile,
  125. SSL_FILETYPE_PEM) <= 0) {
  126. sprintf(msg, "Unable to load private key file %s\n", certfile);
  127. fatal(msg);
  128. }
  129. if (SSL_CTX_use_certificate_file(ctx->ssl_ctx, certfile,
  130. SSL_FILETYPE_PEM) <= 0) {
  131. sprintf(msg, "Unable to load certificate file %s\n", certfile);
  132. fatal(msg);
  133. }
  134. // if (SSL_CTX_set_cipher_list(ctx->ssl_ctx, "DEFAULT") != 1) {
  135. // sprintf(msg, "Unable to set cipher\n");
  136. // fatal(msg);
  137. // }
  138. // Associate socket and ssl object
  139. ctx->ssl = SSL_new(ctx->ssl_ctx);
  140. SSL_set_fd(ctx->ssl, socket);
  141. ret = SSL_accept(ctx->ssl);
  142. if (ret < 0) {
  143. ERR_print_errors_fp(stderr);
  144. return NULL;
  145. }
  146. return ctx;
  147. }
  148. int ws_socket_free(ws_ctx_t *ctx) {
  149. if (ctx->ssl) {
  150. SSL_free(ctx->ssl);
  151. ctx->ssl = NULL;
  152. }
  153. if (ctx->ssl_ctx) {
  154. SSL_CTX_free(ctx->ssl_ctx);
  155. ctx->ssl_ctx = NULL;
  156. }
  157. if (ctx->sockfd) {
  158. close(ctx->sockfd);
  159. ctx->sockfd = 0;
  160. }
  161. free(ctx);
  162. }
  163. /* ------------------------------------------------------- */
  164. int encode(u_char const *src, size_t srclength, char *target, size_t targsize) {
  165. int i, sz = 0, len = 0;
  166. unsigned char chr;
  167. target[sz++] = '\x00';
  168. len = __b64_ntop(src, srclength, target+sz, targsize-sz);
  169. if (len < 0) {
  170. return len;
  171. }
  172. sz += len;
  173. target[sz++] = '\xff';
  174. return sz;
  175. }
  176. int decode(char *src, size_t srclength, u_char *target, size_t targsize) {
  177. char *start, *end, cntstr[4];
  178. int i, len, framecount = 0, retlen = 0;
  179. unsigned char chr;
  180. if ((src[0] != '\x00') || (src[srclength-1] != '\xff')) {
  181. handler_emsg("WebSocket framing error\n");
  182. return -1;
  183. }
  184. start = src+1; // Skip '\x00' start
  185. do {
  186. /* We may have more than one frame */
  187. end = memchr(start, '\xff', srclength);
  188. *end = '\x00';
  189. len = __b64_pton(start, target+retlen, targsize-retlen);
  190. if (len < 0) {
  191. return len;
  192. }
  193. retlen += len;
  194. start = end + 2; // Skip '\xff' end and '\x00' start
  195. framecount++;
  196. } while (end < (src+srclength-1));
  197. if (framecount > 1) {
  198. snprintf(cntstr, 3, "%d", framecount);
  199. traffic(cntstr);
  200. }
  201. return retlen;
  202. }
  203. int parse_handshake(char *handshake, headers_t *headers) {
  204. char *start, *end;
  205. if ((strlen(handshake) < 92) || (bcmp(handshake, "GET ", 4) != 0)) {
  206. return 0;
  207. }
  208. start = handshake+4;
  209. end = strstr(start, " HTTP/1.1");
  210. if (!end) { return 0; }
  211. strncpy(headers->path, start, end-start);
  212. headers->path[end-start] = '\0';
  213. start = strstr(handshake, "\r\nHost: ");
  214. if (!start) { return 0; }
  215. start += 8;
  216. end = strstr(start, "\r\n");
  217. strncpy(headers->host, start, end-start);
  218. headers->host[end-start] = '\0';
  219. start = strstr(handshake, "\r\nOrigin: ");
  220. if (!start) { return 0; }
  221. start += 10;
  222. end = strstr(start, "\r\n");
  223. strncpy(headers->origin, start, end-start);
  224. headers->origin[end-start] = '\0';
  225. start = strstr(handshake, "\r\n\r\n");
  226. if (!start) { return 0; }
  227. start += 4;
  228. if (strlen(start) == 8) {
  229. strncpy(headers->key3, start, 8);
  230. headers->key3[8] = '\0';
  231. start = strstr(handshake, "\r\nSec-WebSocket-Key1: ");
  232. if (!start) { return 0; }
  233. start += 22;
  234. end = strstr(start, "\r\n");
  235. strncpy(headers->key1, start, end-start);
  236. headers->key1[end-start] = '\0';
  237. start = strstr(handshake, "\r\nSec-WebSocket-Key2: ");
  238. if (!start) { return 0; }
  239. start += 22;
  240. end = strstr(start, "\r\n");
  241. strncpy(headers->key2, start, end-start);
  242. headers->key2[end-start] = '\0';
  243. } else {
  244. headers->key1[0] = '\0';
  245. headers->key2[0] = '\0';
  246. headers->key3[0] = '\0';
  247. }
  248. return 1;
  249. }
  250. int gen_md5(headers_t *headers, char *target) {
  251. unsigned int i, spaces1 = 0, spaces2 = 0;
  252. unsigned long num1 = 0, num2 = 0;
  253. unsigned char buf[17];
  254. for (i=0; i < strlen(headers->key1); i++) {
  255. if (headers->key1[i] == ' ') {
  256. spaces1 += 1;
  257. }
  258. if ((headers->key1[i] >= 48) && (headers->key1[i] <= 57)) {
  259. num1 = num1 * 10 + (headers->key1[i] - 48);
  260. }
  261. }
  262. num1 = num1 / spaces1;
  263. for (i=0; i < strlen(headers->key2); i++) {
  264. if (headers->key2[i] == ' ') {
  265. spaces2 += 1;
  266. }
  267. if ((headers->key2[i] >= 48) && (headers->key2[i] <= 57)) {
  268. num2 = num2 * 10 + (headers->key2[i] - 48);
  269. }
  270. }
  271. num2 = num2 / spaces2;
  272. /* Pack it big-endian */
  273. buf[0] = (num1 & 0xff000000) >> 24;
  274. buf[1] = (num1 & 0xff0000) >> 16;
  275. buf[2] = (num1 & 0xff00) >> 8;
  276. buf[3] = num1 & 0xff;
  277. buf[4] = (num2 & 0xff000000) >> 24;
  278. buf[5] = (num2 & 0xff0000) >> 16;
  279. buf[6] = (num2 & 0xff00) >> 8;
  280. buf[7] = num2 & 0xff;
  281. strncpy(buf+8, headers->key3, 8);
  282. buf[16] = '\0';
  283. md5_buffer(buf, 16, target);
  284. target[16] = '\0';
  285. return 1;
  286. }
  287. ws_ctx_t *do_handshake(int sock) {
  288. char handshake[4096], response[4096], trailer[17];
  289. char *scheme, *pre;
  290. headers_t headers;
  291. int len, ret;
  292. ws_ctx_t * ws_ctx;
  293. // Peek, but don't read the data
  294. len = recv(sock, handshake, 1024, MSG_PEEK);
  295. handshake[len] = 0;
  296. if (len == 0) {
  297. handler_msg("ignoring empty handshake\n");
  298. close(sock);
  299. return NULL;
  300. } else if (bcmp(handshake, "<policy-file-request/>", 22) == 0) {
  301. len = recv(sock, handshake, 1024, 0);
  302. handshake[len] = 0;
  303. handler_msg("sending flash policy response\n");
  304. send(sock, policy_response, sizeof(policy_response), 0);
  305. close(sock);
  306. return NULL;
  307. } else if (bcmp(handshake, "\x16", 1) == 0) {
  308. // SSL
  309. if (! settings.cert) { return NULL; }
  310. ws_ctx = ws_socket_ssl(sock, settings.cert);
  311. if (! ws_ctx) { return NULL; }
  312. scheme = "wss";
  313. handler_msg("using SSL socket\n");
  314. } else if (settings.ssl_only) {
  315. handler_msg("non-SSL connection disallowed\n");
  316. close(sock);
  317. return NULL;
  318. } else {
  319. ws_ctx = ws_socket(sock);
  320. if (! ws_ctx) { return NULL; }
  321. scheme = "ws";
  322. handler_msg("using plain (not SSL) socket\n");
  323. }
  324. len = ws_recv(ws_ctx, handshake, 4096);
  325. handshake[len] = 0;
  326. if (!parse_handshake(handshake, &headers)) {
  327. handler_emsg("Invalid WS request\n");
  328. close(sock);
  329. return NULL;
  330. }
  331. if (headers.key3[0] != '\0') {
  332. gen_md5(&headers, trailer);
  333. pre = "Sec-";
  334. handler_msg("using protocol version 76\n");
  335. } else {
  336. trailer[0] = '\0';
  337. pre = "";
  338. handler_msg("using protocol version 75\n");
  339. }
  340. sprintf(response, server_handshake, pre, headers.origin, pre, scheme,
  341. headers.host, headers.path, pre, trailer);
  342. //handler_msg("response: %s\n", response);
  343. ws_send(ws_ctx, response, strlen(response));
  344. return ws_ctx;
  345. }
  346. void signal_handler(sig) {
  347. switch (sig) {
  348. case SIGHUP: break; // ignore for now
  349. case SIGPIPE: pipe_error = 1; break; // handle inline
  350. case SIGTERM: exit(0); break;
  351. }
  352. }
  353. void daemonize(int keepfd) {
  354. int pid, i;
  355. umask(0);
  356. chdir('/');
  357. setgid(getgid());
  358. setuid(getuid());
  359. /* Double fork to daemonize */
  360. pid = fork();
  361. if (pid<0) { fatal("fork error"); }
  362. if (pid>0) { exit(0); } // parent exits
  363. setsid(); // Obtain new process group
  364. pid = fork();
  365. if (pid<0) { fatal("fork error"); }
  366. if (pid>0) { exit(0); } // parent exits
  367. /* Signal handling */
  368. signal(SIGHUP, signal_handler); // catch HUP
  369. signal(SIGTERM, signal_handler); // catch kill
  370. /* Close open files */
  371. for (i=getdtablesize(); i>=0; --i) {
  372. if (i != keepfd) {
  373. close(i);
  374. } else if (settings.verbose) {
  375. printf("keeping fd %d\n", keepfd);
  376. }
  377. }
  378. i=open("/dev/null", O_RDWR); // Redirect stdin
  379. dup(i); // Redirect stdout
  380. dup(i); // Redirect stderr
  381. }
  382. void start_server() {
  383. int lsock, csock, pid, clilen, sopt = 1, i;
  384. struct sockaddr_in serv_addr, cli_addr;
  385. ws_ctx_t *ws_ctx;
  386. /* Initialize buffers */
  387. bufsize = 65536;
  388. if (! (tbuf = malloc(bufsize)) )
  389. { fatal("malloc()"); }
  390. if (! (cbuf = malloc(bufsize)) )
  391. { fatal("malloc()"); }
  392. if (! (tbuf_tmp = malloc(bufsize)) )
  393. { fatal("malloc()"); }
  394. if (! (cbuf_tmp = malloc(bufsize)) )
  395. { fatal("malloc()"); }
  396. lsock = socket(AF_INET, SOCK_STREAM, 0);
  397. if (lsock < 0) { error("ERROR creating listener socket"); }
  398. bzero((char *) &serv_addr, sizeof(serv_addr));
  399. serv_addr.sin_family = AF_INET;
  400. serv_addr.sin_port = htons(settings.listen_port);
  401. /* Resolve listen address */
  402. if (settings.listen_host && (settings.listen_host[0] != '\0')) {
  403. if (resolve_host(&serv_addr.sin_addr, settings.listen_host) < -1) {
  404. fatal("Could not resolve listen address");
  405. }
  406. } else {
  407. serv_addr.sin_addr.s_addr = INADDR_ANY;
  408. }
  409. setsockopt(lsock, SOL_SOCKET, SO_REUSEADDR, (char *)&sopt, sizeof(sopt));
  410. if (bind(lsock, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) {
  411. fatal("ERROR on binding listener socket");
  412. }
  413. listen(lsock,100);
  414. signal(SIGPIPE, signal_handler); // catch pipe
  415. if (settings.daemon) {
  416. daemonize(lsock);
  417. }
  418. // Reep zombies
  419. signal(SIGCHLD, SIG_IGN);
  420. printf("Waiting for connections on %s:%d\n",
  421. settings.listen_host, settings.listen_port);
  422. while (1) {
  423. clilen = sizeof(cli_addr);
  424. pipe_error = 0;
  425. pid = 0;
  426. csock = accept(lsock,
  427. (struct sockaddr *) &cli_addr,
  428. &clilen);
  429. if (csock < 0) {
  430. error("ERROR on accept");
  431. continue;
  432. }
  433. handler_msg("got client connection from %s\n",
  434. inet_ntoa(cli_addr.sin_addr));
  435. /* base64 is 4 bytes for every 3
  436. * 20 for WS '\x00' / '\xff' and good measure */
  437. dbufsize = (bufsize * 3)/4 - 20;
  438. handler_msg("forking handler process\n");
  439. pid = fork();
  440. if (pid == 0) { // handler process
  441. ws_ctx = do_handshake(csock);
  442. if (ws_ctx == NULL) {
  443. close(csock);
  444. handler_msg("No connection after handshake");
  445. break; // Child process exits
  446. }
  447. settings.handler(ws_ctx);
  448. if (pipe_error) {
  449. handler_emsg("Closing due to SIGPIPE\n");
  450. }
  451. close(csock);
  452. handler_msg("handler exit\n");
  453. break; // Child process exits
  454. } else { // parent process
  455. settings.handler_id += 1;
  456. }
  457. }
  458. }