diff options
Diffstat (limited to 'src/lib/libcrypto/mlkem/mlkem_internal.c')
-rw-r--r-- | src/lib/libcrypto/mlkem/mlkem_internal.c | 1286 |
1 files changed, 1286 insertions, 0 deletions
diff --git a/src/lib/libcrypto/mlkem/mlkem_internal.c b/src/lib/libcrypto/mlkem/mlkem_internal.c new file mode 100644 index 0000000000..653b2f332d --- /dev/null +++ b/src/lib/libcrypto/mlkem/mlkem_internal.c | |||
@@ -0,0 +1,1286 @@ | |||
1 | /* $OpenBSD: mlkem_internal.c,v 1.1 2025/09/05 23:30:12 beck Exp $ */ | ||
2 | /* | ||
3 | * Copyright (c) 2024, Google Inc. | ||
4 | * Copyright (c) 2024, 2025 Bob Beck <beck@obtuse.com> | ||
5 | * | ||
6 | * Permission to use, copy, modify, and/or distribute this software for any | ||
7 | * purpose with or without fee is hereby granted, provided that the above | ||
8 | * copyright notice and this permission notice appear in all copies. | ||
9 | * | ||
10 | * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES | ||
11 | * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF | ||
12 | * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY | ||
13 | * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES | ||
14 | * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION | ||
15 | * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN | ||
16 | * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. | ||
17 | */ | ||
18 | |||
19 | #include <assert.h> | ||
20 | #include <stdlib.h> | ||
21 | #include <string.h> | ||
22 | #include <stdio.h> | ||
23 | |||
24 | #include <openssl/mlkem.h> | ||
25 | |||
26 | #include "bytestring.h" | ||
27 | #include "sha3_internal.h" | ||
28 | #include "mlkem_internal.h" | ||
29 | #include "constant_time.h" | ||
30 | #include "crypto_internal.h" | ||
31 | |||
32 | /* | ||
33 | * See | ||
34 | * https://csrc.nist.gov/pubs/fips/203/final | ||
35 | */ | ||
36 | |||
37 | static void | ||
38 | prf(uint8_t *out, size_t out_len, const uint8_t in[33]) | ||
39 | { | ||
40 | sha3_ctx ctx; | ||
41 | shake256_init(&ctx); | ||
42 | shake_update(&ctx, in, 33); | ||
43 | shake_xof(&ctx); | ||
44 | shake_out(&ctx, out, out_len); | ||
45 | } | ||
46 | |||
47 | /* Section 4.1 */ | ||
48 | static void | ||
49 | hash_h(uint8_t out[32], const uint8_t *in, size_t len) | ||
50 | { | ||
51 | sha3_ctx ctx; | ||
52 | sha3_init(&ctx, 32); | ||
53 | sha3_update(&ctx, in, len); | ||
54 | sha3_final(out, &ctx); | ||
55 | } | ||
56 | |||
57 | static void | ||
58 | hash_g(uint8_t out[64], const uint8_t *in, size_t len) | ||
59 | { | ||
60 | sha3_ctx ctx; | ||
61 | sha3_init(&ctx, 64); | ||
62 | sha3_update(&ctx, in, len); | ||
63 | sha3_final(out, &ctx); | ||
64 | } | ||
65 | |||
66 | /* this is called 'J' in the spec */ | ||
67 | static void | ||
68 | kdf(uint8_t out[MLKEM_SHARED_SECRET_LENGTH], const uint8_t failure_secret[32], | ||
69 | const uint8_t *in, size_t len) | ||
70 | { | ||
71 | sha3_ctx ctx; | ||
72 | shake256_init(&ctx); | ||
73 | shake_update(&ctx, failure_secret, 32); | ||
74 | shake_update(&ctx, in, len); | ||
75 | shake_xof(&ctx); | ||
76 | shake_out(&ctx, out, MLKEM_SHARED_SECRET_LENGTH); | ||
77 | } | ||
78 | |||
79 | #define DEGREE 256 | ||
80 | |||
81 | static const size_t kBarrettMultiplier = 5039; | ||
82 | static const unsigned kBarrettShift = 24; | ||
83 | static const uint16_t kPrime = 3329; | ||
84 | static const int kLog2Prime = 12; | ||
85 | static const uint16_t kHalfPrime = (/*kPrime=*/3329 - 1) / 2; | ||
86 | static const int kDU768 = 10; | ||
87 | static const int kDV768 = 4; | ||
88 | static const int kDU1024 = 11; | ||
89 | static const int kDV1024 = 5; | ||
90 | |||
91 | /* | ||
92 | * kInverseDegree is 128^-1 mod 3329; 128 because kPrime does not have a 512th | ||
93 | * root of unity. | ||
94 | */ | ||
95 | static const uint16_t kInverseDegree = 3303; | ||
96 | |||
97 | static inline size_t | ||
98 | encoded_vector_size(uint16_t rank) | ||
99 | { | ||
100 | return (kLog2Prime * DEGREE / 8) * rank; | ||
101 | } | ||
102 | |||
103 | static inline size_t | ||
104 | compressed_vector_size(uint16_t rank) | ||
105 | { | ||
106 | return ((rank == RANK768) ? kDU768 : kDU1024) * rank * DEGREE / 8; | ||
107 | } | ||
108 | |||
109 | typedef struct scalar { | ||
110 | /* On every function entry and exit, 0 <= c < kPrime. */ | ||
111 | uint16_t c[DEGREE]; | ||
112 | } scalar; | ||
113 | |||
114 | /* | ||
115 | * Retrieve a const scalar from const matrix of |rank| at position [row][col] | ||
116 | */ | ||
117 | static inline const scalar * | ||
118 | const_m2s(const scalar *v, size_t row, size_t col, uint16_t rank) | ||
119 | { | ||
120 | return ((scalar *)v) + row * rank + col; | ||
121 | } | ||
122 | |||
123 | /* | ||
124 | * Retrieve a scalar from matrix of |rank| at position [row][col] | ||
125 | */ | ||
126 | static inline scalar * | ||
127 | m2s(scalar *v, size_t row, size_t col, uint16_t rank) | ||
128 | { | ||
129 | return ((scalar *)v) + row * rank + col; | ||
130 | } | ||
131 | |||
132 | /* | ||
133 | * This bit of Python will be referenced in some of the following comments: | ||
134 | * | ||
135 | * p = 3329 | ||
136 | * | ||
137 | * def bitreverse(i): | ||
138 | * ret = 0 | ||
139 | * for n in range(7): | ||
140 | * bit = i & 1 | ||
141 | * ret <<= 1 | ||
142 | * ret |= bit | ||
143 | * i >>= 1 | ||
144 | * return ret | ||
145 | */ | ||
146 | |||
147 | /* kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)] */ | ||
148 | static const uint16_t kNTTRoots[128] = { | ||
149 | 1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797, | ||
150 | 2786, 3260, 569, 1746, 296, 2447, 1339, 1476, 3046, 56, 2240, 1333, | ||
151 | 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, 289, 331, 3253, 1756, | ||
152 | 1197, 2304, 2277, 2055, 650, 1977, 2513, 632, 2865, 33, 1320, 1915, | ||
153 | 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648, | ||
154 | 2474, 3110, 1227, 910, 17, 2761, 583, 2649, 1637, 723, 2288, 1100, | ||
155 | 1409, 2662, 3281, 233, 756, 2156, 3015, 3050, 1703, 1651, 2789, 1789, | ||
156 | 1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641, | ||
157 | 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, 1063, 319, 2773, 757, | ||
158 | 2099, 561, 2466, 2594, 2804, 1092, 403, 1026, 1143, 2150, 2775, 886, | ||
159 | 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154, | ||
160 | }; | ||
161 | |||
162 | /* kInverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)] */ | ||
163 | static const uint16_t kInverseNTTRoots[128] = { | ||
164 | 1, 1600, 40, 749, 2481, 1432, 2699, 687, 1583, 2760, 69, 543, | ||
165 | 2532, 3136, 1410, 2267, 2508, 1355, 450, 936, 447, 2794, 1235, 1903, | ||
166 | 1996, 1089, 3273, 283, 1853, 1990, 882, 3033, 2419, 2102, 219, 855, | ||
167 | 2681, 1848, 712, 682, 927, 1795, 461, 1891, 2877, 2522, 1894, 1010, | ||
168 | 1414, 2009, 3296, 464, 2697, 816, 1352, 2679, 1274, 1052, 1025, 2132, | ||
169 | 1573, 76, 2998, 3040, 1175, 2444, 394, 1219, 2300, 1455, 2117, 1607, | ||
170 | 2443, 554, 1179, 2186, 2303, 2926, 2237, 525, 735, 863, 2768, 1230, | ||
171 | 2572, 556, 3010, 2266, 1684, 1239, 780, 2954, 109, 1292, 1031, 1745, | ||
172 | 2688, 3061, 992, 2596, 941, 892, 1021, 2390, 642, 1868, 2377, 1482, | ||
173 | 1540, 540, 1678, 1626, 279, 314, 1173, 2573, 3096, 48, 667, 1920, | ||
174 | 2229, 1041, 2606, 1692, 680, 2746, 568, 3312, | ||
175 | }; | ||
176 | |||
177 | /* kModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)] */ | ||
178 | static const uint16_t kModRoots[128] = { | ||
179 | 17, 3312, 2761, 568, 583, 2746, 2649, 680, 1637, 1692, 723, 2606, | ||
180 | 2288, 1041, 1100, 2229, 1409, 1920, 2662, 667, 3281, 48, 233, 3096, | ||
181 | 756, 2573, 2156, 1173, 3015, 314, 3050, 279, 1703, 1626, 1651, 1678, | ||
182 | 2789, 540, 1789, 1540, 1847, 1482, 952, 2377, 1461, 1868, 2687, 642, | ||
183 | 939, 2390, 2308, 1021, 2437, 892, 2388, 941, 733, 2596, 2337, 992, | ||
184 | 268, 3061, 641, 2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109, | ||
185 | 375, 2954, 2549, 780, 2090, 1239, 1645, 1684, 1063, 2266, 319, 3010, | ||
186 | 2773, 556, 757, 2572, 2099, 1230, 561, 2768, 2466, 863, 2594, 735, | ||
187 | 2804, 525, 1092, 2237, 403, 2926, 1026, 2303, 1143, 2186, 2150, 1179, | ||
188 | 2775, 554, 886, 2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300, | ||
189 | 2110, 1219, 2935, 394, 885, 2444, 2154, 1175, | ||
190 | }; | ||
191 | |||
192 | /* reduce_once reduces 0 <= x < 2*kPrime, mod kPrime. */ | ||
193 | static uint16_t | ||
194 | reduce_once(uint16_t x) | ||
195 | { | ||
196 | assert(x < 2 * kPrime); | ||
197 | const uint16_t subtracted = x - kPrime; | ||
198 | uint16_t mask = 0u - (subtracted >> 15); | ||
199 | |||
200 | /* | ||
201 | * Although this is a constant-time select, we omit a value barrier here. | ||
202 | * Value barriers impede auto-vectorization (likely because it forces the | ||
203 | * value to transit through a general-purpose register). On AArch64, this | ||
204 | * is a difference of 2x. | ||
205 | * | ||
206 | * We usually add value barriers to selects because Clang turns | ||
207 | * consecutive selects with the same condition into a branch instead of | ||
208 | * CMOV/CSEL. This condition does not occur in ML-KEM, so omitting it | ||
209 | * seems to be safe so far but see | ||
210 | * |scalar_centered_binomial_distribution_eta_2_with_prf|. | ||
211 | */ | ||
212 | return (mask & x) | (~mask & subtracted); | ||
213 | } | ||
214 | |||
215 | /* | ||
216 | * constant time reduce x mod kPrime using Barrett reduction. x must be less | ||
217 | * than kPrime + 2×kPrime². | ||
218 | */ | ||
219 | static uint16_t | ||
220 | reduce(uint32_t x) | ||
221 | { | ||
222 | uint64_t product = (uint64_t)x * kBarrettMultiplier; | ||
223 | uint32_t quotient = (uint32_t)(product >> kBarrettShift); | ||
224 | uint32_t remainder = x - quotient * kPrime; | ||
225 | |||
226 | assert(x < kPrime + 2u * kPrime * kPrime); | ||
227 | return reduce_once(remainder); | ||
228 | } | ||
229 | |||
230 | static void | ||
231 | scalar_zero(scalar *out) | ||
232 | { | ||
233 | memset(out, 0, sizeof(*out)); | ||
234 | } | ||
235 | |||
236 | static void | ||
237 | vector_zero(scalar *out, size_t rank) | ||
238 | { | ||
239 | memset(out, 0, sizeof(*out) * rank); | ||
240 | } | ||
241 | |||
242 | /* | ||
243 | * In place number theoretic transform of a given scalar. | ||
244 | * Note that MLKEM's kPrime 3329 does not have a 512th root of unity, so this | ||
245 | * transform leaves off the last iteration of the usual FFT code, with the 128 | ||
246 | * relevant roots of unity being stored in |kNTTRoots|. This means the output | ||
247 | * should be seen as 128 elements in GF(3329^2), with the coefficients of the | ||
248 | * elements being consecutive entries in |s->c|. | ||
249 | */ | ||
250 | static void | ||
251 | scalar_ntt(scalar *s) | ||
252 | { | ||
253 | int offset = DEGREE; | ||
254 | int step; | ||
255 | /* | ||
256 | * `int` is used here because using `size_t` throughout caused a ~5% slowdown | ||
257 | * with Clang 14 on Aarch64. | ||
258 | */ | ||
259 | for (step = 1; step < DEGREE / 2; step <<= 1) { | ||
260 | int i, j, k = 0; | ||
261 | |||
262 | offset >>= 1; | ||
263 | for (i = 0; i < step; i++) { | ||
264 | const uint32_t step_root = kNTTRoots[i + step]; | ||
265 | |||
266 | for (j = k; j < k + offset; j++) { | ||
267 | uint16_t odd, even; | ||
268 | |||
269 | odd = reduce(step_root * s->c[j + offset]); | ||
270 | even = s->c[j]; | ||
271 | s->c[j] = reduce_once(odd + even); | ||
272 | s->c[j + offset] = reduce_once(even - odd + | ||
273 | kPrime); | ||
274 | } | ||
275 | k += 2 * offset; | ||
276 | } | ||
277 | } | ||
278 | } | ||
279 | |||
280 | static void | ||
281 | vector_ntt(scalar *v, size_t rank) | ||
282 | { | ||
283 | size_t i; | ||
284 | |||
285 | for (i = 0; i < rank; i++) { | ||
286 | scalar_ntt(&v[i]); | ||
287 | } | ||
288 | } | ||
289 | |||
290 | /* | ||
291 | * In place inverse number theoretic transform of a given scalar, with pairs of | ||
292 | * entries of s->v being interpreted as elements of GF(3329^2). Just as with the | ||
293 | * number theoretic transform, this leaves off the first step of the normal iFFT | ||
294 | * to account for the fact that 3329 does not have a 512th root of unity, using | ||
295 | * the precomputed 128 roots of unity stored in |kInverseNTTRoots|. | ||
296 | */ | ||
297 | static void | ||
298 | scalar_inverse_ntt(scalar *s) | ||
299 | { | ||
300 | int i, j, k, offset, step = DEGREE / 2; | ||
301 | |||
302 | /* | ||
303 | * `int` is used here because using `size_t` throughout caused a ~5% slowdown | ||
304 | * with Clang 14 on Aarch64. | ||
305 | */ | ||
306 | for (offset = 2; offset < DEGREE; offset <<= 1) { | ||
307 | step >>= 1; | ||
308 | k = 0; | ||
309 | for (i = 0; i < step; i++) { | ||
310 | uint32_t step_root = kInverseNTTRoots[i + step]; | ||
311 | for (j = k; j < k + offset; j++) { | ||
312 | uint16_t odd, even; | ||
313 | odd = s->c[j + offset]; | ||
314 | even = s->c[j]; | ||
315 | s->c[j] = reduce_once(odd + even); | ||
316 | s->c[j + offset] = reduce(step_root * | ||
317 | (even - odd + kPrime)); | ||
318 | } | ||
319 | k += 2 * offset; | ||
320 | } | ||
321 | } | ||
322 | for (i = 0; i < DEGREE; i++) { | ||
323 | s->c[i] = reduce(s->c[i] * kInverseDegree); | ||
324 | } | ||
325 | } | ||
326 | |||
327 | static void | ||
328 | vector_inverse_ntt(scalar *v, size_t rank) | ||
329 | { | ||
330 | size_t i; | ||
331 | |||
332 | for (i = 0; i < rank; i++) { | ||
333 | scalar_inverse_ntt(&v[i]); | ||
334 | } | ||
335 | } | ||
336 | |||
337 | static void | ||
338 | scalar_add(scalar *lhs, const scalar *rhs) | ||
339 | { | ||
340 | int i; | ||
341 | |||
342 | for (i = 0; i < DEGREE; i++) { | ||
343 | lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]); | ||
344 | } | ||
345 | } | ||
346 | |||
347 | static void | ||
348 | scalar_sub(scalar *lhs, const scalar *rhs) | ||
349 | { | ||
350 | int i; | ||
351 | |||
352 | for (i = 0; i < DEGREE; i++) { | ||
353 | lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime); | ||
354 | } | ||
355 | } | ||
356 | |||
357 | /* | ||
358 | * Multiplying two scalars in the number theoretically transformed state. | ||
359 | * Since 3329 does not have a 512th root of unity, this means we have to | ||
360 | * interpret the 2*ith and (2*i+1)th entries of the scalar as elements of | ||
361 | * GF(3329)[X]/(X^2 - 17^(2*bitreverse(i)+1)). | ||
362 | * The value of 17^(2*bitreverse(i)+1) mod 3329 is stored in the precomputed | ||
363 | * |kModRoots| table. Our Barrett transform only allows us to multiply two | ||
364 | * reduced numbers together, so we need some intermediate reduction steps, | ||
365 | * even if an uint64_t could hold 3 multiplied numbers. | ||
366 | */ | ||
367 | static void | ||
368 | scalar_mult(scalar *out, const scalar *lhs, const scalar *rhs) | ||
369 | { | ||
370 | int i; | ||
371 | |||
372 | for (i = 0; i < DEGREE / 2; i++) { | ||
373 | uint32_t real_real = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i]; | ||
374 | uint32_t img_img = (uint32_t)lhs->c[2 * i + 1] * | ||
375 | rhs->c[2 * i + 1]; | ||
376 | uint32_t real_img = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i + 1]; | ||
377 | uint32_t img_real = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i]; | ||
378 | |||
379 | out->c[2 * i] = | ||
380 | reduce(real_real + | ||
381 | (uint32_t)reduce(img_img) * kModRoots[i]); | ||
382 | out->c[2 * i + 1] = reduce(img_real + real_img); | ||
383 | } | ||
384 | } | ||
385 | |||
386 | static void | ||
387 | vector_add(scalar *lhs, const scalar *rhs, size_t rank) | ||
388 | { | ||
389 | size_t i; | ||
390 | |||
391 | for (i = 0; i < rank; i++) { | ||
392 | scalar_add(&lhs[i], &rhs[i]); | ||
393 | } | ||
394 | } | ||
395 | |||
396 | static void | ||
397 | matrix_mult(scalar *out, const void *m, const scalar *a, size_t rank) | ||
398 | { | ||
399 | size_t i, j; | ||
400 | |||
401 | vector_zero(&out[0], rank); | ||
402 | for (i = 0; i < rank; i++) { | ||
403 | for (j = 0; j < rank; j++) { | ||
404 | scalar product; | ||
405 | |||
406 | scalar_mult(&product, const_m2s(m, i, j, rank), &a[j]); | ||
407 | scalar_add(&out[i], &product); | ||
408 | } | ||
409 | } | ||
410 | } | ||
411 | |||
412 | static void | ||
413 | matrix_mult_transpose(scalar *out, const void *m, const scalar *a, size_t rank) | ||
414 | { | ||
415 | int i, j; | ||
416 | |||
417 | vector_zero(&out[0], rank); | ||
418 | for (i = 0; i < rank; i++) { | ||
419 | for (j = 0; j < rank; j++) { | ||
420 | scalar product; | ||
421 | |||
422 | scalar_mult(&product, const_m2s(m, j, i, rank), &a[j]); | ||
423 | scalar_add(&out[i], &product); | ||
424 | } | ||
425 | } | ||
426 | } | ||
427 | |||
428 | static void | ||
429 | scalar_inner_product(scalar *out, const scalar *lhs, | ||
430 | const scalar *rhs, size_t rank) | ||
431 | { | ||
432 | size_t i; | ||
433 | |||
434 | scalar_zero(out); | ||
435 | for (i = 0; i < rank; i++) { | ||
436 | scalar product; | ||
437 | |||
438 | scalar_mult(&product, &lhs[i], &rhs[i]); | ||
439 | scalar_add(out, &product); | ||
440 | } | ||
441 | } | ||
442 | |||
443 | /* | ||
444 | * Algorithm 6 of spec. Rejection samples a Keccak stream to get uniformly | ||
445 | * distributed elements. This is used for matrix expansion and only operates on | ||
446 | * public inputs. | ||
447 | */ | ||
448 | static void | ||
449 | scalar_from_keccak_vartime(scalar *out, sha3_ctx *keccak_ctx) | ||
450 | { | ||
451 | int i, done = 0; | ||
452 | |||
453 | while (done < DEGREE) { | ||
454 | uint8_t block[168]; | ||
455 | |||
456 | shake_out(keccak_ctx, block, sizeof(block)); | ||
457 | for (i = 0; i < sizeof(block) && done < DEGREE; i += 3) { | ||
458 | uint16_t d1 = block[i] + 256 * (block[i + 1] % 16); | ||
459 | uint16_t d2 = block[i + 1] / 16 + 16 * block[i + 2]; | ||
460 | |||
461 | if (d1 < kPrime) { | ||
462 | out->c[done++] = d1; | ||
463 | } | ||
464 | if (d2 < kPrime && done < DEGREE) { | ||
465 | out->c[done++] = d2; | ||
466 | } | ||
467 | } | ||
468 | } | ||
469 | } | ||
470 | |||
471 | /* | ||
472 | * Algorithm 7 of the spec, with eta fixed to two and the PRF call | ||
473 | * included. Creates binominally distributed elements by sampling 2*|eta| bits, | ||
474 | * and setting the coefficient to the count of the first bits minus the count of | ||
475 | * the second bits, resulting in a centered binomial distribution. Since eta is | ||
476 | * two this gives -2/2 with a probability of 1/16, -1/1 with probability 1/4, | ||
477 | * and 0 with probability 3/8. | ||
478 | */ | ||
479 | static void | ||
480 | scalar_centered_binomial_distribution_eta_2_with_prf(scalar *out, | ||
481 | const uint8_t input[33]) | ||
482 | { | ||
483 | uint8_t entropy[128]; | ||
484 | int i; | ||
485 | |||
486 | CTASSERT(sizeof(entropy) == 2 * /*kEta=*/ 2 * DEGREE / 8); | ||
487 | prf(entropy, sizeof(entropy), input); | ||
488 | |||
489 | for (i = 0; i < DEGREE; i += 2) { | ||
490 | uint8_t byte = entropy[i / 2]; | ||
491 | uint16_t mask; | ||
492 | uint16_t value = (byte & 1) + ((byte >> 1) & 1); | ||
493 | |||
494 | value -= ((byte >> 2) & 1) + ((byte >> 3) & 1); | ||
495 | |||
496 | /* | ||
497 | * Add |kPrime| if |value| underflowed. See |reduce_once| for a | ||
498 | * discussion on why the value barrier is omitted. While this | ||
499 | * could have been written reduce_once(value + kPrime), this is | ||
500 | * one extra addition and small range of |value| tempts some | ||
501 | * versions of Clang to emit a branch. | ||
502 | */ | ||
503 | mask = 0u - (value >> 15); | ||
504 | out->c[i] = ((value + kPrime) & mask) | (value & ~mask); | ||
505 | |||
506 | byte >>= 4; | ||
507 | value = (byte & 1) + ((byte >> 1) & 1); | ||
508 | value -= ((byte >> 2) & 1) + ((byte >> 3) & 1); | ||
509 | /* See above. */ | ||
510 | mask = 0u - (value >> 15); | ||
511 | out->c[i + 1] = ((value + kPrime) & mask) | (value & ~mask); | ||
512 | } | ||
513 | } | ||
514 | |||
515 | /* | ||
516 | * Generates a secret vector by using | ||
517 | * |scalar_centered_binomial_distribution_eta_2_with_prf|, using the given seed | ||
518 | * appending and incrementing |counter| for entry of the vector. | ||
519 | */ | ||
520 | static void | ||
521 | vector_generate_secret_eta_2(scalar *out, uint8_t *counter, | ||
522 | const uint8_t seed[32], size_t rank) | ||
523 | { | ||
524 | uint8_t input[33]; | ||
525 | size_t i; | ||
526 | |||
527 | memcpy(input, seed, 32); | ||
528 | for (i = 0; i < rank; i++) { | ||
529 | input[32] = (*counter)++; | ||
530 | scalar_centered_binomial_distribution_eta_2_with_prf(&out[i], | ||
531 | input); | ||
532 | } | ||
533 | } | ||
534 | |||
535 | /* Expands the matrix of a seed for key generation and for encaps-CPA. */ | ||
536 | static void | ||
537 | matrix_expand(void *out, const uint8_t rho[32], size_t rank) | ||
538 | { | ||
539 | uint8_t input[34]; | ||
540 | size_t i, j; | ||
541 | |||
542 | memcpy(input, rho, 32); | ||
543 | for (i = 0; i < rank; i++) { | ||
544 | for (j = 0; j < rank; j++) { | ||
545 | sha3_ctx keccak_ctx; | ||
546 | |||
547 | input[32] = i; | ||
548 | input[33] = j; | ||
549 | shake128_init(&keccak_ctx); | ||
550 | shake_update(&keccak_ctx, input, sizeof(input)); | ||
551 | shake_xof(&keccak_ctx); | ||
552 | scalar_from_keccak_vartime(m2s(out, i, j, rank), | ||
553 | &keccak_ctx); | ||
554 | } | ||
555 | } | ||
556 | } | ||
557 | |||
558 | static const uint8_t kMasks[8] = {0x01, 0x03, 0x07, 0x0f, | ||
559 | 0x1f, 0x3f, 0x7f, 0xff}; | ||
560 | |||
561 | static void | ||
562 | scalar_encode(uint8_t *out, const scalar *s, int bits) | ||
563 | { | ||
564 | uint8_t out_byte = 0; | ||
565 | int i, out_byte_bits = 0; | ||
566 | |||
567 | assert(bits <= (int)sizeof(*s->c) * 8 && bits != 1); | ||
568 | for (i = 0; i < DEGREE; i++) { | ||
569 | uint16_t element = s->c[i]; | ||
570 | int element_bits_done = 0; | ||
571 | |||
572 | while (element_bits_done < bits) { | ||
573 | int chunk_bits = bits - element_bits_done; | ||
574 | int out_bits_remaining = 8 - out_byte_bits; | ||
575 | |||
576 | if (chunk_bits >= out_bits_remaining) { | ||
577 | chunk_bits = out_bits_remaining; | ||
578 | out_byte |= (element & | ||
579 | kMasks[chunk_bits - 1]) << out_byte_bits; | ||
580 | *out = out_byte; | ||
581 | out++; | ||
582 | out_byte_bits = 0; | ||
583 | out_byte = 0; | ||
584 | } else { | ||
585 | out_byte |= (element & | ||
586 | kMasks[chunk_bits - 1]) << out_byte_bits; | ||
587 | out_byte_bits += chunk_bits; | ||
588 | } | ||
589 | |||
590 | element_bits_done += chunk_bits; | ||
591 | element >>= chunk_bits; | ||
592 | } | ||
593 | } | ||
594 | |||
595 | if (out_byte_bits > 0) { | ||
596 | *out = out_byte; | ||
597 | } | ||
598 | } | ||
599 | |||
600 | /* scalar_encode_1 is |scalar_encode| specialised for |bits| == 1. */ | ||
601 | static void | ||
602 | scalar_encode_1(uint8_t out[32], const scalar *s) | ||
603 | { | ||
604 | int i, j; | ||
605 | |||
606 | for (i = 0; i < DEGREE; i += 8) { | ||
607 | uint8_t out_byte = 0; | ||
608 | |||
609 | for (j = 0; j < 8; j++) { | ||
610 | out_byte |= (s->c[i + j] & 1) << j; | ||
611 | } | ||
612 | *out = out_byte; | ||
613 | out++; | ||
614 | } | ||
615 | } | ||
616 | |||
617 | /* | ||
618 | * Encodes an entire vector into 32*|RANK768|*|bits| bytes. Note that since 256 | ||
619 | * (DEGREE) is divisible by 8, the individual vector entries will always fill a | ||
620 | * whole number of bytes, so we do not need to worry about bit packing here. | ||
621 | */ | ||
622 | static void | ||
623 | vector_encode(uint8_t *out, const scalar *a, int bits, size_t rank) | ||
624 | { | ||
625 | int i; | ||
626 | |||
627 | for (i = 0; i < rank; i++) { | ||
628 | scalar_encode(out + i * bits * DEGREE / 8, &a[i], bits); | ||
629 | } | ||
630 | } | ||
631 | |||
632 | /* Encodes an entire vector as above, but adding it to a CBB */ | ||
633 | static int | ||
634 | vector_encode_cbb(CBB *cbb, const scalar *a, int bits, size_t rank) | ||
635 | { | ||
636 | uint8_t *encoded_vector; | ||
637 | |||
638 | if (!CBB_add_space(cbb, &encoded_vector, encoded_vector_size(rank))) | ||
639 | return 0; | ||
640 | vector_encode(encoded_vector, a, bits, rank); | ||
641 | |||
642 | return 1; | ||
643 | } | ||
644 | |||
645 | /* | ||
646 | * scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in | ||
647 | * |out|. It returns one on success and zero if any parsed value is >= | ||
648 | * |kPrime|. | ||
649 | */ | ||
650 | static int | ||
651 | scalar_decode(scalar *out, const uint8_t *in, int bits) | ||
652 | { | ||
653 | uint8_t in_byte = 0; | ||
654 | int i, in_byte_bits_left = 0; | ||
655 | |||
656 | assert(bits <= (int)sizeof(*out->c) * 8 && bits != 1); | ||
657 | |||
658 | for (i = 0; i < DEGREE; i++) { | ||
659 | uint16_t element = 0; | ||
660 | int element_bits_done = 0; | ||
661 | |||
662 | while (element_bits_done < bits) { | ||
663 | int chunk_bits = bits - element_bits_done; | ||
664 | |||
665 | if (in_byte_bits_left == 0) { | ||
666 | in_byte = *in; | ||
667 | in++; | ||
668 | in_byte_bits_left = 8; | ||
669 | } | ||
670 | |||
671 | if (chunk_bits > in_byte_bits_left) { | ||
672 | chunk_bits = in_byte_bits_left; | ||
673 | } | ||
674 | |||
675 | element |= (in_byte & kMasks[chunk_bits - 1]) << | ||
676 | element_bits_done; | ||
677 | in_byte_bits_left -= chunk_bits; | ||
678 | in_byte >>= chunk_bits; | ||
679 | |||
680 | element_bits_done += chunk_bits; | ||
681 | } | ||
682 | |||
683 | if (element >= kPrime) { | ||
684 | return 0; | ||
685 | } | ||
686 | out->c[i] = element; | ||
687 | } | ||
688 | |||
689 | return 1; | ||
690 | } | ||
691 | |||
692 | /* scalar_decode_1 is |scalar_decode| specialised for |bits| == 1. */ | ||
693 | static void | ||
694 | scalar_decode_1(scalar *out, const uint8_t in[32]) | ||
695 | { | ||
696 | int i, j; | ||
697 | |||
698 | for (i = 0; i < DEGREE; i += 8) { | ||
699 | uint8_t in_byte = *in; | ||
700 | |||
701 | in++; | ||
702 | for (j = 0; j < 8; j++) { | ||
703 | out->c[i + j] = in_byte & 1; | ||
704 | in_byte >>= 1; | ||
705 | } | ||
706 | } | ||
707 | } | ||
708 | |||
709 | /* | ||
710 | * Decodes 32*|RANK768|*|bits| bytes from |in| into |out|. It returns one on | ||
711 | * success or zero if any parsed value is >= |kPrime|. | ||
712 | */ | ||
713 | static int | ||
714 | vector_decode(scalar *out, const uint8_t *in, int bits, size_t rank) | ||
715 | { | ||
716 | size_t i; | ||
717 | |||
718 | for (i = 0; i < rank; i++) { | ||
719 | if (!scalar_decode(&out[i], in + i * bits * DEGREE / 8, | ||
720 | bits)) { | ||
721 | return 0; | ||
722 | } | ||
723 | } | ||
724 | return 1; | ||
725 | } | ||
726 | |||
727 | /* | ||
728 | * Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping | ||
729 | * numbers close to each other together. The formula used is | ||
730 | * round(2^|bits|/kPrime*x) mod 2^|bits|. | ||
731 | * Uses Barrett reduction to achieve constant time. Since we need both the | ||
732 | * remainder (for rounding) and the quotient (as the result), we cannot use | ||
733 | * |reduce| here, but need to do the Barrett reduction directly. | ||
734 | */ | ||
735 | static uint16_t | ||
736 | compress(uint16_t x, int bits) | ||
737 | { | ||
738 | uint32_t shifted = (uint32_t)x << bits; | ||
739 | uint64_t product = (uint64_t)shifted * kBarrettMultiplier; | ||
740 | uint32_t quotient = (uint32_t)(product >> kBarrettShift); | ||
741 | uint32_t remainder = shifted - quotient * kPrime; | ||
742 | |||
743 | /* | ||
744 | * Adjust the quotient to round correctly: | ||
745 | * 0 <= remainder <= kHalfPrime round to 0 | ||
746 | * kHalfPrime < remainder <= kPrime + kHalfPrime round to 1 | ||
747 | * kPrime + kHalfPrime < remainder < 2 * kPrime round to 2 | ||
748 | */ | ||
749 | assert(remainder < 2u * kPrime); | ||
750 | quotient += 1 & constant_time_lt(kHalfPrime, remainder); | ||
751 | quotient += 1 & constant_time_lt(kPrime + kHalfPrime, remainder); | ||
752 | return quotient & ((1 << bits) - 1); | ||
753 | } | ||
754 | |||
755 | /* | ||
756 | * Decompresses |x| by using an equi-distant representative. The formula is | ||
757 | * round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us to | ||
758 | * implement this logic using only bit operations. | ||
759 | */ | ||
760 | static uint16_t | ||
761 | decompress(uint16_t x, int bits) | ||
762 | { | ||
763 | uint32_t product = (uint32_t)x * kPrime; | ||
764 | uint32_t power = 1 << bits; | ||
765 | /* This is |product| % power, since |power| is a power of 2. */ | ||
766 | uint32_t remainder = product & (power - 1); | ||
767 | /* This is |product| / power, since |power| is a power of 2. */ | ||
768 | uint32_t lower = product >> bits; | ||
769 | |||
770 | /* | ||
771 | * The rounding logic works since the first half of numbers mod |power| have a | ||
772 | * 0 as first bit, and the second half has a 1 as first bit, since |power| is | ||
773 | * a power of 2. As a 12 bit number, |remainder| is always positive, so we | ||
774 | * will shift in 0s for a right shift. | ||
775 | */ | ||
776 | return lower + (remainder >> (bits - 1)); | ||
777 | } | ||
778 | |||
779 | static void | ||
780 | scalar_compress(scalar *s, int bits) | ||
781 | { | ||
782 | int i; | ||
783 | |||
784 | for (i = 0; i < DEGREE; i++) { | ||
785 | s->c[i] = compress(s->c[i], bits); | ||
786 | } | ||
787 | } | ||
788 | |||
789 | static void | ||
790 | scalar_decompress(scalar *s, int bits) | ||
791 | { | ||
792 | int i; | ||
793 | |||
794 | for (i = 0; i < DEGREE; i++) { | ||
795 | s->c[i] = decompress(s->c[i], bits); | ||
796 | } | ||
797 | } | ||
798 | |||
799 | static void | ||
800 | vector_compress(scalar *v, int bits, size_t rank) | ||
801 | { | ||
802 | size_t i; | ||
803 | |||
804 | for (i = 0; i < rank; i++) { | ||
805 | scalar_compress(&v[i], bits); | ||
806 | } | ||
807 | } | ||
808 | |||
809 | static void | ||
810 | vector_decompress(scalar *v, int bits, size_t rank) | ||
811 | { | ||
812 | int i; | ||
813 | |||
814 | for (i = 0; i < rank; i++) { | ||
815 | scalar_decompress(&v[i], bits); | ||
816 | } | ||
817 | } | ||
818 | |||
819 | struct public_key { | ||
820 | scalar *t; | ||
821 | uint8_t *rho; | ||
822 | uint8_t *public_key_hash; | ||
823 | scalar *m; | ||
824 | }; | ||
825 | |||
826 | static void | ||
827 | public_key_from_external(const MLKEM_public_key *external, | ||
828 | struct public_key *pub) | ||
829 | { | ||
830 | size_t vector_size = external->rank * sizeof(scalar); | ||
831 | uint8_t *bytes = external->key_768->bytes; | ||
832 | size_t offset = 0; | ||
833 | |||
834 | if (external->rank == RANK1024) | ||
835 | bytes = external->key_1024->bytes; | ||
836 | |||
837 | pub->t = (struct scalar *)bytes + offset; | ||
838 | offset += vector_size; | ||
839 | pub->rho = bytes + offset; | ||
840 | offset += 32; | ||
841 | pub->public_key_hash = bytes + offset; | ||
842 | offset += 32; | ||
843 | pub->m = (void *)(bytes + offset); | ||
844 | offset += vector_size * external->rank; | ||
845 | } | ||
846 | |||
847 | struct private_key { | ||
848 | struct public_key pub; | ||
849 | scalar *s; | ||
850 | uint8_t *fo_failure_secret; | ||
851 | }; | ||
852 | |||
853 | static void | ||
854 | private_key_from_external(const MLKEM_private_key *external, | ||
855 | struct private_key *priv) | ||
856 | { | ||
857 | size_t vector_size = external->rank * sizeof(scalar); | ||
858 | size_t offset = 0; | ||
859 | uint8_t *bytes = external->key_768->bytes; | ||
860 | |||
861 | if (external->rank == RANK1024) | ||
862 | bytes = external->key_1024->bytes; | ||
863 | |||
864 | priv->pub.t = (struct scalar *)(bytes + offset); | ||
865 | offset += vector_size; | ||
866 | priv->pub.rho = bytes + offset; | ||
867 | offset += 32; | ||
868 | priv->pub.public_key_hash = bytes + offset; | ||
869 | offset += 32; | ||
870 | priv->pub.m = (void *)(bytes + offset); | ||
871 | offset += vector_size * external->rank; | ||
872 | priv->s = (void *)(bytes + offset); | ||
873 | offset += vector_size; | ||
874 | priv->fo_failure_secret = bytes + offset; | ||
875 | offset += 32; | ||
876 | } | ||
877 | |||
878 | /* | ||
879 | * Calls |mlkem_generate_key_external_entropy| with random bytes from | ||
880 | * |RAND_bytes|. | ||
881 | */ | ||
882 | int | ||
883 | mlkem_generate_key(uint8_t *out_encoded_public_key, | ||
884 | uint8_t optional_out_seed[MLKEM_SEED_LENGTH], | ||
885 | MLKEM_private_key *out_private_key) | ||
886 | { | ||
887 | uint8_t entropy_buf[MLKEM_SEED_LENGTH]; | ||
888 | uint8_t *entropy = optional_out_seed != NULL ? optional_out_seed : | ||
889 | entropy_buf; | ||
890 | |||
891 | arc4random_buf(entropy, MLKEM_SEED_LENGTH); | ||
892 | return mlkem_generate_key_external_entropy(out_encoded_public_key, | ||
893 | out_private_key, entropy); | ||
894 | } | ||
895 | |||
896 | int | ||
897 | mlkem_private_key_from_seed(const uint8_t *seed, size_t seed_len, | ||
898 | MLKEM_private_key *out_private_key) | ||
899 | { | ||
900 | uint8_t *public_key_buf = NULL; | ||
901 | size_t public_key_buf_len = out_private_key->rank == RANK768 ? | ||
902 | MLKEM768_PUBLIC_KEY_BYTES : MLKEM1024_PUBLIC_KEY_BYTES; | ||
903 | int ret = 0; | ||
904 | |||
905 | if (seed_len != MLKEM_SEED_LENGTH) { | ||
906 | goto err; | ||
907 | } | ||
908 | |||
909 | if ((public_key_buf = calloc(1, public_key_buf_len)) == NULL) | ||
910 | goto err; | ||
911 | |||
912 | ret = mlkem_generate_key_external_entropy(public_key_buf, | ||
913 | out_private_key, seed); | ||
914 | |||
915 | err: | ||
916 | freezero(public_key_buf, public_key_buf_len); | ||
917 | |||
918 | return ret; | ||
919 | } | ||
920 | |||
921 | static int | ||
922 | mlkem_marshal_public_key_internal(CBB *out, const struct public_key *pub, | ||
923 | size_t rank) | ||
924 | { | ||
925 | if (!vector_encode_cbb(out, &pub->t[0], kLog2Prime, rank)) | ||
926 | return 0; | ||
927 | return CBB_add_bytes(out, pub->rho, 32); | ||
928 | } | ||
929 | |||
930 | int | ||
931 | mlkem_generate_key_external_entropy(uint8_t *out_encoded_public_key, | ||
932 | MLKEM_private_key *out_private_key, | ||
933 | const uint8_t entropy[MLKEM_SEED_LENGTH]) | ||
934 | { | ||
935 | struct private_key priv; | ||
936 | uint8_t augmented_seed[33]; | ||
937 | uint8_t *rho, *sigma; | ||
938 | uint8_t counter = 0; | ||
939 | uint8_t hashed[64]; | ||
940 | scalar error[RANK1024]; | ||
941 | CBB cbb; | ||
942 | int ret = 0; | ||
943 | |||
944 | private_key_from_external(out_private_key, &priv); | ||
945 | memset(&cbb, 0, sizeof(CBB)); | ||
946 | memcpy(augmented_seed, entropy, 32); | ||
947 | augmented_seed[32] = out_private_key->rank; | ||
948 | hash_g(hashed, augmented_seed, 33); | ||
949 | rho = hashed; | ||
950 | sigma = hashed + 32; | ||
951 | memcpy(priv.pub.rho, hashed, 32); | ||
952 | matrix_expand(priv.pub.m, rho, out_private_key->rank); | ||
953 | vector_generate_secret_eta_2(priv.s, &counter, sigma, | ||
954 | out_private_key->rank); | ||
955 | vector_ntt(priv.s, out_private_key->rank); | ||
956 | vector_generate_secret_eta_2(&error[0], &counter, sigma, | ||
957 | out_private_key->rank); | ||
958 | vector_ntt(&error[0], out_private_key->rank); | ||
959 | matrix_mult_transpose(priv.pub.t, priv.pub.m, priv.s, | ||
960 | out_private_key->rank); | ||
961 | vector_add(priv.pub.t, &error[0], out_private_key->rank); | ||
962 | |||
963 | if (!CBB_init_fixed(&cbb, out_encoded_public_key, | ||
964 | out_private_key->rank == RANK768 ? MLKEM768_PUBLIC_KEY_BYTES : | ||
965 | MLKEM1024_PUBLIC_KEY_BYTES)) | ||
966 | goto err; | ||
967 | |||
968 | if (!mlkem_marshal_public_key_internal(&cbb, &priv.pub, | ||
969 | out_private_key->rank)) | ||
970 | goto err; | ||
971 | |||
972 | hash_h(priv.pub.public_key_hash, out_encoded_public_key, | ||
973 | out_private_key->rank == RANK768 ? MLKEM768_PUBLIC_KEY_BYTES : | ||
974 | MLKEM1024_PUBLIC_KEY_BYTES); | ||
975 | memcpy(priv.fo_failure_secret, entropy + 32, 32); | ||
976 | |||
977 | ret = 1; | ||
978 | |||
979 | err: | ||
980 | CBB_cleanup(&cbb); | ||
981 | |||
982 | return ret; | ||
983 | } | ||
984 | |||
985 | void | ||
986 | mlkem_public_from_private(const MLKEM_private_key *private_key, | ||
987 | MLKEM_public_key *out_public_key) | ||
988 | { | ||
989 | switch (private_key->rank) { | ||
990 | case RANK768: | ||
991 | memcpy(out_public_key->key_768->bytes, | ||
992 | private_key->key_768->bytes, | ||
993 | sizeof(struct MLKEM768_public_key)); | ||
994 | break; | ||
995 | case RANK1024: | ||
996 | memcpy(out_public_key->key_1024->bytes, | ||
997 | private_key->key_1024->bytes, | ||
998 | sizeof(struct MLKEM1024_public_key)); | ||
999 | break; | ||
1000 | } | ||
1001 | } | ||
1002 | |||
1003 | /* | ||
1004 | * Encrypts a message with given randomness to the ciphertext in |out|. Without | ||
1005 | * applying the Fujisaki-Okamoto transform this would not result in a CCA secure | ||
1006 | * scheme, since lattice schemes are vulnerable to decryption failure oracles. | ||
1007 | */ | ||
1008 | static void | ||
1009 | encrypt_cpa(uint8_t *out, const struct public_key *pub, | ||
1010 | const uint8_t message[32], const uint8_t randomness[32], | ||
1011 | size_t rank) | ||
1012 | { | ||
1013 | scalar secret[RANK1024], error[RANK1024], u[RANK1024]; | ||
1014 | scalar expanded_message, scalar_error; | ||
1015 | uint8_t counter = 0; | ||
1016 | uint8_t input[33]; | ||
1017 | scalar v; | ||
1018 | int u_bits = kDU768; | ||
1019 | int v_bits = kDV768; | ||
1020 | |||
1021 | if (rank == RANK1024) { | ||
1022 | u_bits = kDU1024; | ||
1023 | v_bits = kDV1024; | ||
1024 | } | ||
1025 | vector_generate_secret_eta_2(&secret[0], &counter, randomness, rank); | ||
1026 | vector_ntt(&secret[0], rank); | ||
1027 | vector_generate_secret_eta_2(&error[0], &counter, randomness, rank); | ||
1028 | memcpy(input, randomness, 32); | ||
1029 | input[32] = counter; | ||
1030 | scalar_centered_binomial_distribution_eta_2_with_prf(&scalar_error, | ||
1031 | input); | ||
1032 | matrix_mult(&u[0], pub->m, &secret[0], rank); | ||
1033 | vector_inverse_ntt(&u[0], rank); | ||
1034 | vector_add(&u[0], &error[0], rank); | ||
1035 | scalar_inner_product(&v, &pub->t[0], &secret[0], rank); | ||
1036 | scalar_inverse_ntt(&v); | ||
1037 | scalar_add(&v, &scalar_error); | ||
1038 | scalar_decode_1(&expanded_message, message); | ||
1039 | scalar_decompress(&expanded_message, 1); | ||
1040 | scalar_add(&v, &expanded_message); | ||
1041 | vector_compress(&u[0], u_bits, rank); | ||
1042 | vector_encode(out, &u[0], u_bits, rank); | ||
1043 | scalar_compress(&v, v_bits); | ||
1044 | scalar_encode(out + compressed_vector_size(rank), &v, v_bits); | ||
1045 | } | ||
1046 | |||
1047 | /* Calls mlkem_encap_external_entropy| with random bytes */ | ||
1048 | void | ||
1049 | mlkem_encap(const MLKEM_public_key *public_key, | ||
1050 | uint8_t *out_ciphertext, | ||
1051 | uint8_t out_shared_secret[MLKEM_SHARED_SECRET_LENGTH]) | ||
1052 | { | ||
1053 | uint8_t entropy[MLKEM_ENCAP_ENTROPY]; | ||
1054 | |||
1055 | arc4random_buf(entropy, MLKEM_ENCAP_ENTROPY); | ||
1056 | mlkem_encap_external_entropy(out_ciphertext, | ||
1057 | out_shared_secret, public_key, entropy); | ||
1058 | } | ||
1059 | |||
1060 | /* See section 6.2 of the spec. */ | ||
1061 | void | ||
1062 | mlkem_encap_external_entropy(uint8_t *out_ciphertext, | ||
1063 | uint8_t out_shared_secret[MLKEM_SHARED_SECRET_LENGTH], | ||
1064 | const MLKEM_public_key *public_key, | ||
1065 | const uint8_t entropy[MLKEM_ENCAP_ENTROPY]) | ||
1066 | { | ||
1067 | struct public_key pub; | ||
1068 | uint8_t key_and_randomness[64]; | ||
1069 | uint8_t input[64]; | ||
1070 | |||
1071 | public_key_from_external(public_key, &pub); | ||
1072 | memcpy(input, entropy, MLKEM_ENCAP_ENTROPY); | ||
1073 | memcpy(input + MLKEM_ENCAP_ENTROPY, pub.public_key_hash, | ||
1074 | sizeof(input) - MLKEM_ENCAP_ENTROPY); | ||
1075 | hash_g(key_and_randomness, input, sizeof(input)); | ||
1076 | encrypt_cpa(out_ciphertext, &pub, entropy, key_and_randomness + 32, | ||
1077 | public_key->rank); | ||
1078 | memcpy(out_shared_secret, key_and_randomness, 32); | ||
1079 | } | ||
1080 | |||
1081 | static void | ||
1082 | decrypt_cpa(uint8_t out[32], const struct private_key *priv, | ||
1083 | const uint8_t *ciphertext, size_t rank) | ||
1084 | { | ||
1085 | scalar u[RANK1024]; | ||
1086 | scalar mask, v; | ||
1087 | int u_bits = kDU768; | ||
1088 | int v_bits = kDV768; | ||
1089 | |||
1090 | if (rank == RANK1024) { | ||
1091 | u_bits = kDU1024; | ||
1092 | v_bits = kDV1024; | ||
1093 | } | ||
1094 | vector_decode(&u[0], ciphertext, u_bits, rank); | ||
1095 | vector_decompress(&u[0], u_bits, rank); | ||
1096 | vector_ntt(&u[0], rank); | ||
1097 | scalar_decode(&v, ciphertext + compressed_vector_size(rank), v_bits); | ||
1098 | scalar_decompress(&v, v_bits); | ||
1099 | scalar_inner_product(&mask, &priv->s[0], &u[0], rank); | ||
1100 | scalar_inverse_ntt(&mask); | ||
1101 | scalar_sub(&v, &mask); | ||
1102 | scalar_compress(&v, 1); | ||
1103 | scalar_encode_1(out, &v); | ||
1104 | } | ||
1105 | |||
1106 | /* See section 6.3 */ | ||
1107 | int | ||
1108 | mlkem_decap(const MLKEM_private_key *private_key, const uint8_t *ciphertext, | ||
1109 | size_t ciphertext_len, uint8_t out_shared_secret[MLKEM_SHARED_SECRET_LENGTH]) | ||
1110 | { | ||
1111 | struct private_key priv; | ||
1112 | size_t expected_ciphertext_length = private_key->rank == RANK768 ? | ||
1113 | MLKEM768_CIPHERTEXT_BYTES : MLKEM1024_CIPHERTEXT_BYTES; | ||
1114 | uint8_t *expected_ciphertext = NULL; | ||
1115 | uint8_t key_and_randomness[64]; | ||
1116 | uint8_t failure_key[32]; | ||
1117 | uint8_t decrypted[64]; | ||
1118 | uint8_t mask; | ||
1119 | int i; | ||
1120 | int ret = 0; | ||
1121 | |||
1122 | if (ciphertext_len != expected_ciphertext_length) { | ||
1123 | arc4random_buf(out_shared_secret, MLKEM_SHARED_SECRET_LENGTH); | ||
1124 | goto err; | ||
1125 | } | ||
1126 | |||
1127 | if ((expected_ciphertext = calloc(1, expected_ciphertext_length)) == | ||
1128 | NULL) { | ||
1129 | arc4random_buf(out_shared_secret, MLKEM_SHARED_SECRET_LENGTH); | ||
1130 | goto err; | ||
1131 | } | ||
1132 | |||
1133 | private_key_from_external(private_key, &priv); | ||
1134 | decrypt_cpa(decrypted, &priv, ciphertext, private_key->rank); | ||
1135 | memcpy(decrypted + 32, priv.pub.public_key_hash, | ||
1136 | sizeof(decrypted) - 32); | ||
1137 | hash_g(key_and_randomness, decrypted, sizeof(decrypted)); | ||
1138 | encrypt_cpa(expected_ciphertext, &priv.pub, decrypted, | ||
1139 | key_and_randomness + 32, private_key->rank); | ||
1140 | kdf(failure_key, priv.fo_failure_secret, ciphertext, ciphertext_len); | ||
1141 | mask = constant_time_eq_int_8(memcmp(ciphertext, expected_ciphertext, | ||
1142 | expected_ciphertext_length), 0); | ||
1143 | for (i = 0; i < MLKEM_SHARED_SECRET_LENGTH; i++) { | ||
1144 | out_shared_secret[i] = constant_time_select_8(mask, | ||
1145 | key_and_randomness[i], failure_key[i]); | ||
1146 | } | ||
1147 | |||
1148 | ret = 1; | ||
1149 | |||
1150 | err: | ||
1151 | freezero(expected_ciphertext, expected_ciphertext_length); | ||
1152 | |||
1153 | return ret; | ||
1154 | } | ||
1155 | |||
1156 | int | ||
1157 | mlkem_marshal_public_key(const MLKEM_public_key *public_key, | ||
1158 | uint8_t **output, size_t *output_len) | ||
1159 | { | ||
1160 | struct public_key pub; | ||
1161 | int ret = 0; | ||
1162 | CBB cbb; | ||
1163 | |||
1164 | if (!CBB_init(&cbb, public_key->rank == RANK768 ? | ||
1165 | MLKEM768_PUBLIC_KEY_BYTES : MLKEM1024_PUBLIC_KEY_BYTES)) | ||
1166 | goto err; | ||
1167 | public_key_from_external(public_key, &pub); | ||
1168 | if (!mlkem_marshal_public_key_internal(&cbb, &pub, public_key->rank)) | ||
1169 | goto err; | ||
1170 | if (!CBB_finish(&cbb, output, output_len)) | ||
1171 | goto err; | ||
1172 | |||
1173 | ret = 1; | ||
1174 | |||
1175 | err: | ||
1176 | CBB_cleanup(&cbb); | ||
1177 | |||
1178 | return ret; | ||
1179 | } | ||
1180 | |||
1181 | /* | ||
1182 | * mlkem_parse_public_key_no_hash parses |in| into |pub| but doesn't calculate | ||
1183 | * the value of |pub->public_key_hash|. | ||
1184 | */ | ||
1185 | static int | ||
1186 | mlkem_parse_public_key_no_hash(struct public_key *pub, CBS *in, size_t rank) | ||
1187 | { | ||
1188 | CBS t_bytes; | ||
1189 | |||
1190 | if (!CBS_get_bytes(in, &t_bytes, encoded_vector_size(rank))) | ||
1191 | return 0; | ||
1192 | if (!vector_decode(&pub->t[0], CBS_data(&t_bytes), kLog2Prime, rank)) | ||
1193 | return 0; | ||
1194 | |||
1195 | memcpy(pub->rho, CBS_data(in), 32); | ||
1196 | if (!CBS_skip(in, 32)) | ||
1197 | return 0; | ||
1198 | matrix_expand(pub->m, pub->rho, rank); | ||
1199 | return 1; | ||
1200 | } | ||
1201 | |||
1202 | int | ||
1203 | mlkem_parse_public_key(const uint8_t *input, size_t input_len, | ||
1204 | MLKEM_public_key *public_key) | ||
1205 | { | ||
1206 | struct public_key pub; | ||
1207 | CBS cbs; | ||
1208 | |||
1209 | public_key_from_external(public_key, &pub); | ||
1210 | CBS_init(&cbs, input, input_len); | ||
1211 | if (!mlkem_parse_public_key_no_hash(&pub, &cbs, public_key->rank)) | ||
1212 | return 0; | ||
1213 | if (CBS_len(&cbs) != 0) | ||
1214 | return 0; | ||
1215 | |||
1216 | hash_h(pub.public_key_hash, input, input_len); | ||
1217 | |||
1218 | return 1; | ||
1219 | } | ||
1220 | |||
1221 | int | ||
1222 | mlkem_marshal_private_key(const MLKEM_private_key *private_key, | ||
1223 | uint8_t **out_private_key, size_t *out_private_key_len) | ||
1224 | { | ||
1225 | struct private_key priv; | ||
1226 | size_t key_length = private_key->rank == RANK768 ? | ||
1227 | MLKEM768_PRIVATE_KEY_BYTES : MLKEM1024_PRIVATE_KEY_BYTES; | ||
1228 | CBB cbb; | ||
1229 | int ret = 0; | ||
1230 | |||
1231 | private_key_from_external(private_key, &priv); | ||
1232 | if (!CBB_init(&cbb, key_length)) | ||
1233 | goto err; | ||
1234 | |||
1235 | if (!vector_encode_cbb(&cbb, priv.s, kLog2Prime, private_key->rank)) | ||
1236 | goto err; | ||
1237 | if (!mlkem_marshal_public_key_internal(&cbb, &priv.pub, | ||
1238 | private_key->rank)) | ||
1239 | goto err; | ||
1240 | if (!CBB_add_bytes(&cbb, priv.pub.public_key_hash, 32)) | ||
1241 | goto err; | ||
1242 | if (!CBB_add_bytes(&cbb, priv.fo_failure_secret, 32)) | ||
1243 | goto err; | ||
1244 | |||
1245 | if (!CBB_finish(&cbb, out_private_key, out_private_key_len)) | ||
1246 | goto err; | ||
1247 | |||
1248 | ret = 1; | ||
1249 | |||
1250 | err: | ||
1251 | CBB_cleanup(&cbb); | ||
1252 | |||
1253 | return ret; | ||
1254 | } | ||
1255 | |||
1256 | int | ||
1257 | mlkem_parse_private_key(const uint8_t *input, size_t input_len, | ||
1258 | MLKEM_private_key *out_private_key) | ||
1259 | { | ||
1260 | struct private_key priv; | ||
1261 | CBS cbs, s_bytes; | ||
1262 | |||
1263 | private_key_from_external(out_private_key, &priv); | ||
1264 | CBS_init(&cbs, input, input_len); | ||
1265 | |||
1266 | if (!CBS_get_bytes(&cbs, &s_bytes, | ||
1267 | encoded_vector_size(out_private_key->rank))) | ||
1268 | return 0; | ||
1269 | if (!vector_decode(priv.s, CBS_data(&s_bytes), kLog2Prime, | ||
1270 | out_private_key->rank)) | ||
1271 | return 0; | ||
1272 | if (!mlkem_parse_public_key_no_hash(&priv.pub, &cbs, | ||
1273 | out_private_key->rank)) | ||
1274 | return 0; | ||
1275 | |||
1276 | memcpy(priv.pub.public_key_hash, CBS_data(&cbs), 32); | ||
1277 | if (!CBS_skip(&cbs, 32)) | ||
1278 | return 0; | ||
1279 | memcpy(priv.fo_failure_secret, CBS_data(&cbs), 32); | ||
1280 | if (!CBS_skip(&cbs, 32)) | ||
1281 | return 0; | ||
1282 | if (CBS_len(&cbs) != 0) | ||
1283 | return 0; | ||
1284 | |||
1285 | return 1; | ||
1286 | } | ||