echoserv.cc 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523
  1. /*
  2. * Wslay - The WebSocket Library
  3. *
  4. * Copyright (c) 2011, 2012 Tatsuhiro Tsujikawa
  5. *
  6. * Permission is hereby granted, free of charge, to any person obtaining
  7. * a copy of this software and associated documentation files (the
  8. * "Software"), to deal in the Software without restriction, including
  9. * without limitation the rights to use, copy, modify, merge, publish,
  10. * distribute, sublicense, and/or sell copies of the Software, and to
  11. * permit persons to whom the Software is furnished to do so, subject to
  12. * the following conditions:
  13. *
  14. * The above copyright notice and this permission notice shall be
  15. * included in all copies or substantial portions of the Software.
  16. *
  17. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
  18. * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
  19. * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
  20. * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
  21. * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
  22. * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
  23. * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
  24. */
  25. // WebSocket Echo Server
  26. // This is suitable for Autobahn server test.
  27. // g++ -Wall -O2 -g -o echoserv echoserv.cc -L../lib/.libs -I../lib/includes -lwslay -lnettle
  28. // $ export LD_LIBRARY_PATH=../lib/.libs
// $ ./echoserv 9000
  30. #include <sys/types.h>
  31. #include <sys/socket.h>
  32. #include <netdb.h>
  33. #include <unistd.h>
  34. #include <fcntl.h>
  35. #include <sys/epoll.h>
  36. #include <netinet/in.h>
  37. #include <netinet/tcp.h>
  38. #include <signal.h>
  39. #include <cassert>
  40. #include <cstdio>
  41. #include <cerrno>
  42. #include <cstdlib>
  43. #include <cstring>
  44. #include <string>
  45. #include <iostream>
  46. #include <string>
  47. #include <set>
  48. #include <iomanip>
  49. #include <fstream>
  50. #include <nettle/base64.h>
  51. #include <nettle/sha.h>
  52. #include <wslay/wslay.h>
  53. int create_listen_socket(const char *service) {
  54. struct addrinfo hints;
  55. int sfd = -1;
  56. int r;
  57. memset(&hints, 0, sizeof(struct addrinfo));
  58. hints.ai_family = AF_UNSPEC;
  59. hints.ai_socktype = SOCK_STREAM;
  60. hints.ai_flags = AI_PASSIVE | AI_ADDRCONFIG;
  61. struct addrinfo *res;
  62. r = getaddrinfo(0, service, &hints, &res);
  63. if (r != 0) {
  64. std::cerr << "getaddrinfo: " << gai_strerror(r) << std::endl;
  65. return -1;
  66. }
  67. for (struct addrinfo *rp = res; rp; rp = rp->ai_next) {
  68. sfd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol);
  69. if (sfd == -1) {
  70. continue;
  71. }
  72. int val = 1;
  73. if (setsockopt(sfd, SOL_SOCKET, SO_REUSEADDR, &val,
  74. static_cast<socklen_t>(sizeof(val))) == -1) {
  75. continue;
  76. }
  77. if (bind(sfd, rp->ai_addr, rp->ai_addrlen) == 0) {
  78. break;
  79. }
  80. close(sfd);
  81. }
  82. freeaddrinfo(res);
  83. if (listen(sfd, 16) == -1) {
  84. perror("listen");
  85. close(sfd);
  86. return -1;
  87. }
  88. return sfd;
  89. }
  90. int make_non_block(int fd) {
  91. int flags, r;
  92. while ((flags = fcntl(fd, F_GETFL, 0)) == -1 && errno == EINTR)
  93. ;
  94. if (flags == -1) {
  95. return -1;
  96. }
  97. while ((r = fcntl(fd, F_SETFL, flags | O_NONBLOCK)) == -1 && errno == EINTR)
  98. ;
  99. if (r == -1) {
  100. return -1;
  101. }
  102. return 0;
  103. }
  104. std::string sha1(const std::string &src) {
  105. sha1_ctx ctx;
  106. sha1_init(&ctx);
  107. sha1_update(&ctx, src.size(), reinterpret_cast<const uint8_t *>(src.c_str()));
  108. uint8_t temp[SHA1_DIGEST_SIZE];
  109. sha1_digest(&ctx, SHA1_DIGEST_SIZE, temp);
  110. std::string res(&temp[0], &temp[SHA1_DIGEST_SIZE]);
  111. return res;
  112. }
  113. std::string base64(const std::string &src) {
  114. base64_encode_ctx ctx;
  115. base64_encode_init(&ctx);
  116. int dstlen = BASE64_ENCODE_RAW_LENGTH(src.size());
  117. char *dst = new char[dstlen];
  118. base64_encode_raw(dst, src.size(),
  119. reinterpret_cast<const uint8_t *>(src.c_str()));
  120. std::string res(&dst[0], &dst[dstlen]);
  121. delete[] dst;
  122. return res;
  123. }
  124. std::string create_acceptkey(const std::string &clientkey) {
  125. std::string s = clientkey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
  126. return base64(sha1(s));
  127. }
  128. class EventHandler {
  129. public:
  130. virtual ~EventHandler() {}
  131. virtual int on_read_event() = 0;
  132. virtual int on_write_event() = 0;
  133. virtual bool want_read() = 0;
  134. virtual bool want_write() = 0;
  135. virtual int fd() const = 0;
  136. virtual bool finish() = 0;
  137. virtual EventHandler *next() = 0;
  138. };
  139. ssize_t send_callback(wslay_event_context_ptr ctx, const uint8_t *data,
  140. size_t len, int flags, void *user_data);
  141. ssize_t recv_callback(wslay_event_context_ptr ctx, uint8_t *data, size_t len,
  142. int flags, void *user_data);
  143. void on_msg_recv_callback(wslay_event_context_ptr ctx,
  144. const struct wslay_event_on_msg_recv_arg *arg,
  145. void *user_data);
  146. class EchoWebSocketHandler : public EventHandler {
  147. public:
  148. EchoWebSocketHandler(int fd) : fd_(fd) {
  149. struct wslay_event_callbacks callbacks = {
  150. recv_callback,
  151. send_callback,
  152. NULL, /* genmask_callback */
  153. NULL, /* on_frame_recv_start_callback */
  154. NULL, /* on_frame_recv_callback */
  155. NULL, /* on_frame_recv_end_callback */
  156. on_msg_recv_callback};
  157. wslay_event_context_server_init(&ctx_, &callbacks, this);
  158. }
  159. virtual ~EchoWebSocketHandler() {
  160. wslay_event_context_free(ctx_);
  161. shutdown(fd_, SHUT_WR);
  162. close(fd_);
  163. }
  164. virtual int on_read_event() {
  165. if (wslay_event_recv(ctx_) == 0) {
  166. return 0;
  167. } else {
  168. return -1;
  169. }
  170. }
  171. virtual int on_write_event() {
  172. if (wslay_event_send(ctx_) == 0) {
  173. return 0;
  174. } else {
  175. return -1;
  176. }
  177. }
  178. ssize_t send_data(const uint8_t *data, size_t len, int flags) {
  179. ssize_t r;
  180. int sflags = 0;
  181. #ifdef MSG_MORE
  182. if (flags & WSLAY_MSG_MORE) {
  183. sflags |= MSG_MORE;
  184. }
  185. #endif // MSG_MORE
  186. while ((r = send(fd_, data, len, sflags)) == -1 && errno == EINTR)
  187. ;
  188. return r;
  189. }
  190. ssize_t recv_data(uint8_t *data, size_t len, int flags) {
  191. ssize_t r;
  192. while ((r = recv(fd_, data, len, 0)) == -1 && errno == EINTR)
  193. ;
  194. return r;
  195. }
  196. virtual bool want_read() { return wslay_event_want_read(ctx_); }
  197. virtual bool want_write() { return wslay_event_want_write(ctx_); }
  198. virtual int fd() const { return fd_; }
  199. virtual bool finish() { return !want_read() && !want_write(); }
  200. virtual EventHandler *next() { return 0; }
  201. private:
  202. int fd_;
  203. wslay_event_context_ptr ctx_;
  204. };
  205. ssize_t send_callback(wslay_event_context_ptr ctx, const uint8_t *data,
  206. size_t len, int flags, void *user_data) {
  207. EchoWebSocketHandler *sv = (EchoWebSocketHandler *)user_data;
  208. ssize_t r = sv->send_data(data, len, flags);
  209. if (r == -1) {
  210. if (errno == EAGAIN || errno == EWOULDBLOCK) {
  211. wslay_event_set_error(ctx, WSLAY_ERR_WOULDBLOCK);
  212. } else {
  213. wslay_event_set_error(ctx, WSLAY_ERR_CALLBACK_FAILURE);
  214. }
  215. }
  216. return r;
  217. }
  218. ssize_t recv_callback(wslay_event_context_ptr ctx, uint8_t *data, size_t len,
  219. int flags, void *user_data) {
  220. EchoWebSocketHandler *sv = (EchoWebSocketHandler *)user_data;
  221. ssize_t r = sv->recv_data(data, len, flags);
  222. if (r == -1) {
  223. if (errno == EAGAIN || errno == EWOULDBLOCK) {
  224. wslay_event_set_error(ctx, WSLAY_ERR_WOULDBLOCK);
  225. } else {
  226. wslay_event_set_error(ctx, WSLAY_ERR_CALLBACK_FAILURE);
  227. }
  228. } else if (r == 0) {
  229. wslay_event_set_error(ctx, WSLAY_ERR_CALLBACK_FAILURE);
  230. r = -1;
  231. }
  232. return r;
  233. }
  234. void on_msg_recv_callback(wslay_event_context_ptr ctx,
  235. const struct wslay_event_on_msg_recv_arg *arg,
  236. void *user_data) {
  237. if (!wslay_is_ctrl_frame(arg->opcode)) {
  238. struct wslay_event_msg msgarg = {arg->opcode, arg->msg, arg->msg_length};
  239. wslay_event_queue_msg(ctx, &msgarg);
  240. }
  241. }
  242. class HttpHandshakeSendHandler : public EventHandler {
  243. public:
  244. HttpHandshakeSendHandler(int fd, const std::string &accept_key)
  245. : fd_(fd),
  246. resheaders_("HTTP/1.1 101 Switching Protocols\r\n"
  247. "Upgrade: websocket\r\n"
  248. "Connection: Upgrade\r\n"
  249. "Sec-WebSocket-Accept: " +
  250. accept_key +
  251. "\r\n"
  252. "\r\n"),
  253. off_(0) {}
  254. virtual ~HttpHandshakeSendHandler() {
  255. if (fd_ != -1) {
  256. shutdown(fd_, SHUT_WR);
  257. close(fd_);
  258. }
  259. }
  260. virtual int on_read_event() { return 0; }
  261. virtual int on_write_event() {
  262. while (1) {
  263. size_t len = resheaders_.size() - off_;
  264. if (len == 0) {
  265. break;
  266. }
  267. ssize_t r;
  268. while ((r = write(fd_, resheaders_.c_str() + off_, len)) == -1 &&
  269. errno == EINTR)
  270. ;
  271. if (r == -1) {
  272. if (errno == EAGAIN || errno == EWOULDBLOCK) {
  273. break;
  274. } else {
  275. perror("write");
  276. return -1;
  277. }
  278. } else {
  279. off_ += r;
  280. }
  281. }
  282. return 0;
  283. }
  284. virtual bool want_read() { return false; }
  285. virtual bool want_write() { return true; }
  286. virtual int fd() const { return fd_; }
  287. virtual bool finish() { return off_ == resheaders_.size(); }
  288. virtual EventHandler *next() {
  289. if (finish()) {
  290. int fd = fd_;
  291. fd_ = -1;
  292. return new EchoWebSocketHandler(fd);
  293. } else {
  294. return 0;
  295. }
  296. }
  297. private:
  298. int fd_;
  299. std::string headers_;
  300. std::string resheaders_;
  301. size_t off_;
  302. };
  303. class HttpHandshakeRecvHandler : public EventHandler {
  304. public:
  305. HttpHandshakeRecvHandler(int fd) : fd_(fd) {}
  306. virtual ~HttpHandshakeRecvHandler() {
  307. if (fd_ != -1) {
  308. close(fd_);
  309. }
  310. }
  311. virtual int on_read_event() {
  312. char buf[4096];
  313. ssize_t r;
  314. std::string client_key;
  315. while (1) {
  316. while ((r = read(fd_, buf, sizeof(buf))) == -1 && errno == EINTR)
  317. ;
  318. if (r == -1) {
  319. if (errno == EWOULDBLOCK || errno == EAGAIN) {
  320. break;
  321. } else {
  322. perror("read");
  323. return -1;
  324. }
  325. } else if (r == 0) {
  326. std::cerr << "http_upgrade: Got EOF" << std::endl;
  327. return -1;
  328. } else {
  329. headers_.append(buf, buf + r);
  330. if (headers_.size() > 8192) {
  331. std::cerr << "Too large http header" << std::endl;
  332. return -1;
  333. }
  334. }
  335. }
  336. if (headers_.find("\r\n\r\n") != std::string::npos) {
  337. std::string::size_type keyhdstart;
  338. if (headers_.find("Upgrade: websocket\r\n") == std::string::npos ||
  339. headers_.find("Connection: Upgrade\r\n") == std::string::npos ||
  340. (keyhdstart = headers_.find("Sec-WebSocket-Key: ")) ==
  341. std::string::npos) {
  342. std::cerr << "http_upgrade: missing required headers" << std::endl;
  343. return -1;
  344. }
  345. keyhdstart += 19;
  346. std::string::size_type keyhdend = headers_.find("\r\n", keyhdstart);
  347. client_key = headers_.substr(keyhdstart, keyhdend - keyhdstart);
  348. accept_key_ = create_acceptkey(client_key);
  349. }
  350. return 0;
  351. }
  352. virtual int on_write_event() { return 0; }
  353. virtual bool want_read() { return true; }
  354. virtual bool want_write() { return false; }
  355. virtual int fd() const { return fd_; }
  356. virtual bool finish() { return !accept_key_.empty(); }
  357. virtual EventHandler *next() {
  358. if (finish()) {
  359. int fd = fd_;
  360. fd_ = -1;
  361. return new HttpHandshakeSendHandler(fd, accept_key_);
  362. } else {
  363. return 0;
  364. }
  365. }
  366. private:
  367. int fd_;
  368. std::string headers_;
  369. std::string accept_key_;
  370. };
  371. class ListenEventHandler : public EventHandler {
  372. public:
  373. ListenEventHandler(int fd) : fd_(fd), cfd_(-1) {}
  374. virtual ~ListenEventHandler() {
  375. close(fd_);
  376. close(cfd_);
  377. }
  378. virtual int on_read_event() {
  379. if (cfd_ != -1) {
  380. close(cfd_);
  381. }
  382. while ((cfd_ = accept(fd_, 0, 0)) == -1 && errno == EINTR)
  383. ;
  384. if (cfd_ == -1) {
  385. perror("accept");
  386. }
  387. return 0;
  388. }
  389. virtual int on_write_event() { return 0; }
  390. virtual bool want_read() { return true; }
  391. virtual bool want_write() { return false; }
  392. virtual int fd() const { return fd_; }
  393. virtual bool finish() { return false; }
  394. virtual EventHandler *next() {
  395. if (cfd_ != -1) {
  396. int val = 1;
  397. int fd = cfd_;
  398. cfd_ = -1;
  399. if (make_non_block(fd) == -1 ||
  400. setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &val,
  401. (socklen_t)sizeof(val)) == -1) {
  402. close(fd);
  403. return 0;
  404. }
  405. return new HttpHandshakeRecvHandler(fd);
  406. } else {
  407. return 0;
  408. }
  409. }
  410. private:
  411. int fd_;
  412. int cfd_;
  413. };
  414. int ctl_epollev(int epollfd, int op, EventHandler *handler) {
  415. epoll_event ev;
  416. memset(&ev, 0, sizeof(ev));
  417. int events = 0;
  418. if (handler->want_read()) {
  419. events |= EPOLLIN;
  420. }
  421. if (handler->want_write()) {
  422. events |= EPOLLOUT;
  423. }
  424. ev.events = events;
  425. ev.data.ptr = handler;
  426. return epoll_ctl(epollfd, op, handler->fd(), &ev);
  427. }
  428. void reactor(int sfd) {
  429. std::set<EventHandler *> handlers;
  430. ListenEventHandler *listen_handler = new ListenEventHandler(sfd);
  431. handlers.insert(listen_handler);
  432. int epollfd = epoll_create(16);
  433. if (epollfd == -1) {
  434. perror("epoll_create");
  435. exit(EXIT_FAILURE);
  436. }
  437. if (ctl_epollev(epollfd, EPOLL_CTL_ADD, listen_handler) == -1) {
  438. perror("epoll_ctl");
  439. exit(EXIT_FAILURE);
  440. }
  441. static const size_t MAX_EVENTS = 64;
  442. epoll_event events[MAX_EVENTS];
  443. while (1) {
  444. int nfds = epoll_wait(epollfd, events, MAX_EVENTS, -1);
  445. if (nfds == -1) {
  446. perror("epoll_wait");
  447. return;
  448. }
  449. for (int n = 0; n < nfds; ++n) {
  450. EventHandler *eh = (EventHandler *)events[n].data.ptr;
  451. if (((events[n].events & EPOLLIN) && eh->on_read_event() == -1) ||
  452. ((events[n].events & EPOLLOUT) && eh->on_write_event() == -1) ||
  453. (events[n].events & (EPOLLERR | EPOLLHUP))) {
  454. handlers.erase(eh);
  455. delete eh;
  456. } else {
  457. EventHandler *next = eh->next();
  458. if (next) {
  459. handlers.insert(next);
  460. if (ctl_epollev(epollfd, EPOLL_CTL_ADD, next) == -1) {
  461. if (errno == EEXIST) {
  462. if (ctl_epollev(epollfd, EPOLL_CTL_MOD, next) == -1) {
  463. perror("epoll_ctl");
  464. delete next;
  465. }
  466. } else {
  467. perror("epoll_ctl");
  468. delete next;
  469. }
  470. }
  471. }
  472. if (eh->finish()) {
  473. handlers.erase(eh);
  474. delete eh;
  475. } else {
  476. if (ctl_epollev(epollfd, EPOLL_CTL_MOD, eh) == -1) {
  477. perror("epoll_ctl");
  478. }
  479. }
  480. }
  481. }
  482. }
  483. }
  484. int main(int argc, char **argv) {
  485. if (argc < 2) {
  486. std::cerr << "Usage: " << argv[0] << " PORT" << std::endl;
  487. exit(EXIT_FAILURE);
  488. }
  489. struct sigaction act;
  490. memset(&act, 0, sizeof(struct sigaction));
  491. act.sa_handler = SIG_IGN;
  492. sigaction(SIGPIPE, &act, 0);
  493. int sfd = create_listen_socket(argv[1]);
  494. if (sfd == -1) {
  495. std::cerr << "Failed to create server socket" << std::endl;
  496. exit(EXIT_FAILURE);
  497. }
  498. std::cout << "WebSocket echo server, listening on " << argv[1] << std::endl;
  499. reactor(sfd);
  500. }