diff options
author | beck <> | 2025-09-05 23:30:12 +0000 |
---|---|---|
committer | beck <> | 2025-09-05 23:30:12 +0000 |
commit | 21ce890cad6ae67e0d52f7bfdc44579df2bfc032 (patch) | |
tree | 696ffc96af9e6fa399dc93af7325749db45458c5 /src/lib/libcrypto/mlkem/mlkem_internal.c | |
parent | dc98dc450acd1ba6a9a274662e55679f5c93e5fa (diff) | |
download | openbsd-21ce890cad6ae67e0d52f7bfdc44579df2bfc032.tar.gz openbsd-21ce890cad6ae67e0d52f7bfdc44579df2bfc032.tar.bz2 openbsd-21ce890cad6ae67e0d52f7bfdc44579df2bfc032.zip |
Deduplicate the mlkem 768 and mlkem 1024 code.
This moves everything not public to mlkem_internal.c
removing the old files and doing some further cleanup
on the way.
With this landed mlkem is out of my stack and can be
changed without breaking my subsequent changes
ok tb@
Diffstat (limited to 'src/lib/libcrypto/mlkem/mlkem_internal.c')
-rw-r--r-- | src/lib/libcrypto/mlkem/mlkem_internal.c | 1286 |
1 files changed, 1286 insertions, 0 deletions
diff --git a/src/lib/libcrypto/mlkem/mlkem_internal.c b/src/lib/libcrypto/mlkem/mlkem_internal.c new file mode 100644 index 0000000000..653b2f332d --- /dev/null +++ b/src/lib/libcrypto/mlkem/mlkem_internal.c | |||
@@ -0,0 +1,1286 @@ | |||
1 | /* $OpenBSD: mlkem_internal.c,v 1.1 2025/09/05 23:30:12 beck Exp $ */ | ||
2 | /* | ||
3 | * Copyright (c) 2024, Google Inc. | ||
4 | * Copyright (c) 2024, 2025 Bob Beck <beck@obtuse.com> | ||
5 | * | ||
6 | * Permission to use, copy, modify, and/or distribute this software for any | ||
7 | * purpose with or without fee is hereby granted, provided that the above | ||
8 | * copyright notice and this permission notice appear in all copies. | ||
9 | * | ||
10 | * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES | ||
11 | * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF | ||
12 | * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY | ||
13 | * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES | ||
14 | * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION | ||
15 | * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN | ||
16 | * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. | ||
17 | */ | ||
18 | |||
19 | #include <assert.h> | ||
20 | #include <stdlib.h> | ||
21 | #include <string.h> | ||
22 | #include <stdio.h> | ||
23 | |||
24 | #include <openssl/mlkem.h> | ||
25 | |||
26 | #include "bytestring.h" | ||
27 | #include "sha3_internal.h" | ||
28 | #include "mlkem_internal.h" | ||
29 | #include "constant_time.h" | ||
30 | #include "crypto_internal.h" | ||
31 | |||
32 | /* | ||
33 | * See | ||
34 | * https://csrc.nist.gov/pubs/fips/203/final | ||
35 | */ | ||
36 | |||
37 | static void | ||
38 | prf(uint8_t *out, size_t out_len, const uint8_t in[33]) | ||
39 | { | ||
40 | sha3_ctx ctx; | ||
41 | shake256_init(&ctx); | ||
42 | shake_update(&ctx, in, 33); | ||
43 | shake_xof(&ctx); | ||
44 | shake_out(&ctx, out, out_len); | ||
45 | } | ||
46 | |||
47 | /* Section 4.1 */ | ||
48 | static void | ||
49 | hash_h(uint8_t out[32], const uint8_t *in, size_t len) | ||
50 | { | ||
51 | sha3_ctx ctx; | ||
52 | sha3_init(&ctx, 32); | ||
53 | sha3_update(&ctx, in, len); | ||
54 | sha3_final(out, &ctx); | ||
55 | } | ||
56 | |||
57 | static void | ||
58 | hash_g(uint8_t out[64], const uint8_t *in, size_t len) | ||
59 | { | ||
60 | sha3_ctx ctx; | ||
61 | sha3_init(&ctx, 64); | ||
62 | sha3_update(&ctx, in, len); | ||
63 | sha3_final(out, &ctx); | ||
64 | } | ||
65 | |||
66 | /* this is called 'J' in the spec */ | ||
67 | static void | ||
68 | kdf(uint8_t out[MLKEM_SHARED_SECRET_LENGTH], const uint8_t failure_secret[32], | ||
69 | const uint8_t *in, size_t len) | ||
70 | { | ||
71 | sha3_ctx ctx; | ||
72 | shake256_init(&ctx); | ||
73 | shake_update(&ctx, failure_secret, 32); | ||
74 | shake_update(&ctx, in, len); | ||
75 | shake_xof(&ctx); | ||
76 | shake_out(&ctx, out, MLKEM_SHARED_SECRET_LENGTH); | ||
77 | } | ||
78 | |||
79 | #define DEGREE 256 | ||
80 | |||
81 | static const size_t kBarrettMultiplier = 5039; | ||
82 | static const unsigned kBarrettShift = 24; | ||
83 | static const uint16_t kPrime = 3329; | ||
84 | static const int kLog2Prime = 12; | ||
85 | static const uint16_t kHalfPrime = (/*kPrime=*/3329 - 1) / 2; | ||
86 | static const int kDU768 = 10; | ||
87 | static const int kDV768 = 4; | ||
88 | static const int kDU1024 = 11; | ||
89 | static const int kDV1024 = 5; | ||
90 | |||
91 | /* | ||
92 | * kInverseDegree is 128^-1 mod 3329; 128 because kPrime does not have a 512th | ||
93 | * root of unity. | ||
94 | */ | ||
95 | static const uint16_t kInverseDegree = 3303; | ||
96 | |||
97 | static inline size_t | ||
98 | encoded_vector_size(uint16_t rank) | ||
99 | { | ||
100 | return (kLog2Prime * DEGREE / 8) * rank; | ||
101 | } | ||
102 | |||
103 | static inline size_t | ||
104 | compressed_vector_size(uint16_t rank) | ||
105 | { | ||
106 | return ((rank == RANK768) ? kDU768 : kDU1024) * rank * DEGREE / 8; | ||
107 | } | ||
108 | |||
109 | typedef struct scalar { | ||
110 | /* On every function entry and exit, 0 <= c < kPrime. */ | ||
111 | uint16_t c[DEGREE]; | ||
112 | } scalar; | ||
113 | |||
114 | /* | ||
115 | * Retrieve a const scalar from const matrix of |rank| at position [row][col] | ||
116 | */ | ||
117 | static inline const scalar * | ||
118 | const_m2s(const scalar *v, size_t row, size_t col, uint16_t rank) | ||
119 | { | ||
120 | return ((scalar *)v) + row * rank + col; | ||
121 | } | ||
122 | |||
123 | /* | ||
124 | * Retrieve a scalar from matrix of |rank| at position [row][col] | ||
125 | */ | ||
126 | static inline scalar * | ||
127 | m2s(scalar *v, size_t row, size_t col, uint16_t rank) | ||
128 | { | ||
129 | return ((scalar *)v) + row * rank + col; | ||
130 | } | ||
131 | |||
132 | /* | ||
133 | * This bit of Python will be referenced in some of the following comments: | ||
134 | * | ||
135 | * p = 3329 | ||
136 | * | ||
137 | * def bitreverse(i): | ||
138 | * ret = 0 | ||
139 | * for n in range(7): | ||
140 | * bit = i & 1 | ||
141 | * ret <<= 1 | ||
142 | * ret |= bit | ||
143 | * i >>= 1 | ||
144 | * return ret | ||
145 | */ | ||
146 | |||
147 | /* kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)] */ | ||
148 | static const uint16_t kNTTRoots[128] = { | ||
149 | 1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797, | ||
150 | 2786, 3260, 569, 1746, 296, 2447, 1339, 1476, 3046, 56, 2240, 1333, | ||
151 | 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, 289, 331, 3253, 1756, | ||
152 | 1197, 2304, 2277, 2055, 650, 1977, 2513, 632, 2865, 33, 1320, 1915, | ||
153 | 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648, | ||
154 | 2474, 3110, 1227, 910, 17, 2761, 583, 2649, 1637, 723, 2288, 1100, | ||
155 | 1409, 2662, 3281, 233, 756, 2156, 3015, 3050, 1703, 1651, 2789, 1789, | ||
156 | 1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641, | ||
157 | 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, 1063, 319, 2773, 757, | ||
158 | 2099, 561, 2466, 2594, 2804, 1092, 403, 1026, 1143, 2150, 2775, 886, | ||
159 | 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154, | ||
160 | }; | ||
161 | |||
162 | /* kInverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)] */ | ||
163 | static const uint16_t kInverseNTTRoots[128] = { | ||
164 | 1, 1600, 40, 749, 2481, 1432, 2699, 687, 1583, 2760, 69, 543, | ||
165 | 2532, 3136, 1410, 2267, 2508, 1355, 450, 936, 447, 2794, 1235, 1903, | ||
166 | 1996, 1089, 3273, 283, 1853, 1990, 882, 3033, 2419, 2102, 219, 855, | ||
167 | 2681, 1848, 712, 682, 927, 1795, 461, 1891, 2877, 2522, 1894, 1010, | ||
168 | 1414, 2009, 3296, 464, 2697, 816, 1352, 2679, 1274, 1052, 1025, 2132, | ||
169 | 1573, 76, 2998, 3040, 1175, 2444, 394, 1219, 2300, 1455, 2117, 1607, | ||
170 | 2443, 554, 1179, 2186, 2303, 2926, 2237, 525, 735, 863, 2768, 1230, | ||
171 | 2572, 556, 3010, 2266, 1684, 1239, 780, 2954, 109, 1292, 1031, 1745, | ||
172 | 2688, 3061, 992, 2596, 941, 892, 1021, 2390, 642, 1868, 2377, 1482, | ||
173 | 1540, 540, 1678, 1626, 279, 314, 1173, 2573, 3096, 48, 667, 1920, | ||
174 | 2229, 1041, 2606, 1692, 680, 2746, 568, 3312, | ||
175 | }; | ||
176 | |||
177 | /* kModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)] */ | ||
178 | static const uint16_t kModRoots[128] = { | ||
179 | 17, 3312, 2761, 568, 583, 2746, 2649, 680, 1637, 1692, 723, 2606, | ||
180 | 2288, 1041, 1100, 2229, 1409, 1920, 2662, 667, 3281, 48, 233, 3096, | ||
181 | 756, 2573, 2156, 1173, 3015, 314, 3050, 279, 1703, 1626, 1651, 1678, | ||
182 | 2789, 540, 1789, 1540, 1847, 1482, 952, 2377, 1461, 1868, 2687, 642, | ||
183 | 939, 2390, 2308, 1021, 2437, 892, 2388, 941, 733, 2596, 2337, 992, | ||
184 | 268, 3061, 641, 2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109, | ||
185 | 375, 2954, 2549, 780, 2090, 1239, 1645, 1684, 1063, 2266, 319, 3010, | ||
186 | 2773, 556, 757, 2572, 2099, 1230, 561, 2768, 2466, 863, 2594, 735, | ||
187 | 2804, 525, 1092, 2237, 403, 2926, 1026, 2303, 1143, 2186, 2150, 1179, | ||
188 | 2775, 554, 886, 2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300, | ||
189 | 2110, 1219, 2935, 394, 885, 2444, 2154, 1175, | ||
190 | }; | ||
191 | |||
192 | /* reduce_once reduces 0 <= x < 2*kPrime, mod kPrime. */ | ||
193 | static uint16_t | ||
194 | reduce_once(uint16_t x) | ||
195 | { | ||
196 | assert(x < 2 * kPrime); | ||
197 | const uint16_t subtracted = x - kPrime; | ||
198 | uint16_t mask = 0u - (subtracted >> 15); | ||
199 | |||
200 | /* | ||
201 | * Although this is a constant-time select, we omit a value barrier here. | ||
202 | * Value barriers impede auto-vectorization (likely because it forces the | ||
203 | * value to transit through a general-purpose register). On AArch64, this | ||
204 | * is a difference of 2x. | ||
205 | * | ||
206 | * We usually add value barriers to selects because Clang turns | ||
207 | * consecutive selects with the same condition into a branch instead of | ||
208 | * CMOV/CSEL. This condition does not occur in ML-KEM, so omitting it | ||
209 | * seems to be safe so far but see | ||
210 | * |scalar_centered_binomial_distribution_eta_2_with_prf|. | ||
211 | */ | ||
212 | return (mask & x) | (~mask & subtracted); | ||
213 | } | ||
214 | |||
215 | /* | ||
216 | * constant time reduce x mod kPrime using Barrett reduction. x must be less | ||
217 | * than kPrime + 2×kPrime². | ||
218 | */ | ||
219 | static uint16_t | ||
220 | reduce(uint32_t x) | ||
221 | { | ||
222 | uint64_t product = (uint64_t)x * kBarrettMultiplier; | ||
223 | uint32_t quotient = (uint32_t)(product >> kBarrettShift); | ||
224 | uint32_t remainder = x - quotient * kPrime; | ||
225 | |||
226 | assert(x < kPrime + 2u * kPrime * kPrime); | ||
227 | return reduce_once(remainder); | ||
228 | } | ||
229 | |||
230 | static void | ||
231 | scalar_zero(scalar *out) | ||
232 | { | ||
233 | memset(out, 0, sizeof(*out)); | ||
234 | } | ||
235 | |||
236 | static void | ||
237 | vector_zero(scalar *out, size_t rank) | ||
238 | { | ||
239 | memset(out, 0, sizeof(*out) * rank); | ||
240 | } | ||
241 | |||
242 | /* | ||
243 | * In place number theoretic transform of a given scalar. | ||
244 | * Note that MLKEM's kPrime 3329 does not have a 512th root of unity, so this | ||
245 | * transform leaves off the last iteration of the usual FFT code, with the 128 | ||
246 | * relevant roots of unity being stored in |kNTTRoots|. This means the output | ||
247 | * should be seen as 128 elements in GF(3329^2), with the coefficients of the | ||
248 | * elements being consecutive entries in |s->c|. | ||
249 | */ | ||
250 | static void | ||
251 | scalar_ntt(scalar *s) | ||
252 | { | ||
253 | int offset = DEGREE; | ||
254 | int step; | ||
255 | /* | ||
256 | * `int` is used here because using `size_t` throughout caused a ~5% slowdown | ||
257 | * with Clang 14 on Aarch64. | ||
258 | */ | ||
259 | for (step = 1; step < DEGREE / 2; step <<= 1) { | ||
260 | int i, j, k = 0; | ||
261 | |||
262 | offset >>= 1; | ||
263 | for (i = 0; i < step; i++) { | ||
264 | const uint32_t step_root = kNTTRoots[i + step]; | ||
265 | |||
266 | for (j = k; j < k + offset; j++) { | ||
267 | uint16_t odd, even; | ||
268 | |||
269 | odd = reduce(step_root * s->c[j + offset]); | ||
270 | even = s->c[j]; | ||
271 | s->c[j] = reduce_once(odd + even); | ||
272 | s->c[j + offset] = reduce_once(even - odd + | ||
273 | kPrime); | ||
274 | } | ||
275 | k += 2 * offset; | ||
276 | } | ||
277 | } | ||
278 | } | ||
279 | |||
280 | static void | ||
281 | vector_ntt(scalar *v, size_t rank) | ||
282 | { | ||
283 | size_t i; | ||
284 | |||
285 | for (i = 0; i < rank; i++) { | ||
286 | scalar_ntt(&v[i]); | ||
287 | } | ||
288 | } | ||
289 | |||
290 | /* | ||
291 | * In place inverse number theoretic transform of a given scalar, with pairs of | ||
292 | * entries of s->v being interpreted as elements of GF(3329^2). Just as with the | ||
293 | * number theoretic transform, this leaves off the first step of the normal iFFT | ||
294 | * to account for the fact that 3329 does not have a 512th root of unity, using | ||
295 | * the precomputed 128 roots of unity stored in |kInverseNTTRoots|. | ||
296 | */ | ||
297 | static void | ||
298 | scalar_inverse_ntt(scalar *s) | ||
299 | { | ||
300 | int i, j, k, offset, step = DEGREE / 2; | ||
301 | |||
302 | /* | ||
303 | * `int` is used here because using `size_t` throughout caused a ~5% slowdown | ||
304 | * with Clang 14 on Aarch64. | ||
305 | */ | ||
306 | for (offset = 2; offset < DEGREE; offset <<= 1) { | ||
307 | step >>= 1; | ||
308 | k = 0; | ||
309 | for (i = 0; i < step; i++) { | ||
310 | uint32_t step_root = kInverseNTTRoots[i + step]; | ||
311 | for (j = k; j < k + offset; j++) { | ||
312 | uint16_t odd, even; | ||
313 | odd = s->c[j + offset]; | ||
314 | even = s->c[j]; | ||
315 | s->c[j] = reduce_once(odd + even); | ||
316 | s->c[j + offset] = reduce(step_root * | ||
317 | (even - odd + kPrime)); | ||
318 | } | ||
319 | k += 2 * offset; | ||
320 | } | ||
321 | } | ||
322 | for (i = 0; i < DEGREE; i++) { | ||
323 | s->c[i] = reduce(s->c[i] * kInverseDegree); | ||
324 | } | ||
325 | } | ||
326 | |||
327 | static void | ||
328 | vector_inverse_ntt(scalar *v, size_t rank) | ||
329 | { | ||
330 | size_t i; | ||
331 | |||
332 | for (i = 0; i < rank; i++) { | ||
333 | scalar_inverse_ntt(&v[i]); | ||
334 | } | ||
335 | } | ||
336 | |||
337 | static void | ||
338 | scalar_add(scalar *lhs, const scalar *rhs) | ||
339 | { | ||
340 | int i; | ||
341 | |||
342 | for (i = 0; i < DEGREE; i++) { | ||
343 | lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]); | ||
344 | } | ||
345 | } | ||
346 | |||
347 | static void | ||
348 | scalar_sub(scalar *lhs, const scalar *rhs) | ||
349 | { | ||
350 | int i; | ||
351 | |||
352 | for (i = 0; i < DEGREE; i++) { | ||
353 | lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime); | ||
354 | } | ||
355 | } | ||
356 | |||
357 | /* | ||
358 | * Multiplying two scalars in the number theoretically transformed state. | ||
359 | * Since 3329 does not have a 512th root of unity, this means we have to | ||
360 | * interpret the 2*ith and (2*i+1)th entries of the scalar as elements of | ||
361 | * GF(3329)[X]/(X^2 - 17^(2*bitreverse(i)+1)). | ||
362 | * The value of 17^(2*bitreverse(i)+1) mod 3329 is stored in the precomputed | ||
363 | * |kModRoots| table. Our Barrett transform only allows us to multiply two | ||
364 | * reduced numbers together, so we need some intermediate reduction steps, | ||
365 | * even if an uint64_t could hold 3 multiplied numbers. | ||
366 | */ | ||
367 | static void | ||
368 | scalar_mult(scalar *out, const scalar *lhs, const scalar *rhs) | ||
369 | { | ||
370 | int i; | ||
371 | |||
372 | for (i = 0; i < DEGREE / 2; i++) { | ||
373 | uint32_t real_real = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i]; | ||
374 | uint32_t img_img = (uint32_t)lhs->c[2 * i + 1] * | ||
375 | rhs->c[2 * i + 1]; | ||
376 | uint32_t real_img = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i + 1]; | ||
377 | uint32_t img_real = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i]; | ||
378 | |||
379 | out->c[2 * i] = | ||
380 | reduce(real_real + | ||
381 | (uint32_t)reduce(img_img) * kModRoots[i]); | ||
382 | out->c[2 * i + 1] = reduce(img_real + real_img); | ||
383 | } | ||
384 | } | ||
385 | |||
386 | static void | ||
387 | vector_add(scalar *lhs, const scalar *rhs, size_t rank) | ||
388 | { | ||
389 | size_t i; | ||
390 | |||
391 | for (i = 0; i < rank; i++) { | ||
392 | scalar_add(&lhs[i], &rhs[i]); | ||
393 | } | ||
394 | } | ||
395 | |||
396 | static void | ||
397 | matrix_mult(scalar *out, const void *m, const scalar *a, size_t rank) | ||
398 | { | ||
399 | size_t i, j; | ||
400 | |||
401 | vector_zero(&out[0], rank); | ||
402 | for (i = 0; i < rank; i++) { | ||
403 | for (j = 0; j < rank; j++) { | ||
404 | scalar product; | ||
405 | |||
406 | scalar_mult(&product, const_m2s(m, i, j, rank), &a[j]); | ||
407 | scalar_add(&out[i], &product); | ||
408 | } | ||
409 | } | ||
410 | } | ||
411 | |||
412 | static void | ||
413 | matrix_mult_transpose(scalar *out, const void *m, const scalar *a, size_t rank) | ||
414 | { | ||
415 | int i, j; | ||
416 | |||
417 | vector_zero(&out[0], rank); | ||
418 | for (i = 0; i < rank; i++) { | ||
419 | for (j = 0; j < rank; j++) { | ||
420 | scalar product; | ||
421 | |||
422 | scalar_mult(&product, const_m2s(m, j, i, rank), &a[j]); | ||
423 | scalar_add(&out[i], &product); | ||
424 | } | ||
425 | } | ||
426 | } | ||
427 | |||
428 | static void | ||
429 | scalar_inner_product(scalar *out, const scalar *lhs, | ||
430 | const scalar *rhs, size_t rank) | ||
431 | { | ||
432 | size_t i; | ||
433 | |||
434 | scalar_zero(out); | ||
435 | for (i = 0; i < rank; i++) { | ||
436 | scalar product; | ||
437 | |||
438 | scalar_mult(&product, &lhs[i], &rhs[i]); | ||
439 | scalar_add(out, &product); | ||
440 | } | ||
441 | } | ||
442 | |||
443 | /* | ||
444 | * Algorithm 6 of spec. Rejection samples a Keccak stream to get uniformly | ||
445 | * distributed elements. This is used for matrix expansion and only operates on | ||
446 | * public inputs. | ||
447 | */ | ||
448 | static void | ||
449 | scalar_from_keccak_vartime(scalar *out, sha3_ctx *keccak_ctx) | ||
450 | { | ||
451 | int i, done = 0; | ||
452 | |||
453 | while (done < DEGREE) { | ||
454 | uint8_t block[168]; | ||
455 | |||
456 | shake_out(keccak_ctx, block, sizeof(block)); | ||
457 | for (i = 0; i < sizeof(block) && done < DEGREE; i += 3) { | ||
458 | uint16_t d1 = block[i] + 256 * (block[i + 1] % 16); | ||
459 | uint16_t d2 = block[i + 1] / 16 + 16 * block[i + 2]; | ||
460 | |||
461 | if (d1 < kPrime) { | ||
462 | out->c[done++] = d1; | ||
463 | } | ||
464 | if (d2 < kPrime && done < DEGREE) { | ||
465 | out->c[done++] = d2; | ||
466 | } | ||
467 | } | ||
468 | } | ||
469 | } | ||
470 | |||
471 | /* | ||
472 | * Algorithm 7 of the spec, with eta fixed to two and the PRF call | ||
473 | * included. Creates binominally distributed elements by sampling 2*|eta| bits, | ||
474 | * and setting the coefficient to the count of the first bits minus the count of | ||
475 | * the second bits, resulting in a centered binomial distribution. Since eta is | ||
476 | * two this gives -2/2 with a probability of 1/16, -1/1 with probability 1/4, | ||
477 | * and 0 with probability 3/8. | ||
478 | */ | ||
479 | static void | ||
480 | scalar_centered_binomial_distribution_eta_2_with_prf(scalar *out, | ||
481 | const uint8_t input[33]) | ||
482 | { | ||
483 | uint8_t entropy[128]; | ||
484 | int i; | ||
485 | |||
486 | CTASSERT(sizeof(entropy) == 2 * /*kEta=*/ 2 * DEGREE / 8); | ||
487 | prf(entropy, sizeof(entropy), input); | ||
488 | |||
489 | for (i = 0; i < DEGREE; i += 2) { | ||
490 | uint8_t byte = entropy[i / 2]; | ||
491 | uint16_t mask; | ||
492 | uint16_t value = (byte & 1) + ((byte >> 1) & 1); | ||
493 | |||
494 | value -= ((byte >> 2) & 1) + ((byte >> 3) & 1); | ||
495 | |||
496 | /* | ||
497 | * Add |kPrime| if |value| underflowed. See |reduce_once| for a | ||
498 | * discussion on why the value barrier is omitted. While this | ||
499 | * could have been written reduce_once(value + kPrime), this is | ||
500 | * one extra addition and small range of |value| tempts some | ||
501 | * versions of Clang to emit a branch. | ||
502 | */ | ||
503 | mask = 0u - (value >> 15); | ||
504 | out->c[i] = ((value + kPrime) & mask) | (value & ~mask); | ||
505 | |||
506 | byte >>= 4; | ||
507 | value = (byte & 1) + ((byte >> 1) & 1); | ||
508 | value -= ((byte >> 2) & 1) + ((byte >> 3) & 1); | ||
509 | /* See above. */ | ||
510 | mask = 0u - (value >> 15); | ||
511 | out->c[i + 1] = ((value + kPrime) & mask) | (value & ~mask); | ||
512 | } | ||
513 | } | ||
514 | |||
515 | /* | ||
516 | * Generates a secret vector by using | ||
517 | * |scalar_centered_binomial_distribution_eta_2_with_prf|, using the given seed | ||
518 | * appending and incrementing |counter| for entry of the vector. | ||
519 | */ | ||
520 | static void | ||
521 | vector_generate_secret_eta_2(scalar *out, uint8_t *counter, | ||
522 | const uint8_t seed[32], size_t rank) | ||
523 | { | ||
524 | uint8_t input[33]; | ||
525 | size_t i; | ||
526 | |||
527 | memcpy(input, seed, 32); | ||
528 | for (i = 0; i < rank; i++) { | ||
529 | input[32] = (*counter)++; | ||
530 | scalar_centered_binomial_distribution_eta_2_with_prf(&out[i], | ||
531 | input); | ||
532 | } | ||
533 | } | ||
534 | |||
535 | /* Expands the matrix of a seed for key generation and for encaps-CPA. */ | ||
536 | static void | ||
537 | matrix_expand(void *out, const uint8_t rho[32], size_t rank) | ||
538 | { | ||
539 | uint8_t input[34]; | ||
540 | size_t i, j; | ||
541 | |||
542 | memcpy(input, rho, 32); | ||
543 | for (i = 0; i < rank; i++) { | ||
544 | for (j = 0; j < rank; j++) { | ||
545 | sha3_ctx keccak_ctx; | ||
546 | |||
547 | input[32] = i; | ||
548 | input[33] = j; | ||
549 | shake128_init(&keccak_ctx); | ||
550 | shake_update(&keccak_ctx, input, sizeof(input)); | ||
551 | shake_xof(&keccak_ctx); | ||
552 | scalar_from_keccak_vartime(m2s(out, i, j, rank), | ||
553 | &keccak_ctx); | ||
554 | } | ||
555 | } | ||
556 | } | ||
557 | |||
558 | static const uint8_t kMasks[8] = {0x01, 0x03, 0x07, 0x0f, | ||
559 | 0x1f, 0x3f, 0x7f, 0xff}; | ||
560 | |||
561 | static void | ||
562 | scalar_encode(uint8_t *out, const scalar *s, int bits) | ||
563 | { | ||
564 | uint8_t out_byte = 0; | ||
565 | int i, out_byte_bits = 0; | ||
566 | |||
567 | assert(bits <= (int)sizeof(*s->c) * 8 && bits != 1); | ||
568 | for (i = 0; i < DEGREE; i++) { | ||
569 | uint16_t element = s->c[i]; | ||
570 | int element_bits_done = 0; | ||
571 | |||
572 | while (element_bits_done < bits) { | ||
573 | int chunk_bits = bits - element_bits_done; | ||
574 | int out_bits_remaining = 8 - out_byte_bits; | ||
575 | |||
576 | if (chunk_bits >= out_bits_remaining) { | ||
577 | chunk_bits = out_bits_remaining; | ||
578 | out_byte |= (element & | ||
579 | kMasks[chunk_bits - 1]) << out_byte_bits; | ||
580 | *out = out_byte; | ||
581 | out++; | ||
582 | out_byte_bits = 0; | ||
583 | out_byte = 0; | ||
584 | } else { | ||
585 | out_byte |= (element & | ||
586 | kMasks[chunk_bits - 1]) << out_byte_bits; | ||
587 | out_byte_bits += chunk_bits; | ||
588 | } | ||
589 | |||
590 | element_bits_done += chunk_bits; | ||
591 | element >>= chunk_bits; | ||
592 | } | ||
593 | } | ||
594 | |||
595 | if (out_byte_bits > 0) { | ||
596 | *out = out_byte; | ||
597 | } | ||
598 | } | ||
599 | |||
600 | /* scalar_encode_1 is |scalar_encode| specialised for |bits| == 1. */ | ||
601 | static void | ||
602 | scalar_encode_1(uint8_t out[32], const scalar *s) | ||
603 | { | ||
604 | int i, j; | ||
605 | |||
606 | for (i = 0; i < DEGREE; i += 8) { | ||
607 | uint8_t out_byte = 0; | ||
608 | |||
609 | for (j = 0; j < 8; j++) { | ||
610 | out_byte |= (s->c[i + j] & 1) << j; | ||
611 | } | ||
612 | *out = out_byte; | ||
613 | out++; | ||
614 | } | ||
615 | } | ||
616 | |||
617 | /* | ||
618 | * Encodes an entire vector into 32*|RANK768|*|bits| bytes. Note that since 256 | ||
619 | * (DEGREE) is divisible by 8, the individual vector entries will always fill a | ||
620 | * whole number of bytes, so we do not need to worry about bit packing here. | ||
621 | */ | ||
622 | static void | ||
623 | vector_encode(uint8_t *out, const scalar *a, int bits, size_t rank) | ||
624 | { | ||
625 | int i; | ||
626 | |||
627 | for (i = 0; i < rank; i++) { | ||
628 | scalar_encode(out + i * bits * DEGREE / 8, &a[i], bits); | ||
629 | } | ||
630 | } | ||
631 | |||
632 | /* Encodes an entire vector as above, but adding it to a CBB */ | ||
633 | static int | ||
634 | vector_encode_cbb(CBB *cbb, const scalar *a, int bits, size_t rank) | ||
635 | { | ||
636 | uint8_t *encoded_vector; | ||
637 | |||
638 | if (!CBB_add_space(cbb, &encoded_vector, encoded_vector_size(rank))) | ||
639 | return 0; | ||
640 | vector_encode(encoded_vector, a, bits, rank); | ||
641 | |||
642 | return 1; | ||
643 | } | ||
644 | |||
645 | /* | ||
646 | * scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in | ||
647 | * |out|. It returns one on success and zero if any parsed value is >= | ||
648 | * |kPrime|. | ||
649 | */ | ||
650 | static int | ||
651 | scalar_decode(scalar *out, const uint8_t *in, int bits) | ||
652 | { | ||
653 | uint8_t in_byte = 0; | ||
654 | int i, in_byte_bits_left = 0; | ||
655 | |||
656 | assert(bits <= (int)sizeof(*out->c) * 8 && bits != 1); | ||
657 | |||
658 | for (i = 0; i < DEGREE; i++) { | ||
659 | uint16_t element = 0; | ||
660 | int element_bits_done = 0; | ||
661 | |||
662 | while (element_bits_done < bits) { | ||
663 | int chunk_bits = bits - element_bits_done; | ||
664 | |||
665 | if (in_byte_bits_left == 0) { | ||
666 | in_byte = *in; | ||
667 | in++; | ||
668 | in_byte_bits_left = 8; | ||
669 | } | ||
670 | |||
671 | if (chunk_bits > in_byte_bits_left) { | ||
672 | chunk_bits = in_byte_bits_left; | ||
673 | } | ||
674 | |||
675 | element |= (in_byte & kMasks[chunk_bits - 1]) << | ||
676 | element_bits_done; | ||
677 | in_byte_bits_left -= chunk_bits; | ||
678 | in_byte >>= chunk_bits; | ||
679 | |||
680 | element_bits_done += chunk_bits; | ||
681 | } | ||
682 | |||
683 | if (element >= kPrime) { | ||
684 | return 0; | ||
685 | } | ||
686 | out->c[i] = element; | ||
687 | } | ||
688 | |||
689 | return 1; | ||
690 | } | ||
691 | |||
692 | /* scalar_decode_1 is |scalar_decode| specialised for |bits| == 1. */ | ||
693 | static void | ||
694 | scalar_decode_1(scalar *out, const uint8_t in[32]) | ||
695 | { | ||
696 | int i, j; | ||
697 | |||
698 | for (i = 0; i < DEGREE; i += 8) { | ||
699 | uint8_t in_byte = *in; | ||
700 | |||
701 | in++; | ||
702 | for (j = 0; j < 8; j++) { | ||
703 | out->c[i + j] = in_byte & 1; | ||
704 | in_byte >>= 1; | ||
705 | } | ||
706 | } | ||
707 | } | ||
708 | |||
709 | /* | ||
710 | * Decodes 32*|RANK768|*|bits| bytes from |in| into |out|. It returns one on | ||
711 | * success or zero if any parsed value is >= |kPrime|. | ||
712 | */ | ||
713 | static int | ||
714 | vector_decode(scalar *out, const uint8_t *in, int bits, size_t rank) | ||
715 | { | ||
716 | size_t i; | ||
717 | |||
718 | for (i = 0; i < rank; i++) { | ||
719 | if (!scalar_decode(&out[i], in + i * bits * DEGREE / 8, | ||
720 | bits)) { | ||
721 | return 0; | ||
722 | } | ||
723 | } | ||
724 | return 1; | ||
725 | } | ||
726 | |||
727 | /* | ||
728 | * Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping | ||
729 | * numbers close to each other together. The formula used is | ||
730 | * round(2^|bits|/kPrime*x) mod 2^|bits|. | ||
731 | * Uses Barrett reduction to achieve constant time. Since we need both the | ||
732 | * remainder (for rounding) and the quotient (as the result), we cannot use | ||
733 | * |reduce| here, but need to do the Barrett reduction directly. | ||
734 | */ | ||
735 | static uint16_t | ||
736 | compress(uint16_t x, int bits) | ||
737 | { | ||
738 | uint32_t shifted = (uint32_t)x << bits; | ||
739 | uint64_t product = (uint64_t)shifted * kBarrettMultiplier; | ||
740 | uint32_t quotient = (uint32_t)(product >> kBarrettShift); | ||
741 | uint32_t remainder = shifted - quotient * kPrime; | ||
742 | |||
743 | /* | ||
744 | * Adjust the quotient to round correctly: | ||
745 | * 0 <= remainder <= kHalfPrime round to 0 | ||
746 | * kHalfPrime < remainder <= kPrime + kHalfPrime round to 1 | ||
747 | * kPrime + kHalfPrime < remainder < 2 * kPrime round to 2 | ||
748 | */ | ||
749 | assert(remainder < 2u * kPrime); | ||
750 | quotient += 1 & constant_time_lt(kHalfPrime, remainder); | ||
751 | quotient += 1 & constant_time_lt(kPrime + kHalfPrime, remainder); | ||
752 | return quotient & ((1 << bits) - 1); | ||
753 | } | ||
754 | |||
755 | /* | ||
756 | * Decompresses |x| by using an equi-distant representative. The formula is | ||
757 | * round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us to | ||
758 | * implement this logic using only bit operations. | ||
759 | */ | ||
760 | static uint16_t | ||
761 | decompress(uint16_t x, int bits) | ||
762 | { | ||
763 | uint32_t product = (uint32_t)x * kPrime; | ||
764 | uint32_t power = 1 << bits; | ||
765 | /* This is |product| % power, since |power| is a power of 2. */ | ||
766 | uint32_t remainder = product & (power - 1); | ||
767 | /* This is |product| / power, since |power| is a power of 2. */ | ||
768 | uint32_t lower = product >> bits; | ||
769 | |||
770 | /* | ||
771 | * The rounding logic works since the first half of numbers mod |power| have a | ||
772 | * 0 as first bit, and the second half has a 1 as first bit, since |power| is | ||
773 | * a power of 2. As a 12 bit number, |remainder| is always positive, so we | ||
774 | * will shift in 0s for a right shift. | ||
775 | */ | ||
776 | return lower + (remainder >> (bits - 1)); | ||
777 | } | ||
778 | |||
779 | static void | ||
780 | scalar_compress(scalar *s, int bits) | ||
781 | { | ||
782 | int i; | ||
783 | |||
784 | for (i = 0; i < DEGREE; i++) { | ||
785 | s->c[i] = compress(s->c[i], bits); | ||
786 | } | ||
787 | } | ||
788 | |||
789 | static void | ||
790 | scalar_decompress(scalar *s, int bits) | ||
791 | { | ||
792 | int i; | ||
793 | |||
794 | for (i = 0; i < DEGREE; i++) { | ||
795 | s->c[i] = decompress(s->c[i], bits); | ||
796 | } | ||
797 | } | ||
798 | |||
799 | static void | ||
800 | vector_compress(scalar *v, int bits, size_t rank) | ||
801 | { | ||
802 | size_t i; | ||
803 | |||
804 | for (i = 0; i < rank; i++) { | ||
805 | scalar_compress(&v[i], bits); | ||
806 | } | ||
807 | } | ||
808 | |||
809 | static void | ||
810 | vector_decompress(scalar *v, int bits, size_t rank) | ||
811 | { | ||
812 | int i; | ||
813 | |||
814 | for (i = 0; i < rank; i++) { | ||
815 | scalar_decompress(&v[i], bits); | ||
816 | } | ||
817 | } | ||
818 | |||
819 | struct public_key { | ||
820 | scalar *t; | ||
821 | uint8_t *rho; | ||
822 | uint8_t *public_key_hash; | ||
823 | scalar *m; | ||
824 | }; | ||
825 | |||
826 | static void | ||
827 | public_key_from_external(const MLKEM_public_key *external, | ||
828 | struct public_key *pub) | ||
829 | { | ||
830 | size_t vector_size = external->rank * sizeof(scalar); | ||
831 | uint8_t *bytes = external->key_768->bytes; | ||
832 | size_t offset = 0; | ||
833 | |||
834 | if (external->rank == RANK1024) | ||
835 | bytes = external->key_1024->bytes; | ||
836 | |||
837 | pub->t = (struct scalar *)bytes + offset; | ||
838 | offset += vector_size; | ||
839 | pub->rho = bytes + offset; | ||
840 | offset += 32; | ||
841 | pub->public_key_hash = bytes + offset; | ||
842 | offset += 32; | ||
843 | pub->m = (void *)(bytes + offset); | ||
844 | offset += vector_size * external->rank; | ||
845 | } | ||
846 | |||
847 | struct private_key { | ||
848 | struct public_key pub; | ||
849 | scalar *s; | ||
850 | uint8_t *fo_failure_secret; | ||
851 | }; | ||
852 | |||
853 | static void | ||
854 | private_key_from_external(const MLKEM_private_key *external, | ||
855 | struct private_key *priv) | ||
856 | { | ||
857 | size_t vector_size = external->rank * sizeof(scalar); | ||
858 | size_t offset = 0; | ||
859 | uint8_t *bytes = external->key_768->bytes; | ||
860 | |||
861 | if (external->rank == RANK1024) | ||
862 | bytes = external->key_1024->bytes; | ||
863 | |||
864 | priv->pub.t = (struct scalar *)(bytes + offset); | ||
865 | offset += vector_size; | ||
866 | priv->pub.rho = bytes + offset; | ||
867 | offset += 32; | ||
868 | priv->pub.public_key_hash = bytes + offset; | ||
869 | offset += 32; | ||
870 | priv->pub.m = (void *)(bytes + offset); | ||
871 | offset += vector_size * external->rank; | ||
872 | priv->s = (void *)(bytes + offset); | ||
873 | offset += vector_size; | ||
874 | priv->fo_failure_secret = bytes + offset; | ||
875 | offset += 32; | ||
876 | } | ||
877 | |||
878 | /* | ||
879 | * Calls |mlkem_generate_key_external_entropy| with random bytes from | ||
880 | * |RAND_bytes|. | ||
881 | */ | ||
882 | int | ||
883 | mlkem_generate_key(uint8_t *out_encoded_public_key, | ||
884 | uint8_t optional_out_seed[MLKEM_SEED_LENGTH], | ||
885 | MLKEM_private_key *out_private_key) | ||
886 | { | ||
887 | uint8_t entropy_buf[MLKEM_SEED_LENGTH]; | ||
888 | uint8_t *entropy = optional_out_seed != NULL ? optional_out_seed : | ||
889 | entropy_buf; | ||
890 | |||
891 | arc4random_buf(entropy, MLKEM_SEED_LENGTH); | ||
892 | return mlkem_generate_key_external_entropy(out_encoded_public_key, | ||
893 | out_private_key, entropy); | ||
894 | } | ||
895 | |||
896 | int | ||
897 | mlkem_private_key_from_seed(const uint8_t *seed, size_t seed_len, | ||
898 | MLKEM_private_key *out_private_key) | ||
899 | { | ||
900 | uint8_t *public_key_buf = NULL; | ||
901 | size_t public_key_buf_len = out_private_key->rank == RANK768 ? | ||
902 | MLKEM768_PUBLIC_KEY_BYTES : MLKEM1024_PUBLIC_KEY_BYTES; | ||
903 | int ret = 0; | ||
904 | |||
905 | if (seed_len != MLKEM_SEED_LENGTH) { | ||
906 | goto err; | ||
907 | } | ||
908 | |||
909 | if ((public_key_buf = calloc(1, public_key_buf_len)) == NULL) | ||
910 | goto err; | ||
911 | |||
912 | ret = mlkem_generate_key_external_entropy(public_key_buf, | ||
913 | out_private_key, seed); | ||
914 | |||
915 | err: | ||
916 | freezero(public_key_buf, public_key_buf_len); | ||
917 | |||
918 | return ret; | ||
919 | } | ||
920 | |||
921 | static int | ||
922 | mlkem_marshal_public_key_internal(CBB *out, const struct public_key *pub, | ||
923 | size_t rank) | ||
924 | { | ||
925 | if (!vector_encode_cbb(out, &pub->t[0], kLog2Prime, rank)) | ||
926 | return 0; | ||
927 | return CBB_add_bytes(out, pub->rho, 32); | ||
928 | } | ||
929 | |||
930 | int | ||
931 | mlkem_generate_key_external_entropy(uint8_t *out_encoded_public_key, | ||
932 | MLKEM_private_key *out_private_key, | ||
933 | const uint8_t entropy[MLKEM_SEED_LENGTH]) | ||
934 | { | ||
935 | struct private_key priv; | ||
936 | uint8_t augmented_seed[33]; | ||
937 | uint8_t *rho, *sigma; | ||
938 | uint8_t counter = 0; | ||
939 | uint8_t hashed[64]; | ||
940 | scalar error[RANK1024]; | ||
941 | CBB cbb; | ||
942 | int ret = 0; | ||
943 | |||
944 | private_key_from_external(out_private_key, &priv); | ||
945 | memset(&cbb, 0, sizeof(CBB)); | ||
946 | memcpy(augmented_seed, entropy, 32); | ||
947 | augmented_seed[32] = out_private_key->rank; | ||
948 | hash_g(hashed, augmented_seed, 33); | ||
949 | rho = hashed; | ||
950 | sigma = hashed + 32; | ||
951 | memcpy(priv.pub.rho, hashed, 32); | ||
952 | matrix_expand(priv.pub.m, rho, out_private_key->rank); | ||
953 | vector_generate_secret_eta_2(priv.s, &counter, sigma, | ||
954 | out_private_key->rank); | ||
955 | vector_ntt(priv.s, out_private_key->rank); | ||
956 | vector_generate_secret_eta_2(&error[0], &counter, sigma, | ||
957 | out_private_key->rank); | ||
958 | vector_ntt(&error[0], out_private_key->rank); | ||
959 | matrix_mult_transpose(priv.pub.t, priv.pub.m, priv.s, | ||
960 | out_private_key->rank); | ||
961 | vector_add(priv.pub.t, &error[0], out_private_key->rank); | ||
962 | |||
963 | if (!CBB_init_fixed(&cbb, out_encoded_public_key, | ||
964 | out_private_key->rank == RANK768 ? MLKEM768_PUBLIC_KEY_BYTES : | ||
965 | MLKEM1024_PUBLIC_KEY_BYTES)) | ||
966 | goto err; | ||
967 | |||
968 | if (!mlkem_marshal_public_key_internal(&cbb, &priv.pub, | ||
969 | out_private_key->rank)) | ||
970 | goto err; | ||
971 | |||
972 | hash_h(priv.pub.public_key_hash, out_encoded_public_key, | ||
973 | out_private_key->rank == RANK768 ? MLKEM768_PUBLIC_KEY_BYTES : | ||
974 | MLKEM1024_PUBLIC_KEY_BYTES); | ||
975 | memcpy(priv.fo_failure_secret, entropy + 32, 32); | ||
976 | |||
977 | ret = 1; | ||
978 | |||
979 | err: | ||
980 | CBB_cleanup(&cbb); | ||
981 | |||
982 | return ret; | ||
983 | } | ||
984 | |||
985 | void | ||
986 | mlkem_public_from_private(const MLKEM_private_key *private_key, | ||
987 | MLKEM_public_key *out_public_key) | ||
988 | { | ||
989 | switch (private_key->rank) { | ||
990 | case RANK768: | ||
991 | memcpy(out_public_key->key_768->bytes, | ||
992 | private_key->key_768->bytes, | ||
993 | sizeof(struct MLKEM768_public_key)); | ||
994 | break; | ||
995 | case RANK1024: | ||
996 | memcpy(out_public_key->key_1024->bytes, | ||
997 | private_key->key_1024->bytes, | ||
998 | sizeof(struct MLKEM1024_public_key)); | ||
999 | break; | ||
1000 | } | ||
1001 | } | ||
1002 | |||
1003 | /* | ||
1004 | * Encrypts a message with given randomness to the ciphertext in |out|. Without | ||
1005 | * applying the Fujisaki-Okamoto transform this would not result in a CCA secure | ||
1006 | * scheme, since lattice schemes are vulnerable to decryption failure oracles. | ||
1007 | */ | ||
1008 | static void | ||
1009 | encrypt_cpa(uint8_t *out, const struct public_key *pub, | ||
1010 | const uint8_t message[32], const uint8_t randomness[32], | ||
1011 | size_t rank) | ||
1012 | { | ||
1013 | scalar secret[RANK1024], error[RANK1024], u[RANK1024]; | ||
1014 | scalar expanded_message, scalar_error; | ||
1015 | uint8_t counter = 0; | ||
1016 | uint8_t input[33]; | ||
1017 | scalar v; | ||
1018 | int u_bits = kDU768; | ||
1019 | int v_bits = kDV768; | ||
1020 | |||
1021 | if (rank == RANK1024) { | ||
1022 | u_bits = kDU1024; | ||
1023 | v_bits = kDV1024; | ||
1024 | } | ||
1025 | vector_generate_secret_eta_2(&secret[0], &counter, randomness, rank); | ||
1026 | vector_ntt(&secret[0], rank); | ||
1027 | vector_generate_secret_eta_2(&error[0], &counter, randomness, rank); | ||
1028 | memcpy(input, randomness, 32); | ||
1029 | input[32] = counter; | ||
1030 | scalar_centered_binomial_distribution_eta_2_with_prf(&scalar_error, | ||
1031 | input); | ||
1032 | matrix_mult(&u[0], pub->m, &secret[0], rank); | ||
1033 | vector_inverse_ntt(&u[0], rank); | ||
1034 | vector_add(&u[0], &error[0], rank); | ||
1035 | scalar_inner_product(&v, &pub->t[0], &secret[0], rank); | ||
1036 | scalar_inverse_ntt(&v); | ||
1037 | scalar_add(&v, &scalar_error); | ||
1038 | scalar_decode_1(&expanded_message, message); | ||
1039 | scalar_decompress(&expanded_message, 1); | ||
1040 | scalar_add(&v, &expanded_message); | ||
1041 | vector_compress(&u[0], u_bits, rank); | ||
1042 | vector_encode(out, &u[0], u_bits, rank); | ||
1043 | scalar_compress(&v, v_bits); | ||
1044 | scalar_encode(out + compressed_vector_size(rank), &v, v_bits); | ||
1045 | } | ||
1046 | |||
1047 | /* Calls mlkem_encap_external_entropy| with random bytes */ | ||
1048 | void | ||
1049 | mlkem_encap(const MLKEM_public_key *public_key, | ||
1050 | uint8_t *out_ciphertext, | ||
1051 | uint8_t out_shared_secret[MLKEM_SHARED_SECRET_LENGTH]) | ||
1052 | { | ||
1053 | uint8_t entropy[MLKEM_ENCAP_ENTROPY]; | ||
1054 | |||
1055 | arc4random_buf(entropy, MLKEM_ENCAP_ENTROPY); | ||
1056 | mlkem_encap_external_entropy(out_ciphertext, | ||
1057 | out_shared_secret, public_key, entropy); | ||
1058 | } | ||
1059 | |||
1060 | /* See section 6.2 of the spec. */ | ||
1061 | void | ||
1062 | mlkem_encap_external_entropy(uint8_t *out_ciphertext, | ||
1063 | uint8_t out_shared_secret[MLKEM_SHARED_SECRET_LENGTH], | ||
1064 | const MLKEM_public_key *public_key, | ||
1065 | const uint8_t entropy[MLKEM_ENCAP_ENTROPY]) | ||
1066 | { | ||
1067 | struct public_key pub; | ||
1068 | uint8_t key_and_randomness[64]; | ||
1069 | uint8_t input[64]; | ||
1070 | |||
1071 | public_key_from_external(public_key, &pub); | ||
1072 | memcpy(input, entropy, MLKEM_ENCAP_ENTROPY); | ||
1073 | memcpy(input + MLKEM_ENCAP_ENTROPY, pub.public_key_hash, | ||
1074 | sizeof(input) - MLKEM_ENCAP_ENTROPY); | ||
1075 | hash_g(key_and_randomness, input, sizeof(input)); | ||
1076 | encrypt_cpa(out_ciphertext, &pub, entropy, key_and_randomness + 32, | ||
1077 | public_key->rank); | ||
1078 | memcpy(out_shared_secret, key_and_randomness, 32); | ||
1079 | } | ||
1080 | |||
1081 | static void | ||
1082 | decrypt_cpa(uint8_t out[32], const struct private_key *priv, | ||
1083 | const uint8_t *ciphertext, size_t rank) | ||
1084 | { | ||
1085 | scalar u[RANK1024]; | ||
1086 | scalar mask, v; | ||
1087 | int u_bits = kDU768; | ||
1088 | int v_bits = kDV768; | ||
1089 | |||
1090 | if (rank == RANK1024) { | ||
1091 | u_bits = kDU1024; | ||
1092 | v_bits = kDV1024; | ||
1093 | } | ||
1094 | vector_decode(&u[0], ciphertext, u_bits, rank); | ||
1095 | vector_decompress(&u[0], u_bits, rank); | ||
1096 | vector_ntt(&u[0], rank); | ||
1097 | scalar_decode(&v, ciphertext + compressed_vector_size(rank), v_bits); | ||
1098 | scalar_decompress(&v, v_bits); | ||
1099 | scalar_inner_product(&mask, &priv->s[0], &u[0], rank); | ||
1100 | scalar_inverse_ntt(&mask); | ||
1101 | scalar_sub(&v, &mask); | ||
1102 | scalar_compress(&v, 1); | ||
1103 | scalar_encode_1(out, &v); | ||
1104 | } | ||
1105 | |||
1106 | /* See section 6.3 */ | ||
1107 | int | ||
1108 | mlkem_decap(const MLKEM_private_key *private_key, const uint8_t *ciphertext, | ||
1109 | size_t ciphertext_len, uint8_t out_shared_secret[MLKEM_SHARED_SECRET_LENGTH]) | ||
1110 | { | ||
1111 | struct private_key priv; | ||
1112 | size_t expected_ciphertext_length = private_key->rank == RANK768 ? | ||
1113 | MLKEM768_CIPHERTEXT_BYTES : MLKEM1024_CIPHERTEXT_BYTES; | ||
1114 | uint8_t *expected_ciphertext = NULL; | ||
1115 | uint8_t key_and_randomness[64]; | ||
1116 | uint8_t failure_key[32]; | ||
1117 | uint8_t decrypted[64]; | ||
1118 | uint8_t mask; | ||
1119 | int i; | ||
1120 | int ret = 0; | ||
1121 | |||
1122 | if (ciphertext_len != expected_ciphertext_length) { | ||
1123 | arc4random_buf(out_shared_secret, MLKEM_SHARED_SECRET_LENGTH); | ||
1124 | goto err; | ||
1125 | } | ||
1126 | |||
1127 | if ((expected_ciphertext = calloc(1, expected_ciphertext_length)) == | ||
1128 | NULL) { | ||
1129 | arc4random_buf(out_shared_secret, MLKEM_SHARED_SECRET_LENGTH); | ||
1130 | goto err; | ||
1131 | } | ||
1132 | |||
1133 | private_key_from_external(private_key, &priv); | ||
1134 | decrypt_cpa(decrypted, &priv, ciphertext, private_key->rank); | ||
1135 | memcpy(decrypted + 32, priv.pub.public_key_hash, | ||
1136 | sizeof(decrypted) - 32); | ||
1137 | hash_g(key_and_randomness, decrypted, sizeof(decrypted)); | ||
1138 | encrypt_cpa(expected_ciphertext, &priv.pub, decrypted, | ||
1139 | key_and_randomness + 32, private_key->rank); | ||
1140 | kdf(failure_key, priv.fo_failure_secret, ciphertext, ciphertext_len); | ||
1141 | mask = constant_time_eq_int_8(memcmp(ciphertext, expected_ciphertext, | ||
1142 | expected_ciphertext_length), 0); | ||
1143 | for (i = 0; i < MLKEM_SHARED_SECRET_LENGTH; i++) { | ||
1144 | out_shared_secret[i] = constant_time_select_8(mask, | ||
1145 | key_and_randomness[i], failure_key[i]); | ||
1146 | } | ||
1147 | |||
1148 | ret = 1; | ||
1149 | |||
1150 | err: | ||
1151 | freezero(expected_ciphertext, expected_ciphertext_length); | ||
1152 | |||
1153 | return ret; | ||
1154 | } | ||
1155 | |||
1156 | int | ||
1157 | mlkem_marshal_public_key(const MLKEM_public_key *public_key, | ||
1158 | uint8_t **output, size_t *output_len) | ||
1159 | { | ||
1160 | struct public_key pub; | ||
1161 | int ret = 0; | ||
1162 | CBB cbb; | ||
1163 | |||
1164 | if (!CBB_init(&cbb, public_key->rank == RANK768 ? | ||
1165 | MLKEM768_PUBLIC_KEY_BYTES : MLKEM1024_PUBLIC_KEY_BYTES)) | ||
1166 | goto err; | ||
1167 | public_key_from_external(public_key, &pub); | ||
1168 | if (!mlkem_marshal_public_key_internal(&cbb, &pub, public_key->rank)) | ||
1169 | goto err; | ||
1170 | if (!CBB_finish(&cbb, output, output_len)) | ||
1171 | goto err; | ||
1172 | |||
1173 | ret = 1; | ||
1174 | |||
1175 | err: | ||
1176 | CBB_cleanup(&cbb); | ||
1177 | |||
1178 | return ret; | ||
1179 | } | ||
1180 | |||
1181 | /* | ||
1182 | * mlkem_parse_public_key_no_hash parses |in| into |pub| but doesn't calculate | ||
1183 | * the value of |pub->public_key_hash|. | ||
1184 | */ | ||
1185 | static int | ||
1186 | mlkem_parse_public_key_no_hash(struct public_key *pub, CBS *in, size_t rank) | ||
1187 | { | ||
1188 | CBS t_bytes; | ||
1189 | |||
1190 | if (!CBS_get_bytes(in, &t_bytes, encoded_vector_size(rank))) | ||
1191 | return 0; | ||
1192 | if (!vector_decode(&pub->t[0], CBS_data(&t_bytes), kLog2Prime, rank)) | ||
1193 | return 0; | ||
1194 | |||
1195 | memcpy(pub->rho, CBS_data(in), 32); | ||
1196 | if (!CBS_skip(in, 32)) | ||
1197 | return 0; | ||
1198 | matrix_expand(pub->m, pub->rho, rank); | ||
1199 | return 1; | ||
1200 | } | ||
1201 | |||
1202 | int | ||
1203 | mlkem_parse_public_key(const uint8_t *input, size_t input_len, | ||
1204 | MLKEM_public_key *public_key) | ||
1205 | { | ||
1206 | struct public_key pub; | ||
1207 | CBS cbs; | ||
1208 | |||
1209 | public_key_from_external(public_key, &pub); | ||
1210 | CBS_init(&cbs, input, input_len); | ||
1211 | if (!mlkem_parse_public_key_no_hash(&pub, &cbs, public_key->rank)) | ||
1212 | return 0; | ||
1213 | if (CBS_len(&cbs) != 0) | ||
1214 | return 0; | ||
1215 | |||
1216 | hash_h(pub.public_key_hash, input, input_len); | ||
1217 | |||
1218 | return 1; | ||
1219 | } | ||
1220 | |||
1221 | int | ||
1222 | mlkem_marshal_private_key(const MLKEM_private_key *private_key, | ||
1223 | uint8_t **out_private_key, size_t *out_private_key_len) | ||
1224 | { | ||
1225 | struct private_key priv; | ||
1226 | size_t key_length = private_key->rank == RANK768 ? | ||
1227 | MLKEM768_PRIVATE_KEY_BYTES : MLKEM1024_PRIVATE_KEY_BYTES; | ||
1228 | CBB cbb; | ||
1229 | int ret = 0; | ||
1230 | |||
1231 | private_key_from_external(private_key, &priv); | ||
1232 | if (!CBB_init(&cbb, key_length)) | ||
1233 | goto err; | ||
1234 | |||
1235 | if (!vector_encode_cbb(&cbb, priv.s, kLog2Prime, private_key->rank)) | ||
1236 | goto err; | ||
1237 | if (!mlkem_marshal_public_key_internal(&cbb, &priv.pub, | ||
1238 | private_key->rank)) | ||
1239 | goto err; | ||
1240 | if (!CBB_add_bytes(&cbb, priv.pub.public_key_hash, 32)) | ||
1241 | goto err; | ||
1242 | if (!CBB_add_bytes(&cbb, priv.fo_failure_secret, 32)) | ||
1243 | goto err; | ||
1244 | |||
1245 | if (!CBB_finish(&cbb, out_private_key, out_private_key_len)) | ||
1246 | goto err; | ||
1247 | |||
1248 | ret = 1; | ||
1249 | |||
1250 | err: | ||
1251 | CBB_cleanup(&cbb); | ||
1252 | |||
1253 | return ret; | ||
1254 | } | ||
1255 | |||
1256 | int | ||
1257 | mlkem_parse_private_key(const uint8_t *input, size_t input_len, | ||
1258 | MLKEM_private_key *out_private_key) | ||
1259 | { | ||
1260 | struct private_key priv; | ||
1261 | CBS cbs, s_bytes; | ||
1262 | |||
1263 | private_key_from_external(out_private_key, &priv); | ||
1264 | CBS_init(&cbs, input, input_len); | ||
1265 | |||
1266 | if (!CBS_get_bytes(&cbs, &s_bytes, | ||
1267 | encoded_vector_size(out_private_key->rank))) | ||
1268 | return 0; | ||
1269 | if (!vector_decode(priv.s, CBS_data(&s_bytes), kLog2Prime, | ||
1270 | out_private_key->rank)) | ||
1271 | return 0; | ||
1272 | if (!mlkem_parse_public_key_no_hash(&priv.pub, &cbs, | ||
1273 | out_private_key->rank)) | ||
1274 | return 0; | ||
1275 | |||
1276 | memcpy(priv.pub.public_key_hash, CBS_data(&cbs), 32); | ||
1277 | if (!CBS_skip(&cbs, 32)) | ||
1278 | return 0; | ||
1279 | memcpy(priv.fo_failure_secret, CBS_data(&cbs), 32); | ||
1280 | if (!CBS_skip(&cbs, 32)) | ||
1281 | return 0; | ||
1282 | if (CBS_len(&cbs) != 0) | ||
1283 | return 0; | ||
1284 | |||
1285 | return 1; | ||
1286 | } | ||