echoserv.cc 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606
  1. /*
  2. * Wslay - The WebSocket Library
  3. *
  4. * Copyright (c) 2011, 2012 Tatsuhiro Tsujikawa
  5. *
  6. * Permission is hereby granted, free of charge, to any person obtaining
  7. * a copy of this software and associated documentation files (the
  8. * "Software"), to deal in the Software without restriction, including
  9. * without limitation the rights to use, copy, modify, merge, publish,
  10. * distribute, sublicense, and/or sell copies of the Software, and to
  11. * permit persons to whom the Software is furnished to do so, subject to
  12. * the following conditions:
  13. *
  14. * The above copyright notice and this permission notice shall be
  15. * included in all copies or substantial portions of the Software.
  16. *
  17. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
  18. * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
  19. * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
  20. * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
  21. * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
  22. * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
  23. * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
  24. */
  25. // WebSocket Echo Server
  26. // This is suitable for running the Autobahn server test suite against.
  27. // g++ -Wall -O2 -g -o echoserv echoserv.cc -L../lib/.libs -I../lib/includes -lwslay -lnettle
  28. // $ export LD_LIBRARY_PATH=../lib/.libs
  29. // $ ./echoserv 9000
  30. #include <sys/types.h>
  31. #include <sys/socket.h>
  32. #include <netdb.h>
  33. #include <unistd.h>
  34. #include <fcntl.h>
  35. #include <sys/epoll.h>
  36. #include <netinet/in.h>
  37. #include <netinet/tcp.h>
  38. #include <signal.h>
  39. #include <cassert>
  40. #include <cstdio>
  41. #include <cerrno>
  42. #include <cstdlib>
  43. #include <cstring>
  44. #include <string>
  45. #include <iostream>
  46. #include <string>
  47. #include <set>
  48. #include <iomanip>
  49. #include <fstream>
  50. #include <nettle/base64.h>
  51. #include <nettle/sha.h>
  52. #include <wslay/wslay.h>
  // Creates a TCP socket bound to `service` (a port number or service name)
  // on the passive/wildcard address and puts it into the listening state.
  //
  // Every address returned by getaddrinfo() is tried in order until one can
  // be configured with SO_REUSEADDR and bound.  Returns the listening
  // descriptor, or -1 on failure (a diagnostic is written to stderr).
  int create_listen_socket(const char *service)
  {
    struct addrinfo hints;
    memset(&hints, 0, sizeof(struct addrinfo));
    hints.ai_family = AF_UNSPEC;
    hints.ai_socktype = SOCK_STREAM;
    hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG;
    struct addrinfo *res;
    int r = getaddrinfo(0, service, &hints, &res);
    if(r != 0) {
      std::cerr << "getaddrinfo: " << gai_strerror(r) << std::endl;
      return -1;
    }
    int sfd = -1;
    for(struct addrinfo *rp = res; rp; rp = rp->ai_next) {
      sfd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
      if(sfd == -1) {
        continue;
      }
      int val = 1;
      if(setsockopt(sfd, SOL_SOCKET, SO_REUSEADDR, &val,
                    static_cast<socklen_t>(sizeof(val))) == 0 &&
         bind(sfd, rp->ai_addr, rp->ai_addrlen) == 0) {
        break;
      }
      // Release the descriptor on every failure path and forget it, so a
      // closed descriptor is never passed to listen() below.
      close(sfd);
      sfd = -1;
    }
    freeaddrinfo(res);
    if(sfd == -1) {
      std::cerr << "Could not bind to any address for " << service
                << std::endl;
      return -1;
    }
    if(listen(sfd, 16) == -1) {
      perror("listen");
      close(sfd);
      return -1;
    }
    return sfd;
  }
  // Puts `fd` into non-blocking mode while keeping its other file status
  // flags.  Each fcntl() call is restarted if a signal interrupts it.
  // Returns 0 on success, -1 on failure (errno as left by fcntl()).
  int make_non_block(int fd)
  {
    int current;
    do {
      current = fcntl(fd, F_GETFL, 0);
    } while(current == -1 && errno == EINTR);
    if(current == -1) {
      return -1;
    }
    int rv;
    do {
      rv = fcntl(fd, F_SETFL, current | O_NONBLOCK);
    } while(rv == -1 && errno == EINTR);
    return rv == -1 ? -1 : 0;
  }
  104. std::string sha1(const std::string& src)
  105. {
  106. sha1_ctx ctx;
  107. sha1_init(&ctx);
  108. sha1_update(&ctx, src.size(), reinterpret_cast<const uint8_t*>(src.c_str()));
  109. uint8_t temp[SHA1_DIGEST_SIZE];
  110. sha1_digest(&ctx, SHA1_DIGEST_SIZE, temp);
  111. std::string res(&temp[0], &temp[SHA1_DIGEST_SIZE]);
  112. return res;
  113. }
  114. std::string base64(const std::string& src)
  115. {
  116. base64_encode_ctx ctx;
  117. base64_encode_init(&ctx);
  118. int dstlen = BASE64_ENCODE_RAW_LENGTH(src.size());
  119. uint8_t *dst = new uint8_t[dstlen];
  120. base64_encode_raw(dst, src.size(), reinterpret_cast<const uint8_t*>(src.c_str()));
  121. std::string res(&dst[0], &dst[dstlen]);
  122. delete [] dst;
  123. return res;
  124. }
  125. std::string create_acceptkey(const std::string& clientkey)
  126. {
  127. std::string s = clientkey+"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
  128. return base64(sha1(s));
  129. }
  // Abstract per-descriptor handler driven by reactor().
  //
  // reactor() registers fd() with epoll for the directions reported by
  // want_read()/want_write() and dispatches readiness to on_read_event() /
  // on_write_event().  A -1 return from either destroys the handler.
  // Otherwise reactor() registers next() (if non-null) and destroys this
  // handler when finish() reports true.
  class EventHandler {
  public:
    virtual ~EventHandler() {}

    // Descriptor this handler is registered under.
    virtual int fd() const = 0;

    // Interest set used to build the epoll event mask.
    virtual bool want_read() = 0;
    virtual bool want_write() = 0;

    // Readiness callbacks: 0 keeps the handler, -1 asks for its destruction.
    virtual int on_read_event() = 0;
    virtual int on_write_event() = 0;

    // Lifecycle: successor handler to register (or 0), and whether this
    // handler is done and should be destroyed.
    virtual EventHandler* next() = 0;
    virtual bool finish() = 0;
  };
  141. ssize_t send_callback(wslay_event_context_ptr ctx,
  142. const uint8_t *data, size_t len, int flags,
  143. void *user_data);
  144. ssize_t recv_callback(wslay_event_context_ptr ctx, uint8_t *data, size_t len,
  145. int flags, void *user_data);
  146. void on_msg_recv_callback(wslay_event_context_ptr ctx,
  147. const struct wslay_event_on_msg_recv_arg *arg,
  148. void *user_data);
  149. class EchoWebSocketHandler : public EventHandler {
  150. public:
  151. EchoWebSocketHandler(int fd)
  152. : fd_(fd)
  153. {
  154. struct wslay_event_callbacks callbacks = {
  155. recv_callback,
  156. send_callback,
  157. NULL, /* genmask_callback */
  158. NULL, /* on_frame_recv_start_callback */
  159. NULL, /* on_frame_recv_callback */
  160. NULL, /* on_frame_recv_end_callback */
  161. on_msg_recv_callback
  162. };
  163. wslay_event_context_server_init(&ctx_, &callbacks, this);
  164. }
  165. virtual ~EchoWebSocketHandler()
  166. {
  167. wslay_event_context_free(ctx_);
  168. shutdown(fd_, SHUT_WR);
  169. close(fd_);
  170. }
  171. virtual int on_read_event()
  172. {
  173. if(wslay_event_recv(ctx_) == 0) {
  174. return 0;
  175. } else {
  176. return -1;
  177. }
  178. }
  179. virtual int on_write_event()
  180. {
  181. if(wslay_event_send(ctx_) == 0) {
  182. return 0;
  183. } else {
  184. return -1;
  185. }
  186. }
  187. ssize_t send_data(const uint8_t *data, size_t len, int flags)
  188. {
  189. ssize_t r;
  190. int sflags = 0;
  191. #ifdef MSG_MORE
  192. if(flags & WSLAY_MSG_MORE) {
  193. sflags |= MSG_MORE;
  194. }
  195. #endif // MSG_MORE
  196. while((r = send(fd_, data, len, sflags)) == -1 && errno == EINTR);
  197. return r;
  198. }
  199. ssize_t recv_data(uint8_t *data, size_t len, int flags)
  200. {
  201. ssize_t r;
  202. while((r = recv(fd_, data, len, 0)) == -1 && errno == EINTR);
  203. return r;
  204. }
  205. virtual bool want_read()
  206. {
  207. return wslay_event_want_read(ctx_);
  208. }
  209. virtual bool want_write()
  210. {
  211. return wslay_event_want_write(ctx_);
  212. }
  213. virtual int fd() const
  214. {
  215. return fd_;
  216. }
  217. virtual bool finish()
  218. {
  219. return !want_read() && !want_write();
  220. }
  221. virtual EventHandler* next()
  222. {
  223. return 0;
  224. }
  225. private:
  226. int fd_;
  227. wslay_event_context_ptr ctx_;
  228. };
  229. ssize_t send_callback(wslay_event_context_ptr ctx,
  230. const uint8_t *data, size_t len, int flags,
  231. void *user_data)
  232. {
  233. EchoWebSocketHandler *sv = (EchoWebSocketHandler*)user_data;
  234. ssize_t r = sv->send_data(data, len, flags);
  235. if(r == -1) {
  236. if(errno == EAGAIN || errno == EWOULDBLOCK) {
  237. wslay_event_set_error(ctx, WSLAY_ERR_WOULDBLOCK);
  238. } else {
  239. wslay_event_set_error(ctx, WSLAY_ERR_CALLBACK_FAILURE);
  240. }
  241. }
  242. return r;
  243. }
  244. ssize_t recv_callback(wslay_event_context_ptr ctx, uint8_t *data, size_t len,
  245. int flags, void *user_data)
  246. {
  247. EchoWebSocketHandler *sv = (EchoWebSocketHandler*)user_data;
  248. ssize_t r = sv->recv_data(data, len, flags);
  249. if(r == -1) {
  250. if(errno == EAGAIN || errno == EWOULDBLOCK) {
  251. wslay_event_set_error(ctx, WSLAY_ERR_WOULDBLOCK);
  252. } else {
  253. wslay_event_set_error(ctx, WSLAY_ERR_CALLBACK_FAILURE);
  254. }
  255. } else if(r == 0) {
  256. wslay_event_set_error(ctx, WSLAY_ERR_CALLBACK_FAILURE);
  257. r = -1;
  258. }
  259. return r;
  260. }
  261. void on_msg_recv_callback(wslay_event_context_ptr ctx,
  262. const struct wslay_event_on_msg_recv_arg *arg,
  263. void *user_data)
  264. {
  265. if(!wslay_is_ctrl_frame(arg->opcode)) {
  266. struct wslay_event_msg msgarg = {
  267. arg->opcode, arg->msg, arg->msg_length
  268. };
  269. wslay_event_queue_msg(ctx, &msgarg);
  270. }
  271. }
  272. class HttpHandshakeSendHandler : public EventHandler {
  273. public:
  274. HttpHandshakeSendHandler(int fd, const std::string& accept_key)
  275. : fd_(fd),
  276. resheaders_("HTTP/1.1 101 Switching Protocols\r\n"
  277. "Upgrade: websocket\r\n"
  278. "Connection: Upgrade\r\n"
  279. "Sec-WebSocket-Accept: "+accept_key+"\r\n"
  280. "\r\n"),
  281. off_(0)
  282. {}
  283. virtual ~HttpHandshakeSendHandler()
  284. {
  285. if(fd_ != -1) {
  286. shutdown(fd_, SHUT_WR);
  287. close(fd_);
  288. }
  289. }
  290. virtual int on_read_event()
  291. {
  292. return 0;
  293. }
  294. virtual int on_write_event()
  295. {
  296. while(1) {
  297. size_t len = resheaders_.size()-off_;
  298. if(len == 0) {
  299. break;
  300. }
  301. ssize_t r;
  302. while((r = write(fd_, resheaders_.c_str()+off_, len)) == -1 &&
  303. errno == EINTR);
  304. if(r == -1) {
  305. if(errno == EAGAIN || errno == EWOULDBLOCK) {
  306. break;
  307. } else {
  308. perror("write");
  309. return -1;
  310. }
  311. } else {
  312. off_ += r;
  313. }
  314. }
  315. return 0;
  316. }
  317. virtual bool want_read()
  318. {
  319. return false;
  320. }
  321. virtual bool want_write()
  322. {
  323. return true;
  324. }
  325. virtual int fd() const
  326. {
  327. return fd_;
  328. }
  329. virtual bool finish()
  330. {
  331. return off_ == resheaders_.size();
  332. }
  333. virtual EventHandler* next()
  334. {
  335. if(finish()) {
  336. int fd = fd_;
  337. fd_ = -1;
  338. return new EchoWebSocketHandler(fd);
  339. } else {
  340. return 0;
  341. }
  342. }
  343. private:
  344. int fd_;
  345. std::string headers_;
  346. std::string resheaders_;
  347. size_t off_;
  348. };
  349. class HttpHandshakeRecvHandler : public EventHandler {
  350. public:
  351. HttpHandshakeRecvHandler(int fd)
  352. : fd_(fd)
  353. {}
  354. virtual ~HttpHandshakeRecvHandler()
  355. {
  356. if(fd_ != -1) {
  357. close(fd_);
  358. }
  359. }
  360. virtual int on_read_event()
  361. {
  362. char buf[4096];
  363. ssize_t r;
  364. std::string client_key;
  365. while(1) {
  366. while((r = read(fd_, buf, sizeof(buf))) == -1 && errno == EINTR);
  367. if(r == -1) {
  368. if(errno == EWOULDBLOCK || errno == EAGAIN) {
  369. break;
  370. } else {
  371. perror("read");
  372. return -1;
  373. }
  374. } else if(r == 0) {
  375. std::cerr << "http_upgrade: Got EOF" << std::endl;
  376. return -1;
  377. } else {
  378. headers_.append(buf, buf+r);
  379. if(headers_.size() > 8192) {
  380. std::cerr << "Too large http header" << std::endl;
  381. return -1;
  382. }
  383. }
  384. }
  385. if(headers_.find("\r\n\r\n") != std::string::npos) {
  386. std::string::size_type keyhdstart;
  387. if(headers_.find("Upgrade: websocket\r\n") == std::string::npos ||
  388. headers_.find("Connection: Upgrade\r\n") == std::string::npos ||
  389. (keyhdstart = headers_.find("Sec-WebSocket-Key: ")) ==
  390. std::string::npos) {
  391. std::cerr << "http_upgrade: missing required headers" << std::endl;
  392. return -1;
  393. }
  394. keyhdstart += 19;
  395. std::string::size_type keyhdend = headers_.find("\r\n", keyhdstart);
  396. client_key = headers_.substr(keyhdstart, keyhdend-keyhdstart);
  397. accept_key_ = create_acceptkey(client_key);
  398. }
  399. return 0;
  400. }
  401. virtual int on_write_event()
  402. {
  403. return 0;
  404. }
  405. virtual bool want_read()
  406. {
  407. return true;
  408. }
  409. virtual bool want_write()
  410. {
  411. return false;
  412. }
  413. virtual int fd() const
  414. {
  415. return fd_;
  416. }
  417. virtual bool finish()
  418. {
  419. return !accept_key_.empty();
  420. }
  421. virtual EventHandler* next()
  422. {
  423. if(finish()) {
  424. int fd = fd_;
  425. fd_ = -1;
  426. return new HttpHandshakeSendHandler(fd, accept_key_);
  427. } else {
  428. return 0;
  429. }
  430. }
  431. private:
  432. int fd_;
  433. std::string headers_;
  434. std::string accept_key_;
  435. };
  436. class ListenEventHandler : public EventHandler {
  437. public:
  438. ListenEventHandler(int fd)
  439. : fd_(fd), cfd_(-1)
  440. {}
  441. virtual ~ListenEventHandler()
  442. {
  443. close(fd_);
  444. close(cfd_);
  445. }
  446. virtual int on_read_event()
  447. {
  448. if(cfd_ != -1) {
  449. close(cfd_);
  450. }
  451. while((cfd_ = accept(fd_, 0, 0)) == -1 && errno == EINTR);
  452. if(cfd_ == -1) {
  453. perror("accept");
  454. }
  455. return 0;
  456. }
  457. virtual int on_write_event()
  458. {
  459. return 0;
  460. }
  461. virtual bool want_read()
  462. {
  463. return true;
  464. }
  465. virtual bool want_write()
  466. {
  467. return false;
  468. }
  469. virtual int fd() const
  470. {
  471. return fd_;
  472. }
  473. virtual bool finish()
  474. {
  475. return false;
  476. }
  477. virtual EventHandler* next()
  478. {
  479. if(cfd_ != -1) {
  480. int val = 1;
  481. int fd = cfd_;
  482. cfd_ = -1;
  483. if(make_non_block(fd) == -1 ||
  484. setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &val, (socklen_t)sizeof(val))
  485. == -1) {
  486. close(fd);
  487. return 0;
  488. }
  489. return new HttpHandshakeRecvHandler(fd);
  490. } else {
  491. return 0;
  492. }
  493. }
  494. private:
  495. int fd_;
  496. int cfd_;
  497. };
  498. int ctl_epollev(int epollfd, int op, EventHandler *handler)
  499. {
  500. epoll_event ev;
  501. memset(&ev, 0, sizeof(ev));
  502. int events = 0;
  503. if(handler->want_read()) {
  504. events |= EPOLLIN;
  505. }
  506. if(handler->want_write()) {
  507. events |= EPOLLOUT;
  508. }
  509. ev.events = events;
  510. ev.data.ptr = handler;
  511. return epoll_ctl(epollfd, op, handler->fd(), &ev);
  512. }
  513. void reactor(int sfd)
  514. {
  515. std::set<EventHandler*> handlers;
  516. ListenEventHandler* listen_handler = new ListenEventHandler(sfd);
  517. handlers.insert(listen_handler);
  518. int epollfd = epoll_create(16);
  519. if(epollfd == -1) {
  520. perror("epoll_create");
  521. exit(EXIT_FAILURE);
  522. }
  523. if(ctl_epollev(epollfd, EPOLL_CTL_ADD, listen_handler) == -1) {
  524. perror("epoll_ctl");
  525. exit(EXIT_FAILURE);
  526. }
  527. static const size_t MAX_EVENTS = 64;
  528. epoll_event events[MAX_EVENTS];
  529. while(1) {
  530. int nfds = epoll_wait(epollfd, events, MAX_EVENTS, -1);
  531. if(nfds == -1) {
  532. perror("epoll_wait");
  533. return;
  534. }
  535. for(int n = 0; n < nfds; ++n) {
  536. EventHandler* eh = (EventHandler*)events[n].data.ptr;
  537. if(((events[n].events & EPOLLIN) && eh->on_read_event() == -1) ||
  538. ((events[n].events & EPOLLOUT) && eh->on_write_event() == -1) ||
  539. (events[n].events & (EPOLLERR | EPOLLHUP))) {
  540. handlers.erase(eh);
  541. delete eh;
  542. } else {
  543. EventHandler* next = eh->next();
  544. if(next) {
  545. handlers.insert(next);
  546. if(ctl_epollev(epollfd, EPOLL_CTL_ADD, next) == -1) {
  547. if(errno == EEXIST) {
  548. if(ctl_epollev(epollfd, EPOLL_CTL_MOD, next) == -1) {
  549. perror("epoll_ctl");
  550. delete next;
  551. }
  552. } else {
  553. perror("epoll_ctl");
  554. delete next;
  555. }
  556. }
  557. }
  558. if(eh->finish()) {
  559. handlers.erase(eh);
  560. delete eh;
  561. } else {
  562. if(ctl_epollev(epollfd, EPOLL_CTL_MOD, eh) == -1) {
  563. perror("epoll_ctl");
  564. }
  565. }
  566. }
  567. }
  568. }
  569. }
  570. int main(int argc, char **argv)
  571. {
  572. if(argc < 2) {
  573. std::cerr << "Usage: " << argv[0] << " PORT" << std::endl;
  574. exit(EXIT_FAILURE);
  575. }
  576. struct sigaction act;
  577. memset(&act, 0, sizeof(struct sigaction));
  578. act.sa_handler = SIG_IGN;
  579. sigaction(SIGPIPE, &act, 0);
  580. int sfd = create_listen_socket(argv[1]);
  581. if(sfd == -1) {
  582. std::cerr << "Failed to create server socket" << std::endl;
  583. exit(EXIT_FAILURE);
  584. }
  585. std::cout << "WebSocket echo server, listening on " << argv[1] << std::endl;
  586. reactor(sfd);
  587. }