tls12.c 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078
  1. /*
  2. * Copyright 2014-2023 The GmSSL Project. All Rights Reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the License); you may
  5. * not use this file except in compliance with the License.
  6. *
  7. * http://www.apache.org/licenses/LICENSE-2.0
  8. */
  9. #include <time.h>
  10. #include <stdio.h>
  11. #include <stdlib.h>
  12. #include <string.h>
  13. #include <gmssl/rand.h>
  14. #include <gmssl/x509.h>
  15. #include <gmssl/error.h>
  16. #include <gmssl/sm2.h>
  17. #include <gmssl/sm3.h>
  18. #include <gmssl/sm4.h>
  19. #include <gmssl/pem.h>
  20. #include <gmssl/mem.h>
  21. #include <gmssl/tls.h>
  22. static const int tls12_ciphers[] = {
  23. TLS_cipher_ecdhe_sm4_cbc_sm3,
  24. };
  25. static const size_t tls12_ciphers_count = sizeof(tls12_ciphers)/sizeof(tls12_ciphers[0]);
  26. static const uint8_t tls12_exts[] = {
  27. /* supported_groups */ 0x00,0x0A, 0x00,0x04, 0x00,0x02, 0x00,30,//0x29, // curveSM2
  28. /* ec_point_formats */ 0x00,0x0B, 0x00,0x02, 0x01, 0x00, // uncompressed
  29. /* signature_algors */ 0x00,0x0D, 0x00,0x04, 0x00,0x02, 0x07,0x07,//0x08, // sm2sig_sm3
  30. };
  31. int tls12_record_print(FILE *fp, const uint8_t *record, size_t recordlen, int format, int indent)
  32. {
  33. // 目前只支持TLCP的ECC公钥加密套件,因此不论用哪个套件解析都是一样的
  34. // 如果未来支持ECDHE套件,可以将函数改为宏,直接传入 (conn->cipher_suite << 8)
  35. format |= tls12_ciphers[0] << 8;
  36. return tls_record_print(fp, record, recordlen, format, indent);
  37. }
  38. int tls_record_set_handshake_server_key_exchange_ecdhe(uint8_t *record, size_t *recordlen,
  39. int curve, const SM2_POINT *point, const uint8_t *sig, size_t siglen)
  40. {
  41. int type = TLS_handshake_server_key_exchange;
  42. uint8_t *server_ecdh_params = record + 9;
  43. uint8_t *p = server_ecdh_params + 69;
  44. size_t len = 69;
  45. if (!record || !recordlen || !tls_named_curve_name(curve) || !point
  46. || !sig || !siglen || siglen > TLS_MAX_SIGNATURE_SIZE) {
  47. error_print();
  48. return -1;
  49. }
  50. server_ecdh_params[0] = TLS_curve_type_named_curve;
  51. server_ecdh_params[1] = curve >> 8;
  52. server_ecdh_params[2] = curve;
  53. server_ecdh_params[3] = 65;
  54. sm2_point_to_uncompressed_octets(point, server_ecdh_params + 4);
  55. tls_uint16_to_bytes(TLS_sig_sm2sig_sm3, &p, &len);
  56. tls_uint16array_to_bytes(sig, siglen, &p, &len);
  57. tls_record_set_handshake(record, recordlen, type, NULL, len);
  58. return 1;
  59. }
  60. // 这里返回的应该是一个SM2_POINT吗?
  61. int tls_record_get_handshake_server_key_exchange_ecdhe(const uint8_t *record,
  62. int *curve, SM2_POINT *point, const uint8_t **sig, size_t *siglen)
  63. {
  64. int type;
  65. const uint8_t *p;
  66. size_t len;
  67. uint8_t curve_type;
  68. uint16_t named_curve;
  69. const uint8_t *octets;
  70. size_t octetslen;
  71. uint16_t sig_alg;
  72. if (!record || !curve || !point || !sig || !siglen) {
  73. error_print();
  74. return -1;
  75. }
  76. if (tls_record_get_handshake(record, &type, &p, &len) != 1
  77. || type != TLS_handshake_server_key_exchange) {
  78. error_print();
  79. return -1;
  80. }
  81. if (tls_uint8_from_bytes(&curve_type, &p, &len) != 1
  82. || tls_uint16_from_bytes(&named_curve, &p, &len) != 1
  83. || tls_uint8array_from_bytes(&octets, &octetslen, &p, &len) != 1
  84. || tls_uint16_from_bytes(&sig_alg, &p, &len) != 1
  85. || tls_uint16array_from_bytes(sig, siglen, &p, &len) != 1
  86. || tls_length_is_zero(len) != 1) {
  87. error_print();
  88. return -1;
  89. }
  90. if (curve_type != TLS_curve_type_named_curve) {
  91. error_print();
  92. return -1;
  93. }
  94. if (named_curve != TLS_curve_sm2p256v1) {
  95. error_print();
  96. return -1;
  97. }
  98. *curve = named_curve;
  99. if (octetslen != 65
  100. || sm2_point_from_octets(point, octets, octetslen) != 1) {
  101. error_print();
  102. return -1;
  103. }
  104. if (sig_alg != TLS_sig_sm2sig_sm3) {
  105. error_print();
  106. return -1;
  107. }
  108. return 1;
  109. }
  110. int tls_record_set_handshake_client_key_exchange_ecdhe(uint8_t *record, size_t *recordlen,
  111. const SM2_POINT *point)
  112. {
  113. int type = TLS_handshake_client_key_exchange;
  114. record[9] = 65;
  115. sm2_point_to_uncompressed_octets(point, record + 9 + 1);
  116. tls_record_set_handshake(record, recordlen, type, NULL, 1 + 65);
  117. return 1;
  118. }
  119. int tls_record_get_handshake_client_key_exchange_ecdhe(const uint8_t *record, SM2_POINT *point)
  120. {
  121. int type;
  122. const uint8_t *p;
  123. size_t len;
  124. const uint8_t *octets;
  125. size_t octetslen;
  126. if (tls_record_get_handshake(record, &type, &p, &len) != 1
  127. || type != TLS_handshake_client_key_exchange) {
  128. error_print();
  129. return -1;
  130. }
  131. if (tls_uint8array_from_bytes(&octets, &octetslen, &p, &len) != 1
  132. || len > 0) {
  133. error_print();
  134. return -1;
  135. }
  136. if (octetslen != 65
  137. || sm2_point_from_octets(point, octets, octetslen) != 1) {
  138. error_print();
  139. return -1;
  140. }
  141. return 1;
  142. }
  /*
   * TLS 1.2 full handshake message flow (RFC 5246, Section 7.3).
   * '*' marks messages that are only sent in some handshakes.
   *
   *   Client                                        Server
   *
   *   ClientHello              -------->
   *                                            ServerHello
   *                                            Certificate
   *                                      ServerKeyExchange
   *                                    CertificateRequest*
   *                            <--------   ServerHelloDone
   *   Certificate*
   *   ClientKeyExchange
   *   CertificateVerify*
   *   [ChangeCipherSpec]
   *   Finished                 -------->
   *                                     [ChangeCipherSpec]
   *                            <--------          Finished
   *   Application Data         <------->  Application Data
   */
  160. int tls12_do_connect(TLS_CONNECT *conn)
  161. {
  162. int ret = -1;
  163. uint8_t *record = conn->record;
  164. uint8_t finished_record[TLS_FINISHED_RECORD_BUF_SIZE];
  165. size_t recordlen, finished_record_len;
  166. uint8_t client_random[32];
  167. uint8_t server_random[32];
  168. int protocol;
  169. int cipher_suite;
  170. const uint8_t *random;
  171. const uint8_t *session_id;
  172. size_t session_id_len;
  173. uint8_t client_exts[TLS_MAX_EXTENSIONS_SIZE];
  174. size_t client_exts_len = 0;
  175. const uint8_t *server_exts;
  176. size_t server_exts_len;
  177. // 扩展的协商结果,-1 表示服务器不支持该扩展(未给出响应)
  178. int ec_point_format = -1;
  179. int supported_group = -1;
  180. int signature_algor = -1;
  181. SM2_KEY server_sign_key;
  182. SM2_SIGN_CTX sign_ctx;
  183. const uint8_t *sig;
  184. size_t siglen;
  185. uint8_t pre_master_secret[48];
  186. SM3_CTX sm3_ctx;
  187. SM3_CTX tmp_sm3_ctx;
  188. uint8_t sm3_hash[32];
  189. const uint8_t *verify_data;
  190. size_t verify_data_len;
  191. uint8_t local_verify_data[12];
  192. int handshake_type;
  193. const uint8_t *cp;
  194. uint8_t *p;
  195. size_t len;
  196. int depth = 5;
  197. int alert = 0;
  198. int verify_result;
  199. // 初始化记录缓冲
  200. tls_record_set_protocol(record, TLS_protocol_tls1); // ClientHello的记录层协议版本是TLSv1.0
  201. tls_record_set_protocol(finished_record, conn->protocol);
  202. // 准备Finished Context(和ClientVerify)
  203. sm3_init(&sm3_ctx);
  204. if (conn->client_certs_len)
  205. sm2_sign_init(&sign_ctx, &conn->sign_key, SM2_DEFAULT_ID, SM2_DEFAULT_ID_LENGTH);
  206. // send ClientHello
  207. tls_random_generate(client_random);
  208. int ec_point_formats[] = { TLS_point_uncompressed };
  209. size_t ec_point_formats_cnt = 1;
  210. int supported_groups[] = { TLS_curve_sm2p256v1 };
  211. size_t supported_groups_cnt = 1;
  212. int signature_algors[] = { TLS_sig_sm2sig_sm3 };
  213. size_t signature_algors_cnt = 1;
  214. p = client_exts;
  215. client_exts_len = 0;
  216. tls_ec_point_formats_ext_to_bytes(ec_point_formats, ec_point_formats_cnt, &p, &client_exts_len);
  217. tls_supported_groups_ext_to_bytes(supported_groups, supported_groups_cnt, &p, &client_exts_len);
  218. tls_signature_algorithms_ext_to_bytes(signature_algors, signature_algors_cnt, &p, &client_exts_len);
  219. if (tls_record_set_handshake_client_hello(record, &recordlen,
  220. conn->protocol, client_random, NULL, 0,
  221. tls12_ciphers, tls12_ciphers_count,
  222. client_exts, client_exts_len) != 1) {
  223. error_print();
  224. goto end;
  225. }
  226. tls_trace("send ClientHello\n");
  227. tls12_record_trace(stderr, record, recordlen, 0, 0);
  228. if (tls_record_send(record, recordlen, conn->sock) != 1) {
  229. error_print();
  230. goto end;
  231. }
  232. sm3_update(&sm3_ctx, record + 5, recordlen - 5);
  233. if (conn->client_certs_len)
  234. sm2_sign_update(&sign_ctx, record + 5, recordlen - 5);
  235. // recv ServerHello
  236. tls_trace("recv ServerHello\n");
  237. if (tls_record_recv(record, &recordlen, conn->sock) != 1) {
  238. error_print();
  239. tls_send_alert(conn, TLS_alert_unexpected_message);
  240. goto end;
  241. }
  242. tls12_record_trace(stderr, record, recordlen, 0, 0);
  243. if (tls_record_protocol(record) != conn->protocol) {
  244. error_print();
  245. tls_send_alert(conn, TLS_alert_protocol_version);
  246. goto end;
  247. }
  248. if (tls_record_get_handshake_server_hello(record,
  249. &protocol, &random, &session_id, &session_id_len, &cipher_suite,
  250. &server_exts, &server_exts_len) != 1) {
  251. error_print();
  252. tls_send_alert(conn, TLS_alert_unexpected_message);
  253. goto end;
  254. }
  255. if (protocol != conn->protocol) {
  256. error_print();
  257. tls_send_alert(conn, TLS_alert_protocol_version);
  258. goto end;
  259. }
  260. // tls12_ciphers 应该改为conn的内部变量
  261. if (tls_cipher_suite_in_list(cipher_suite, tls12_ciphers, tls12_ciphers_count) != 1) {
  262. error_print();
  263. tls_send_alert(conn, TLS_alert_handshake_failure);
  264. goto end;
  265. }
  266. if (!server_exts) {
  267. error_print();
  268. tls_send_alert(conn, TLS_alert_unexpected_message);
  269. goto end;
  270. }
  271. if (tls_process_server_hello_exts(server_exts, server_exts_len, &ec_point_format, &supported_group, &signature_algor) != 1
  272. || ec_point_format < 0
  273. || supported_group < 0
  274. || signature_algor < 0) {
  275. error_print();
  276. tls_send_alert(conn, TLS_alert_unexpected_message);
  277. goto end;
  278. }
  279. memcpy(server_random, random, 32);
  280. memcpy(conn->session_id, session_id, session_id_len);
  281. conn->cipher_suite = cipher_suite;
  282. sm3_update(&sm3_ctx, record + 5, recordlen - 5);
  283. if (conn->client_certs_len)
  284. sm2_sign_update(&sign_ctx, record + 5, recordlen - 5);
  285. // recv ServerCertificate
  286. tls_trace("recv ServerCertificate\n");
  287. if (tls_record_recv(record, &recordlen, conn->sock) != 1
  288. || tls_record_protocol(record) != conn->protocol) {
  289. error_print();
  290. tls_send_alert(conn, TLS_alert_unexpected_message);
  291. goto end;
  292. }
  293. tls12_record_trace(stderr, record, recordlen, 0, 0);
  294. if (tls_record_get_handshake_certificate(record,
  295. conn->server_certs, &conn->server_certs_len) != 1) {
  296. error_print();
  297. tls_send_alert(conn, TLS_alert_unexpected_message);
  298. goto end;
  299. }
  300. sm3_update(&sm3_ctx, record + 5, recordlen - 5);
  301. if (conn->client_certs_len)
  302. sm2_sign_update(&sign_ctx, record + 5, recordlen - 5);
  303. // verify ServerCertificate
  304. if (x509_certs_verify(conn->server_certs, conn->server_certs_len, X509_cert_chain_server,
  305. conn->ca_certs, conn->ca_certs_len, depth, &verify_result) != 1) {
  306. error_print();
  307. tls_send_alert(conn, TLS_alert_bad_certificate);
  308. goto end;
  309. }
  310. // recv ServerKeyExchange
  311. tls_trace("recv ServerKeyExchange\n");
  312. if (tls_record_recv(record, &recordlen, conn->sock) != 1
  313. || tls_record_protocol(record) != conn->protocol) {
  314. error_print();
  315. tls_send_alert(conn, TLS_alert_unexpected_message);
  316. goto end;
  317. }
  318. tls12_record_trace(stderr, record, recordlen, 0, 0);
  319. int curve;
  320. SM2_POINT server_ecdhe_public;
  321. if (tls_record_get_handshake_server_key_exchange_ecdhe(record, &curve, &server_ecdhe_public, &sig, &siglen) != 1) {
  322. error_print();
  323. tls_send_alert(conn, TLS_alert_unexpected_message);
  324. goto end;
  325. }
  326. if (curve != TLS_curve_sm2p256v1) {
  327. error_print();
  328. tls_send_alert(conn, TLS_alert_unexpected_message);
  329. goto end;
  330. }
  331. sm3_update(&sm3_ctx, record + 5, recordlen - 5);
  332. if (conn->client_certs_len)
  333. sm2_sign_update(&sign_ctx, record + 5, recordlen - 5);
  334. // verify ServerKeyExchange
  335. if (x509_certs_get_cert_by_index(conn->server_certs, conn->server_certs_len, 0, &cp, &len) != 1
  336. || x509_cert_get_subject_public_key(cp, len, &server_sign_key) != 1) {
  337. error_print();
  338. tls_send_alert(conn, TLS_alert_bad_certificate);
  339. goto end;
  340. }
  341. if (tls_verify_server_ecdh_params(&server_sign_key, // 这应该是签名公钥
  342. client_random, server_random, curve, &server_ecdhe_public, sig, siglen) != 1) {
  343. error_print();
  344. tls_send_alert(conn, TLS_alert_internal_error);
  345. goto end;
  346. }
  347. // recv CertificateRequest or ServerHelloDone
  348. if (tls_record_recv(record, &recordlen, conn->sock) != 1
  349. || tls_record_protocol(record) != conn->protocol
  350. || tls_record_get_handshake(record, &handshake_type, &cp, &len) != 1) {
  351. error_print();
  352. tls_send_alert(conn, TLS_alert_unexpected_message);
  353. goto end;
  354. }
  355. if (handshake_type == TLS_handshake_certificate_request) {
  356. const uint8_t *cert_types;
  357. size_t cert_types_len;
  358. const uint8_t *ca_names;
  359. size_t ca_names_len;
  360. // recv CertificateRequest
  361. tls_trace("recv CertificateRequest\n");
  362. tls12_record_trace(stderr, record, recordlen, 0, 0);
  363. if (tls_record_get_handshake_certificate_request(record,
  364. &cert_types, &cert_types_len, &ca_names, &ca_names_len) != 1) {
  365. error_print();
  366. tls_send_alert(conn, TLS_alert_unexpected_message);
  367. goto end;
  368. }
  369. if(!conn->client_certs_len) {
  370. error_print();
  371. tls_send_alert(conn, TLS_alert_internal_error);
  372. goto end;
  373. }
  374. if (tls_cert_types_accepted(cert_types, cert_types_len, conn->client_certs, conn->client_certs_len) != 1
  375. || tls_authorities_issued_certificate(ca_names, ca_names_len, conn->client_certs, conn->client_certs_len) != 1) {
  376. error_print();
  377. tls_send_alert(conn, TLS_alert_unsupported_certificate);
  378. goto end;
  379. }
  380. sm3_update(&sm3_ctx, record + 5, recordlen - 5);
  381. sm2_sign_update(&sign_ctx, record + 5, recordlen - 5);
  382. // recv ServerHelloDone
  383. if (tls_record_recv(record, &recordlen, conn->sock) != 1
  384. || tls_record_protocol(record) != conn->protocol) {
  385. error_print();
  386. tls_send_alert(conn, TLS_alert_unexpected_message);
  387. goto end;
  388. }
  389. } else {
  390. // 这个得处理一下
  391. conn->client_certs_len = 0;
  392. gmssl_secure_clear(&conn->sign_key, sizeof(SM2_KEY));
  393. }
  394. tls_trace("recv ServerHelloDone\n");
  395. tls12_record_trace(stderr, record, recordlen, 0, 0);
  396. if (tls_record_get_handshake_server_hello_done(record) != 1) {
  397. error_print();
  398. tls_send_alert(conn, TLS_alert_unexpected_message);
  399. goto end;
  400. }
  401. sm3_update(&sm3_ctx, record + 5, recordlen - 5);
  402. if (conn->client_certs_len)
  403. sm2_sign_update(&sign_ctx, record + 5, recordlen - 5);
  404. // send ClientCertificate
  405. if (conn->client_certs_len) {
  406. tls_trace("send ClientCertificate\n");
  407. if (tls_record_set_handshake_certificate(record, &recordlen, conn->client_certs, conn->client_certs_len) != 1) {
  408. error_print();
  409. tls_send_alert(conn, TLS_alert_internal_error);
  410. goto end;
  411. }
  412. tls12_record_trace(stderr, record, recordlen, 0, 0);
  413. if (tls_record_send(record, recordlen, conn->sock) != 1) {
  414. error_print();
  415. goto end;
  416. }
  417. sm3_update(&sm3_ctx, record + 5, recordlen - 5);
  418. sm2_sign_update(&sign_ctx, record + 5, recordlen - 5);
  419. }
  420. // generate MASTER_SECRET
  421. tls_trace("generate secrets\n");
  422. SM2_KEY client_ecdh;
  423. sm2_key_generate(&client_ecdh);
  424. sm2_do_ecdh(&client_ecdh, &server_ecdhe_public, &server_ecdhe_public);
  425. memcpy(pre_master_secret, &server_ecdhe_public, 32); // 这个做法很不优雅
  426. // ECDHE和ECC的PMS结构是不一样的吗?
  427. if (tls_prf(pre_master_secret, 32, "master secret",
  428. client_random, 32, server_random, 32,
  429. 48, conn->master_secret) != 1
  430. || tls_prf(conn->master_secret, 48, "key expansion",
  431. server_random, 32, client_random, 32,
  432. 96, conn->key_block) != 1) {
  433. error_print();
  434. tls_send_alert(conn, TLS_alert_internal_error);
  435. goto end;
  436. }
  437. sm3_hmac_init(&conn->client_write_mac_ctx, conn->key_block, 32);
  438. sm3_hmac_init(&conn->server_write_mac_ctx, conn->key_block + 32, 32);
  439. sm4_set_encrypt_key(&conn->client_write_enc_key, conn->key_block + 64);
  440. sm4_set_decrypt_key(&conn->server_write_enc_key, conn->key_block + 80);
  441. /*
  442. tls_secrets_print(stderr,
  443. pre_master_secret, 48,
  444. client_random, server_random,
  445. conn->master_secret,
  446. conn->key_block, 96,
  447. 0, 4);
  448. */
  449. // send ClientKeyExchange
  450. tls_trace("send ClientKeyExchange\n");
  451. if (tls_record_set_handshake_client_key_exchange_ecdhe(record, &recordlen, &client_ecdh.public_key) != 1) {
  452. error_print();
  453. tls_send_alert(conn, TLS_alert_internal_error);
  454. goto end;
  455. }
  456. tls12_record_trace(stderr, record, recordlen, 0, 0);
  457. if (tls_record_send(record, recordlen, conn->sock) != 1) {
  458. error_print();
  459. goto end;
  460. }
  461. sm3_update(&sm3_ctx, record + 5, recordlen - 5);
  462. if (conn->client_certs_len)
  463. sm2_sign_update(&sign_ctx, record + 5, recordlen - 5);
  464. // send CertificateVerify
  465. if (conn->client_certs_len) {
  466. tls_trace("send CertificateVerify\n");
  467. uint8_t sigbuf[SM2_MAX_SIGNATURE_SIZE];
  468. if (sm2_sign_finish(&sign_ctx, sigbuf, &siglen) != 1
  469. || tls_record_set_handshake_certificate_verify(record, &recordlen, sigbuf, siglen) != 1) {
  470. error_print();
  471. tls_send_alert(conn, TLS_alert_internal_error);
  472. goto end;
  473. }
  474. tls12_record_trace(stderr, record, recordlen, 0, 0);
  475. if (tls_record_send(record, recordlen, conn->sock) != 1) {
  476. error_print();
  477. goto end;
  478. }
  479. sm3_update(&sm3_ctx, record + 5, recordlen - 5);
  480. }
  481. // send [ChangeCipherSpec]
  482. tls_trace("send [ChangeCipherSpec]\n");
  483. if (tls_record_set_change_cipher_spec(record, &recordlen) !=1) {
  484. error_print();
  485. tls_send_alert(conn, TLS_alert_internal_error);
  486. goto end;
  487. }
  488. tls12_record_trace(stderr, record, recordlen, 0, 0);
  489. if (tls_record_send(record, recordlen, conn->sock) != 1) {
  490. error_print();
  491. goto end;
  492. }
  493. // send Client Finished
  494. tls_trace("send Finished\n");
  495. memcpy(&tmp_sm3_ctx, &sm3_ctx, sizeof(sm3_ctx));
  496. sm3_finish(&tmp_sm3_ctx, sm3_hash);
  497. if (tls_prf(conn->master_secret, 48, "client finished",
  498. sm3_hash, 32, NULL, 0, sizeof(local_verify_data), local_verify_data) != 1
  499. || tls_record_set_handshake_finished(finished_record, &finished_record_len,
  500. local_verify_data, sizeof(local_verify_data)) != 1) {
  501. error_print();
  502. tls_send_alert(conn, TLS_alert_internal_error);
  503. goto end;
  504. }
  505. tls12_record_trace(stderr, finished_record, finished_record_len, 0, 0);
  506. sm3_update(&sm3_ctx, finished_record + 5, finished_record_len - 5);
  507. // encrypt Client Finished
  508. tls_trace("encrypt Finished\n");
  509. if (tls_record_encrypt(&conn->client_write_mac_ctx, &conn->client_write_enc_key,
  510. conn->client_seq_num, finished_record, finished_record_len, record, &recordlen) != 1) {
  511. error_print();
  512. tls_send_alert(conn, TLS_alert_internal_error);
  513. goto end;
  514. }
  515. tls12_record_trace(stderr, record, recordlen, (1<<24), 0); // 强制打印密文原数据
  516. tls_seq_num_incr(conn->client_seq_num);
  517. if (tls_record_send(record, recordlen, conn->sock) != 1) {
  518. error_print();
  519. goto end;
  520. }
  521. // [ChangeCipherSpec]
  522. tls_trace("recv [ChangeCipherSpec]\n");
  523. if (tls_record_recv(record, &recordlen, conn->sock) != 1
  524. || tls_record_protocol(record) != conn->protocol) {
  525. error_print();
  526. tls_send_alert(conn, TLS_alert_unexpected_message);
  527. goto end;
  528. }
  529. tls12_record_trace(stderr, record, recordlen, 0, 0);
  530. if (tls_record_get_change_cipher_spec(record) != 1) {
  531. error_print();
  532. tls_send_alert(conn, TLS_alert_unexpected_message);
  533. goto end;
  534. }
  535. // Finished
  536. tls_trace("recv Finished\n");
  537. if (tls_record_recv(record, &recordlen, conn->sock) != 1
  538. || tls_record_protocol(record) != conn->protocol) {
  539. error_print();
  540. tls_send_alert(conn, TLS_alert_unexpected_message);
  541. goto end;
  542. }
  543. if (recordlen > sizeof(finished_record)) {
  544. error_print(); // 解密可能导致 finished_record 溢出
  545. tls_send_alert(conn, TLS_alert_bad_record_mac);
  546. goto end;
  547. }
  548. tls12_record_trace(stderr, record, recordlen, (1<<24), 0); // 强制打印密文原数据
  549. tls_trace("decrypt Finished\n");
  550. if (tls_record_decrypt(&conn->server_write_mac_ctx, &conn->server_write_enc_key,
  551. conn->server_seq_num, record, recordlen, finished_record, &finished_record_len) != 1) {
  552. error_print();
  553. tls_send_alert(conn, TLS_alert_bad_record_mac);
  554. goto end;
  555. }
  556. tls12_record_trace(stderr, finished_record, finished_record_len, 0, 0);
  557. tls_seq_num_incr(conn->server_seq_num);
  558. if (tls_record_get_handshake_finished(finished_record, &verify_data, &verify_data_len) != 1) {
  559. error_print();
  560. tls_send_alert(conn, TLS_alert_unexpected_message);
  561. goto end;
  562. }
  563. if (verify_data_len != sizeof(local_verify_data)) {
  564. error_print();
  565. tls_send_alert(conn, TLS_alert_unexpected_message);
  566. goto end;
  567. }
  568. sm3_finish(&sm3_ctx, sm3_hash);
  569. if (tls_prf(conn->master_secret, 48, "server finished",
  570. sm3_hash, 32, NULL, 0, sizeof(local_verify_data), local_verify_data) != 1) {
  571. error_print();
  572. tls_send_alert(conn, TLS_alert_internal_error);
  573. goto end;
  574. }
  575. if (memcmp(verify_data, local_verify_data, sizeof(local_verify_data)) != 0) {
  576. error_print();
  577. tls_send_alert(conn, TLS_alert_decrypt_error);
  578. goto end;
  579. }
  580. fprintf(stderr, "Connection established!\n");
  581. conn->protocol = conn->protocol;
  582. conn->cipher_suite = cipher_suite;
  583. ret = 1;
  584. end:
  585. gmssl_secure_clear(&sign_ctx, sizeof(sign_ctx));
  586. gmssl_secure_clear(pre_master_secret, sizeof(pre_master_secret));
  587. return ret;
  588. }
  589. int tls12_do_accept(TLS_CONNECT *conn)
  590. {
  591. int ret = -1;
  592. int client_verify = 0;
  593. uint8_t *record = conn->record;
  594. uint8_t finished_record[TLS_FINISHED_RECORD_BUF_SIZE]; // 解密可能导致前面的record被覆盖
  595. size_t recordlen, finished_record_len;
  596. // 这个ciphers不是应该在CTX中设置的吗
  597. const int server_ciphers[] = { TLS_cipher_ecdhe_sm4_cbc_sm3 }; // 未来应该支持GCM/CBC两个套件
  598. // ClientHello, ServerHello
  599. uint8_t client_random[32];
  600. uint8_t server_random[32];
  601. int protocol;
  602. const uint8_t *random;
  603. const uint8_t *session_id; // TLCP服务器忽略客户端SessionID,也不主动设置SessionID
  604. size_t session_id_len;
  605. const uint8_t *client_ciphers;
  606. size_t client_ciphers_len;
  607. const uint8_t *client_exts;
  608. size_t client_exts_len;
  609. uint8_t server_exts[TLS_MAX_EXTENSIONS_SIZE];
  610. size_t server_exts_len;
  611. int curve = TLS_curve_sm2p256v1; // 这个是否应该在conn中设置?
  612. // ServerKeyExchange
  613. SM2_KEY server_ecdhe_key;
  614. SM2_SIGN_CTX sign_ctx;
  615. uint8_t sigbuf[SM2_MAX_SIGNATURE_SIZE];
  616. size_t siglen;
  617. // ClientCertificate, CertificateVerify
  618. TLS_CLIENT_VERIFY_CTX client_verify_ctx;
  619. SM2_KEY client_sign_key;
  620. const uint8_t *sig;
  621. const int verify_depth = 5;
  622. int verify_result;
  623. // ClientKeyExchange
  624. SM2_POINT client_ecdhe_point;
  625. uint8_t pre_master_secret[SM2_MAX_PLAINTEXT_SIZE]; // sm2_decrypt 保证输出不会溢出
  626. // Finished
  627. SM3_CTX sm3_ctx;
  628. SM3_CTX tmp_sm3_ctx;
  629. uint8_t sm3_hash[32];
  630. uint8_t local_verify_data[12];
  631. const uint8_t *verify_data;
  632. size_t verify_data_len;
  633. const uint8_t *cp;
  634. size_t len;
  635. // 服务器端如果设置了CA
  636. if (conn->ca_certs_len)
  637. client_verify = 1;
  638. // 初始化Finished和客户端验证环境
  639. sm3_init(&sm3_ctx);
  640. if (client_verify)
  641. tls_client_verify_init(&client_verify_ctx);
  642. // recv ClientHello
  643. tls_trace("recv ClientHello\n");
  644. if (tls_record_recv(record, &recordlen, conn->sock) != 1) {
  645. error_print();
  646. tls_send_alert(conn, TLS_alert_unexpected_message);
  647. goto end;
  648. }
  649. tls12_record_trace(stderr, record, recordlen, 0, 0);
  650. if (tls_record_protocol(record) != conn->protocol
  651. && tls_record_protocol(record) != TLS_protocol_tls1) {
  652. error_print();
  653. tls_send_alert(conn, TLS_alert_protocol_version);
  654. goto end;
  655. }
  656. if (tls_record_get_handshake_client_hello(record,
  657. &protocol, &random, &session_id, &session_id_len,
  658. &client_ciphers, &client_ciphers_len,
  659. &client_exts, &client_exts_len) != 1) {
  660. error_print();
  661. tls_send_alert(conn, TLS_alert_unexpected_message);
  662. goto end;
  663. }
  664. if (protocol != conn->protocol) {
  665. error_print();
  666. tls_send_alert(conn, TLS_alert_protocol_version);
  667. goto end;
  668. }
  669. memcpy(client_random, random, 32);
  670. if (tls_cipher_suites_select(client_ciphers, client_ciphers_len,
  671. server_ciphers, sizeof(server_ciphers)/sizeof(server_ciphers[0]),
  672. &conn->cipher_suite) != 1) {
  673. error_print();
  674. tls_send_alert(conn, TLS_alert_insufficient_security);
  675. goto end;
  676. }
  677. if (client_exts) {
  678. server_exts_len = 0;
  679. curve = TLS_curve_sm2p256v1;
  680. tls_process_client_hello_exts(client_exts, client_exts_len, server_exts, &server_exts_len, sizeof(server_exts));
  681. }
  682. sm3_update(&sm3_ctx, record + 5, recordlen - 5);
  683. if (client_verify)
  684. tls_client_verify_update(&client_verify_ctx, record + 5, recordlen - 5);
  685. // send ServerHello
  686. tls_trace("send ServerHello\n");
  687. tls_random_generate(server_random);
  688. tls_record_set_protocol(record, conn->protocol);
  689. if (tls_record_set_handshake_server_hello(record, &recordlen,
  690. conn->protocol, server_random, NULL, 0,
  691. conn->cipher_suite, server_exts, server_exts_len) != 1) {
  692. error_print();
  693. tls_send_alert(conn, TLS_alert_internal_error);
  694. goto end;
  695. }
  696. tls12_record_trace(stderr, record, recordlen, 0, 0);
  697. if (tls_record_send(record, recordlen, conn->sock) != 1) {
  698. error_print();
  699. goto end;
  700. }
  701. sm3_update(&sm3_ctx, record + 5, recordlen - 5);
  702. if (client_verify)
  703. tls_client_verify_update(&client_verify_ctx, record + 5, recordlen - 5);
  704. // send ServerCertificate
  705. tls_trace("send ServerCertificate\n");
  706. if (tls_record_set_handshake_certificate(record, &recordlen,
  707. conn->server_certs, conn->server_certs_len) != 1) {
  708. error_print();
  709. tls_send_alert(conn, TLS_alert_internal_error);
  710. goto end;
  711. }
  712. tls12_record_trace(stderr, record, recordlen, 0, 0);
  713. if (tls_record_send(record, recordlen, conn->sock) != 1) {
  714. error_print();
  715. goto end;
  716. }
  717. sm3_update(&sm3_ctx, record + 5, recordlen - 5);
  718. if (client_verify)
  719. tls_client_verify_update(&client_verify_ctx, record + 5, recordlen - 5);
  720. // send ServerKeyExchange
  721. tls_trace("send ServerKeyExchange\n");
  722. sm2_key_generate(&server_ecdhe_key);
  723. if (tls_sign_server_ecdh_params(&conn->sign_key,
  724. client_random, server_random, TLS_curve_sm2p256v1, &server_ecdhe_key.public_key,
  725. sigbuf, &siglen) != 1) {
  726. error_print();
  727. tls_send_alert(conn, TLS_alert_internal_error);
  728. return -1;
  729. }
  730. if (tls_record_set_handshake_server_key_exchange_ecdhe(record, &recordlen,
  731. curve, &server_ecdhe_key.public_key, sigbuf, siglen) != 1) {
  732. error_print();
  733. tls_send_alert(conn, TLS_alert_internal_error);
  734. goto end;
  735. }
  736. tls12_record_trace(stderr, record, recordlen, 0, 0);
  737. if (tls_record_send(record, recordlen, conn->sock) != 1) {
  738. error_print();
  739. goto end;
  740. }
  741. sm3_update(&sm3_ctx, record + 5, recordlen - 5);
  742. if (client_verify)
  743. tls_client_verify_update(&client_verify_ctx, record + 5, recordlen - 5);
  744. // send CertificateRequest
  745. if (client_verify) {
  746. const uint8_t cert_types[] = { TLS_cert_type_ecdsa_sign };
  747. uint8_t ca_names[TLS_MAX_CA_NAMES_SIZE] = {0}; // TODO: 根据客户端验证CA证书列计算缓冲大小,或直接输出到record缓冲
  748. size_t ca_names_len = 0;
  749. tls_trace("send CertificateRequest\n");
  750. if (tls_authorities_from_certs(ca_names, &ca_names_len, sizeof(ca_names),
  751. conn->ca_certs, conn->ca_certs_len) != 1) {
  752. error_print();
  753. goto end;
  754. }
  755. if (tls_record_set_handshake_certificate_request(record, &recordlen,
  756. cert_types, sizeof(cert_types),
  757. ca_names, ca_names_len) != 1) {
  758. error_print();
  759. goto end;
  760. }
  761. tls12_record_trace(stderr, record, recordlen, 0, 0);
  762. if (tls_record_send(record, recordlen, conn->sock) != 1) {
  763. error_print();
  764. goto end;
  765. }
  766. sm3_update(&sm3_ctx, record + 5, recordlen - 5);
  767. tls_client_verify_update(&client_verify_ctx, record + 5, recordlen - 5);
  768. }
  769. // send ServerHelloDone
  770. tls_trace("send ServerHelloDone\n");
  771. tls_record_set_handshake_server_hello_done(record, &recordlen);
  772. tls12_record_trace(stderr, record, recordlen, 0, 0);
  773. if (tls_record_send(record, recordlen, conn->sock) != 1) {
  774. error_print();
  775. goto end;
  776. }
  777. sm3_update(&sm3_ctx, record + 5, recordlen - 5);
  778. if (client_verify)
  779. tls_client_verify_update(&client_verify_ctx, record + 5, recordlen - 5);
  780. // recv ClientCertificate
  781. if (conn->ca_certs_len) {
  782. tls_trace("recv ClientCertificate\n");
  783. if (tls_record_recv(record, &recordlen, conn->sock) != 1
  784. || tls_record_protocol(record) != conn->protocol) { // protocol检查应该在trace之后
  785. error_print();
  786. tls_send_alert(conn, TLS_alert_unexpected_message);
  787. goto end;
  788. }
  789. tls12_record_trace(stderr, record, recordlen, 0, 0);
  790. if (tls_record_get_handshake_certificate(record, conn->client_certs, &conn->client_certs_len) != 1) {
  791. error_print();
  792. tls_send_alert(conn, TLS_alert_unexpected_message);
  793. goto end;
  794. }
  795. if (x509_certs_verify(conn->client_certs, conn->client_certs_len, X509_cert_chain_client,
  796. conn->ca_certs, conn->ca_certs_len, verify_depth, &verify_result) != 1) {
  797. error_print();
  798. tls_send_alert(conn, TLS_alert_bad_certificate);
  799. goto end;
  800. }
  801. sm3_update(&sm3_ctx, record + 5, recordlen - 5);
  802. tls_client_verify_update(&client_verify_ctx, record + 5, recordlen - 5);
  803. }
  804. // recv ClientKeyExchange
  805. tls_trace("recv ClientKeyExchange\n");
  806. if (tls_record_recv(record, &recordlen, conn->sock) != 1
  807. || tls_record_protocol(record) != conn->protocol) {
  808. error_print();
  809. tls_send_alert(conn, TLS_alert_unexpected_message);
  810. goto end;
  811. }
  812. tls12_record_trace(stderr, record, recordlen, 0, 0); // 应该给tls12一个独立的trace
  813. if (tls_record_get_handshake_client_key_exchange_ecdhe(record, &client_ecdhe_point) != 1) {
  814. error_print();
  815. tls_send_alert(conn, TLS_alert_unexpected_message);
  816. goto end;
  817. }
  818. sm3_update(&sm3_ctx, record + 5, recordlen - 5);
  819. if (client_verify)
  820. tls_client_verify_update(&client_verify_ctx, record + 5, recordlen - 5);
  821. // recv CertificateVerify
  822. if (client_verify) {
  823. tls_trace("recv CertificateVerify\n");
  824. if (tls_record_recv(record, &recordlen, conn->sock) != 1
  825. || tls_record_protocol(record) != conn->protocol) {
  826. tls_send_alert(conn, TLS_alert_unexpected_message);
  827. error_print();
  828. goto end;
  829. }
  830. tls12_record_trace(stderr, record, recordlen, 0, 0);
  831. if (tls_record_get_handshake_certificate_verify(record, &sig, &siglen) != 1) {
  832. tls_send_alert(conn, TLS_alert_unexpected_message);
  833. error_print();
  834. goto end;
  835. }
  836. if (x509_certs_get_cert_by_index(conn->client_certs, conn->client_certs_len, 0, &cp, &len) != 1
  837. || x509_cert_get_subject_public_key(cp, len, &client_sign_key) != 1) {
  838. error_print();
  839. tls_send_alert(conn, TLS_alert_bad_certificate);
  840. goto end;
  841. }
  842. if (tls_client_verify_finish(&client_verify_ctx, sig, siglen, &client_sign_key) != 1) {
  843. error_print();
  844. tls_send_alert(conn, TLS_alert_decrypt_error);
  845. goto end;
  846. }
  847. sm3_update(&sm3_ctx, record + 5, recordlen - 5);
  848. }
  849. // generate secrets
  850. tls_trace("generate secrets\n");
  851. sm2_do_ecdh(&server_ecdhe_key, &client_ecdhe_point, &client_ecdhe_point);
  852. memcpy(pre_master_secret, (uint8_t *)&client_ecdhe_point, 32); // 这里应该修改一下表示方式,比如get_xy()
  853. tls_prf(pre_master_secret, 32, "master secret",
  854. client_random, 32, server_random, 32,
  855. 48, conn->master_secret);
  856. tls_prf(conn->master_secret, 48, "key expansion",
  857. server_random, 32, client_random, 32,
  858. 96, conn->key_block);
  859. sm3_hmac_init(&conn->client_write_mac_ctx, conn->key_block, 32);
  860. sm3_hmac_init(&conn->server_write_mac_ctx, conn->key_block + 32, 32);
  861. sm4_set_decrypt_key(&conn->client_write_enc_key, conn->key_block + 64);
  862. sm4_set_encrypt_key(&conn->server_write_enc_key, conn->key_block + 80);
  863. /*
  864. tls_secrets_print(stderr, pre_master_secret, 32, client_random, server_random,
  865. conn->master_secret, conn->key_block, 96, 0, 4);
  866. */
  867. // recv [ChangeCipherSpec]
  868. tls_trace("recv [ChangeCipherSpec]\n");
  869. if (tls_record_recv(record, &recordlen, conn->sock) != 1
  870. || tls_record_protocol(record) != conn->protocol) {
  871. error_print();
  872. tls_send_alert(conn, TLS_alert_unexpected_message);
  873. goto end;
  874. }
  875. tls12_record_trace(stderr, record, recordlen, 0, 0);
  876. if (tls_record_get_change_cipher_spec(record) != 1) {
  877. error_print();
  878. tls_send_alert(conn, TLS_alert_unexpected_message);
  879. goto end;
  880. }
  881. // recv ClientFinished
  882. tls_trace("recv Finished\n");
  883. if (tls_record_recv(record, &recordlen, conn->sock) != 1
  884. || tls_record_protocol(record) != conn->protocol) {
  885. error_print();
  886. tls_send_alert(conn, TLS_alert_unexpected_message);
  887. goto end;
  888. }
  889. if (recordlen > sizeof(finished_record)) {
  890. error_print();
  891. tls_send_alert(conn, TLS_alert_unexpected_message);
  892. goto end;
  893. }
  894. tls12_record_trace(stderr, record, recordlen, (1<<24), 0); // 强制打印密文原数据
  895. // decrypt ClientFinished
  896. tls_trace("decrypt Finished\n");
  897. if (tls_record_decrypt(&conn->client_write_mac_ctx, &conn->client_write_enc_key,
  898. conn->client_seq_num, record, recordlen, finished_record, &finished_record_len) != 1) {
  899. error_print();
  900. tls_send_alert(conn, TLS_alert_bad_record_mac);
  901. goto end;
  902. }
  903. tls12_record_trace(stderr, finished_record, finished_record_len, 0, 0);
  904. tls_seq_num_incr(conn->client_seq_num);
  905. if (tls_record_get_handshake_finished(finished_record, &verify_data, &verify_data_len) != 1) {
  906. error_print();
  907. tls_send_alert(conn, TLS_alert_bad_record_mac);
  908. goto end;
  909. }
  910. if (verify_data_len != sizeof(local_verify_data)) {
  911. error_print();
  912. tls_send_alert(conn, TLS_alert_bad_record_mac);
  913. goto end;
  914. }
  915. // verify ClientFinished
  916. memcpy(&tmp_sm3_ctx, &sm3_ctx, sizeof(SM3_CTX));
  917. sm3_update(&sm3_ctx, finished_record + 5, finished_record_len - 5);
  918. sm3_finish(&tmp_sm3_ctx, sm3_hash);
  919. if (tls_prf(conn->master_secret, 48, "client finished", sm3_hash, 32, NULL, 0,
  920. sizeof(local_verify_data), local_verify_data) != 1) {
  921. error_print();
  922. tls_send_alert(conn, TLS_alert_internal_error);
  923. goto end;
  924. }
  925. if (memcmp(verify_data, local_verify_data, sizeof(local_verify_data)) != 0) {
  926. error_puts("client_finished.verify_data verification failure");
  927. tls_send_alert(conn, TLS_alert_decrypt_error);
  928. goto end;
  929. }
  930. // send [ChangeCipherSpec]
  931. tls_trace("send [ChangeCipherSpec]\n");
  932. if (tls_record_set_change_cipher_spec(record, &recordlen) != 1) {
  933. error_print();
  934. tls_send_alert(conn, TLS_alert_internal_error);
  935. goto end;
  936. }
  937. tls12_record_trace(stderr, record, recordlen, 0, 0);
  938. if (tls_record_send(record, recordlen, conn->sock) != 1) {
  939. error_print();
  940. goto end;
  941. }
  942. // send ServerFinished
  943. tls_trace("send Finished\n");
  944. sm3_finish(&sm3_ctx, sm3_hash);
  945. if (tls_prf(conn->master_secret, 48, "server finished", sm3_hash, 32, NULL, 0,
  946. sizeof(local_verify_data), local_verify_data) != 1
  947. || tls_record_set_handshake_finished(finished_record, &finished_record_len,
  948. local_verify_data, sizeof(local_verify_data)) != 1) {
  949. error_print();
  950. tls_send_alert(conn, TLS_alert_internal_error);
  951. goto end;
  952. }
  953. tls12_record_trace(stderr, finished_record, finished_record_len, 0, 0);
  954. if (tls_record_encrypt(&conn->server_write_mac_ctx, &conn->server_write_enc_key,
  955. conn->server_seq_num, finished_record, finished_record_len, record, &recordlen) != 1) {
  956. error_print();
  957. tls_send_alert(conn, TLS_alert_internal_error);
  958. goto end;
  959. }
  960. tls_trace("encrypt Finished\n");
  961. tls12_record_trace(stderr, record, recordlen, (1<<24), 0); // 强制打印密文原数据
  962. tls_seq_num_incr(conn->server_seq_num);
  963. if (tls_record_send(record, recordlen, conn->sock) != 1) {
  964. error_print();
  965. goto end;
  966. }
  967. conn->protocol = conn->protocol;
  968. fprintf(stderr, "Connection Established!\n\n");
  969. ret = 1;
  970. end:
  971. gmssl_secure_clear(&sign_ctx, sizeof(sign_ctx));
  972. gmssl_secure_clear(pre_master_secret, sizeof(pre_master_secret));
  973. if (client_verify) tls_client_verify_cleanup(&client_verify_ctx);
  974. return ret;
  975. }