WinTLSSession.cc 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881
  1. /* <!-- copyright */
  2. /*
  3. * aria2 - The high speed download utility
  4. *
  5. * Copyright (C) 2013 Nils Maier
  6. *
  7. * This program is free software; you can redistribute it and/or modify
  8. * it under the terms of the GNU General Public License as published by
  9. * the Free Software Foundation; either version 2 of the License, or
  10. * (at your option) any later version.
  11. *
  12. * This program is distributed in the hope that it will be useful,
  13. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  14. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  15. * GNU General Public License for more details.
  16. *
  17. * You should have received a copy of the GNU General Public License
  18. * along with this program; if not, write to the Free Software
  19. * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
  20. *
  21. * In addition, as a special exception, the copyright holders give
  22. * permission to link the code of portions of this program with the
  23. * OpenSSL library under certain conditions as described in each
  24. * individual source file, and distribute linked combinations
  25. * including the two.
  26. * You must obey the GNU General Public License in all respects
  27. * for all of the code used other than OpenSSL. If you modify
  28. * file(s) with this exception, you may extend this exception to your
  29. * version of the file(s), but you are not obligated to do so. If you
  30. * do not wish to do so, delete this exception statement from your
  31. * version. If you delete this exception statement from all source
  32. * files in the program, then also delete it here.
  33. */
  34. /* copyright --> */
  35. #include "WinTLSSession.h"
  36. #include <cassert>
  37. #include <sstream>
  38. #include "LogFactory.h"
  39. #include "a2functional.h"
  40. #include "fmt.h"
  41. #include "util.h"
  42. #ifndef SECBUFFER_ALERT
  43. #define SECBUFFER_ALERT 17
  44. #endif
  45. #ifndef SZ_ALG_MAX_SIZE
  46. #define SZ_ALG_MAX_SIZE 64
  47. #endif
  48. #ifndef SECPKGCONTEXT_CIPHERINFO_V1
  49. #define SECPKGCONTEXT_CIPHERINFO_V1 1
  50. #endif
  51. #ifndef SECPKG_ATTR_CIPHER_INFO
  52. #define SECPKG_ATTR_CIPHER_INFO 0x64
  53. #endif
  54. namespace {
  55. using namespace aria2;
  56. struct WinSecPkgContext_CipherInfo
  57. {
  58. DWORD dwVersion;
  59. DWORD dwProtocol;
  60. DWORD dwCipherSuite;
  61. DWORD dwBaseCipherSuite;
  62. WCHAR szCipherSuite[SZ_ALG_MAX_SIZE];
  63. WCHAR szCipher[SZ_ALG_MAX_SIZE];
  64. DWORD dwCipherLen;
  65. DWORD dwCipherBlockLen; // in bytes
  66. WCHAR szHash[SZ_ALG_MAX_SIZE];
  67. DWORD dwHashLen;
  68. WCHAR szExchange[SZ_ALG_MAX_SIZE];
  69. DWORD dwMinExchangeLen;
  70. DWORD dwMaxExchangeLen;
  71. WCHAR szCertificate[SZ_ALG_MAX_SIZE];
  72. DWORD dwKeyType;
  73. };
  74. static const ULONG kReqFlags =
  75. ISC_REQ_SEQUENCE_DETECT | ISC_REQ_REPLAY_DETECT | ISC_REQ_CONFIDENTIALITY |
  76. ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_USE_SUPPLIED_CREDS | ISC_REQ_STREAM;
  77. static const ULONG kReqAFlags =
  78. ASC_REQ_SEQUENCE_DETECT | ASC_REQ_REPLAY_DETECT | ASC_REQ_CONFIDENTIALITY |
  79. ASC_REQ_EXTENDED_ERROR | ASC_REQ_ALLOCATE_MEMORY | ASC_REQ_STREAM;
  80. class TLSBuffer : public ::SecBuffer
  81. {
  82. public:
  83. explicit TLSBuffer(ULONG type, ULONG size, void* data)
  84. {
  85. cbBuffer = size;
  86. BufferType = type;
  87. pvBuffer = data;
  88. }
  89. };
  90. class TLSBufferDesc : public ::SecBufferDesc
  91. {
  92. public:
  93. explicit TLSBufferDesc(SecBuffer* arr, ULONG buffers)
  94. {
  95. ulVersion = SECBUFFER_VERSION;
  96. cBuffers = buffers;
  97. pBuffers = arr;
  98. }
  99. };
  100. inline static std::string getCipherSuite(CtxtHandle* handle)
  101. {
  102. WinSecPkgContext_CipherInfo info = {SECPKGCONTEXT_CIPHERINFO_V1};
  103. if (QueryContextAttributes(handle, SECPKG_ATTR_CIPHER_INFO, &info) ==
  104. SEC_E_OK) {
  105. return wCharToUtf8(info.szCipherSuite);
  106. }
  107. return "Unknown";
  108. }
  109. inline static uint32_t getProtocolVersion(CtxtHandle* handle)
  110. {
  111. WinSecPkgContext_CipherInfo info = {SECPKGCONTEXT_CIPHERINFO_V1};
  112. if (QueryContextAttributes(handle, SECPKG_ATTR_CIPHER_INFO, &info) ==
  113. SEC_E_OK) {
  114. return info.dwProtocol;
  115. }
  116. // XXX Assume the best?!
  117. return std::numeric_limits<uint32_t>::max();
  118. }
  119. } // namespace
  120. namespace aria2 {
  121. TLSSession* TLSSession::make(TLSContext* ctx)
  122. {
  123. return new WinTLSSession(static_cast<WinTLSContext*>(ctx));
  124. }
  125. WinTLSSession::WinTLSSession(WinTLSContext* ctx)
  126. : sockfd_(0),
  127. side_(ctx->getSide()),
  128. cred_(ctx->getCredHandle()),
  129. writeBuffered_(0),
  130. state_(st_constructed),
  131. status_(SEC_E_OK)
  132. {
  133. memset(&handle_, 0, sizeof(handle_));
  134. }
  135. WinTLSSession::~WinTLSSession()
  136. {
  137. ::DeleteSecurityContext(&handle_);
  138. state_ = st_error;
  139. }
  140. int WinTLSSession::init(sock_t sockfd)
  141. {
  142. if (state_ != st_constructed) {
  143. status_ = SEC_E_INVALID_HANDLE;
  144. return TLS_ERR_ERROR;
  145. }
  146. sockfd_ = sockfd;
  147. state_ = st_initialized;
  148. return TLS_ERR_OK;
  149. }
  150. int WinTLSSession::setSNIHostname(const std::string& hostname)
  151. {
  152. if (state_ != st_initialized) {
  153. status_ = SEC_E_INVALID_HANDLE;
  154. return TLS_ERR_ERROR;
  155. }
  156. hostname_ = hostname;
  157. return TLS_ERR_OK;
  158. }
  159. int WinTLSSession::closeConnection()
  160. {
  161. if (state_ != st_connected && state_ != st_closing) {
  162. if (state_ != st_error) {
  163. status_ = SEC_E_INVALID_HANDLE;
  164. state_ = st_error;
  165. }
  166. A2_LOG_DEBUG("WinTLS: Cannot close connection");
  167. return TLS_ERR_ERROR;
  168. }
  169. if (state_ == st_connected) {
  170. A2_LOG_DEBUG("WinTLS: Closing connection");
  171. state_ = st_closing;
  172. DWORD dwShut = SCHANNEL_SHUTDOWN;
  173. TLSBuffer shut(SECBUFFER_TOKEN, sizeof(dwShut), &dwShut);
  174. TLSBufferDesc shutDesc(&shut, 1);
  175. status_ = ::ApplyControlToken(&handle_, &shutDesc);
  176. if (status_ != SEC_E_OK) {
  177. state_ = st_error;
  178. return TLS_ERR_ERROR;
  179. }
  180. TLSBuffer ctx(SECBUFFER_EMPTY, 0, nullptr);
  181. TLSBufferDesc desc(&ctx, 1);
  182. ULONG flags = 0;
  183. if (side_ == TLS_CLIENT) {
  184. SEC_CHAR* host = hostname_.empty() ?
  185. nullptr :
  186. const_cast<SEC_CHAR*>(hostname_.c_str());
  187. status_ = ::InitializeSecurityContext(cred_,
  188. &handle_,
  189. host,
  190. kReqFlags,
  191. 0,
  192. 0,
  193. nullptr,
  194. 0,
  195. &handle_,
  196. &desc,
  197. &flags,
  198. nullptr);
  199. }
  200. else {
  201. status_ = ::AcceptSecurityContext(cred_,
  202. &handle_,
  203. nullptr,
  204. kReqAFlags,
  205. 0,
  206. &handle_,
  207. &desc,
  208. &flags,
  209. nullptr);
  210. }
  211. if (status_ == SEC_E_OK || status_ == SEC_I_CONTEXT_EXPIRED) {
  212. size_t len = ctx.cbBuffer;
  213. ssize_t rv = writeData(ctx.pvBuffer, ctx.cbBuffer);
  214. ::FreeContextBuffer(ctx.pvBuffer);
  215. if (rv == TLS_ERR_WOULDBLOCK) {
  216. return rv;
  217. }
  218. // Alright data is sent or buffered
  219. if (rv - len != 0) {
  220. return TLS_ERR_WOULDBLOCK;
  221. }
  222. }
  223. }
  224. // Send remaining data.
  225. while (writeBuf_.size()) {
  226. int rv = writeData(nullptr, 0);
  227. if (rv == TLS_ERR_WOULDBLOCK) {
  228. return rv;
  229. }
  230. }
  231. A2_LOG_DEBUG("WinTLS: Closed Connection");
  232. state_ = st_closed;
  233. return TLS_ERR_OK;
  234. }
  235. int WinTLSSession::checkDirection()
  236. {
  237. if (state_ == st_handshake_write || state_ == st_handshake_write_last) {
  238. return TLS_WANT_WRITE;
  239. }
  240. if (state_ == st_handshake_read) {
  241. return TLS_WANT_READ;
  242. }
  243. if (readBuf_.size() || decBuf_.size()) {
  244. return TLS_WANT_READ;
  245. }
  246. if (writeBuf_.size()) {
  247. return TLS_WANT_WRITE;
  248. }
  249. return TLS_WANT_READ;
  250. }
  251. ssize_t WinTLSSession::writeData(const void* data, size_t len)
  252. {
  253. if (state_ == st_handshake_write || state_ == st_handshake_write_last ||
  254. state_ == st_handshake_read) {
  255. // Renegotiating
  256. std::string hn, err;
  257. auto connect = tlsConnect(hn, err);
  258. if (connect != TLS_ERR_OK) {
  259. return connect;
  260. }
  261. // Continue.
  262. }
  263. if (state_ != st_connected && state_ != st_closing) {
  264. status_ = SEC_E_INVALID_HANDLE;
  265. return TLS_ERR_ERROR;
  266. }
  267. A2_LOG_DEBUG(fmt("WinTLS: Write request: %" PRIu64 " buffered: %" PRIu64,
  268. (uint64_t)len,
  269. (uint64_t)writeBuf_.size()));
  270. // Write remaining buffered data, if any.
  271. size_t written = 0;
  272. while (writeBuf_.size()) {
  273. written = ::send(sockfd_, writeBuf_.data(), writeBuf_.size(), 0);
  274. errno = ::WSAGetLastError();
  275. if (written < 0 && errno == WSAEINTR) {
  276. continue;
  277. }
  278. if (written < 0 && errno == WSAEWOULDBLOCK) {
  279. return TLS_ERR_WOULDBLOCK;
  280. }
  281. if (written == 0) {
  282. return written;
  283. }
  284. if (written < 0) {
  285. status_ = SEC_E_INVALID_HANDLE;
  286. state_ = st_error;
  287. return TLS_ERR_ERROR;
  288. }
  289. writeBuf_.eat(written);
  290. }
  291. if (len == 0) {
  292. return 0;
  293. }
  294. if (!streamSizes_) {
  295. streamSizes_.reset(new SecPkgContext_StreamSizes());
  296. status_ = ::QueryContextAttributes(
  297. &handle_, SECPKG_ATTR_STREAM_SIZES, streamSizes_.get());
  298. if (status_ != SEC_E_OK || !streamSizes_->cbMaximumMessage) {
  299. state_ = st_error;
  300. return TLS_ERR_ERROR;
  301. }
  302. }
  303. size_t process = len;
  304. auto bytes = reinterpret_cast<const char*>(data);
  305. if (writeBuffered_) {
  306. // There was buffered data, hence we need to "remove" that data from the
  307. // incoming buffer to avoid writing it again
  308. if (len < writeBuffered_) {
  309. // We didn't get called with the same data again, obviously.
  310. status_ = SEC_E_INVALID_HANDLE;
  311. status_ = st_error;
  312. return TLS_ERR_ERROR;
  313. }
  314. // just advance the buffer by writeBuffered_ bytes
  315. bytes += writeBuffered_;
  316. process -= writeBuffered_;
  317. writeBuffered_ = 0;
  318. }
  319. if (!process) {
  320. // The buffer contained the full remainder. At this point, the buffer has
  321. // been written, so the request is done in its entirety;
  322. return len;
  323. }
  324. // Buffered data was already written ;)
  325. // If there was no buffered data, this will be len - len = 0.
  326. len = len - process;
  327. while (process) {
  328. // Set up an outgoing message, according to streamSizes_
  329. writeBuffered_ = std::min(process, (size_t)streamSizes_->cbMaximumMessage);
  330. size_t dl =
  331. streamSizes_->cbHeader + writeBuffered_ + streamSizes_->cbTrailer;
  332. auto buf = make_unique<char[]>(dl);
  333. TLSBuffer buffers[] = {
  334. TLSBuffer(SECBUFFER_STREAM_HEADER, streamSizes_->cbHeader, buf.get()),
  335. TLSBuffer(
  336. SECBUFFER_DATA, writeBuffered_, buf.get() + streamSizes_->cbHeader),
  337. TLSBuffer(SECBUFFER_STREAM_TRAILER,
  338. streamSizes_->cbTrailer,
  339. buf.get() + streamSizes_->cbHeader + writeBuffered_),
  340. TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
  341. };
  342. TLSBufferDesc desc(buffers, 4);
  343. memcpy(buffers[1].pvBuffer, bytes, writeBuffered_);
  344. status_ = ::EncryptMessage(&handle_, 0, &desc, 0);
  345. if (status_ != SEC_E_OK) {
  346. A2_LOG_ERROR(fmt("WinTLS: Failed to encrypt a message! %s",
  347. getLastErrorString().c_str()));
  348. state_ = st_error;
  349. return TLS_ERR_ERROR;
  350. }
  351. // EncryptMessage may have truncated the buffers.
  352. // Should rarely happen, if ever, except for the trailer.
  353. dl = buffers[0].cbBuffer;
  354. if (dl < streamSizes_->cbHeader) {
  355. // Move message.
  356. memmove(buf.get() + dl, buffers[1].pvBuffer, buffers[1].cbBuffer);
  357. }
  358. dl += buffers[1].cbBuffer;
  359. if (dl < streamSizes_->cbHeader + writeBuffered_) {
  360. // Move tailer.
  361. memmove(buf.get() + dl, buffers[2].pvBuffer, buffers[2].cbBuffer);
  362. }
  363. dl += buffers[2].cbBuffer;
  364. // Write (or buffer) the message.
  365. char* p = buf.get();
  366. while (dl) {
  367. written = ::send(sockfd_, p, dl, 0);
  368. errno = ::WSAGetLastError();
  369. if (written < 0 && errno == WSAEINTR) {
  370. continue;
  371. }
  372. if (written < 0 && errno == WSAEWOULDBLOCK) {
  373. // Buffer the rest of the message...
  374. writeBuf_.write(p, dl);
  375. // and return...
  376. return len;
  377. }
  378. if (written == 0) {
  379. A2_LOG_ERROR("WinTLS: Connection closed while writing");
  380. status_ = SEC_E_INCOMPLETE_MESSAGE;
  381. state_ = st_error;
  382. return TLS_ERR_ERROR;
  383. }
  384. if (written < 0) {
  385. A2_LOG_ERROR("WinTLS: Connection error while writing");
  386. status_ = SEC_E_INCOMPLETE_MESSAGE;
  387. state_ = st_error;
  388. return TLS_ERR_ERROR;
  389. }
  390. dl -= written;
  391. p += written;
  392. }
  393. len += writeBuffered_;
  394. bytes += writeBuffered_;
  395. process -= writeBuffered_;
  396. writeBuffered_ = 0;
  397. }
  398. A2_LOG_DEBUG(fmt("WinTLS: Write result: %" PRIu64 " buffered: %" PRIu64,
  399. (uint64_t)len,
  400. (uint64_t)writeBuf_.size()));
  401. if (!len) {
  402. return TLS_ERR_WOULDBLOCK;
  403. }
  404. return len;
  405. }
  406. ssize_t WinTLSSession::readData(void* data, size_t len)
  407. {
  408. A2_LOG_DEBUG(fmt("WinTLS: Read request: %" PRIu64 " buffered: %" PRIu64,
  409. (uint64_t)len,
  410. (uint64_t)readBuf_.size()));
  411. if (len == 0) {
  412. return 0;
  413. }
  414. // Can be filled from decBuffer entirely?
  415. if (decBuf_.size() >= len) {
  416. A2_LOG_DEBUG("WinTLS: Fullfilling req from buffer");
  417. memcpy(data, decBuf_.data(), len);
  418. decBuf_.eat(len);
  419. return len;
  420. }
  421. if (state_ == st_closing || state_ == st_closed || state_ == st_error) {
  422. auto nread = decBuf_.size();
  423. if (nread) {
  424. assert(nread < len);
  425. memcpy(data, decBuf_.data(), nread);
  426. decBuf_.clear();
  427. A2_LOG_DEBUG("WinTLS: Sending out decrypted buffer after EOF");
  428. return nread;
  429. }
  430. A2_LOG_DEBUG("WinTLS: Read request aborted. Connection already closed");
  431. return state_ == st_error ? TLS_ERR_ERROR : 0;
  432. }
  433. if (state_ == st_handshake_write || state_ == st_handshake_write_last ||
  434. state_ == st_handshake_read) {
  435. // Renegotiating
  436. std::string hn, err;
  437. auto connect = tlsConnect(hn, err);
  438. if (connect != TLS_ERR_OK) {
  439. return connect;
  440. }
  441. // Continue.
  442. }
  443. if (state_ != st_connected) {
  444. status_ = SEC_E_INVALID_HANDLE;
  445. return TLS_ERR_ERROR;
  446. }
  447. // Read as many bytes as available from the connection, up to len + 4k.
  448. readBuf_.resize(len + 4096);
  449. while (readBuf_.free()) {
  450. ssize_t read = ::recv(sockfd_, readBuf_.end(), readBuf_.free(), 0);
  451. errno = ::WSAGetLastError();
  452. if (read < 0 && errno == WSAEINTR) {
  453. continue;
  454. }
  455. if (read < 0 && errno == WSAEWOULDBLOCK) {
  456. break;
  457. }
  458. if (read < 0) {
  459. status_ = SEC_E_INCOMPLETE_MESSAGE;
  460. state_ = st_error;
  461. return TLS_ERR_ERROR;
  462. }
  463. if (read == 0) {
  464. A2_LOG_DEBUG("WinTLS: Connection abruptly closed!");
  465. // At least try to gracefully close our write end.
  466. closeConnection();
  467. break;
  468. }
  469. readBuf_.advance(read);
  470. }
  471. // Try to decrypt as many messages as possible from the readBuf_.
  472. while (readBuf_.size()) {
  473. TLSBuffer bufs[] = {
  474. TLSBuffer(SECBUFFER_DATA, readBuf_.size(), readBuf_.data()),
  475. TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
  476. TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
  477. TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
  478. };
  479. TLSBufferDesc desc(bufs, 4);
  480. status_ = ::DecryptMessage(&handle_, &desc, 0, nullptr);
  481. if (status_ == SEC_E_INCOMPLETE_MESSAGE) {
  482. // Need to stop now, and wait for more bytes to arrive on the socket.
  483. break;
  484. }
  485. if (status_ != SEC_E_OK && status_ != SEC_I_CONTEXT_EXPIRED &&
  486. status_ != SEC_I_RENEGOTIATE) {
  487. A2_LOG_ERROR(fmt("WinTLS: Failed to decrypt a message! %s",
  488. getLastErrorString().c_str()));
  489. state_ = st_error;
  490. return TLS_ERR_ERROR;
  491. }
  492. // Decrypted message successfully.
  493. bool ate = false;
  494. for (auto& buf : bufs) {
  495. if (buf.BufferType == SECBUFFER_DATA && buf.cbBuffer > 0) {
  496. decBuf_.write(buf.pvBuffer, buf.cbBuffer);
  497. }
  498. else if (buf.BufferType == SECBUFFER_EXTRA && buf.cbBuffer > 0) {
  499. readBuf_.eat(readBuf_.size() - buf.cbBuffer);
  500. ate = true;
  501. }
  502. }
  503. if (!ate) {
  504. readBuf_.clear();
  505. }
  506. if (status_ == SEC_I_RENEGOTIATE) {
  507. // Renegotiation basically means performing another handshake
  508. state_ = st_initialized;
  509. A2_LOG_INFO("WinTLS: Renegotiate");
  510. std::string hn, err;
  511. auto connect = tlsConnect(hn, err);
  512. if (connect == TLS_ERR_WOULDBLOCK) {
  513. break;
  514. }
  515. if (connect == TLS_ERR_ERROR) {
  516. return connect;
  517. }
  518. // Still good.
  519. }
  520. if (status_ == SEC_I_CONTEXT_EXPIRED) {
  521. // Connection is gone now, but the buffered bytes are still valid.
  522. A2_LOG_DEBUG("WinTLS: Connection gracefully closed!");
  523. closeConnection();
  524. break;
  525. }
  526. }
  527. len = std::min(decBuf_.size(), len);
  528. if (len == 0) {
  529. if (state_ != st_connected) {
  530. return state_ == st_error ? TLS_ERR_ERROR : 0;
  531. }
  532. return TLS_ERR_WOULDBLOCK;
  533. }
  534. memcpy(data, decBuf_.data(), len);
  535. decBuf_.eat(len);
  536. return len;
  537. }
  538. int WinTLSSession::tlsConnect(const std::string& hostname,
  539. std::string& handshakeErr)
  540. {
  541. // Handshaking will require sending multiple read/write exchanges until the
  542. // handshake is actually done. The client will first generate the initial
  543. // handshake message, then write that to the server, read the response
  544. // message, and write and/or read additional messages until the handshake is
  545. // either complete and successful, or something went wrong.
  546. // The server works analog to that.
  547. A2_LOG_DEBUG("WinTLS: Starting/Resuming TLS Connect");
  548. ULONG flags = 0;
  549. restart:
  550. switch (state_) {
  551. default:
  552. A2_LOG_ERROR("WinTLS: Invalid state");
  553. status_ = SEC_E_INVALID_HANDLE;
  554. return TLS_ERR_ERROR;
  555. case st_initialized: {
  556. if (side_ == TLS_SERVER) {
  557. goto read;
  558. }
  559. if (!hostname.empty()) {
  560. setSNIHostname(hostname);
  561. }
  562. A2_LOG_DEBUG("WinTLS: Initializing handshake");
  563. TLSBuffer buf(SECBUFFER_EMPTY, 0, nullptr);
  564. TLSBufferDesc desc(&buf, 1);
  565. SEC_CHAR* host =
  566. hostname_.empty() ? nullptr : const_cast<SEC_CHAR*>(hostname_.c_str());
  567. status_ = ::InitializeSecurityContext(cred_,
  568. nullptr,
  569. host,
  570. kReqFlags,
  571. 0,
  572. 0,
  573. nullptr,
  574. 0,
  575. &handle_,
  576. &desc,
  577. &flags,
  578. nullptr);
  579. if (status_ != SEC_I_CONTINUE_NEEDED) {
  580. // Has to be SEC_I_CONTINUE_NEEDED, as we did not actually send data
  581. // at this point.
  582. state_ = st_error;
  583. return TLS_ERR_ERROR;
  584. }
  585. // Queue the initial message...
  586. writeBuf_.write(buf.pvBuffer, buf.cbBuffer);
  587. FreeContextBuffer(buf.pvBuffer);
  588. // ... and start sending it
  589. state_ = st_handshake_write;
  590. }
  591. // Fall through
  592. case st_handshake_write_last:
  593. case st_handshake_write: {
  594. A2_LOG_DEBUG("WinTLS: Writing handshake");
  595. // Write the currently queued handshake message until all data is sent.
  596. while (writeBuf_.size()) {
  597. ssize_t writ = ::send(sockfd_, writeBuf_.data(), writeBuf_.size(), 0);
  598. errno = ::WSAGetLastError();
  599. if (writ < 0 && errno == WSAEINTR) {
  600. continue;
  601. }
  602. if (writ < 0 && errno == WSAEWOULDBLOCK) {
  603. return TLS_ERR_WOULDBLOCK;
  604. }
  605. if (writ <= 0) {
  606. status_ = SEC_E_INCOMPLETE_MESSAGE;
  607. state_ = st_error;
  608. return TLS_ERR_ERROR;
  609. }
  610. writeBuf_.eat(writ);
  611. }
  612. if (state_ == st_handshake_write_last) {
  613. state_ = st_handshake_done;
  614. goto restart;
  615. }
  616. // Have to read one or more response messages.
  617. state_ = st_handshake_read;
  618. }
  619. // Fall through
  620. case st_handshake_read: {
  621. read:
  622. A2_LOG_DEBUG("WinTLS: Reading handshake...");
  623. // All write buffered data is invalid at this point!
  624. writeBuf_.clear();
  625. // Read as many bytes as possible, up to 4k new bytes.
  626. // We do not know how many bytes will arrive from the server at this
  627. // point.
  628. readBuf_.resize(readBuf_.size() + 4096);
  629. while (readBuf_.free()) {
  630. ssize_t read = ::recv(sockfd_, readBuf_.end(), readBuf_.free(), 0);
  631. errno = ::WSAGetLastError();
  632. if (read < 0 && errno == WSAEINTR) {
  633. continue;
  634. }
  635. if (read < 0 && errno == WSAEWOULDBLOCK) {
  636. break;
  637. }
  638. if (read <= 0) {
  639. status_ = SEC_E_INCOMPLETE_MESSAGE;
  640. state_ = st_error;
  641. return TLS_ERR_ERROR;
  642. }
  643. if (read == 0) {
  644. A2_LOG_DEBUG("WinTLS: Connection abruptly closed during handshake!");
  645. status_ = SEC_E_INCOMPLETE_MESSAGE;
  646. state_ = st_error;
  647. return TLS_ERR_ERROR;
  648. }
  649. readBuf_.advance(read);
  650. break;
  651. }
  652. if (!readBuf_.size()) {
  653. return TLS_ERR_WOULDBLOCK;
  654. }
  655. // Need to copy the data, as Schannel is free to mess with it. But we
  656. // might later need unmodified data from the original read buffer.
  657. auto bufcopy = make_unique<char[]>(readBuf_.size());
  658. memcpy(bufcopy.get(), readBuf_.data(), readBuf_.size());
  659. // Set up buffers. inbufs will be the raw bytes the library has to decode.
  660. // outbufs will contain generated responses, if any.
  661. TLSBuffer inbufs[] = {
  662. TLSBuffer(SECBUFFER_TOKEN, readBuf_.size(), bufcopy.get()),
  663. TLSBuffer(SECBUFFER_EMPTY, 0, nullptr),
  664. };
  665. TLSBufferDesc indesc(inbufs, 2);
  666. TLSBuffer outbufs[] = {
  667. TLSBuffer(SECBUFFER_TOKEN, 0, nullptr),
  668. TLSBuffer(SECBUFFER_ALERT, 0, nullptr),
  669. };
  670. TLSBufferDesc outdesc(outbufs, 2);
  671. if (side_ == TLS_CLIENT) {
  672. SEC_CHAR* host = hostname_.empty() ?
  673. nullptr :
  674. const_cast<SEC_CHAR*>(hostname_.c_str());
  675. status_ = ::InitializeSecurityContext(cred_,
  676. &handle_,
  677. host,
  678. kReqFlags,
  679. 0,
  680. 0,
  681. &indesc,
  682. 0,
  683. nullptr,
  684. &outdesc,
  685. &flags,
  686. nullptr);
  687. }
  688. else {
  689. status_ =
  690. ::AcceptSecurityContext(cred_,
  691. state_ == st_initialized ? nullptr : &handle_,
  692. &indesc,
  693. kReqAFlags,
  694. 0,
  695. state_ == st_initialized ? &handle_ : nullptr,
  696. &outdesc,
  697. &flags,
  698. nullptr);
  699. }
  700. if (status_ == SEC_E_INCOMPLETE_MESSAGE) {
  701. // Not enough raw bytes read yet to decode a full message.
  702. return TLS_ERR_WOULDBLOCK;
  703. }
  704. if (status_ != SEC_E_OK && status_ != SEC_I_CONTINUE_NEEDED) {
  705. state_ = st_error;
  706. return TLS_ERR_ERROR;
  707. }
  708. // Raw bytes where not entirely consumed, i.e. readBuf_ still contains
  709. // unprocessed data from the next message?
  710. if (inbufs[1].BufferType == SECBUFFER_EXTRA && inbufs[1].cbBuffer > 0) {
  711. readBuf_.eat(readBuf_.size() - inbufs[1].cbBuffer);
  712. }
  713. else {
  714. readBuf_.clear();
  715. }
  716. // Check if the library produced a new outgoing message and queue it.
  717. for (auto& buf : outbufs) {
  718. if (buf.BufferType == SECBUFFER_TOKEN && buf.cbBuffer > 0) {
  719. writeBuf_.write(buf.pvBuffer, buf.cbBuffer);
  720. FreeContextBuffer(buf.pvBuffer);
  721. state_ = st_handshake_write;
  722. }
  723. }
  724. // Need to read additional messages?
  725. if (status_ == SEC_I_CONTINUE_NEEDED) {
  726. A2_LOG_DEBUG("WinTLS: Continuing with handshake");
  727. goto restart;
  728. }
  729. if (side_ == TLS_CLIENT && flags != kReqFlags) {
  730. A2_LOG_ERROR(fmt("WinTLS: Channel setup failed. Schannel provider did "
  731. "not fulfill requested flags. "
  732. "Excepted: %lu Actual: %lu",
  733. kReqFlags,
  734. flags));
  735. status_ = SEC_E_INTERNAL_ERROR;
  736. state_ = st_error;
  737. return TLS_ERR_ERROR;
  738. }
  739. if (state_ == st_handshake_write) {
  740. A2_LOG_DEBUG("WinTLS: Continuing with handshake (last write)");
  741. state_ = st_handshake_write_last;
  742. goto restart;
  743. }
  744. }
  745. // Fall through
  746. case st_handshake_done: {
  747. // All ready now :D
  748. state_ = st_connected;
  749. A2_LOG_INFO(
  750. fmt("WinTLS: connected with: %s", getCipherSuite(&handle_).c_str()));
  751. auto proto = getProtocolVersion(&handle_);
  752. if (proto < 0x301) {
  753. std::string protoAndSuite;
  754. switch (proto) {
  755. case 0x300:
  756. protoAndSuite = "SSLv3";
  757. break;
  758. default:
  759. protoAndSuite = "Unknown";
  760. break;
  761. }
  762. protoAndSuite += " " + getCipherSuite(&handle_);
  763. A2_LOG_WARN(fmt(MSG_WARN_OLD_TLS_CONNECTION, protoAndSuite.c_str()));
  764. }
  765. return TLS_ERR_OK;
  766. }
  767. }
  768. A2_LOG_ERROR("WinTLS: Unreachable reached during tlsConnect! This is a bug!");
  769. state_ = st_error;
  770. return TLS_ERR_ERROR;
  771. }
  772. int WinTLSSession::tlsAccept()
  773. {
  774. std::string host, err;
  775. return tlsConnect(host, err);
  776. }
  777. std::string WinTLSSession::getLastErrorString()
  778. {
  779. std::stringstream ss;
  780. wchar_t* buf = nullptr;
  781. auto rv = FormatMessageW(FORMAT_MESSAGE_ALLOCATE_BUFFER |
  782. FORMAT_MESSAGE_FROM_SYSTEM |
  783. FORMAT_MESSAGE_IGNORE_INSERTS,
  784. nullptr,
  785. status_,
  786. MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
  787. (LPWSTR) & buf,
  788. 1024,
  789. nullptr);
  790. if (rv && buf) {
  791. ss << "Error: " << wCharToUtf8(buf) << "(" << std::hex << status_ << ")";
  792. LocalFree(buf);
  793. }
  794. else {
  795. ss << "Error: " << std::hex << status_;
  796. }
  797. return ss.str();
  798. }
  799. } // namespace aria2