aeadtest.c 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385
  1. /*
  2. * Copyright 2014-2023 The GmSSL Project. All Rights Reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the License); you may
  5. * not use this file except in compliance with the License.
  6. *
  7. * http://www.apache.org/licenses/LICENSE-2.0
  8. */
  9. #include <stdio.h>
  10. #include <string.h>
  11. #include <stdlib.h>
  12. #include <assert.h>
  13. #include <gmssl/hex.h>
  14. #include <gmssl/rand.h>
  15. #include <gmssl/aead.h>
  16. #include <gmssl/error.h>
  17. static int test_aead_sm4_cbc_sm3_hmac(void)
  18. {
  19. SM4_CBC_SM3_HMAC_CTX aead_ctx;
  20. uint8_t key[16 + 32];
  21. uint8_t iv[16];
  22. uint8_t aad[29];
  23. uint8_t plain[71];
  24. size_t plainlen = sizeof(plain);
  25. uint8_t cipher[256];
  26. size_t cipherlen = 0;
  27. uint8_t buf[256];
  28. size_t buflen = 0;
  29. size_t lens[] = { 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37 };
  30. uint8_t *in = plain;
  31. uint8_t *out = cipher;
  32. size_t inlen, outlen;
  33. size_t i;
  34. rand_bytes(key, sizeof(key));
  35. rand_bytes(iv, sizeof(iv));
  36. rand_bytes(aad, sizeof(aad));
  37. rand_bytes(plain, plainlen);
  38. if (sm4_cbc_sm3_hmac_encrypt_init(&aead_ctx, key, sizeof(key), iv, sizeof(iv), aad, sizeof(aad)) != 1) {
  39. error_print();
  40. return -1;
  41. }
  42. for (i = 0; plainlen; i++) {
  43. assert(i < sizeof(lens)/sizeof(lens[0]));
  44. inlen = plainlen < lens[i] ? plainlen : lens[i];
  45. if (sm4_cbc_sm3_hmac_encrypt_update(&aead_ctx, in, inlen, out, &outlen) != 1) {
  46. error_print();
  47. return -1;
  48. }
  49. in += inlen;
  50. plainlen -= inlen;
  51. out += outlen;
  52. cipherlen += outlen;
  53. }
  54. if (sm4_cbc_sm3_hmac_encrypt_finish(&aead_ctx, out, &outlen) != 1) {
  55. error_print();
  56. return -1;
  57. }
  58. out += outlen;
  59. cipherlen += outlen;
  60. format_bytes(stdout, 0, 4, "plaintext ", plain, sizeof(plain));
  61. format_bytes(stdout, 0, 4, "ciphertext", cipher, cipherlen);
  62. {
  63. SM4_KEY sm4_key;
  64. SM3_HMAC_CTX sm3_hmac_ctx;
  65. uint8_t tmp[256];
  66. size_t tmplen;
  67. sm4_set_encrypt_key(&sm4_key, key);
  68. if (sm4_cbc_padding_encrypt(&sm4_key, iv, plain, sizeof(plain), tmp, &tmplen) != 1) {
  69. error_print();
  70. return -1;
  71. }
  72. sm3_hmac_init(&sm3_hmac_ctx, key + 16, 32);
  73. sm3_hmac_update(&sm3_hmac_ctx, aad, sizeof(aad));
  74. sm3_hmac_update(&sm3_hmac_ctx, tmp, tmplen);
  75. sm3_hmac_finish(&sm3_hmac_ctx, tmp + tmplen);
  76. tmplen += 32;
  77. format_bytes(stdout, 0, 4, "ciphertext", tmp, tmplen);
  78. if (cipherlen != tmplen
  79. || memcmp(cipher, tmp, tmplen) != 0) {
  80. error_print();
  81. return -1;
  82. }
  83. }
  84. in = cipher;
  85. out = buf;
  86. if (sm4_cbc_sm3_hmac_decrypt_init(&aead_ctx, key, sizeof(key), iv, sizeof(iv), aad, sizeof(aad)) != 1) {
  87. error_print();
  88. return -1;
  89. }
  90. for (i = sizeof(lens)/sizeof(lens[0]) - 1; cipherlen; i--) {
  91. inlen = cipherlen < lens[i] ? cipherlen : lens[i];
  92. if (sm4_cbc_sm3_hmac_decrypt_update(&aead_ctx, in, inlen, out, &outlen) != 1) {
  93. error_print();
  94. return -1;
  95. }
  96. in += inlen;
  97. cipherlen -= inlen;
  98. out += outlen;
  99. buflen += outlen;
  100. }
  101. if (sm4_cbc_sm3_hmac_decrypt_finish(&aead_ctx, out, &outlen) != 1) {
  102. error_print();
  103. return -1;
  104. }
  105. out += outlen;
  106. buflen += outlen;
  107. format_bytes(stdout, 0, 4, "plaintext ", buf, buflen);
  108. if (buflen != sizeof(plain)) {
  109. error_print();
  110. return -1;
  111. }
  112. if (memcmp(buf, plain, sizeof(plain)) != 0) {
  113. error_print();
  114. return -1;
  115. }
  116. printf("%s() ok\n", __FUNCTION__);
  117. return 1;
  118. }
  119. static int test_aead_sm4_ctr_sm3_hmac(void)
  120. {
  121. SM4_CTR_SM3_HMAC_CTX aead_ctx;
  122. uint8_t key[16 + 32];
  123. uint8_t iv[16];
  124. uint8_t aad[29];
  125. uint8_t plain[71];
  126. size_t plainlen = sizeof(plain);
  127. uint8_t cipher[256];
  128. size_t cipherlen = 0;
  129. uint8_t buf[256];
  130. size_t buflen = 0;
  131. size_t lens[] = { 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37 };
  132. uint8_t *in = plain;
  133. uint8_t *out = cipher;
  134. size_t inlen, outlen;
  135. size_t i;
  136. rand_bytes(key, sizeof(key));
  137. rand_bytes(iv, sizeof(iv));
  138. rand_bytes(aad, sizeof(aad));
  139. rand_bytes(plain, plainlen);
  140. if (sm4_ctr_sm3_hmac_encrypt_init(&aead_ctx, key, sizeof(key), iv, sizeof(iv), aad, sizeof(aad)) != 1) {
  141. error_print();
  142. return -1;
  143. }
  144. for (i = 0; plainlen; i++) {
  145. assert(i < sizeof(lens)/sizeof(lens[0]));
  146. inlen = plainlen < lens[i] ? plainlen : lens[i];
  147. if (sm4_ctr_sm3_hmac_encrypt_update(&aead_ctx, in, inlen, out, &outlen) != 1) {
  148. error_print();
  149. return -1;
  150. }
  151. in += inlen;
  152. plainlen -= inlen;
  153. out += outlen;
  154. cipherlen += outlen;
  155. }
  156. if (sm4_ctr_sm3_hmac_encrypt_finish(&aead_ctx, out, &outlen) != 1) {
  157. error_print();
  158. return -1;
  159. }
  160. out += outlen;
  161. cipherlen += outlen;
  162. format_bytes(stdout, 0, 4, "plaintext ", plain, sizeof(plain));
  163. format_bytes(stdout, 0, 4, "ciphertext", cipher, cipherlen);
  164. {
  165. SM4_KEY sm4_key;
  166. uint8_t ctr[16];
  167. SM3_HMAC_CTX sm3_hmac_ctx;
  168. uint8_t tmp[256];
  169. size_t tmplen;
  170. sm4_set_encrypt_key(&sm4_key, key);
  171. memcpy(ctr, iv, 16);
  172. sm4_ctr_encrypt(&sm4_key, ctr, plain, sizeof(plain), tmp);
  173. tmplen = sizeof(plain);
  174. sm3_hmac_init(&sm3_hmac_ctx, key + 16, 32);
  175. sm3_hmac_update(&sm3_hmac_ctx, aad, sizeof(aad));
  176. sm3_hmac_update(&sm3_hmac_ctx, tmp, tmplen);
  177. sm3_hmac_finish(&sm3_hmac_ctx, tmp + tmplen);
  178. tmplen += 32;
  179. format_bytes(stdout, 0, 4, "ciphertext", tmp, tmplen);
  180. if (cipherlen != tmplen
  181. || memcmp(cipher, tmp, tmplen) != 0) {
  182. error_print();
  183. return -1;
  184. }
  185. }
  186. in = cipher;
  187. out = buf;
  188. if (sm4_ctr_sm3_hmac_decrypt_init(&aead_ctx, key, sizeof(key), iv, sizeof(iv), aad, sizeof(aad)) != 1) {
  189. error_print();
  190. return -1;
  191. }
  192. for (i = sizeof(lens)/sizeof(lens[0]) - 1; cipherlen; i--) {
  193. inlen = cipherlen < lens[i] ? cipherlen : lens[i];
  194. if (sm4_ctr_sm3_hmac_decrypt_update(&aead_ctx, in, inlen, out, &outlen) != 1) {
  195. error_print();
  196. return -1;
  197. }
  198. in += inlen;
  199. cipherlen -= inlen;
  200. out += outlen;
  201. buflen += outlen;
  202. }
  203. if (sm4_ctr_sm3_hmac_decrypt_finish(&aead_ctx, out, &outlen) != 1) {
  204. error_print();
  205. return -1;
  206. }
  207. out += outlen;
  208. buflen += outlen;
  209. format_bytes(stdout, 0, 4, "plaintext ", buf, buflen);
  210. if (buflen != sizeof(plain)) {
  211. error_print();
  212. return -1;
  213. }
  214. if (memcmp(buf, plain, sizeof(plain)) != 0) {
  215. error_print();
  216. return -1;
  217. }
  218. printf("%s() ok\n", __FUNCTION__);
  219. return 1;
  220. }
  221. static int test_aead_sm4_gcm(void)
  222. {
  223. SM4_GCM_CTX aead_ctx;
  224. uint8_t key[16];
  225. uint8_t iv[16];
  226. uint8_t aad[29];
  227. uint8_t plain[71];
  228. size_t plainlen = sizeof(plain);
  229. uint8_t cipher[256];
  230. size_t cipherlen = 0;
  231. uint8_t buf[256];
  232. size_t buflen = 0;
  233. size_t lens[] = { 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37 };
  234. uint8_t *in = plain;
  235. uint8_t *out = cipher;
  236. size_t inlen, outlen;
  237. size_t i;
  238. rand_bytes(key, sizeof(key));
  239. rand_bytes(iv, sizeof(iv));
  240. rand_bytes(aad, sizeof(aad));
  241. rand_bytes(plain, plainlen);
  242. if (sm4_gcm_encrypt_init(&aead_ctx, key, sizeof(key), iv, sizeof(iv), aad, sizeof(aad), GHASH_SIZE) != 1) {
  243. error_print();
  244. return -1;
  245. }
  246. for (i = 0; plainlen; i++) {
  247. assert(i < sizeof(lens)/sizeof(lens[0]));
  248. inlen = plainlen < lens[i] ? plainlen : lens[i];
  249. if (sm4_gcm_encrypt_update(&aead_ctx, in, inlen, out, &outlen) != 1) {
  250. error_print();
  251. return -1;
  252. }
  253. in += inlen;
  254. plainlen -= inlen;
  255. out += outlen;
  256. cipherlen += outlen;
  257. }
  258. if (sm4_gcm_encrypt_finish(&aead_ctx, out, &outlen) != 1) {
  259. error_print();
  260. return -1;
  261. }
  262. out += outlen;
  263. cipherlen += outlen;
  264. format_bytes(stdout, 0, 4, "plaintext ", plain, sizeof(plain));
  265. format_bytes(stdout, 0, 4, "ciphertext", cipher, cipherlen);
  266. {
  267. SM4_KEY sm4_key;
  268. uint8_t tmp[256];
  269. size_t tmplen;
  270. sm4_set_encrypt_key(&sm4_key, key);
  271. if (sm4_gcm_encrypt(&sm4_key, iv, sizeof(iv), aad, sizeof(aad), plain, sizeof(plain),
  272. tmp, GHASH_SIZE, tmp + sizeof(plain)) != 1) {
  273. error_print();
  274. return -1;
  275. }
  276. tmplen = sizeof(plain) + GHASH_SIZE;
  277. format_bytes(stdout, 0, 4, "ciphertext", tmp, tmplen);
  278. if (cipherlen != tmplen
  279. || memcmp(cipher, tmp, tmplen) != 0) {
  280. error_print();
  281. return -1;
  282. }
  283. }
  284. in = cipher;
  285. out = buf;
  286. if (sm4_gcm_decrypt_init(&aead_ctx, key, sizeof(key), iv, sizeof(iv), aad, sizeof(aad), GHASH_SIZE) != 1) {
  287. error_print();
  288. return -1;
  289. }
  290. for (i = sizeof(lens)/sizeof(lens[0]) - 1; cipherlen; i--) {
  291. inlen = cipherlen < lens[i] ? cipherlen : lens[i];
  292. if (sm4_gcm_decrypt_update(&aead_ctx, in, inlen, out, &outlen) != 1) {
  293. error_print();
  294. return -1;
  295. }
  296. in += inlen;
  297. cipherlen -= inlen;
  298. out += outlen;
  299. buflen += outlen;
  300. }
  301. if (sm4_gcm_decrypt_finish(&aead_ctx, out, &outlen) != 1) {
  302. error_print();
  303. return -1;
  304. }
  305. out += outlen;
  306. buflen += outlen;
  307. format_bytes(stdout, 0, 4, "plaintext ", buf, buflen);
  308. if (buflen != sizeof(plain)) {
  309. error_print();
  310. return -1;
  311. }
  312. if (memcmp(buf, plain, sizeof(plain)) != 0) {
  313. error_print();
  314. return -1;
  315. }
  316. printf("%s() ok\n", __FUNCTION__);
  317. return 1;
  318. }
  319. int main(void)
  320. {
  321. if (test_aead_sm4_cbc_sm3_hmac() != 1) { error_print(); return -1; }
  322. if (test_aead_sm4_ctr_sm3_hmac() != 1) { error_print(); return -1; }
  323. if (test_aead_sm4_gcm() != 1) { error_print(); return -1; }
  324. printf("%s all tests passed!\n", __FILE__);
  325. return 0;
  326. }