SocketCore.cc 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. /* <!-- copyright */
  2. /*
  3. * aria2 - a simple utility for downloading files faster
  4. *
  5. * Copyright (C) 2006 Tatsuhiro Tsujikawa
  6. *
  7. * This program is free software; you can redistribute it and/or modify
  8. * it under the terms of the GNU General Public License as published by
  9. * the Free Software Foundation; either version 2 of the License, or
  10. * (at your option) any later version.
  11. *
  12. * This program is distributed in the hope that it will be useful,
  13. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  14. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  15. * GNU General Public License for more details.
  16. *
  17. * You should have received a copy of the GNU General Public License
  18. * along with this program; if not, write to the Free Software
  19. * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
  20. */
  21. /* copyright --> */
  22. #include "SocketCore.h"
  23. #include <unistd.h>
  24. #include <fcntl.h>
  25. #include <netdb.h>
  26. #include <sys/types.h>
  27. #include <sys/socket.h>
  28. #include <netinet/in.h>
  29. #include <arpa/inet.h>
  30. #include <sys/time.h>
  31. #include <netdb.h>
  32. #include "DlRetryEx.h"
  33. #include "DlAbortEx.h"
  34. #include <errno.h>
  35. #include "message.h"
  36. SocketCore::SocketCore():sockfd(-1) {
  37. init();
  38. }
  39. SocketCore::SocketCore(int sockfd):sockfd(sockfd) {
  40. init();
  41. }
  42. void SocketCore::init() {
  43. use = 1;
  44. secure = false;
  45. #ifdef HAVE_LIBSSL
  46. // for SSL
  47. sslCtx = NULL;
  48. ssl = NULL;
  49. #endif // HAVE_LIBSSL
  50. }
  51. SocketCore::~SocketCore() {
  52. closeConnection();
  53. }
  54. void SocketCore::beginListen() {
  55. closeConnection();
  56. //sockfd = socket(AF_UNSPEC, SOCK_STREAM, PF_UNSPEC);
  57. sockfd = socket(AF_INET, SOCK_STREAM, 0);
  58. if(sockfd == -1) {
  59. throw new DlAbortEx(strerror(errno));
  60. }
  61. socklen_t sockopt = 1;
  62. if(setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &sockopt, sizeof(socklen_t)) < 0) {
  63. close(sockfd);
  64. sockfd = -1;
  65. throw new DlAbortEx(strerror(errno));
  66. }
  67. struct sockaddr_in sockaddr;
  68. memset((char*)&sockaddr, 0, sizeof(sockaddr));
  69. sockaddr.sin_family = AF_INET;
  70. sockaddr.sin_addr.s_addr = INADDR_ANY;
  71. sockaddr.sin_port = htons(0);
  72. if(bind(sockfd, (struct sockaddr*)&sockaddr, sizeof(sockaddr)) == -1) {
  73. throw new DlAbortEx(strerror(errno));
  74. }
  75. if(listen(sockfd, 1) == -1) {
  76. throw new DlAbortEx(strerror(errno));
  77. }
  78. }
  79. SocketCore* SocketCore::acceptConnection() const {
  80. struct sockaddr_in sockaddr;
  81. socklen_t len = sizeof(sockaddr);
  82. memset((char*)&sockaddr, 0, sizeof(sockaddr));
  83. int fd;
  84. if((fd = accept(sockfd, (struct sockaddr*)&sockaddr, &len)) == -1) {
  85. throw new DlAbortEx(strerror(errno));
  86. }
  87. SocketCore* s = new SocketCore(fd);
  88. return s;
  89. }
  90. void SocketCore::getAddrInfo(pair<string, int>& addrinfo) const {
  91. struct sockaddr_in listenaddr;
  92. memset((char*)&listenaddr, 0, sizeof(listenaddr));
  93. socklen_t len = sizeof(listenaddr);
  94. if(getsockname(sockfd, (struct sockaddr*)&listenaddr, &len) == -1) {
  95. throw new DlAbortEx(strerror(errno));
  96. }
  97. addrinfo.first = inet_ntoa(listenaddr.sin_addr);
  98. addrinfo.second = ntohs(listenaddr.sin_port);
  99. }
  100. void SocketCore::establishConnection(string host, int port) {
  101. closeConnection();
  102. sockfd = socket(AF_INET, SOCK_STREAM, 0);
  103. if(sockfd == -1) {
  104. throw new DlAbortEx(strerror(errno));
  105. }
  106. socklen_t sockopt = 1;
  107. if(setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &sockopt, sizeof(socklen_t)) < 0) {
  108. close(sockfd);
  109. sockfd = -1;
  110. throw new DlAbortEx(strerror(errno));
  111. }
  112. struct sockaddr_in sockaddr;
  113. memset((char*)&sockaddr, 0, sizeof(sockaddr));
  114. sockaddr.sin_family = AF_INET;
  115. sockaddr.sin_port = htons(port);
  116. if(inet_aton(host.c_str(), &sockaddr.sin_addr)) {
  117. // ok
  118. } else {
  119. struct addrinfo ai;
  120. ai.ai_flags = 0;
  121. ai.ai_family = PF_INET;
  122. ai.ai_socktype = SOCK_STREAM;
  123. ai.ai_protocol = 0;
  124. ai.ai_addr = (struct sockaddr*)&sockaddr;
  125. struct addrinfo* res;
  126. int ec;
  127. if((ec = getaddrinfo(host.c_str(), NULL, &ai, &res)) != 0) {
  128. throw new DlAbortEx(gai_strerror(ec));
  129. }
  130. sockaddr.sin_addr = ((struct sockaddr_in*)res->ai_addr)->sin_addr;
  131. freeaddrinfo(res);
  132. }
  133. // make socket non-blocking mode
  134. int flags = fcntl(sockfd, F_GETFL, 0);
  135. fcntl(sockfd, F_SETFL, flags|O_NONBLOCK);
  136. if(connect(sockfd, (struct sockaddr*)&sockaddr, (socklen_t)sizeof(sockaddr)) == -1 && errno != EINPROGRESS) {
  137. throw new DlAbortEx(strerror(errno));
  138. }
  139. }
  140. void SocketCore::setBlockingMode() const {
  141. int flags = fcntl(sockfd, F_GETFL, 0);
  142. fcntl(sockfd, F_SETFL, flags&~O_NONBLOCK);
  143. }
  144. void SocketCore::closeConnection() {
  145. #ifdef HAVE_LIBSSL
  146. // for SSL
  147. if(secure && ssl != NULL) {
  148. SSL_shutdown(ssl);
  149. }
  150. #endif // HAVE_LIBSSL
  151. if(sockfd != -1) {
  152. close(sockfd);
  153. sockfd = -1;
  154. }
  155. #ifdef HAVE_LIBSSL
  156. // for SSL
  157. if(secure && ssl != NULL) {
  158. SSL_free(ssl);
  159. SSL_CTX_free(sslCtx);
  160. ssl = NULL;
  161. sslCtx = NULL;
  162. }
  163. #endif // HAVE_LIBSSL
  164. }
  165. bool SocketCore::isWritable(int timeout) const {
  166. fd_set fds;
  167. FD_ZERO(&fds);
  168. FD_SET(sockfd, &fds);
  169. struct timeval tv;
  170. tv.tv_sec = timeout;
  171. tv.tv_usec = 0;
  172. int r = select(sockfd+1, NULL, &fds, NULL, &tv);
  173. if(r == 1) {
  174. return true;
  175. } else if(r == 0) {
  176. // time out
  177. return false;
  178. } else {
  179. if(errno == EINPROGRESS) {
  180. return false;
  181. } else {
  182. throw new DlRetryEx(strerror(errno));
  183. }
  184. }
  185. }
  186. bool SocketCore::isReadable(int timeout) const {
  187. fd_set fds;
  188. FD_ZERO(&fds);
  189. FD_SET(sockfd, &fds);
  190. struct timeval tv;
  191. tv.tv_sec = timeout;
  192. tv.tv_usec = 0;
  193. int r = select(sockfd+1, &fds, NULL, NULL, &tv);
  194. if(r == 1) {
  195. return true;
  196. } else if(r == 0) {
  197. // time out
  198. return false;
  199. } else {
  200. if(errno == EINPROGRESS) {
  201. return false;
  202. } else {
  203. throw new DlRetryEx(strerror(errno));
  204. }
  205. }
  206. }
  207. void SocketCore::writeData(const char* data, int len, int timeout) const {
  208. if(!isWritable(timeout) ||
  209. !secure && send(sockfd, data, (size_t)len, 0) != len
  210. #ifdef HAVE_LIBSSL
  211. // for SSL
  212. // TODO handling len == 0 case required
  213. || secure && SSL_write(ssl, data, len) != len
  214. #endif // HAVE_LIBSSL
  215. ) {
  216. throw new DlRetryEx(strerror(errno));
  217. }
  218. }
  219. void SocketCore::readData(char* data, int& len, int timeout) const {
  220. if(!isReadable(timeout) ||
  221. !secure && (len = recv(sockfd, data, (size_t)len, 0)) < 0
  222. #ifdef HAVE_LIBSSL
  223. // for SSL
  224. // TODO handling len == 0 case required
  225. || secure && (len = SSL_read(ssl, data, len)) < 0
  226. #endif // HAVE_LIBSSL
  227. ) {
  228. throw new DlRetryEx(strerror(errno));
  229. }
  230. }
  231. void SocketCore::peekData(char* data, int& len, int timeout) const {
  232. if(!isReadable(timeout) ||
  233. !secure && (len = recv(sockfd, data, (size_t)len, MSG_PEEK)) < 0
  234. #ifdef HAVE_LIBSSL
  235. // for SSL
  236. // TODO handling len == 0 case required
  237. || secure && (len == SSL_peek(ssl, data, len)) < 0
  238. #endif // HAVE_LIBSSL
  239. ) {
  240. throw new DlRetryEx(strerror(errno));
  241. }
  242. }
  243. void SocketCore::initiateSecureConnection() {
  244. #ifdef HAVE_LIBSSL
  245. // for SSL
  246. if(!secure) {
  247. sslCtx = SSL_CTX_new(SSLv23_client_method());
  248. if(sslCtx == NULL) {
  249. throw new DlAbortEx(EX_SSL_INIT_FAILURE);
  250. }
  251. SSL_CTX_set_mode(sslCtx, SSL_MODE_AUTO_RETRY);
  252. ssl = SSL_new(sslCtx);
  253. if(ssl == NULL) {
  254. throw new DlAbortEx(EX_SSL_INIT_FAILURE);
  255. }
  256. if(SSL_set_fd(ssl, sockfd) == 0) {
  257. throw new DlAbortEx(EX_SSL_INIT_FAILURE);
  258. }
  259. // TODO handling return value == 0 case required
  260. if(SSL_connect(ssl) <= 0) {
  261. throw new DlAbortEx(EX_SSL_INIT_FAILURE);
  262. }
  263. secure = true;
  264. }
  265. #endif // HAVE_LIBSSL
  266. }