WinTLSSession.cc 25 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857
  1. /* <!-- copyright */
  2. /*
  3. * aria2 - The high speed download utility
  4. *
  5. * Copyright (C) 2013 Nils Maier
  6. *
  7. * This program is free software; you can redistribute it and/or modify
  8. * it under the terms of the GNU General Public License as published by
  9. * the Free Software Foundation; either version 2 of the License, or
  10. * (at your option) any later version.
  11. *
  12. * This program is distributed in the hope that it will be useful,
  13. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  14. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  15. * GNU General Public License for more details.
  16. *
  17. * You should have received a copy of the GNU General Public License
  18. * along with this program; if not, write to the Free Software
  19. * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
  20. *
  21. * In addition, as a special exception, the copyright holders give
  22. * permission to link the code of portions of this program with the
  23. * OpenSSL library under certain conditions as described in each
  24. * individual source file, and distribute linked combinations
  25. * including the two.
  26. * You must obey the GNU General Public License in all respects
  27. * for all of the code used other than OpenSSL. If you modify
  28. * file(s) with this exception, you may extend this exception to your
  29. * version of the file(s), but you are not obligated to do so. If you
  30. * do not wish to do so, delete this exception statement from your
  31. * version. If you delete this exception statement from all source
  32. * files in the program, then also delete it here.
  33. */
  34. /* copyright --> */
  35. #include "WinTLSSession.h"
  36. #include <cassert>
  37. #include <sstream>
  38. #include "LogFactory.h"
  39. #include "a2functional.h"
  40. #include "fmt.h"
  41. #include "util.h"
// Fallback definitions for Schannel constants that the Windows/MinGW SDK
// headers in use may not provide. Each value is only defined when the SDK
// has not already defined it.
#ifndef SECBUFFER_ALERT
// Buffer type for the alert output buffer passed to the handshake calls in
// tlsConnect().
#define SECBUFFER_ALERT 17
#endif
#ifndef SZ_ALG_MAX_SIZE
// Length (in WCHARs) of the fixed-size name arrays in
// WinSecPkgContext_CipherInfo.
#define SZ_ALG_MAX_SIZE 64
#endif
#ifndef SECPKGCONTEXT_CIPHERINFO_V1
// Structure version written into WinSecPkgContext_CipherInfo::dwVersion
// before querying it.
#define SECPKGCONTEXT_CIPHERINFO_V1 1
#endif
#ifndef SECPKG_ATTR_CIPHER_INFO
// QueryContextAttributes attribute id used to fetch the cipher info record.
#define SECPKG_ATTR_CIPHER_INFO 0x64
#endif
  54. namespace {
  55. using namespace aria2;
// Local definition of the record QueryContextAttributes() fills in for
// SECPKG_ATTR_CIPHER_INFO (used by getCipherSuite/getProtocolVersion).
// NOTE(review): the field order and types presumably have to match Schannel's
// SecPkgContext_CipherInfo exactly; do not reorder or resize fields without
// checking schannel.h.
struct WinSecPkgContext_CipherInfo {
  // Structure version; callers set this to SECPKGCONTEXT_CIPHERINFO_V1.
  DWORD dwVersion;
  // Negotiated protocol; tlsConnect maps 0x300..0x303 to SSL3..TLS1.2.
  DWORD dwProtocol;
  DWORD dwCipherSuite;
  DWORD dwBaseCipherSuite;
  // Cipher suite name; converted to UTF-8 for logging by getCipherSuite.
  WCHAR szCipherSuite[SZ_ALG_MAX_SIZE];
  WCHAR szCipher[SZ_ALG_MAX_SIZE];
  DWORD dwCipherLen;
  DWORD dwCipherBlockLen; // in bytes
  WCHAR szHash[SZ_ALG_MAX_SIZE];
  DWORD dwHashLen;
  WCHAR szExchange[SZ_ALG_MAX_SIZE];
  DWORD dwMinExchangeLen;
  DWORD dwMaxExchangeLen;
  WCHAR szCertificate[SZ_ALG_MAX_SIZE];
  DWORD dwKeyType;
};
  73. static const ULONG kReqFlags =
  74. ISC_REQ_SEQUENCE_DETECT | ISC_REQ_REPLAY_DETECT | ISC_REQ_CONFIDENTIALITY |
  75. ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_USE_SUPPLIED_CREDS | ISC_REQ_STREAM;
  76. static const ULONG kReqAFlags =
  77. ASC_REQ_SEQUENCE_DETECT | ASC_REQ_REPLAY_DETECT | ASC_REQ_CONFIDENTIALITY |
  78. ASC_REQ_EXTENDED_ERROR | ASC_REQ_ALLOCATE_MEMORY | ASC_REQ_STREAM;
  79. class TLSBufferDesc : public ::SecBufferDesc {
  80. public:
  81. explicit TLSBufferDesc(SecBuffer* arr, ULONG buffers)
  82. {
  83. ulVersion = SECBUFFER_VERSION;
  84. cBuffers = buffers;
  85. pBuffers = arr;
  86. }
  87. };
  88. inline static std::string getCipherSuite(CtxtHandle* handle)
  89. {
  90. WinSecPkgContext_CipherInfo info = {SECPKGCONTEXT_CIPHERINFO_V1};
  91. if (QueryContextAttributes(handle, SECPKG_ATTR_CIPHER_INFO, &info) ==
  92. SEC_E_OK) {
  93. return wCharToUtf8(info.szCipherSuite);
  94. }
  95. return "Unknown";
  96. }
  97. inline static uint32_t getProtocolVersion(CtxtHandle* handle)
  98. {
  99. WinSecPkgContext_CipherInfo info = {SECPKGCONTEXT_CIPHERINFO_V1};
  100. if (QueryContextAttributes(handle, SECPKG_ATTR_CIPHER_INFO, &info) ==
  101. SEC_E_OK) {
  102. return info.dwProtocol;
  103. }
  104. // XXX Assume the best?!
  105. return std::numeric_limits<uint32_t>::max();
  106. }
  107. } // namespace
  108. namespace aria2 {
  109. TLSSession* TLSSession::make(TLSContext* ctx)
  110. {
  111. return new WinTLSSession(static_cast<WinTLSContext*>(ctx));
  112. }
  113. WinTLSSession::WinTLSSession(WinTLSContext* ctx)
  114. : sockfd_(0),
  115. side_(ctx->getSide()),
  116. cred_(ctx->getCredHandle()),
  117. writeBuffered_(0),
  118. state_(st_constructed),
  119. status_(SEC_E_OK),
  120. recordBytesSent_(0)
  121. {
  122. memset(&handle_, 0, sizeof(handle_));
  123. }
  124. WinTLSSession::~WinTLSSession()
  125. {
  126. ::DeleteSecurityContext(&handle_);
  127. state_ = st_error;
  128. }
  129. int WinTLSSession::init(sock_t sockfd)
  130. {
  131. if (state_ != st_constructed) {
  132. status_ = SEC_E_INVALID_HANDLE;
  133. return TLS_ERR_ERROR;
  134. }
  135. sockfd_ = sockfd;
  136. state_ = st_initialized;
  137. return TLS_ERR_OK;
  138. }
  139. int WinTLSSession::setSNIHostname(const std::string& hostname)
  140. {
  141. if (state_ != st_initialized) {
  142. status_ = SEC_E_INVALID_HANDLE;
  143. return TLS_ERR_ERROR;
  144. }
  145. hostname_ = hostname;
  146. return TLS_ERR_OK;
  147. }
  148. int WinTLSSession::closeConnection()
  149. {
  150. if (state_ != st_connected && state_ != st_closing) {
  151. if (state_ != st_error) {
  152. status_ = SEC_E_INVALID_HANDLE;
  153. state_ = st_error;
  154. }
  155. A2_LOG_DEBUG("WinTLS: Cannot close connection");
  156. return TLS_ERR_ERROR;
  157. }
  158. if (state_ == st_connected) {
  159. A2_LOG_DEBUG("WinTLS: Closing connection");
  160. state_ = st_closing;
  161. DWORD dwShut = SCHANNEL_SHUTDOWN;
  162. TLSBuffer shut(SECBUFFER_TOKEN, sizeof(dwShut), &dwShut);
  163. TLSBufferDesc shutDesc(&shut, 1);
  164. status_ = ::ApplyControlToken(&handle_, &shutDesc);
  165. if (status_ != SEC_E_OK) {
  166. state_ = st_error;
  167. return TLS_ERR_ERROR;
  168. }
  169. TLSBuffer ctx(SECBUFFER_EMPTY, 0, nullptr);
  170. TLSBufferDesc desc(&ctx, 1);
  171. ULONG flags = 0;
  172. if (side_ == TLS_CLIENT) {
  173. SEC_CHAR* host = hostname_.empty()
  174. ? nullptr
  175. : const_cast<SEC_CHAR*>(hostname_.c_str());
  176. status_ = ::InitializeSecurityContext(cred_, &handle_, host, kReqFlags, 0,
  177. 0, nullptr, 0, &handle_, &desc,
  178. &flags, nullptr);
  179. }
  180. else {
  181. status_ = ::AcceptSecurityContext(cred_, &handle_, nullptr, kReqAFlags, 0,
  182. &handle_, &desc, &flags, nullptr);
  183. }
  184. if ((status_ == SEC_E_OK || status_ == SEC_I_CONTEXT_EXPIRED) &&
  185. getLeftTLSRecordSize() == 0) {
  186. size_t len = ctx.cbBuffer;
  187. ssize_t rv = writeData(ctx.pvBuffer, ctx.cbBuffer);
  188. ::FreeContextBuffer(ctx.pvBuffer);
  189. if (rv == TLS_ERR_WOULDBLOCK) {
  190. return rv;
  191. }
  192. // Alright data is sent or buffered
  193. if (rv - len != 0) {
  194. return TLS_ERR_WOULDBLOCK;
  195. }
  196. }
  197. }
  198. A2_LOG_DEBUG("WinTLS: Closed Connection");
  199. state_ = st_closed;
  200. return TLS_ERR_OK;
  201. }
  202. int WinTLSSession::checkDirection()
  203. {
  204. if (state_ == st_handshake_write || state_ == st_handshake_write_last) {
  205. return TLS_WANT_WRITE;
  206. }
  207. if (state_ == st_handshake_read) {
  208. return TLS_WANT_READ;
  209. }
  210. if (readBuf_.size() || decBuf_.size()) {
  211. return TLS_WANT_READ;
  212. }
  213. if (getLeftTLSRecordSize() || writeBuf_.size()) {
  214. return TLS_WANT_WRITE;
  215. }
  216. return TLS_WANT_READ;
  217. }
  218. namespace {
  219. // Fills |iov| of length |len| to send remaining data in |buffers|.
  220. // We have already sent |offset| bytes. This function returns the
  221. // number of |iov| filled. It assumes the array |buffers| is at least
  222. // |len| elements.
  223. size_t fillSendIOV(a2iovec* iov, size_t len, TLSBuffer* buffers, size_t offset)
  224. {
  225. size_t iovcnt = 0;
  226. for (size_t i = 0; i < len; ++i) {
  227. if (offset < buffers[i].cbBuffer) {
  228. iov[iovcnt].A2IOVEC_BASE =
  229. static_cast<char*>(buffers[i].pvBuffer) + offset;
  230. iov[iovcnt].A2IOVEC_LEN = buffers[i].cbBuffer - offset;
  231. ++iovcnt;
  232. offset = 0;
  233. }
  234. else {
  235. offset -= buffers[i].cbBuffer;
  236. }
  237. }
  238. return iovcnt;
  239. }
  240. } // namespace
  241. size_t WinTLSSession::getLeftTLSRecordSize() const
  242. {
  243. return sendRecordBuffers_[0].cbBuffer + sendRecordBuffers_[1].cbBuffer +
  244. sendRecordBuffers_[2].cbBuffer - recordBytesSent_;
  245. }
  246. int WinTLSSession::sendTLSRecord()
  247. {
  248. A2_LOG_DEBUG(fmt("WinTLS: TLS record %" PRIu64 " bytes left",
  249. static_cast<uint64_t>(getLeftTLSRecordSize())));
  250. while (getLeftTLSRecordSize()) {
  251. std::array<a2iovec, 3> iov;
  252. auto iovcnt = fillSendIOV(iov.data(), iov.size(), sendRecordBuffers_.data(),
  253. recordBytesSent_);
  254. DWORD nwrite;
  255. auto rv =
  256. WSASend(sockfd_, iov.data(), iovcnt, &nwrite, 0, nullptr, nullptr);
  257. if (rv != 0) {
  258. auto errnum = ::WSAGetLastError();
  259. if (errnum == WSAEINTR) {
  260. continue;
  261. }
  262. if (errnum == WSAEWOULDBLOCK) {
  263. return TLS_ERR_WOULDBLOCK;
  264. }
  265. A2_LOG_ERROR("WinTLS: Connection error while writing");
  266. status_ = SEC_E_INCOMPLETE_MESSAGE;
  267. state_ = st_error;
  268. return TLS_ERR_ERROR;
  269. }
  270. recordBytesSent_ += nwrite;
  271. }
  272. recordBytesSent_ = 0;
  273. sendRecordBuffers_[0].cbBuffer = 0;
  274. sendRecordBuffers_[1].cbBuffer = 0;
  275. sendRecordBuffers_[2].cbBuffer = 0;
  276. return 0;
  277. }
  278. ssize_t WinTLSSession::writeData(const void* data, size_t len)
  279. {
  280. if (state_ == st_handshake_write || state_ == st_handshake_write_last ||
  281. state_ == st_handshake_read) {
  282. // Renegotiating
  283. std::string hn, err;
  284. TLSVersion ver;
  285. auto connect = tlsConnect(hn, ver, err);
  286. if (connect != TLS_ERR_OK) {
  287. return connect;
  288. }
  289. // Continue.
  290. }
  291. if (state_ != st_connected && state_ != st_closing) {
  292. status_ = SEC_E_INVALID_HANDLE;
  293. return TLS_ERR_ERROR;
  294. }
  295. A2_LOG_DEBUG(fmt("WinTLS: Write request: %" PRIu64 " buffered: %" PRIu64,
  296. (uint64_t)len, (uint64_t)recordBytesSent_));
  297. auto rv = sendTLSRecord();
  298. if (rv != 0) {
  299. return rv;
  300. }
  301. auto left = len;
  302. auto bytes = static_cast<const char*>(data);
  303. if (writeBuffered_) {
  304. // There was buffered data, hence we need to "remove" that data from the
  305. // incoming buffer to avoid writing it again
  306. if (len < writeBuffered_) {
  307. // We didn't get called with the same data again, obviously.
  308. status_ = SEC_E_INVALID_HANDLE;
  309. status_ = st_error;
  310. return TLS_ERR_ERROR;
  311. }
  312. // just advance the buffer by writeBuffered_ bytes
  313. bytes += writeBuffered_;
  314. left -= writeBuffered_;
  315. writeBuffered_ = 0;
  316. }
  317. if (!left) {
  318. // The buffer contained the full remainder. At this point, the buffer has
  319. // been written, so the request is done in its entirety;
  320. return len;
  321. }
  322. // Buffered data was already written ;)
  323. // If there was no buffered data, this will be len - len = 0.
  324. len -= left;
  325. while (left) {
  326. // Set up an outgoing message, according to streamSizes_
  327. writeBuffered_ =
  328. std::min(left, static_cast<size_t>(streamSizes_.cbMaximumMessage));
  329. sendRecordBuffers_ = {
  330. TLSBuffer(SECBUFFER_STREAM_HEADER, streamSizes_.cbHeader,
  331. sendBuffer_.data()),
  332. TLSBuffer(SECBUFFER_DATA, writeBuffered_,
  333. sendBuffer_.data() + streamSizes_.cbHeader),
  334. TLSBuffer(SECBUFFER_STREAM_TRAILER, streamSizes_.cbTrailer,
  335. sendBuffer_.data() + streamSizes_.cbHeader + writeBuffered_),
  336. TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
  337. };
  338. TLSBufferDesc desc(sendRecordBuffers_.data(), sendRecordBuffers_.size());
  339. std::copy_n(bytes, writeBuffered_,
  340. static_cast<char*>(sendRecordBuffers_[1].pvBuffer));
  341. status_ = ::EncryptMessage(&handle_, 0, &desc, 0);
  342. if (status_ != SEC_E_OK) {
  343. A2_LOG_ERROR(fmt("WinTLS: Failed to encrypt a message! %s",
  344. getLastErrorString().c_str()));
  345. state_ = st_error;
  346. return TLS_ERR_ERROR;
  347. }
  348. A2_LOG_DEBUG(fmt("WinTLS: Write TLS record header: %" PRIu64
  349. " body: %" PRIu64 " trailer: %" PRIu64,
  350. static_cast<uint64_t>(sendRecordBuffers_[0].cbBuffer),
  351. static_cast<uint64_t>(sendRecordBuffers_[1].cbBuffer),
  352. static_cast<uint64_t>(sendRecordBuffers_[2].cbBuffer)));
  353. auto rv = sendTLSRecord();
  354. if (rv == TLS_ERR_WOULDBLOCK) {
  355. if (len == 0) {
  356. return TLS_ERR_WOULDBLOCK;
  357. }
  358. return len;
  359. }
  360. if (rv != 0) {
  361. return rv;
  362. }
  363. len += writeBuffered_;
  364. bytes += writeBuffered_;
  365. left -= writeBuffered_;
  366. writeBuffered_ = 0;
  367. }
  368. A2_LOG_DEBUG(fmt("WinTLS: Write result: %" PRIu64, (uint64_t)len));
  369. return len;
  370. }
  371. ssize_t WinTLSSession::readData(void* data, size_t len)
  372. {
  373. A2_LOG_DEBUG(fmt("WinTLS: Read request: %" PRIu64 " buffered: %" PRIu64,
  374. (uint64_t)len, (uint64_t)readBuf_.size()));
  375. if (len == 0) {
  376. return 0;
  377. }
  378. // Can be filled from decBuffer entirely?
  379. if (decBuf_.size() >= len) {
  380. A2_LOG_DEBUG("WinTLS: Fullfilling req from buffer");
  381. memcpy(data, decBuf_.data(), len);
  382. decBuf_.eat(len);
  383. return len;
  384. }
  385. if (state_ == st_closing || state_ == st_closed || state_ == st_error) {
  386. auto nread = decBuf_.size();
  387. if (nread) {
  388. assert(nread < len);
  389. memcpy(data, decBuf_.data(), nread);
  390. decBuf_.clear();
  391. A2_LOG_DEBUG("WinTLS: Sending out decrypted buffer after EOF");
  392. return nread;
  393. }
  394. A2_LOG_DEBUG("WinTLS: Read request aborted. Connection already closed");
  395. return state_ == st_error ? TLS_ERR_ERROR : 0;
  396. }
  397. if (state_ == st_handshake_write || state_ == st_handshake_write_last ||
  398. state_ == st_handshake_read) {
  399. // Renegotiating
  400. std::string hn, err;
  401. TLSVersion ver;
  402. auto connect = tlsConnect(hn, ver, err);
  403. if (connect != TLS_ERR_OK) {
  404. return connect;
  405. }
  406. // Continue.
  407. }
  408. if (state_ != st_connected) {
  409. status_ = SEC_E_INVALID_HANDLE;
  410. return TLS_ERR_ERROR;
  411. }
  412. // Read as many bytes as available from the connection, up to len + 4k.
  413. readBuf_.resize(len + 4_k);
  414. while (readBuf_.free()) {
  415. ssize_t read = ::recv(sockfd_, readBuf_.end(), readBuf_.free(), 0);
  416. errno = ::WSAGetLastError();
  417. if (read < 0 && errno == WSAEINTR) {
  418. continue;
  419. }
  420. if (read < 0 && errno == WSAEWOULDBLOCK) {
  421. break;
  422. }
  423. if (read < 0) {
  424. status_ = errno;
  425. state_ = st_error;
  426. return TLS_ERR_ERROR;
  427. }
  428. if (read == 0) {
  429. A2_LOG_DEBUG("WinTLS: Connection abruptly closed!");
  430. // At least try to gracefully close our write end.
  431. closeConnection();
  432. break;
  433. }
  434. readBuf_.advance(read);
  435. }
  436. // Try to decrypt as many messages as possible from the readBuf_.
  437. while (readBuf_.size()) {
  438. TLSBuffer bufs[] = {
  439. TLSBuffer(SECBUFFER_DATA, readBuf_.size(), readBuf_.data()),
  440. TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
  441. TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
  442. TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
  443. };
  444. TLSBufferDesc desc(bufs, 4);
  445. status_ = ::DecryptMessage(&handle_, &desc, 0, nullptr);
  446. if (status_ == SEC_E_INCOMPLETE_MESSAGE) {
  447. // Need to stop now, and wait for more bytes to arrive on the socket.
  448. break;
  449. }
  450. if (status_ != SEC_E_OK && status_ != SEC_I_CONTEXT_EXPIRED &&
  451. status_ != SEC_I_RENEGOTIATE) {
  452. A2_LOG_ERROR(fmt("WinTLS: Failed to decrypt a message! %s",
  453. getLastErrorString().c_str()));
  454. state_ = st_error;
  455. return TLS_ERR_ERROR;
  456. }
  457. // Decrypted message successfully.
  458. bool ate = false;
  459. for (auto& buf : bufs) {
  460. if (buf.BufferType == SECBUFFER_DATA && buf.cbBuffer > 0) {
  461. decBuf_.write(buf.pvBuffer, buf.cbBuffer);
  462. }
  463. else if (buf.BufferType == SECBUFFER_EXTRA && buf.cbBuffer > 0) {
  464. readBuf_.eat(readBuf_.size() - buf.cbBuffer);
  465. ate = true;
  466. }
  467. }
  468. if (!ate) {
  469. readBuf_.clear();
  470. }
  471. if (status_ == SEC_I_RENEGOTIATE) {
  472. // Renegotiation basically means performing another handshake
  473. state_ = st_initialized;
  474. A2_LOG_INFO("WinTLS: Renegotiate");
  475. std::string hn, err;
  476. TLSVersion ver;
  477. auto connect = tlsConnect(hn, ver, err);
  478. if (connect == TLS_ERR_WOULDBLOCK) {
  479. break;
  480. }
  481. if (connect == TLS_ERR_ERROR) {
  482. return connect;
  483. }
  484. // Still good.
  485. }
  486. if (status_ == SEC_I_CONTEXT_EXPIRED) {
  487. // Connection is gone now, but the buffered bytes are still valid.
  488. A2_LOG_DEBUG("WinTLS: Connection gracefully closed!");
  489. closeConnection();
  490. break;
  491. }
  492. }
  493. len = std::min(decBuf_.size(), len);
  494. if (len == 0) {
  495. if (state_ != st_connected) {
  496. return state_ == st_error ? TLS_ERR_ERROR : 0;
  497. }
  498. return TLS_ERR_WOULDBLOCK;
  499. }
  500. memcpy(data, decBuf_.data(), len);
  501. decBuf_.eat(len);
  502. return len;
  503. }
  504. int WinTLSSession::tlsConnect(const std::string& hostname, TLSVersion& version,
  505. std::string& handshakeErr)
  506. {
  507. // Handshaking will require sending multiple read/write exchanges until the
  508. // handshake is actually done. The client will first generate the initial
  509. // handshake message, then write that to the server, read the response
  510. // message, and write and/or read additional messages until the handshake is
  511. // either complete and successful, or something went wrong.
  512. // The server works analog to that.
  513. A2_LOG_DEBUG("WinTLS: Starting/Resuming TLS Connect");
  514. ULONG flags = 0;
  515. restart:
  516. switch (state_) {
  517. default:
  518. A2_LOG_ERROR("WinTLS: Invalid state");
  519. status_ = SEC_E_INVALID_HANDLE;
  520. return TLS_ERR_ERROR;
  521. case st_initialized: {
  522. if (side_ == TLS_SERVER) {
  523. goto read;
  524. }
  525. if (!hostname.empty()) {
  526. setSNIHostname(hostname);
  527. }
  528. A2_LOG_DEBUG("WinTLS: Initializing handshake");
  529. TLSBuffer buf(SECBUFFER_EMPTY, 0, nullptr);
  530. TLSBufferDesc desc(&buf, 1);
  531. SEC_CHAR* host =
  532. hostname_.empty() ? nullptr : const_cast<SEC_CHAR*>(hostname_.c_str());
  533. status_ = ::InitializeSecurityContext(cred_, nullptr, host, kReqFlags, 0, 0,
  534. nullptr, 0, &handle_, &desc, &flags,
  535. nullptr);
  536. if (status_ != SEC_I_CONTINUE_NEEDED) {
  537. // Has to be SEC_I_CONTINUE_NEEDED, as we did not actually send data
  538. // at this point.
  539. state_ = st_error;
  540. return TLS_ERR_ERROR;
  541. }
  542. // Queue the initial message...
  543. writeBuf_.write(buf.pvBuffer, buf.cbBuffer);
  544. FreeContextBuffer(buf.pvBuffer);
  545. // ... and start sending it
  546. state_ = st_handshake_write;
  547. }
  548. // Fall through
  549. case st_handshake_write_last:
  550. case st_handshake_write: {
  551. A2_LOG_DEBUG("WinTLS: Writing handshake");
  552. // Write the currently queued handshake message until all data is sent.
  553. while (writeBuf_.size()) {
  554. ssize_t writ = ::send(sockfd_, writeBuf_.data(), writeBuf_.size(), 0);
  555. errno = ::WSAGetLastError();
  556. if (writ < 0 && errno == WSAEINTR) {
  557. continue;
  558. }
  559. if (writ < 0 && errno == WSAEWOULDBLOCK) {
  560. return TLS_ERR_WOULDBLOCK;
  561. }
  562. if (writ <= 0) {
  563. status_ = errno;
  564. state_ = st_error;
  565. return TLS_ERR_ERROR;
  566. }
  567. writeBuf_.eat(writ);
  568. }
  569. if (state_ == st_handshake_write_last) {
  570. state_ = st_handshake_done;
  571. goto restart;
  572. }
  573. // Have to read one or more response messages.
  574. state_ = st_handshake_read;
  575. }
  576. // Fall through
  577. case st_handshake_read: {
  578. read:
  579. A2_LOG_DEBUG("WinTLS: Reading handshake...");
  580. // All write buffered data is invalid at this point!
  581. writeBuf_.clear();
  582. // Read as many bytes as possible, up to 4k new bytes.
  583. // We do not know how many bytes will arrive from the server at this
  584. // point.
  585. readBuf_.resize(readBuf_.size() + 4_k);
  586. while (readBuf_.free()) {
  587. ssize_t read = ::recv(sockfd_, readBuf_.end(), readBuf_.free(), 0);
  588. errno = ::WSAGetLastError();
  589. if (read < 0 && errno == WSAEINTR) {
  590. continue;
  591. }
  592. if (read < 0 && errno == WSAEWOULDBLOCK) {
  593. break;
  594. }
  595. if (read <= 0) {
  596. status_ = errno;
  597. state_ = st_error;
  598. return TLS_ERR_ERROR;
  599. }
  600. if (read == 0) {
  601. A2_LOG_DEBUG("WinTLS: Connection abruptly closed during handshake!");
  602. status_ = SEC_E_INCOMPLETE_MESSAGE;
  603. state_ = st_error;
  604. return TLS_ERR_ERROR;
  605. }
  606. readBuf_.advance(read);
  607. break;
  608. }
  609. if (!readBuf_.size()) {
  610. return TLS_ERR_WOULDBLOCK;
  611. }
  612. // Need to copy the data, as Schannel is free to mess with it. But we
  613. // might later need unmodified data from the original read buffer.
  614. auto bufcopy = make_unique<char[]>(readBuf_.size());
  615. memcpy(bufcopy.get(), readBuf_.data(), readBuf_.size());
  616. // Set up buffers. inbufs will be the raw bytes the library has to decode.
  617. // outbufs will contain generated responses, if any.
  618. TLSBuffer inbufs[] = {
  619. TLSBuffer(SECBUFFER_TOKEN, readBuf_.size(), bufcopy.get()),
  620. TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
  621. };
  622. TLSBufferDesc indesc(inbufs, 2);
  623. TLSBuffer outbufs[] = {
  624. TLSBuffer(SECBUFFER_TOKEN, 0, nullptr),
  625. TLSBuffer(SECBUFFER_ALERT, 0, nullptr),
  626. };
  627. TLSBufferDesc outdesc(outbufs, 2);
  628. if (side_ == TLS_CLIENT) {
  629. SEC_CHAR* host = hostname_.empty()
  630. ? nullptr
  631. : const_cast<SEC_CHAR*>(hostname_.c_str());
  632. status_ = ::InitializeSecurityContext(cred_, &handle_, host, kReqFlags, 0,
  633. 0, &indesc, 0, nullptr, &outdesc,
  634. &flags, nullptr);
  635. }
  636. else {
  637. status_ = ::AcceptSecurityContext(
  638. cred_, state_ == st_initialized ? nullptr : &handle_, &indesc,
  639. kReqAFlags, 0, state_ == st_initialized ? &handle_ : nullptr,
  640. &outdesc, &flags, nullptr);
  641. }
  642. if (status_ == SEC_E_INCOMPLETE_MESSAGE) {
  643. // Not enough raw bytes read yet to decode a full message.
  644. return TLS_ERR_WOULDBLOCK;
  645. }
  646. if (status_ != SEC_E_OK && status_ != SEC_I_CONTINUE_NEEDED) {
  647. state_ = st_error;
  648. return TLS_ERR_ERROR;
  649. }
  650. // Raw bytes where not entirely consumed, i.e. readBuf_ still contains
  651. // unprocessed data from the next message?
  652. if (inbufs[1].BufferType == SECBUFFER_EXTRA && inbufs[1].cbBuffer > 0) {
  653. readBuf_.eat(readBuf_.size() - inbufs[1].cbBuffer);
  654. }
  655. else {
  656. readBuf_.clear();
  657. }
  658. // Check if the library produced a new outgoing message and queue it.
  659. for (auto& buf : outbufs) {
  660. if (buf.BufferType == SECBUFFER_TOKEN && buf.cbBuffer > 0) {
  661. writeBuf_.write(buf.pvBuffer, buf.cbBuffer);
  662. FreeContextBuffer(buf.pvBuffer);
  663. state_ = st_handshake_write;
  664. }
  665. }
  666. // Need to read additional messages?
  667. if (status_ == SEC_I_CONTINUE_NEEDED) {
  668. A2_LOG_DEBUG("WinTLS: Continuing with handshake");
  669. goto restart;
  670. }
  671. if (side_ == TLS_CLIENT && flags != kReqFlags) {
  672. A2_LOG_ERROR(fmt("WinTLS: Channel setup failed. Schannel provider did "
  673. "not fulfill requested flags. "
  674. "Excepted: %lu Actual: %lu",
  675. kReqFlags, flags));
  676. status_ = SEC_E_INTERNAL_ERROR;
  677. state_ = st_error;
  678. return TLS_ERR_ERROR;
  679. }
  680. if (state_ == st_handshake_write) {
  681. A2_LOG_DEBUG("WinTLS: Continuing with handshake (last write)");
  682. state_ = st_handshake_write_last;
  683. goto restart;
  684. }
  685. }
  686. // Fall through
  687. case st_handshake_done:
  688. if (obtainTLSRecordSizes() != 0) {
  689. return TLS_ERR_ERROR;
  690. }
  691. ensureSendBuffer();
  692. // All ready now :D
  693. state_ = st_connected;
  694. A2_LOG_INFO(
  695. fmt("WinTLS: connected with: %s", getCipherSuite(&handle_).c_str()));
  696. switch (getProtocolVersion(&handle_)) {
  697. case 0x300:
  698. version = TLS_PROTO_SSL3;
  699. break;
  700. case 0x301:
  701. version = TLS_PROTO_TLS10;
  702. break;
  703. case 0x302:
  704. version = TLS_PROTO_TLS11;
  705. break;
  706. case 0x303:
  707. version = TLS_PROTO_TLS12;
  708. break;
  709. default:
  710. version = TLS_PROTO_NONE;
  711. break;
  712. }
  713. return TLS_ERR_OK;
  714. }
  715. A2_LOG_ERROR("WinTLS: Unreachable reached during tlsConnect! This is a bug!");
  716. state_ = st_error;
  717. return TLS_ERR_ERROR;
  718. }
  719. int WinTLSSession::tlsAccept(TLSVersion& version)
  720. {
  721. std::string host, err;
  722. return tlsConnect(host, version, err);
  723. }
  724. std::string WinTLSSession::getLastErrorString()
  725. {
  726. std::stringstream ss;
  727. wchar_t* buf = nullptr;
  728. auto rv = FormatMessageW(
  729. FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM |
  730. FORMAT_MESSAGE_IGNORE_INSERTS,
  731. nullptr, status_, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPWSTR)&buf,
  732. 1024, nullptr);
  733. if (rv && buf) {
  734. ss << "Error: " << wCharToUtf8(buf) << "(" << std::hex << status_ << ")";
  735. LocalFree(buf);
  736. }
  737. else {
  738. ss << "Error: " << std::hex << status_;
  739. }
  740. return ss.str();
  741. }
  742. size_t WinTLSSession::getRecvBufferedLength() { return decBuf_.size(); }
  743. int WinTLSSession::obtainTLSRecordSizes()
  744. {
  745. status_ = ::QueryContextAttributes(&handle_, SECPKG_ATTR_STREAM_SIZES,
  746. &streamSizes_);
  747. if (status_ != SEC_E_OK || !streamSizes_.cbMaximumMessage) {
  748. A2_LOG_ERROR("WinTLS: Unable to obtain stream sizes");
  749. state_ = st_error;
  750. return -1;
  751. }
  752. return 0;
  753. }
  754. void WinTLSSession::ensureSendBuffer()
  755. {
  756. auto sum = streamSizes_.cbHeader + streamSizes_.cbMaximumMessage +
  757. streamSizes_.cbTrailer;
  758. if (sendBuffer_.size() < sum) {
  759. sendBuffer_.resize(sum);
  760. }
  761. }
  762. } // namespace aria2