WinTLSSession.cc 24 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837
  1. /* <!-- copyright */
  2. /*
  3. * aria2 - The high speed download utility
  4. *
  5. * Copyright (C) 2013 Nils Maier
  6. *
  7. * This program is free software; you can redistribute it and/or modify
  8. * it under the terms of the GNU General Public License as published by
  9. * the Free Software Foundation; either version 2 of the License, or
  10. * (at your option) any later version.
  11. *
  12. * This program is distributed in the hope that it will be useful,
  13. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  14. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  15. * GNU General Public License for more details.
  16. *
  17. * You should have received a copy of the GNU General Public License
  18. * along with this program; if not, write to the Free Software
  19. * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
  20. *
  21. * In addition, as a special exception, the copyright holders give
  22. * permission to link the code of portions of this program with the
  23. * OpenSSL library under certain conditions as described in each
  24. * individual source file, and distribute linked combinations
  25. * including the two.
  26. * You must obey the GNU General Public License in all respects
  27. * for all of the code used other than OpenSSL. If you modify
  28. * file(s) with this exception, you may extend this exception to your
  29. * version of the file(s), but you are not obligated to do so. If you
  30. * do not wish to do so, delete this exception statement from your
  31. * version. If you delete this exception statement from all source
  32. * files in the program, then also delete it here.
  33. */
  34. /* copyright --> */
  35. #include "WinTLSSession.h"
  36. #include <cassert>
  37. #include <sstream>
  38. #include "LogFactory.h"
  39. #include "a2functional.h"
  40. #include "fmt.h"
  41. #include "util.h"
  // Fallback values for Schannel constants used below. Each is defined only
  // when the platform headers in use did not already provide it.
  #if !defined(SECBUFFER_ALERT)
  #define SECBUFFER_ALERT 17
  #endif

  #if !defined(SZ_ALG_MAX_SIZE)
  #define SZ_ALG_MAX_SIZE 64
  #endif

  #if !defined(SECPKGCONTEXT_CIPHERINFO_V1)
  #define SECPKGCONTEXT_CIPHERINFO_V1 1
  #endif

  #if !defined(SECPKG_ATTR_CIPHER_INFO)
  #define SECPKG_ATTR_CIPHER_INFO 0x64
  #endif
  54. namespace {
  55. using namespace aria2;
  56. struct WinSecPkgContext_CipherInfo {
  57. DWORD dwVersion;
  58. DWORD dwProtocol;
  59. DWORD dwCipherSuite;
  60. DWORD dwBaseCipherSuite;
  61. WCHAR szCipherSuite[SZ_ALG_MAX_SIZE];
  62. WCHAR szCipher[SZ_ALG_MAX_SIZE];
  63. DWORD dwCipherLen;
  64. DWORD dwCipherBlockLen; // in bytes
  65. WCHAR szHash[SZ_ALG_MAX_SIZE];
  66. DWORD dwHashLen;
  67. WCHAR szExchange[SZ_ALG_MAX_SIZE];
  68. DWORD dwMinExchangeLen;
  69. DWORD dwMaxExchangeLen;
  70. WCHAR szCertificate[SZ_ALG_MAX_SIZE];
  71. DWORD dwKeyType;
  72. };
  73. static const ULONG kReqFlags =
  74. ISC_REQ_SEQUENCE_DETECT | ISC_REQ_REPLAY_DETECT | ISC_REQ_CONFIDENTIALITY |
  75. ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_USE_SUPPLIED_CREDS | ISC_REQ_STREAM;
  76. static const ULONG kReqAFlags =
  77. ASC_REQ_SEQUENCE_DETECT | ASC_REQ_REPLAY_DETECT | ASC_REQ_CONFIDENTIALITY |
  78. ASC_REQ_EXTENDED_ERROR | ASC_REQ_ALLOCATE_MEMORY | ASC_REQ_STREAM;
  79. class TLSBuffer : public ::SecBuffer {
  80. public:
  81. explicit TLSBuffer(ULONG type, ULONG size, void* data)
  82. {
  83. cbBuffer = size;
  84. BufferType = type;
  85. pvBuffer = data;
  86. }
  87. };
  88. class TLSBufferDesc : public ::SecBufferDesc {
  89. public:
  90. explicit TLSBufferDesc(SecBuffer* arr, ULONG buffers)
  91. {
  92. ulVersion = SECBUFFER_VERSION;
  93. cBuffers = buffers;
  94. pBuffers = arr;
  95. }
  96. };
  97. inline static std::string getCipherSuite(CtxtHandle* handle)
  98. {
  99. WinSecPkgContext_CipherInfo info = {SECPKGCONTEXT_CIPHERINFO_V1};
  100. if (QueryContextAttributes(handle, SECPKG_ATTR_CIPHER_INFO, &info) ==
  101. SEC_E_OK) {
  102. return wCharToUtf8(info.szCipherSuite);
  103. }
  104. return "Unknown";
  105. }
  106. inline static uint32_t getProtocolVersion(CtxtHandle* handle)
  107. {
  108. WinSecPkgContext_CipherInfo info = {SECPKGCONTEXT_CIPHERINFO_V1};
  109. if (QueryContextAttributes(handle, SECPKG_ATTR_CIPHER_INFO, &info) ==
  110. SEC_E_OK) {
  111. return info.dwProtocol;
  112. }
  113. // XXX Assume the best?!
  114. return std::numeric_limits<uint32_t>::max();
  115. }
  116. } // namespace
  117. namespace aria2 {
  118. TLSSession* TLSSession::make(TLSContext* ctx)
  119. {
  120. return new WinTLSSession(static_cast<WinTLSContext*>(ctx));
  121. }
  122. WinTLSSession::WinTLSSession(WinTLSContext* ctx)
  123. : sockfd_(0),
  124. side_(ctx->getSide()),
  125. cred_(ctx->getCredHandle()),
  126. writeBuffered_(0),
  127. state_(st_constructed),
  128. status_(SEC_E_OK)
  129. {
  130. memset(&handle_, 0, sizeof(handle_));
  131. }
  132. WinTLSSession::~WinTLSSession()
  133. {
  134. ::DeleteSecurityContext(&handle_);
  135. state_ = st_error;
  136. }
  137. int WinTLSSession::init(sock_t sockfd)
  138. {
  139. if (state_ != st_constructed) {
  140. status_ = SEC_E_INVALID_HANDLE;
  141. return TLS_ERR_ERROR;
  142. }
  143. sockfd_ = sockfd;
  144. state_ = st_initialized;
  145. return TLS_ERR_OK;
  146. }
  147. int WinTLSSession::setSNIHostname(const std::string& hostname)
  148. {
  149. if (state_ != st_initialized) {
  150. status_ = SEC_E_INVALID_HANDLE;
  151. return TLS_ERR_ERROR;
  152. }
  153. hostname_ = hostname;
  154. return TLS_ERR_OK;
  155. }
  156. int WinTLSSession::closeConnection()
  157. {
  158. if (state_ != st_connected && state_ != st_closing) {
  159. if (state_ != st_error) {
  160. status_ = SEC_E_INVALID_HANDLE;
  161. state_ = st_error;
  162. }
  163. A2_LOG_DEBUG("WinTLS: Cannot close connection");
  164. return TLS_ERR_ERROR;
  165. }
  166. if (state_ == st_connected) {
  167. A2_LOG_DEBUG("WinTLS: Closing connection");
  168. state_ = st_closing;
  169. DWORD dwShut = SCHANNEL_SHUTDOWN;
  170. TLSBuffer shut(SECBUFFER_TOKEN, sizeof(dwShut), &dwShut);
  171. TLSBufferDesc shutDesc(&shut, 1);
  172. status_ = ::ApplyControlToken(&handle_, &shutDesc);
  173. if (status_ != SEC_E_OK) {
  174. state_ = st_error;
  175. return TLS_ERR_ERROR;
  176. }
  177. TLSBuffer ctx(SECBUFFER_EMPTY, 0, nullptr);
  178. TLSBufferDesc desc(&ctx, 1);
  179. ULONG flags = 0;
  180. if (side_ == TLS_CLIENT) {
  181. SEC_CHAR* host = hostname_.empty()
  182. ? nullptr
  183. : const_cast<SEC_CHAR*>(hostname_.c_str());
  184. status_ = ::InitializeSecurityContext(cred_, &handle_, host, kReqFlags, 0,
  185. 0, nullptr, 0, &handle_, &desc,
  186. &flags, nullptr);
  187. }
  188. else {
  189. status_ = ::AcceptSecurityContext(cred_, &handle_, nullptr, kReqAFlags, 0,
  190. &handle_, &desc, &flags, nullptr);
  191. }
  192. if (status_ == SEC_E_OK || status_ == SEC_I_CONTEXT_EXPIRED) {
  193. size_t len = ctx.cbBuffer;
  194. ssize_t rv = writeData(ctx.pvBuffer, ctx.cbBuffer);
  195. ::FreeContextBuffer(ctx.pvBuffer);
  196. if (rv == TLS_ERR_WOULDBLOCK) {
  197. return rv;
  198. }
  199. // Alright data is sent or buffered
  200. if (rv - len != 0) {
  201. return TLS_ERR_WOULDBLOCK;
  202. }
  203. }
  204. }
  205. // Send remaining data.
  206. while (writeBuf_.size()) {
  207. int rv = writeData(nullptr, 0);
  208. if (rv == 0) {
  209. break;
  210. }
  211. if (rv < 0) {
  212. return rv;
  213. }
  214. }
  215. A2_LOG_DEBUG("WinTLS: Closed Connection");
  216. state_ = st_closed;
  217. return TLS_ERR_OK;
  218. }
  219. int WinTLSSession::checkDirection()
  220. {
  221. if (state_ == st_handshake_write || state_ == st_handshake_write_last) {
  222. return TLS_WANT_WRITE;
  223. }
  224. if (state_ == st_handshake_read) {
  225. return TLS_WANT_READ;
  226. }
  227. if (readBuf_.size() || decBuf_.size()) {
  228. return TLS_WANT_READ;
  229. }
  230. if (writeBuf_.size()) {
  231. return TLS_WANT_WRITE;
  232. }
  233. return TLS_WANT_READ;
  234. }
  235. ssize_t WinTLSSession::writeData(const void* data, size_t len)
  236. {
  237. if (state_ == st_handshake_write || state_ == st_handshake_write_last ||
  238. state_ == st_handshake_read) {
  239. // Renegotiating
  240. std::string hn, err;
  241. TLSVersion ver;
  242. auto connect = tlsConnect(hn, ver, err);
  243. if (connect != TLS_ERR_OK) {
  244. return connect;
  245. }
  246. // Continue.
  247. }
  248. if (state_ != st_connected && state_ != st_closing) {
  249. status_ = SEC_E_INVALID_HANDLE;
  250. return TLS_ERR_ERROR;
  251. }
  252. A2_LOG_DEBUG(fmt("WinTLS: Write request: %" PRIu64 " buffered: %" PRIu64,
  253. (uint64_t)len, (uint64_t)writeBuf_.size()));
  254. // Write remaining buffered data, if any.
  255. size_t written = 0;
  256. while (writeBuf_.size()) {
  257. written = ::send(sockfd_, writeBuf_.data(), writeBuf_.size(), 0);
  258. errno = ::WSAGetLastError();
  259. if (written < 0 && errno == WSAEINTR) {
  260. continue;
  261. }
  262. if (written < 0 && errno == WSAEWOULDBLOCK) {
  263. return TLS_ERR_WOULDBLOCK;
  264. }
  265. if (written == 0) {
  266. return written;
  267. }
  268. if (written < 0) {
  269. status_ = SEC_E_INVALID_HANDLE;
  270. state_ = st_error;
  271. return TLS_ERR_ERROR;
  272. }
  273. writeBuf_.eat(written);
  274. }
  275. if (len == 0) {
  276. return 0;
  277. }
  278. if (!streamSizes_) {
  279. streamSizes_.reset(new SecPkgContext_StreamSizes());
  280. status_ = ::QueryContextAttributes(&handle_, SECPKG_ATTR_STREAM_SIZES,
  281. streamSizes_.get());
  282. if (status_ != SEC_E_OK || !streamSizes_->cbMaximumMessage) {
  283. state_ = st_error;
  284. return TLS_ERR_ERROR;
  285. }
  286. }
  287. size_t process = len;
  288. auto bytes = reinterpret_cast<const char*>(data);
  289. if (writeBuffered_) {
  290. // There was buffered data, hence we need to "remove" that data from the
  291. // incoming buffer to avoid writing it again
  292. if (len < writeBuffered_) {
  293. // We didn't get called with the same data again, obviously.
  294. status_ = SEC_E_INVALID_HANDLE;
  295. status_ = st_error;
  296. return TLS_ERR_ERROR;
  297. }
  298. // just advance the buffer by writeBuffered_ bytes
  299. bytes += writeBuffered_;
  300. process -= writeBuffered_;
  301. writeBuffered_ = 0;
  302. }
  303. if (!process) {
  304. // The buffer contained the full remainder. At this point, the buffer has
  305. // been written, so the request is done in its entirety;
  306. return len;
  307. }
  308. // Buffered data was already written ;)
  309. // If there was no buffered data, this will be len - len = 0.
  310. len = len - process;
  311. while (process) {
  312. // Set up an outgoing message, according to streamSizes_
  313. writeBuffered_ = std::min(process, (size_t)streamSizes_->cbMaximumMessage);
  314. size_t dl =
  315. streamSizes_->cbHeader + writeBuffered_ + streamSizes_->cbTrailer;
  316. auto buf = make_unique<char[]>(dl);
  317. TLSBuffer buffers[] = {
  318. TLSBuffer(SECBUFFER_STREAM_HEADER, streamSizes_->cbHeader, buf.get()),
  319. TLSBuffer(SECBUFFER_DATA, writeBuffered_,
  320. buf.get() + streamSizes_->cbHeader),
  321. TLSBuffer(SECBUFFER_STREAM_TRAILER, streamSizes_->cbTrailer,
  322. buf.get() + streamSizes_->cbHeader + writeBuffered_),
  323. TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
  324. };
  325. TLSBufferDesc desc(buffers, 4);
  326. memcpy(buffers[1].pvBuffer, bytes, writeBuffered_);
  327. status_ = ::EncryptMessage(&handle_, 0, &desc, 0);
  328. if (status_ != SEC_E_OK) {
  329. A2_LOG_ERROR(fmt("WinTLS: Failed to encrypt a message! %s",
  330. getLastErrorString().c_str()));
  331. state_ = st_error;
  332. return TLS_ERR_ERROR;
  333. }
  334. // EncryptMessage may have truncated the buffers.
  335. // Should rarely happen, if ever, except for the trailer.
  336. dl = buffers[0].cbBuffer;
  337. if (dl < streamSizes_->cbHeader) {
  338. // Move message.
  339. memmove(buf.get() + dl, buffers[1].pvBuffer, buffers[1].cbBuffer);
  340. }
  341. dl += buffers[1].cbBuffer;
  342. if (dl < streamSizes_->cbHeader + writeBuffered_) {
  343. // Move trailer.
  344. memmove(buf.get() + dl, buffers[2].pvBuffer, buffers[2].cbBuffer);
  345. }
  346. dl += buffers[2].cbBuffer;
  347. // Write (or buffer) the message.
  348. char* p = buf.get();
  349. while (dl) {
  350. written = ::send(sockfd_, p, dl, 0);
  351. errno = ::WSAGetLastError();
  352. if (written < 0 && errno == WSAEINTR) {
  353. continue;
  354. }
  355. if (written < 0 && errno == WSAEWOULDBLOCK) {
  356. // Buffer the rest of the message...
  357. writeBuf_.write(p, dl);
  358. // and return...
  359. return len;
  360. }
  361. if (written == 0) {
  362. A2_LOG_ERROR("WinTLS: Connection closed while writing");
  363. status_ = SEC_E_INCOMPLETE_MESSAGE;
  364. state_ = st_error;
  365. return TLS_ERR_ERROR;
  366. }
  367. if (written < 0) {
  368. A2_LOG_ERROR("WinTLS: Connection error while writing");
  369. status_ = SEC_E_INCOMPLETE_MESSAGE;
  370. state_ = st_error;
  371. return TLS_ERR_ERROR;
  372. }
  373. dl -= written;
  374. p += written;
  375. }
  376. len += writeBuffered_;
  377. bytes += writeBuffered_;
  378. process -= writeBuffered_;
  379. writeBuffered_ = 0;
  380. }
  381. A2_LOG_DEBUG(fmt("WinTLS: Write result: %" PRIu64 " buffered: %" PRIu64,
  382. (uint64_t)len, (uint64_t)writeBuf_.size()));
  383. if (!len) {
  384. return TLS_ERR_WOULDBLOCK;
  385. }
  386. return len;
  387. }
  388. ssize_t WinTLSSession::readData(void* data, size_t len)
  389. {
  390. A2_LOG_DEBUG(fmt("WinTLS: Read request: %" PRIu64 " buffered: %" PRIu64,
  391. (uint64_t)len, (uint64_t)readBuf_.size()));
  392. if (len == 0) {
  393. return 0;
  394. }
  395. // Can be filled from decBuffer entirely?
  396. if (decBuf_.size() >= len) {
  397. A2_LOG_DEBUG("WinTLS: Fullfilling req from buffer");
  398. memcpy(data, decBuf_.data(), len);
  399. decBuf_.eat(len);
  400. return len;
  401. }
  402. if (state_ == st_closing || state_ == st_closed || state_ == st_error) {
  403. auto nread = decBuf_.size();
  404. if (nread) {
  405. assert(nread < len);
  406. memcpy(data, decBuf_.data(), nread);
  407. decBuf_.clear();
  408. A2_LOG_DEBUG("WinTLS: Sending out decrypted buffer after EOF");
  409. return nread;
  410. }
  411. A2_LOG_DEBUG("WinTLS: Read request aborted. Connection already closed");
  412. return state_ == st_error ? TLS_ERR_ERROR : 0;
  413. }
  414. if (state_ == st_handshake_write || state_ == st_handshake_write_last ||
  415. state_ == st_handshake_read) {
  416. // Renegotiating
  417. std::string hn, err;
  418. TLSVersion ver;
  419. auto connect = tlsConnect(hn, ver, err);
  420. if (connect != TLS_ERR_OK) {
  421. return connect;
  422. }
  423. // Continue.
  424. }
  425. if (state_ != st_connected) {
  426. status_ = SEC_E_INVALID_HANDLE;
  427. return TLS_ERR_ERROR;
  428. }
  429. // Read as many bytes as available from the connection, up to len + 4k.
  430. readBuf_.resize(len + 4_k);
  431. while (readBuf_.free()) {
  432. ssize_t read = ::recv(sockfd_, readBuf_.end(), readBuf_.free(), 0);
  433. errno = ::WSAGetLastError();
  434. if (read < 0 && errno == WSAEINTR) {
  435. continue;
  436. }
  437. if (read < 0 && errno == WSAEWOULDBLOCK) {
  438. break;
  439. }
  440. if (read < 0) {
  441. status_ = SEC_E_INCOMPLETE_MESSAGE;
  442. state_ = st_error;
  443. return TLS_ERR_ERROR;
  444. }
  445. if (read == 0) {
  446. A2_LOG_DEBUG("WinTLS: Connection abruptly closed!");
  447. // At least try to gracefully close our write end.
  448. closeConnection();
  449. break;
  450. }
  451. readBuf_.advance(read);
  452. }
  453. // Try to decrypt as many messages as possible from the readBuf_.
  454. while (readBuf_.size()) {
  455. TLSBuffer bufs[] = {
  456. TLSBuffer(SECBUFFER_DATA, readBuf_.size(), readBuf_.data()),
  457. TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
  458. TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
  459. TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
  460. };
  461. TLSBufferDesc desc(bufs, 4);
  462. status_ = ::DecryptMessage(&handle_, &desc, 0, nullptr);
  463. if (status_ == SEC_E_INCOMPLETE_MESSAGE) {
  464. // Need to stop now, and wait for more bytes to arrive on the socket.
  465. break;
  466. }
  467. if (status_ != SEC_E_OK && status_ != SEC_I_CONTEXT_EXPIRED &&
  468. status_ != SEC_I_RENEGOTIATE) {
  469. A2_LOG_ERROR(fmt("WinTLS: Failed to decrypt a message! %s",
  470. getLastErrorString().c_str()));
  471. state_ = st_error;
  472. return TLS_ERR_ERROR;
  473. }
  474. // Decrypted message successfully.
  475. bool ate = false;
  476. for (auto& buf : bufs) {
  477. if (buf.BufferType == SECBUFFER_DATA && buf.cbBuffer > 0) {
  478. decBuf_.write(buf.pvBuffer, buf.cbBuffer);
  479. }
  480. else if (buf.BufferType == SECBUFFER_EXTRA && buf.cbBuffer > 0) {
  481. readBuf_.eat(readBuf_.size() - buf.cbBuffer);
  482. ate = true;
  483. }
  484. }
  485. if (!ate) {
  486. readBuf_.clear();
  487. }
  488. if (status_ == SEC_I_RENEGOTIATE) {
  489. // Renegotiation basically means performing another handshake
  490. state_ = st_initialized;
  491. A2_LOG_INFO("WinTLS: Renegotiate");
  492. std::string hn, err;
  493. TLSVersion ver;
  494. auto connect = tlsConnect(hn, ver, err);
  495. if (connect == TLS_ERR_WOULDBLOCK) {
  496. break;
  497. }
  498. if (connect == TLS_ERR_ERROR) {
  499. return connect;
  500. }
  501. // Still good.
  502. }
  503. if (status_ == SEC_I_CONTEXT_EXPIRED) {
  504. // Connection is gone now, but the buffered bytes are still valid.
  505. A2_LOG_DEBUG("WinTLS: Connection gracefully closed!");
  506. closeConnection();
  507. break;
  508. }
  509. }
  510. len = std::min(decBuf_.size(), len);
  511. if (len == 0) {
  512. if (state_ != st_connected) {
  513. return state_ == st_error ? TLS_ERR_ERROR : 0;
  514. }
  515. return TLS_ERR_WOULDBLOCK;
  516. }
  517. memcpy(data, decBuf_.data(), len);
  518. decBuf_.eat(len);
  519. return len;
  520. }
  521. int WinTLSSession::tlsConnect(const std::string& hostname, TLSVersion& version,
  522. std::string& handshakeErr)
  523. {
  524. // Handshaking will require sending multiple read/write exchanges until the
  525. // handshake is actually done. The client will first generate the initial
  526. // handshake message, then write that to the server, read the response
  527. // message, and write and/or read additional messages until the handshake is
  528. // either complete and successful, or something went wrong.
  529. // The server works analog to that.
  530. A2_LOG_DEBUG("WinTLS: Starting/Resuming TLS Connect");
  531. ULONG flags = 0;
  532. restart:
  533. switch (state_) {
  534. default:
  535. A2_LOG_ERROR("WinTLS: Invalid state");
  536. status_ = SEC_E_INVALID_HANDLE;
  537. return TLS_ERR_ERROR;
  538. case st_initialized: {
  539. if (side_ == TLS_SERVER) {
  540. goto read;
  541. }
  542. if (!hostname.empty()) {
  543. setSNIHostname(hostname);
  544. }
  545. A2_LOG_DEBUG("WinTLS: Initializing handshake");
  546. TLSBuffer buf(SECBUFFER_EMPTY, 0, nullptr);
  547. TLSBufferDesc desc(&buf, 1);
  548. SEC_CHAR* host =
  549. hostname_.empty() ? nullptr : const_cast<SEC_CHAR*>(hostname_.c_str());
  550. status_ = ::InitializeSecurityContext(cred_, nullptr, host, kReqFlags, 0, 0,
  551. nullptr, 0, &handle_, &desc, &flags,
  552. nullptr);
  553. if (status_ != SEC_I_CONTINUE_NEEDED) {
  554. // Has to be SEC_I_CONTINUE_NEEDED, as we did not actually send data
  555. // at this point.
  556. state_ = st_error;
  557. return TLS_ERR_ERROR;
  558. }
  559. // Queue the initial message...
  560. writeBuf_.write(buf.pvBuffer, buf.cbBuffer);
  561. FreeContextBuffer(buf.pvBuffer);
  562. // ... and start sending it
  563. state_ = st_handshake_write;
  564. }
  565. // Fall through
  566. case st_handshake_write_last:
  567. case st_handshake_write: {
  568. A2_LOG_DEBUG("WinTLS: Writing handshake");
  569. // Write the currently queued handshake message until all data is sent.
  570. while (writeBuf_.size()) {
  571. ssize_t writ = ::send(sockfd_, writeBuf_.data(), writeBuf_.size(), 0);
  572. errno = ::WSAGetLastError();
  573. if (writ < 0 && errno == WSAEINTR) {
  574. continue;
  575. }
  576. if (writ < 0 && errno == WSAEWOULDBLOCK) {
  577. return TLS_ERR_WOULDBLOCK;
  578. }
  579. if (writ <= 0) {
  580. status_ = SEC_E_INCOMPLETE_MESSAGE;
  581. state_ = st_error;
  582. return TLS_ERR_ERROR;
  583. }
  584. writeBuf_.eat(writ);
  585. }
  586. if (state_ == st_handshake_write_last) {
  587. state_ = st_handshake_done;
  588. goto restart;
  589. }
  590. // Have to read one or more response messages.
  591. state_ = st_handshake_read;
  592. }
  593. // Fall through
  594. case st_handshake_read: {
  595. read:
  596. A2_LOG_DEBUG("WinTLS: Reading handshake...");
  597. // All write buffered data is invalid at this point!
  598. writeBuf_.clear();
  599. // Read as many bytes as possible, up to 4k new bytes.
  600. // We do not know how many bytes will arrive from the server at this
  601. // point.
  602. readBuf_.resize(readBuf_.size() + 4_k);
  603. while (readBuf_.free()) {
  604. ssize_t read = ::recv(sockfd_, readBuf_.end(), readBuf_.free(), 0);
  605. errno = ::WSAGetLastError();
  606. if (read < 0 && errno == WSAEINTR) {
  607. continue;
  608. }
  609. if (read < 0 && errno == WSAEWOULDBLOCK) {
  610. break;
  611. }
  612. if (read <= 0) {
  613. status_ = SEC_E_INCOMPLETE_MESSAGE;
  614. state_ = st_error;
  615. return TLS_ERR_ERROR;
  616. }
  617. if (read == 0) {
  618. A2_LOG_DEBUG("WinTLS: Connection abruptly closed during handshake!");
  619. status_ = SEC_E_INCOMPLETE_MESSAGE;
  620. state_ = st_error;
  621. return TLS_ERR_ERROR;
  622. }
  623. readBuf_.advance(read);
  624. break;
  625. }
  626. if (!readBuf_.size()) {
  627. return TLS_ERR_WOULDBLOCK;
  628. }
  629. // Need to copy the data, as Schannel is free to mess with it. But we
  630. // might later need unmodified data from the original read buffer.
  631. auto bufcopy = make_unique<char[]>(readBuf_.size());
  632. memcpy(bufcopy.get(), readBuf_.data(), readBuf_.size());
  633. // Set up buffers. inbufs will be the raw bytes the library has to decode.
  634. // outbufs will contain generated responses, if any.
  635. TLSBuffer inbufs[] = {
  636. TLSBuffer(SECBUFFER_TOKEN, readBuf_.size(), bufcopy.get()),
  637. TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
  638. };
  639. TLSBufferDesc indesc(inbufs, 2);
  640. TLSBuffer outbufs[] = {
  641. TLSBuffer(SECBUFFER_TOKEN, 0, nullptr),
  642. TLSBuffer(SECBUFFER_ALERT, 0, nullptr),
  643. };
  644. TLSBufferDesc outdesc(outbufs, 2);
  645. if (side_ == TLS_CLIENT) {
  646. SEC_CHAR* host = hostname_.empty()
  647. ? nullptr
  648. : const_cast<SEC_CHAR*>(hostname_.c_str());
  649. status_ = ::InitializeSecurityContext(cred_, &handle_, host, kReqFlags, 0,
  650. 0, &indesc, 0, nullptr, &outdesc,
  651. &flags, nullptr);
  652. }
  653. else {
  654. status_ = ::AcceptSecurityContext(
  655. cred_, state_ == st_initialized ? nullptr : &handle_, &indesc,
  656. kReqAFlags, 0, state_ == st_initialized ? &handle_ : nullptr,
  657. &outdesc, &flags, nullptr);
  658. }
  659. if (status_ == SEC_E_INCOMPLETE_MESSAGE) {
  660. // Not enough raw bytes read yet to decode a full message.
  661. return TLS_ERR_WOULDBLOCK;
  662. }
  663. if (status_ != SEC_E_OK && status_ != SEC_I_CONTINUE_NEEDED) {
  664. state_ = st_error;
  665. return TLS_ERR_ERROR;
  666. }
  667. // Raw bytes where not entirely consumed, i.e. readBuf_ still contains
  668. // unprocessed data from the next message?
  669. if (inbufs[1].BufferType == SECBUFFER_EXTRA && inbufs[1].cbBuffer > 0) {
  670. readBuf_.eat(readBuf_.size() - inbufs[1].cbBuffer);
  671. }
  672. else {
  673. readBuf_.clear();
  674. }
  675. // Check if the library produced a new outgoing message and queue it.
  676. for (auto& buf : outbufs) {
  677. if (buf.BufferType == SECBUFFER_TOKEN && buf.cbBuffer > 0) {
  678. writeBuf_.write(buf.pvBuffer, buf.cbBuffer);
  679. FreeContextBuffer(buf.pvBuffer);
  680. state_ = st_handshake_write;
  681. }
  682. }
  683. // Need to read additional messages?
  684. if (status_ == SEC_I_CONTINUE_NEEDED) {
  685. A2_LOG_DEBUG("WinTLS: Continuing with handshake");
  686. goto restart;
  687. }
  688. if (side_ == TLS_CLIENT && flags != kReqFlags) {
  689. A2_LOG_ERROR(fmt("WinTLS: Channel setup failed. Schannel provider did "
  690. "not fulfill requested flags. "
  691. "Excepted: %lu Actual: %lu",
  692. kReqFlags, flags));
  693. status_ = SEC_E_INTERNAL_ERROR;
  694. state_ = st_error;
  695. return TLS_ERR_ERROR;
  696. }
  697. if (state_ == st_handshake_write) {
  698. A2_LOG_DEBUG("WinTLS: Continuing with handshake (last write)");
  699. state_ = st_handshake_write_last;
  700. goto restart;
  701. }
  702. }
  703. // Fall through
  704. case st_handshake_done:
  705. // All ready now :D
  706. state_ = st_connected;
  707. A2_LOG_INFO(
  708. fmt("WinTLS: connected with: %s", getCipherSuite(&handle_).c_str()));
  709. switch (getProtocolVersion(&handle_)) {
  710. case 0x300:
  711. version = TLS_PROTO_SSL3;
  712. break;
  713. case 0x301:
  714. version = TLS_PROTO_TLS10;
  715. break;
  716. case 0x302:
  717. version = TLS_PROTO_TLS11;
  718. break;
  719. case 0x303:
  720. version = TLS_PROTO_TLS12;
  721. break;
  722. default:
  723. version = TLS_PROTO_NONE;
  724. break;
  725. }
  726. return TLS_ERR_OK;
  727. }
  728. A2_LOG_ERROR("WinTLS: Unreachable reached during tlsConnect! This is a bug!");
  729. state_ = st_error;
  730. return TLS_ERR_ERROR;
  731. }
  732. int WinTLSSession::tlsAccept(TLSVersion& version)
  733. {
  734. std::string host, err;
  735. return tlsConnect(host, version, err);
  736. }
  737. std::string WinTLSSession::getLastErrorString()
  738. {
  739. std::stringstream ss;
  740. wchar_t* buf = nullptr;
  741. auto rv = FormatMessageW(
  742. FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM |
  743. FORMAT_MESSAGE_IGNORE_INSERTS,
  744. nullptr, status_, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPWSTR)&buf,
  745. 1024, nullptr);
  746. if (rv && buf) {
  747. ss << "Error: " << wCharToUtf8(buf) << "(" << std::hex << status_ << ")";
  748. LocalFree(buf);
  749. }
  750. else {
  751. ss << "Error: " << std::hex << status_;
  752. }
  753. return ss.str();
  754. }
  755. size_t WinTLSSession::getRecvBufferedLength() { return decBuf_.size(); }
  756. } // namespace aria2