SocketCore.cc 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450
  1. /* <!-- copyright */
  2. /*
  3. * aria2 - a simple utility for downloading files faster
  4. *
  5. * Copyright (C) 2006 Tatsuhiro Tsujikawa
  6. *
  7. * This program is free software; you can redistribute it and/or modify
  8. * it under the terms of the GNU General Public License as published by
  9. * the Free Software Foundation; either version 2 of the License, or
  10. * (at your option) any later version.
  11. *
  12. * This program is distributed in the hope that it will be useful,
  13. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  14. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  15. * GNU General Public License for more details.
  16. *
  17. * You should have received a copy of the GNU General Public License
  18. * along with this program; if not, write to the Free Software
  19. * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
  20. */
  21. /* copyright --> */
  #include "SocketCore.h"
  #include "DlRetryEx.h"
  #include "DlAbortEx.h"
  #include "message.h"
  #include <unistd.h>
  #include <fcntl.h>
  #include <netdb.h>
  #include <sys/types.h>
  #include <sys/socket.h>
  #include <netinet/in.h>
  #include <arpa/inet.h>
  #include <sys/time.h>
  #include <netdb.h>
  #include <errno.h>
  #include <string.h>
  36. SocketCore::SocketCore():sockfd(-1) {
  37. init();
  38. }
  39. SocketCore::SocketCore(int sockfd):sockfd(sockfd) {
  40. init();
  41. }
  42. void SocketCore::init() {
  43. use = 1;
  44. secure = false;
  45. #ifdef HAVE_LIBSSL
  46. // for SSL
  47. sslCtx = NULL;
  48. ssl = NULL;
  49. #endif // HAVE_LIBSSL
  50. #ifdef HAVE_LIBGNUTLS
  51. sslSession = NULL;
  52. sslXcred = NULL;
  53. peekBufMax = 4096;
  54. peekBuf = new char[peekBufMax];
  55. peekBufLength = 0;
  56. #endif //HAVE_LIBGNUTLS
  57. }
  58. SocketCore::~SocketCore() {
  59. closeConnection();
  60. #ifdef HAVE_LIBGNUTLS
  61. delete [] peekBuf;
  62. #endif // HAVE_LIBGNUTLS
  63. }
  64. void SocketCore::beginListen(int port) {
  65. closeConnection();
  66. //sockfd = socket(AF_UNSPEC, SOCK_STREAM, PF_UNSPEC);
  67. sockfd = socket(AF_INET, SOCK_STREAM, 0);
  68. if(sockfd == -1) {
  69. throw new DlAbortEx(EX_SOCKET_OPEN, strerror(errno));
  70. }
  71. socklen_t sockopt = 1;
  72. if(setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &sockopt, sizeof(socklen_t)) < 0) {
  73. close(sockfd);
  74. sockfd = -1;
  75. throw new DlAbortEx(EX_SOCKET_SET_OPT, strerror(errno));
  76. }
  77. struct sockaddr_in sockaddr;
  78. memset((char*)&sockaddr, 0, sizeof(sockaddr));
  79. sockaddr.sin_family = AF_INET;
  80. sockaddr.sin_addr.s_addr = INADDR_ANY;
  81. sockaddr.sin_port = htons(port);
  82. if(bind(sockfd, (struct sockaddr*)&sockaddr, sizeof(sockaddr)) == -1) {
  83. throw new DlAbortEx(EX_SOCKET_BIND, strerror(errno));
  84. }
  85. if(listen(sockfd, 1) == -1) {
  86. throw new DlAbortEx(EX_SOCKET_LISTEN, strerror(errno));
  87. }
  88. }
  89. SocketCore* SocketCore::acceptConnection() const {
  90. struct sockaddr_in sockaddr;
  91. socklen_t len = sizeof(sockaddr);
  92. memset((char*)&sockaddr, 0, sizeof(sockaddr));
  93. int fd;
  94. if((fd = accept(sockfd, (struct sockaddr*)&sockaddr, &len)) == -1) {
  95. throw new DlAbortEx(EX_SOCKET_ACCEPT, strerror(errno));
  96. }
  97. SocketCore* s = new SocketCore(fd);
  98. return s;
  99. }
  100. void SocketCore::getAddrInfo(pair<string, int>& addrinfo) const {
  101. struct sockaddr_in listenaddr;
  102. memset((char*)&listenaddr, 0, sizeof(listenaddr));
  103. socklen_t len = sizeof(listenaddr);
  104. if(getsockname(sockfd, (struct sockaddr*)&listenaddr, &len) == -1) {
  105. throw new DlAbortEx(EX_SOCKET_GET_NAME, strerror(errno));
  106. }
  107. addrinfo.first = inet_ntoa(listenaddr.sin_addr);
  108. addrinfo.second = ntohs(listenaddr.sin_port);
  109. }
  110. void SocketCore::getPeerInfo(pair<string, int>& peerinfo) const {
  111. struct sockaddr_in peerin;
  112. memset(&peerin, 0, sizeof(peerin));
  113. int len = sizeof(peerin);
  114. if(getpeername(sockfd, (struct sockaddr*)&peerin, (socklen_t*)&len) < 0) {
  115. throw new DlAbortEx(EX_SOCKET_GET_PEER, strerror(errno));
  116. }
  117. peerinfo.first = inet_ntoa(peerin.sin_addr);
  118. peerinfo.second = ntohs(peerin.sin_port);
  119. }
  // Opens an IPv4 TCP socket, switches it to non-blocking mode and starts a
  // connect() to host:port. host may be a dotted-quad literal (inet_aton) or
  // a name resolved with getaddrinfo() (first PF_INET result is used).
  // Returns once connect() has succeeded or reported EINPROGRESS; completion
  // is presumably awaited by the caller (e.g. via isWritable()) -- confirm.
  //
  // Throws DlAbortEx on socket/setsockopt/resolve/connect failure. After
  // socket() succeeds, every failure path closes the descriptor and resets
  // sockfd to -1 before throwing.
  void SocketCore::establishConnection(const string& host, int port) {
    closeConnection();
    sockfd = socket(AF_INET, SOCK_STREAM, 0);
    if(sockfd == -1) {
      throw new DlAbortEx(EX_SOCKET_OPEN, strerror(errno));
    }
    // SO_REUSEADDR takes an int option value.
    int sockopt = 1;
    if(setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &sockopt, sizeof(sockopt)) < 0) {
      // Save errno before close(), which may overwrite it.
      int err = errno;
      close(sockfd);
      sockfd = -1;
      throw new DlAbortEx(EX_SOCKET_SET_OPT, strerror(err));
    }
    struct sockaddr_in addr;
    memset(&addr, 0, sizeof(addr));
    addr.sin_family = AF_INET;
    addr.sin_port = htons(port);
    if(inet_aton(host.c_str(), &addr.sin_addr) == 0) {
      // Not a numeric IPv4 literal: resolve the host name (IPv4 only).
      struct addrinfo hints;
      memset(&hints, 0, sizeof(hints));
      hints.ai_family = PF_INET;
      hints.ai_socktype = SOCK_STREAM;
      struct addrinfo* res;
      int ec = getaddrinfo(host.c_str(), NULL, &hints, &res);
      if(ec != 0) {
        close(sockfd);
        sockfd = -1;
        throw new DlAbortEx(EX_RESOLVE_HOSTNAME,
                            host.c_str(), gai_strerror(ec));
      }
      addr.sin_addr = ((struct sockaddr_in*)res->ai_addr)->sin_addr;
      freeaddrinfo(res);
    }
    // make socket non-blocking mode
    int flags = fcntl(sockfd, F_GETFL, 0);
    fcntl(sockfd, F_SETFL, flags|O_NONBLOCK);
    if(connect(sockfd, (struct sockaddr*)&addr, (socklen_t)sizeof(addr)) == -1 &&
       errno != EINPROGRESS) {
      int err = errno;
      close(sockfd);
      sockfd = -1;
      throw new DlAbortEx(EX_SOCKET_CONNECT, host.c_str(), strerror(err));
    }
  }
  161. void SocketCore::setBlockingMode() const {
  162. int flags = fcntl(sockfd, F_GETFL, 0);
  163. fcntl(sockfd, F_SETFL, flags&~O_NONBLOCK);
  164. }
  165. void SocketCore::closeConnection() {
  166. #ifdef HAVE_LIBSSL
  167. // for SSL
  168. if(secure) {
  169. SSL_shutdown(ssl);
  170. }
  171. #endif // HAVE_LIBSSL
  172. #ifdef HAVE_LIBGNUTLS
  173. if(secure) {
  174. gnutls_bye(sslSession, GNUTLS_SHUT_RDWR);
  175. }
  176. #endif // HAVE_LIBGNUTLS
  177. if(sockfd != -1) {
  178. close(sockfd);
  179. sockfd = -1;
  180. }
  181. #ifdef HAVE_LIBSSL
  182. // for SSL
  183. if(secure) {
  184. SSL_free(ssl);
  185. SSL_CTX_free(sslCtx);
  186. }
  187. #endif // HAVE_LIBSSL
  188. #ifdef HAVE_LIBGNUTLS
  189. if(secure) {
  190. gnutls_deinit(sslSession);
  191. gnutls_certificate_free_credentials(sslXcred);
  192. }
  193. #endif // HAVE_LIBGNUTLS
  194. }
  195. bool SocketCore::isWritable(int timeout) const {
  196. fd_set fds;
  197. FD_ZERO(&fds);
  198. FD_SET(sockfd, &fds);
  199. struct timeval tv;
  200. tv.tv_sec = timeout;
  201. tv.tv_usec = 0;
  202. int r = select(sockfd+1, NULL, &fds, NULL, &tv);
  203. if(r == 1) {
  204. return true;
  205. } else if(r == 0) {
  206. // time out
  207. return false;
  208. } else {
  209. if(errno == EINPROGRESS || errno == EINTR) {
  210. return false;
  211. } else {
  212. throw new DlRetryEx(EX_SOCKET_CHECK_WRITABLE, strerror(errno));
  213. }
  214. }
  215. }
  216. bool SocketCore::isReadable(int timeout) const {
  217. #ifdef HAVE_LIBGNUTLS
  218. if(secure && peekBufLength > 0) {
  219. return true;
  220. }
  221. #endif // HAVE_LIBGNUTLS
  222. fd_set fds;
  223. FD_ZERO(&fds);
  224. FD_SET(sockfd, &fds);
  225. struct timeval tv;
  226. tv.tv_sec = timeout;
  227. tv.tv_usec = 0;
  228. int r = select(sockfd+1, &fds, NULL, NULL, &tv);
  229. if(r == 1) {
  230. return true;
  231. } else if(r == 0) {
  232. // time out
  233. return false;
  234. } else {
  235. if(errno == EINPROGRESS || errno == EINTR) {
  236. return false;
  237. } else {
  238. throw new DlRetryEx(EX_SOCKET_CHECK_READABLE, strerror(errno));
  239. }
  240. }
  241. }
  242. void SocketCore::writeData(const char* data, int len) {
  243. int ret = 0;
  244. if(!secure && (ret = send(sockfd, data, (size_t)len, 0)) != len
  245. #ifdef HAVE_LIBSSL
  246. // for SSL
  247. // TODO handling len == 0 case required
  248. || secure && (ret = SSL_write(ssl, data, len)) != len
  249. #endif // HAVE_LIBSSL
  250. #ifdef HAVE_LIBGNUTLS
  251. || secure && (ret = gnutls_record_send(sslSession, data, len)) != len
  252. #endif // HAVE_LIBGNUTLS
  253. ) {
  254. const char* errorMsg;
  255. #ifdef HAVE_LIBGNUTLS
  256. if(secure) {
  257. errorMsg = gnutls_strerror(ret);
  258. } else {
  259. errorMsg = strerror(errno);
  260. }
  261. #else // HAVE_LIBGNUTLS
  262. errorMsg = strerror(errno);
  263. #endif
  264. throw new DlRetryEx(EX_SOCKET_SEND, errorMsg);
  265. }
  266. }
  267. void SocketCore::readData(char* data, int& len) {
  268. int ret = 0;
  269. if(!secure && (ret = recv(sockfd, data, (size_t)len, 0)) < 0
  270. #ifdef HAVE_LIBSSL
  271. // for SSL
  272. // TODO handling len == 0 case required
  273. || secure && (ret = SSL_read(ssl, data, len)) < 0
  274. #endif // HAVE_LIBSSL
  275. #ifdef HAVE_LIBGNUTLS
  276. || secure && (ret = gnutlsRecv(data, len)) < 0
  277. #endif // HAVE_LIBGNUTLS
  278. ) {
  279. const char* errorMsg;
  280. #ifdef HAVE_LIBGNUTLS
  281. if(secure) {
  282. errorMsg = gnutls_strerror(ret);
  283. } else {
  284. errorMsg = strerror(errno);
  285. }
  286. #else // HAVE_LIBGNUTLS
  287. errorMsg = strerror(errno);
  288. #endif
  289. throw new DlRetryEx(EX_SOCKET_RECV, errorMsg);
  290. }
  291. len = ret;
  292. }
  293. void SocketCore::peekData(char* data, int& len) {
  294. int ret = 0;
  295. if(!secure && (ret = recv(sockfd, data, (size_t)len, MSG_PEEK)) < 0
  296. #ifdef HAVE_LIBSSL
  297. // for SSL
  298. // TODO handling len == 0 case required
  299. || secure && (ret = SSL_peek(ssl, data, len)) < 0
  300. #endif // HAVE_LIBSSL
  301. #ifdef HAVE_LIBGNUTLS
  302. || secure && (ret = gnutlsPeek(data, len)) < 0
  303. #endif // HAVE_LIBGNUTLS
  304. ) {
  305. const char* errorMsg;
  306. #ifdef HAVE_LIBGNUTLS
  307. if(secure) {
  308. errorMsg = gnutls_strerror(ret);
  309. } else {
  310. errorMsg = strerror(errno);
  311. }
  312. #else // HAVE_LIBGNUTLS
  313. errorMsg = strerror(errno);
  314. #endif
  315. throw new DlRetryEx(EX_SOCKET_PEEK, errorMsg);
  316. }
  317. len = ret;
  318. }
  319. #ifdef HAVE_LIBGNUTLS
  320. int SocketCore::shiftPeekData(char* data, int len) {
  321. if(peekBufLength <= len) {
  322. memcpy(data, peekBuf, peekBufLength);
  323. int ret = peekBufLength;
  324. peekBufLength = 0;
  325. return ret;
  326. } else {
  327. memcpy(data, peekBuf, len);
  328. char* temp = new char[peekBufMax];
  329. memcpy(temp, peekBuf+len, peekBufLength-len);
  330. delete [] peekBuf;
  331. peekBuf = temp;
  332. peekBufLength -= len;
  333. return len;
  334. }
  335. }
  336. void SocketCore::addPeekData(char* data, int len) {
  337. if(peekBufLength+len > peekBufMax) {
  338. char* temp = new char[peekBufMax+len];
  339. memcpy(temp, peekBuf, peekBufLength);
  340. delete [] peekBuf;
  341. peekBuf = temp;
  342. peekBufMax = peekBufLength+len;
  343. }
  344. memcpy(peekBuf+peekBufLength, data, len);
  345. peekBufLength += len;
  346. }
  347. int SocketCore::gnutlsRecv(char* data, int len) {
  348. int plen = shiftPeekData(data, len);
  349. if(plen < len) {
  350. int ret = gnutls_record_recv(sslSession, data+plen, len-plen);
  351. if(ret < 0) {
  352. throw new DlRetryEx(EX_SOCKET_RECV, gnutls_strerror(ret));
  353. }
  354. return plen+ret;
  355. } else {
  356. return plen;
  357. }
  358. }
  359. int SocketCore::gnutlsPeek(char* data, int len) {
  360. if(peekBufLength >= len) {
  361. memcpy(data, peekBuf, len);
  362. return len;
  363. } else {
  364. memcpy(data, peekBuf, peekBufLength);
  365. int ret = gnutls_record_recv(sslSession, data+peekBufLength, len-peekBufLength);
  366. if(ret < 0) {
  367. throw new DlRetryEx(EX_SOCKET_PEEK, gnutls_strerror(ret));
  368. }
  369. addPeekData(data+peekBufLength, ret);
  370. return peekBufLength;
  371. }
  372. }
  373. #endif // HAVE_LIBGNUTLS
  374. void SocketCore::initiateSecureConnection() {
  375. #ifdef HAVE_LIBSSL
  376. // for SSL
  377. if(!secure) {
  378. sslCtx = SSL_CTX_new(SSLv23_client_method());
  379. if(sslCtx == NULL) {
  380. throw new DlAbortEx(EX_SSL_INIT_FAILURE);
  381. }
  382. SSL_CTX_set_mode(sslCtx, SSL_MODE_AUTO_RETRY);
  383. ssl = SSL_new(sslCtx);
  384. if(ssl == NULL) {
  385. throw new DlAbortEx(EX_SSL_INIT_FAILURE);
  386. }
  387. if(SSL_set_fd(ssl, sockfd) == 0) {
  388. throw new DlAbortEx(EX_SSL_INIT_FAILURE);
  389. }
  390. // TODO handling return value == 0 case required
  391. if(SSL_connect(ssl) <= 0) {
  392. throw new DlAbortEx(EX_SSL_INIT_FAILURE);
  393. }
  394. secure = true;
  395. }
  396. #endif // HAVE_LIBSSL
  397. #ifdef HAVE_LIBGNUTLS
  398. if(!secure) {
  399. const int cert_type_priority[3] = { GNUTLS_CRT_X509,
  400. GNUTLS_CRT_OPENPGP, 0
  401. };
  402. // while we do not support X509 certificate, most web servers require
  403. // X509 stuff.
  404. gnutls_certificate_allocate_credentials (&sslXcred);
  405. gnutls_init(&sslSession, GNUTLS_CLIENT);
  406. gnutls_set_default_priority(sslSession);
  407. gnutls_kx_set_priority(sslSession, cert_type_priority);
  408. // put the x509 credentials to the current session
  409. gnutls_credentials_set(sslSession, GNUTLS_CRD_CERTIFICATE, sslXcred);
  410. gnutls_transport_set_ptr(sslSession, (gnutls_transport_ptr_t)sockfd);
  411. int ret = gnutls_handshake(sslSession);
  412. if(ret < 0) {
  413. throw new DlAbortEx(gnutls_strerror(ret));
  414. }
  415. secure = true;
  416. }
  417. #endif // HAVE_LIBGNUTLS
  418. }