summaryrefslogtreecommitdiff
path: root/src/lib/libcrypto/mlkem/mlkem_internal.c
diff options
context:
space:
mode:
authorbeck <>2025-09-05 23:30:12 +0000
committerbeck <>2025-09-05 23:30:12 +0000
commit21ce890cad6ae67e0d52f7bfdc44579df2bfc032 (patch)
tree696ffc96af9e6fa399dc93af7325749db45458c5 /src/lib/libcrypto/mlkem/mlkem_internal.c
parentdc98dc450acd1ba6a9a274662e55679f5c93e5fa (diff)
downloadopenbsd-21ce890cad6ae67e0d52f7bfdc44579df2bfc032.tar.gz
openbsd-21ce890cad6ae67e0d52f7bfdc44579df2bfc032.tar.bz2
openbsd-21ce890cad6ae67e0d52f7bfdc44579df2bfc032.zip
Deduplicate the mlkem 768 and mlkem 1024 code.
This moves everything not public to mlkem_internal.c removing the old files and doing some further cleanup on the way. With this landed mlkem is out of my stack and can be changed without breaking my subsequent changes ok tb@
Diffstat (limited to 'src/lib/libcrypto/mlkem/mlkem_internal.c')
-rw-r--r--src/lib/libcrypto/mlkem/mlkem_internal.c1286
1 files changed, 1286 insertions, 0 deletions
diff --git a/src/lib/libcrypto/mlkem/mlkem_internal.c b/src/lib/libcrypto/mlkem/mlkem_internal.c
new file mode 100644
index 0000000000..653b2f332d
--- /dev/null
+++ b/src/lib/libcrypto/mlkem/mlkem_internal.c
@@ -0,0 +1,1286 @@
1/* $OpenBSD: mlkem_internal.c,v 1.1 2025/09/05 23:30:12 beck Exp $ */
2/*
3 * Copyright (c) 2024, Google Inc.
4 * Copyright (c) 2024, 2025 Bob Beck <beck@obtuse.com>
5 *
6 * Permission to use, copy, modify, and/or distribute this software for any
7 * purpose with or without fee is hereby granted, provided that the above
8 * copyright notice and this permission notice appear in all copies.
9 *
10 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
13 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
15 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
16 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17 */
18
19#include <assert.h>
20#include <stdlib.h>
21#include <string.h>
22#include <stdio.h>
23
24#include <openssl/mlkem.h>
25
26#include "bytestring.h"
27#include "sha3_internal.h"
28#include "mlkem_internal.h"
29#include "constant_time.h"
30#include "crypto_internal.h"
31
32/*
33 * See
34 * https://csrc.nist.gov/pubs/fips/203/final
35 */
36
37static void
38prf(uint8_t *out, size_t out_len, const uint8_t in[33])
39{
40 sha3_ctx ctx;
41 shake256_init(&ctx);
42 shake_update(&ctx, in, 33);
43 shake_xof(&ctx);
44 shake_out(&ctx, out, out_len);
45}
46
47/* Section 4.1 */
48static void
49hash_h(uint8_t out[32], const uint8_t *in, size_t len)
50{
51 sha3_ctx ctx;
52 sha3_init(&ctx, 32);
53 sha3_update(&ctx, in, len);
54 sha3_final(out, &ctx);
55}
56
57static void
58hash_g(uint8_t out[64], const uint8_t *in, size_t len)
59{
60 sha3_ctx ctx;
61 sha3_init(&ctx, 64);
62 sha3_update(&ctx, in, len);
63 sha3_final(out, &ctx);
64}
65
66/* this is called 'J' in the spec */
67static void
68kdf(uint8_t out[MLKEM_SHARED_SECRET_LENGTH], const uint8_t failure_secret[32],
69 const uint8_t *in, size_t len)
70{
71 sha3_ctx ctx;
72 shake256_init(&ctx);
73 shake_update(&ctx, failure_secret, 32);
74 shake_update(&ctx, in, len);
75 shake_xof(&ctx);
76 shake_out(&ctx, out, MLKEM_SHARED_SECRET_LENGTH);
77}
78
79#define DEGREE 256
80
81static const size_t kBarrettMultiplier = 5039;
82static const unsigned kBarrettShift = 24;
83static const uint16_t kPrime = 3329;
84static const int kLog2Prime = 12;
85static const uint16_t kHalfPrime = (/*kPrime=*/3329 - 1) / 2;
86static const int kDU768 = 10;
87static const int kDV768 = 4;
88static const int kDU1024 = 11;
89static const int kDV1024 = 5;
90
91/*
92 * kInverseDegree is 128^-1 mod 3329; 128 because kPrime does not have a 512th
93 * root of unity.
94 */
95static const uint16_t kInverseDegree = 3303;
96
97static inline size_t
98encoded_vector_size(uint16_t rank)
99{
100 return (kLog2Prime * DEGREE / 8) * rank;
101}
102
103static inline size_t
104compressed_vector_size(uint16_t rank)
105{
106 return ((rank == RANK768) ? kDU768 : kDU1024) * rank * DEGREE / 8;
107}
108
109typedef struct scalar {
110 /* On every function entry and exit, 0 <= c < kPrime. */
111 uint16_t c[DEGREE];
112} scalar;
113
114/*
115 * Retrieve a const scalar from const matrix of |rank| at position [row][col]
116 */
117static inline const scalar *
118const_m2s(const scalar *v, size_t row, size_t col, uint16_t rank)
119{
120 return ((scalar *)v) + row * rank + col;
121}
122
123/*
124 * Retrieve a scalar from matrix of |rank| at position [row][col]
125 */
126static inline scalar *
127m2s(scalar *v, size_t row, size_t col, uint16_t rank)
128{
129 return ((scalar *)v) + row * rank + col;
130}
131
132/*
133 * This bit of Python will be referenced in some of the following comments:
134 *
135 * p = 3329
136 *
137 * def bitreverse(i):
138 * ret = 0
139 * for n in range(7):
140 * bit = i & 1
141 * ret <<= 1
142 * ret |= bit
143 * i >>= 1
144 * return ret
145 */
146
147/* kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)] */
148static const uint16_t kNTTRoots[128] = {
149 1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797,
150 2786, 3260, 569, 1746, 296, 2447, 1339, 1476, 3046, 56, 2240, 1333,
151 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, 289, 331, 3253, 1756,
152 1197, 2304, 2277, 2055, 650, 1977, 2513, 632, 2865, 33, 1320, 1915,
153 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648,
154 2474, 3110, 1227, 910, 17, 2761, 583, 2649, 1637, 723, 2288, 1100,
155 1409, 2662, 3281, 233, 756, 2156, 3015, 3050, 1703, 1651, 2789, 1789,
156 1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641,
157 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, 1063, 319, 2773, 757,
158 2099, 561, 2466, 2594, 2804, 1092, 403, 1026, 1143, 2150, 2775, 886,
159 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154,
160};
161
162/* kInverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)] */
163static const uint16_t kInverseNTTRoots[128] = {
164 1, 1600, 40, 749, 2481, 1432, 2699, 687, 1583, 2760, 69, 543,
165 2532, 3136, 1410, 2267, 2508, 1355, 450, 936, 447, 2794, 1235, 1903,
166 1996, 1089, 3273, 283, 1853, 1990, 882, 3033, 2419, 2102, 219, 855,
167 2681, 1848, 712, 682, 927, 1795, 461, 1891, 2877, 2522, 1894, 1010,
168 1414, 2009, 3296, 464, 2697, 816, 1352, 2679, 1274, 1052, 1025, 2132,
169 1573, 76, 2998, 3040, 1175, 2444, 394, 1219, 2300, 1455, 2117, 1607,
170 2443, 554, 1179, 2186, 2303, 2926, 2237, 525, 735, 863, 2768, 1230,
171 2572, 556, 3010, 2266, 1684, 1239, 780, 2954, 109, 1292, 1031, 1745,
172 2688, 3061, 992, 2596, 941, 892, 1021, 2390, 642, 1868, 2377, 1482,
173 1540, 540, 1678, 1626, 279, 314, 1173, 2573, 3096, 48, 667, 1920,
174 2229, 1041, 2606, 1692, 680, 2746, 568, 3312,
175};
176
177/* kModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)] */
178static const uint16_t kModRoots[128] = {
179 17, 3312, 2761, 568, 583, 2746, 2649, 680, 1637, 1692, 723, 2606,
180 2288, 1041, 1100, 2229, 1409, 1920, 2662, 667, 3281, 48, 233, 3096,
181 756, 2573, 2156, 1173, 3015, 314, 3050, 279, 1703, 1626, 1651, 1678,
182 2789, 540, 1789, 1540, 1847, 1482, 952, 2377, 1461, 1868, 2687, 642,
183 939, 2390, 2308, 1021, 2437, 892, 2388, 941, 733, 2596, 2337, 992,
184 268, 3061, 641, 2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
185 375, 2954, 2549, 780, 2090, 1239, 1645, 1684, 1063, 2266, 319, 3010,
186 2773, 556, 757, 2572, 2099, 1230, 561, 2768, 2466, 863, 2594, 735,
187 2804, 525, 1092, 2237, 403, 2926, 1026, 2303, 1143, 2186, 2150, 1179,
188 2775, 554, 886, 2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
189 2110, 1219, 2935, 394, 885, 2444, 2154, 1175,
190};
191
192/* reduce_once reduces 0 <= x < 2*kPrime, mod kPrime. */
193static uint16_t
194reduce_once(uint16_t x)
195{
196 assert(x < 2 * kPrime);
197 const uint16_t subtracted = x - kPrime;
198 uint16_t mask = 0u - (subtracted >> 15);
199
200 /*
201 * Although this is a constant-time select, we omit a value barrier here.
202 * Value barriers impede auto-vectorization (likely because it forces the
203 * value to transit through a general-purpose register). On AArch64, this
204 * is a difference of 2x.
205 *
206 * We usually add value barriers to selects because Clang turns
207 * consecutive selects with the same condition into a branch instead of
208 * CMOV/CSEL. This condition does not occur in ML-KEM, so omitting it
209 * seems to be safe so far but see
210 * |scalar_centered_binomial_distribution_eta_2_with_prf|.
211 */
212 return (mask & x) | (~mask & subtracted);
213}
214
215/*
216 * constant time reduce x mod kPrime using Barrett reduction. x must be less
217 * than kPrime + 2×kPrime².
218 */
219static uint16_t
220reduce(uint32_t x)
221{
222 uint64_t product = (uint64_t)x * kBarrettMultiplier;
223 uint32_t quotient = (uint32_t)(product >> kBarrettShift);
224 uint32_t remainder = x - quotient * kPrime;
225
226 assert(x < kPrime + 2u * kPrime * kPrime);
227 return reduce_once(remainder);
228}
229
230static void
231scalar_zero(scalar *out)
232{
233 memset(out, 0, sizeof(*out));
234}
235
236static void
237vector_zero(scalar *out, size_t rank)
238{
239 memset(out, 0, sizeof(*out) * rank);
240}
241
242/*
243 * In place number theoretic transform of a given scalar.
244 * Note that MLKEM's kPrime 3329 does not have a 512th root of unity, so this
245 * transform leaves off the last iteration of the usual FFT code, with the 128
246 * relevant roots of unity being stored in |kNTTRoots|. This means the output
247 * should be seen as 128 elements in GF(3329^2), with the coefficients of the
248 * elements being consecutive entries in |s->c|.
249 */
250static void
251scalar_ntt(scalar *s)
252{
253 int offset = DEGREE;
254 int step;
255 /*
256 * `int` is used here because using `size_t` throughout caused a ~5% slowdown
257 * with Clang 14 on Aarch64.
258 */
259 for (step = 1; step < DEGREE / 2; step <<= 1) {
260 int i, j, k = 0;
261
262 offset >>= 1;
263 for (i = 0; i < step; i++) {
264 const uint32_t step_root = kNTTRoots[i + step];
265
266 for (j = k; j < k + offset; j++) {
267 uint16_t odd, even;
268
269 odd = reduce(step_root * s->c[j + offset]);
270 even = s->c[j];
271 s->c[j] = reduce_once(odd + even);
272 s->c[j + offset] = reduce_once(even - odd +
273 kPrime);
274 }
275 k += 2 * offset;
276 }
277 }
278}
279
280static void
281vector_ntt(scalar *v, size_t rank)
282{
283 size_t i;
284
285 for (i = 0; i < rank; i++) {
286 scalar_ntt(&v[i]);
287 }
288}
289
290/*
291 * In place inverse number theoretic transform of a given scalar, with pairs of
292 * entries of s->v being interpreted as elements of GF(3329^2). Just as with the
293 * number theoretic transform, this leaves off the first step of the normal iFFT
294 * to account for the fact that 3329 does not have a 512th root of unity, using
295 * the precomputed 128 roots of unity stored in |kInverseNTTRoots|.
296 */
297static void
298scalar_inverse_ntt(scalar *s)
299{
300 int i, j, k, offset, step = DEGREE / 2;
301
302 /*
303 * `int` is used here because using `size_t` throughout caused a ~5% slowdown
304 * with Clang 14 on Aarch64.
305 */
306 for (offset = 2; offset < DEGREE; offset <<= 1) {
307 step >>= 1;
308 k = 0;
309 for (i = 0; i < step; i++) {
310 uint32_t step_root = kInverseNTTRoots[i + step];
311 for (j = k; j < k + offset; j++) {
312 uint16_t odd, even;
313 odd = s->c[j + offset];
314 even = s->c[j];
315 s->c[j] = reduce_once(odd + even);
316 s->c[j + offset] = reduce(step_root *
317 (even - odd + kPrime));
318 }
319 k += 2 * offset;
320 }
321 }
322 for (i = 0; i < DEGREE; i++) {
323 s->c[i] = reduce(s->c[i] * kInverseDegree);
324 }
325}
326
327static void
328vector_inverse_ntt(scalar *v, size_t rank)
329{
330 size_t i;
331
332 for (i = 0; i < rank; i++) {
333 scalar_inverse_ntt(&v[i]);
334 }
335}
336
337static void
338scalar_add(scalar *lhs, const scalar *rhs)
339{
340 int i;
341
342 for (i = 0; i < DEGREE; i++) {
343 lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]);
344 }
345}
346
347static void
348scalar_sub(scalar *lhs, const scalar *rhs)
349{
350 int i;
351
352 for (i = 0; i < DEGREE; i++) {
353 lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime);
354 }
355}
356
357/*
358 * Multiplying two scalars in the number theoretically transformed state.
359 * Since 3329 does not have a 512th root of unity, this means we have to
360 * interpret the 2*ith and (2*i+1)th entries of the scalar as elements of
361 * GF(3329)[X]/(X^2 - 17^(2*bitreverse(i)+1)).
362 * The value of 17^(2*bitreverse(i)+1) mod 3329 is stored in the precomputed
363 * |kModRoots| table. Our Barrett transform only allows us to multiply two
364 * reduced numbers together, so we need some intermediate reduction steps,
365 * even if an uint64_t could hold 3 multiplied numbers.
366 */
367static void
368scalar_mult(scalar *out, const scalar *lhs, const scalar *rhs)
369{
370 int i;
371
372 for (i = 0; i < DEGREE / 2; i++) {
373 uint32_t real_real = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i];
374 uint32_t img_img = (uint32_t)lhs->c[2 * i + 1] *
375 rhs->c[2 * i + 1];
376 uint32_t real_img = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i + 1];
377 uint32_t img_real = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i];
378
379 out->c[2 * i] =
380 reduce(real_real +
381 (uint32_t)reduce(img_img) * kModRoots[i]);
382 out->c[2 * i + 1] = reduce(img_real + real_img);
383 }
384}
385
386static void
387vector_add(scalar *lhs, const scalar *rhs, size_t rank)
388{
389 size_t i;
390
391 for (i = 0; i < rank; i++) {
392 scalar_add(&lhs[i], &rhs[i]);
393 }
394}
395
396static void
397matrix_mult(scalar *out, const void *m, const scalar *a, size_t rank)
398{
399 size_t i, j;
400
401 vector_zero(&out[0], rank);
402 for (i = 0; i < rank; i++) {
403 for (j = 0; j < rank; j++) {
404 scalar product;
405
406 scalar_mult(&product, const_m2s(m, i, j, rank), &a[j]);
407 scalar_add(&out[i], &product);
408 }
409 }
410}
411
412static void
413matrix_mult_transpose(scalar *out, const void *m, const scalar *a, size_t rank)
414{
415 int i, j;
416
417 vector_zero(&out[0], rank);
418 for (i = 0; i < rank; i++) {
419 for (j = 0; j < rank; j++) {
420 scalar product;
421
422 scalar_mult(&product, const_m2s(m, j, i, rank), &a[j]);
423 scalar_add(&out[i], &product);
424 }
425 }
426}
427
428static void
429scalar_inner_product(scalar *out, const scalar *lhs,
430 const scalar *rhs, size_t rank)
431{
432 size_t i;
433
434 scalar_zero(out);
435 for (i = 0; i < rank; i++) {
436 scalar product;
437
438 scalar_mult(&product, &lhs[i], &rhs[i]);
439 scalar_add(out, &product);
440 }
441}
442
443/*
444 * Algorithm 6 of spec. Rejection samples a Keccak stream to get uniformly
445 * distributed elements. This is used for matrix expansion and only operates on
446 * public inputs.
447 */
448static void
449scalar_from_keccak_vartime(scalar *out, sha3_ctx *keccak_ctx)
450{
451 int i, done = 0;
452
453 while (done < DEGREE) {
454 uint8_t block[168];
455
456 shake_out(keccak_ctx, block, sizeof(block));
457 for (i = 0; i < sizeof(block) && done < DEGREE; i += 3) {
458 uint16_t d1 = block[i] + 256 * (block[i + 1] % 16);
459 uint16_t d2 = block[i + 1] / 16 + 16 * block[i + 2];
460
461 if (d1 < kPrime) {
462 out->c[done++] = d1;
463 }
464 if (d2 < kPrime && done < DEGREE) {
465 out->c[done++] = d2;
466 }
467 }
468 }
469}
470
471/*
472 * Algorithm 7 of the spec, with eta fixed to two and the PRF call
473 * included. Creates binominally distributed elements by sampling 2*|eta| bits,
474 * and setting the coefficient to the count of the first bits minus the count of
475 * the second bits, resulting in a centered binomial distribution. Since eta is
476 * two this gives -2/2 with a probability of 1/16, -1/1 with probability 1/4,
477 * and 0 with probability 3/8.
478 */
479static void
480scalar_centered_binomial_distribution_eta_2_with_prf(scalar *out,
481 const uint8_t input[33])
482{
483 uint8_t entropy[128];
484 int i;
485
486 CTASSERT(sizeof(entropy) == 2 * /*kEta=*/ 2 * DEGREE / 8);
487 prf(entropy, sizeof(entropy), input);
488
489 for (i = 0; i < DEGREE; i += 2) {
490 uint8_t byte = entropy[i / 2];
491 uint16_t mask;
492 uint16_t value = (byte & 1) + ((byte >> 1) & 1);
493
494 value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
495
496 /*
497 * Add |kPrime| if |value| underflowed. See |reduce_once| for a
498 * discussion on why the value barrier is omitted. While this
499 * could have been written reduce_once(value + kPrime), this is
500 * one extra addition and small range of |value| tempts some
501 * versions of Clang to emit a branch.
502 */
503 mask = 0u - (value >> 15);
504 out->c[i] = ((value + kPrime) & mask) | (value & ~mask);
505
506 byte >>= 4;
507 value = (byte & 1) + ((byte >> 1) & 1);
508 value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
509 /* See above. */
510 mask = 0u - (value >> 15);
511 out->c[i + 1] = ((value + kPrime) & mask) | (value & ~mask);
512 }
513}
514
515/*
516 * Generates a secret vector by using
517 * |scalar_centered_binomial_distribution_eta_2_with_prf|, using the given seed
518 * appending and incrementing |counter| for entry of the vector.
519 */
520static void
521vector_generate_secret_eta_2(scalar *out, uint8_t *counter,
522 const uint8_t seed[32], size_t rank)
523{
524 uint8_t input[33];
525 size_t i;
526
527 memcpy(input, seed, 32);
528 for (i = 0; i < rank; i++) {
529 input[32] = (*counter)++;
530 scalar_centered_binomial_distribution_eta_2_with_prf(&out[i],
531 input);
532 }
533}
534
535/* Expands the matrix of a seed for key generation and for encaps-CPA. */
536static void
537matrix_expand(void *out, const uint8_t rho[32], size_t rank)
538{
539 uint8_t input[34];
540 size_t i, j;
541
542 memcpy(input, rho, 32);
543 for (i = 0; i < rank; i++) {
544 for (j = 0; j < rank; j++) {
545 sha3_ctx keccak_ctx;
546
547 input[32] = i;
548 input[33] = j;
549 shake128_init(&keccak_ctx);
550 shake_update(&keccak_ctx, input, sizeof(input));
551 shake_xof(&keccak_ctx);
552 scalar_from_keccak_vartime(m2s(out, i, j, rank),
553 &keccak_ctx);
554 }
555 }
556}
557
558static const uint8_t kMasks[8] = {0x01, 0x03, 0x07, 0x0f,
559 0x1f, 0x3f, 0x7f, 0xff};
560
561static void
562scalar_encode(uint8_t *out, const scalar *s, int bits)
563{
564 uint8_t out_byte = 0;
565 int i, out_byte_bits = 0;
566
567 assert(bits <= (int)sizeof(*s->c) * 8 && bits != 1);
568 for (i = 0; i < DEGREE; i++) {
569 uint16_t element = s->c[i];
570 int element_bits_done = 0;
571
572 while (element_bits_done < bits) {
573 int chunk_bits = bits - element_bits_done;
574 int out_bits_remaining = 8 - out_byte_bits;
575
576 if (chunk_bits >= out_bits_remaining) {
577 chunk_bits = out_bits_remaining;
578 out_byte |= (element &
579 kMasks[chunk_bits - 1]) << out_byte_bits;
580 *out = out_byte;
581 out++;
582 out_byte_bits = 0;
583 out_byte = 0;
584 } else {
585 out_byte |= (element &
586 kMasks[chunk_bits - 1]) << out_byte_bits;
587 out_byte_bits += chunk_bits;
588 }
589
590 element_bits_done += chunk_bits;
591 element >>= chunk_bits;
592 }
593 }
594
595 if (out_byte_bits > 0) {
596 *out = out_byte;
597 }
598}
599
600/* scalar_encode_1 is |scalar_encode| specialised for |bits| == 1. */
601static void
602scalar_encode_1(uint8_t out[32], const scalar *s)
603{
604 int i, j;
605
606 for (i = 0; i < DEGREE; i += 8) {
607 uint8_t out_byte = 0;
608
609 for (j = 0; j < 8; j++) {
610 out_byte |= (s->c[i + j] & 1) << j;
611 }
612 *out = out_byte;
613 out++;
614 }
615}
616
617/*
618 * Encodes an entire vector into 32*|RANK768|*|bits| bytes. Note that since 256
619 * (DEGREE) is divisible by 8, the individual vector entries will always fill a
620 * whole number of bytes, so we do not need to worry about bit packing here.
621 */
622static void
623vector_encode(uint8_t *out, const scalar *a, int bits, size_t rank)
624{
625 int i;
626
627 for (i = 0; i < rank; i++) {
628 scalar_encode(out + i * bits * DEGREE / 8, &a[i], bits);
629 }
630}
631
632/* Encodes an entire vector as above, but adding it to a CBB */
633static int
634vector_encode_cbb(CBB *cbb, const scalar *a, int bits, size_t rank)
635{
636 uint8_t *encoded_vector;
637
638 if (!CBB_add_space(cbb, &encoded_vector, encoded_vector_size(rank)))
639 return 0;
640 vector_encode(encoded_vector, a, bits, rank);
641
642 return 1;
643}
644
645/*
646 * scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
647 * |out|. It returns one on success and zero if any parsed value is >=
648 * |kPrime|.
649 */
650static int
651scalar_decode(scalar *out, const uint8_t *in, int bits)
652{
653 uint8_t in_byte = 0;
654 int i, in_byte_bits_left = 0;
655
656 assert(bits <= (int)sizeof(*out->c) * 8 && bits != 1);
657
658 for (i = 0; i < DEGREE; i++) {
659 uint16_t element = 0;
660 int element_bits_done = 0;
661
662 while (element_bits_done < bits) {
663 int chunk_bits = bits - element_bits_done;
664
665 if (in_byte_bits_left == 0) {
666 in_byte = *in;
667 in++;
668 in_byte_bits_left = 8;
669 }
670
671 if (chunk_bits > in_byte_bits_left) {
672 chunk_bits = in_byte_bits_left;
673 }
674
675 element |= (in_byte & kMasks[chunk_bits - 1]) <<
676 element_bits_done;
677 in_byte_bits_left -= chunk_bits;
678 in_byte >>= chunk_bits;
679
680 element_bits_done += chunk_bits;
681 }
682
683 if (element >= kPrime) {
684 return 0;
685 }
686 out->c[i] = element;
687 }
688
689 return 1;
690}
691
692/* scalar_decode_1 is |scalar_decode| specialised for |bits| == 1. */
693static void
694scalar_decode_1(scalar *out, const uint8_t in[32])
695{
696 int i, j;
697
698 for (i = 0; i < DEGREE; i += 8) {
699 uint8_t in_byte = *in;
700
701 in++;
702 for (j = 0; j < 8; j++) {
703 out->c[i + j] = in_byte & 1;
704 in_byte >>= 1;
705 }
706 }
707}
708
709/*
710 * Decodes 32*|RANK768|*|bits| bytes from |in| into |out|. It returns one on
711 * success or zero if any parsed value is >= |kPrime|.
712 */
713static int
714vector_decode(scalar *out, const uint8_t *in, int bits, size_t rank)
715{
716 size_t i;
717
718 for (i = 0; i < rank; i++) {
719 if (!scalar_decode(&out[i], in + i * bits * DEGREE / 8,
720 bits)) {
721 return 0;
722 }
723 }
724 return 1;
725}
726
727/*
728 * Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
729 * numbers close to each other together. The formula used is
730 * round(2^|bits|/kPrime*x) mod 2^|bits|.
731 * Uses Barrett reduction to achieve constant time. Since we need both the
732 * remainder (for rounding) and the quotient (as the result), we cannot use
733 * |reduce| here, but need to do the Barrett reduction directly.
734 */
735static uint16_t
736compress(uint16_t x, int bits)
737{
738 uint32_t shifted = (uint32_t)x << bits;
739 uint64_t product = (uint64_t)shifted * kBarrettMultiplier;
740 uint32_t quotient = (uint32_t)(product >> kBarrettShift);
741 uint32_t remainder = shifted - quotient * kPrime;
742
743 /*
744 * Adjust the quotient to round correctly:
745 * 0 <= remainder <= kHalfPrime round to 0
746 * kHalfPrime < remainder <= kPrime + kHalfPrime round to 1
747 * kPrime + kHalfPrime < remainder < 2 * kPrime round to 2
748 */
749 assert(remainder < 2u * kPrime);
750 quotient += 1 & constant_time_lt(kHalfPrime, remainder);
751 quotient += 1 & constant_time_lt(kPrime + kHalfPrime, remainder);
752 return quotient & ((1 << bits) - 1);
753}
754
755/*
756 * Decompresses |x| by using an equi-distant representative. The formula is
757 * round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us to
758 * implement this logic using only bit operations.
759 */
760static uint16_t
761decompress(uint16_t x, int bits)
762{
763 uint32_t product = (uint32_t)x * kPrime;
764 uint32_t power = 1 << bits;
765 /* This is |product| % power, since |power| is a power of 2. */
766 uint32_t remainder = product & (power - 1);
767 /* This is |product| / power, since |power| is a power of 2. */
768 uint32_t lower = product >> bits;
769
770 /*
771 * The rounding logic works since the first half of numbers mod |power| have a
772 * 0 as first bit, and the second half has a 1 as first bit, since |power| is
773 * a power of 2. As a 12 bit number, |remainder| is always positive, so we
774 * will shift in 0s for a right shift.
775 */
776 return lower + (remainder >> (bits - 1));
777}
778
779static void
780scalar_compress(scalar *s, int bits)
781{
782 int i;
783
784 for (i = 0; i < DEGREE; i++) {
785 s->c[i] = compress(s->c[i], bits);
786 }
787}
788
789static void
790scalar_decompress(scalar *s, int bits)
791{
792 int i;
793
794 for (i = 0; i < DEGREE; i++) {
795 s->c[i] = decompress(s->c[i], bits);
796 }
797}
798
799static void
800vector_compress(scalar *v, int bits, size_t rank)
801{
802 size_t i;
803
804 for (i = 0; i < rank; i++) {
805 scalar_compress(&v[i], bits);
806 }
807}
808
809static void
810vector_decompress(scalar *v, int bits, size_t rank)
811{
812 int i;
813
814 for (i = 0; i < rank; i++) {
815 scalar_decompress(&v[i], bits);
816 }
817}
818
819struct public_key {
820 scalar *t;
821 uint8_t *rho;
822 uint8_t *public_key_hash;
823 scalar *m;
824};
825
826static void
827public_key_from_external(const MLKEM_public_key *external,
828 struct public_key *pub)
829{
830 size_t vector_size = external->rank * sizeof(scalar);
831 uint8_t *bytes = external->key_768->bytes;
832 size_t offset = 0;
833
834 if (external->rank == RANK1024)
835 bytes = external->key_1024->bytes;
836
837 pub->t = (struct scalar *)bytes + offset;
838 offset += vector_size;
839 pub->rho = bytes + offset;
840 offset += 32;
841 pub->public_key_hash = bytes + offset;
842 offset += 32;
843 pub->m = (void *)(bytes + offset);
844 offset += vector_size * external->rank;
845}
846
847struct private_key {
848 struct public_key pub;
849 scalar *s;
850 uint8_t *fo_failure_secret;
851};
852
853static void
854private_key_from_external(const MLKEM_private_key *external,
855 struct private_key *priv)
856{
857 size_t vector_size = external->rank * sizeof(scalar);
858 size_t offset = 0;
859 uint8_t *bytes = external->key_768->bytes;
860
861 if (external->rank == RANK1024)
862 bytes = external->key_1024->bytes;
863
864 priv->pub.t = (struct scalar *)(bytes + offset);
865 offset += vector_size;
866 priv->pub.rho = bytes + offset;
867 offset += 32;
868 priv->pub.public_key_hash = bytes + offset;
869 offset += 32;
870 priv->pub.m = (void *)(bytes + offset);
871 offset += vector_size * external->rank;
872 priv->s = (void *)(bytes + offset);
873 offset += vector_size;
874 priv->fo_failure_secret = bytes + offset;
875 offset += 32;
876}
877
878/*
879 * Calls |mlkem_generate_key_external_entropy| with random bytes from
880 * |RAND_bytes|.
881 */
882int
883mlkem_generate_key(uint8_t *out_encoded_public_key,
884 uint8_t optional_out_seed[MLKEM_SEED_LENGTH],
885 MLKEM_private_key *out_private_key)
886{
887 uint8_t entropy_buf[MLKEM_SEED_LENGTH];
888 uint8_t *entropy = optional_out_seed != NULL ? optional_out_seed :
889 entropy_buf;
890
891 arc4random_buf(entropy, MLKEM_SEED_LENGTH);
892 return mlkem_generate_key_external_entropy(out_encoded_public_key,
893 out_private_key, entropy);
894}
895
896int
897mlkem_private_key_from_seed(const uint8_t *seed, size_t seed_len,
898 MLKEM_private_key *out_private_key)
899{
900 uint8_t *public_key_buf = NULL;
901 size_t public_key_buf_len = out_private_key->rank == RANK768 ?
902 MLKEM768_PUBLIC_KEY_BYTES : MLKEM1024_PUBLIC_KEY_BYTES;
903 int ret = 0;
904
905 if (seed_len != MLKEM_SEED_LENGTH) {
906 goto err;
907 }
908
909 if ((public_key_buf = calloc(1, public_key_buf_len)) == NULL)
910 goto err;
911
912 ret = mlkem_generate_key_external_entropy(public_key_buf,
913 out_private_key, seed);
914
915 err:
916 freezero(public_key_buf, public_key_buf_len);
917
918 return ret;
919}
920
921static int
922mlkem_marshal_public_key_internal(CBB *out, const struct public_key *pub,
923 size_t rank)
924{
925 if (!vector_encode_cbb(out, &pub->t[0], kLog2Prime, rank))
926 return 0;
927 return CBB_add_bytes(out, pub->rho, 32);
928}
929
930int
931mlkem_generate_key_external_entropy(uint8_t *out_encoded_public_key,
932 MLKEM_private_key *out_private_key,
933 const uint8_t entropy[MLKEM_SEED_LENGTH])
934{
935 struct private_key priv;
936 uint8_t augmented_seed[33];
937 uint8_t *rho, *sigma;
938 uint8_t counter = 0;
939 uint8_t hashed[64];
940 scalar error[RANK1024];
941 CBB cbb;
942 int ret = 0;
943
944 private_key_from_external(out_private_key, &priv);
945 memset(&cbb, 0, sizeof(CBB));
946 memcpy(augmented_seed, entropy, 32);
947 augmented_seed[32] = out_private_key->rank;
948 hash_g(hashed, augmented_seed, 33);
949 rho = hashed;
950 sigma = hashed + 32;
951 memcpy(priv.pub.rho, hashed, 32);
952 matrix_expand(priv.pub.m, rho, out_private_key->rank);
953 vector_generate_secret_eta_2(priv.s, &counter, sigma,
954 out_private_key->rank);
955 vector_ntt(priv.s, out_private_key->rank);
956 vector_generate_secret_eta_2(&error[0], &counter, sigma,
957 out_private_key->rank);
958 vector_ntt(&error[0], out_private_key->rank);
959 matrix_mult_transpose(priv.pub.t, priv.pub.m, priv.s,
960 out_private_key->rank);
961 vector_add(priv.pub.t, &error[0], out_private_key->rank);
962
963 if (!CBB_init_fixed(&cbb, out_encoded_public_key,
964 out_private_key->rank == RANK768 ? MLKEM768_PUBLIC_KEY_BYTES :
965 MLKEM1024_PUBLIC_KEY_BYTES))
966 goto err;
967
968 if (!mlkem_marshal_public_key_internal(&cbb, &priv.pub,
969 out_private_key->rank))
970 goto err;
971
972 hash_h(priv.pub.public_key_hash, out_encoded_public_key,
973 out_private_key->rank == RANK768 ? MLKEM768_PUBLIC_KEY_BYTES :
974 MLKEM1024_PUBLIC_KEY_BYTES);
975 memcpy(priv.fo_failure_secret, entropy + 32, 32);
976
977 ret = 1;
978
979 err:
980 CBB_cleanup(&cbb);
981
982 return ret;
983}
984
985void
986mlkem_public_from_private(const MLKEM_private_key *private_key,
987 MLKEM_public_key *out_public_key)
988{
989 switch (private_key->rank) {
990 case RANK768:
991 memcpy(out_public_key->key_768->bytes,
992 private_key->key_768->bytes,
993 sizeof(struct MLKEM768_public_key));
994 break;
995 case RANK1024:
996 memcpy(out_public_key->key_1024->bytes,
997 private_key->key_1024->bytes,
998 sizeof(struct MLKEM1024_public_key));
999 break;
1000 }
1001}
1002
1003/*
1004 * Encrypts a message with given randomness to the ciphertext in |out|. Without
1005 * applying the Fujisaki-Okamoto transform this would not result in a CCA secure
1006 * scheme, since lattice schemes are vulnerable to decryption failure oracles.
1007 */
1008static void
1009encrypt_cpa(uint8_t *out, const struct public_key *pub,
1010 const uint8_t message[32], const uint8_t randomness[32],
1011 size_t rank)
1012{
1013 scalar secret[RANK1024], error[RANK1024], u[RANK1024];
1014 scalar expanded_message, scalar_error;
1015 uint8_t counter = 0;
1016 uint8_t input[33];
1017 scalar v;
1018 int u_bits = kDU768;
1019 int v_bits = kDV768;
1020
1021 if (rank == RANK1024) {
1022 u_bits = kDU1024;
1023 v_bits = kDV1024;
1024 }
1025 vector_generate_secret_eta_2(&secret[0], &counter, randomness, rank);
1026 vector_ntt(&secret[0], rank);
1027 vector_generate_secret_eta_2(&error[0], &counter, randomness, rank);
1028 memcpy(input, randomness, 32);
1029 input[32] = counter;
1030 scalar_centered_binomial_distribution_eta_2_with_prf(&scalar_error,
1031 input);
1032 matrix_mult(&u[0], pub->m, &secret[0], rank);
1033 vector_inverse_ntt(&u[0], rank);
1034 vector_add(&u[0], &error[0], rank);
1035 scalar_inner_product(&v, &pub->t[0], &secret[0], rank);
1036 scalar_inverse_ntt(&v);
1037 scalar_add(&v, &scalar_error);
1038 scalar_decode_1(&expanded_message, message);
1039 scalar_decompress(&expanded_message, 1);
1040 scalar_add(&v, &expanded_message);
1041 vector_compress(&u[0], u_bits, rank);
1042 vector_encode(out, &u[0], u_bits, rank);
1043 scalar_compress(&v, v_bits);
1044 scalar_encode(out + compressed_vector_size(rank), &v, v_bits);
1045}
1046
1047/* Calls mlkem_encap_external_entropy| with random bytes */
1048void
1049mlkem_encap(const MLKEM_public_key *public_key,
1050 uint8_t *out_ciphertext,
1051 uint8_t out_shared_secret[MLKEM_SHARED_SECRET_LENGTH])
1052{
1053 uint8_t entropy[MLKEM_ENCAP_ENTROPY];
1054
1055 arc4random_buf(entropy, MLKEM_ENCAP_ENTROPY);
1056 mlkem_encap_external_entropy(out_ciphertext,
1057 out_shared_secret, public_key, entropy);
1058}
1059
1060/* See section 6.2 of the spec. */
1061void
1062mlkem_encap_external_entropy(uint8_t *out_ciphertext,
1063 uint8_t out_shared_secret[MLKEM_SHARED_SECRET_LENGTH],
1064 const MLKEM_public_key *public_key,
1065 const uint8_t entropy[MLKEM_ENCAP_ENTROPY])
1066{
1067 struct public_key pub;
1068 uint8_t key_and_randomness[64];
1069 uint8_t input[64];
1070
1071 public_key_from_external(public_key, &pub);
1072 memcpy(input, entropy, MLKEM_ENCAP_ENTROPY);
1073 memcpy(input + MLKEM_ENCAP_ENTROPY, pub.public_key_hash,
1074 sizeof(input) - MLKEM_ENCAP_ENTROPY);
1075 hash_g(key_and_randomness, input, sizeof(input));
1076 encrypt_cpa(out_ciphertext, &pub, entropy, key_and_randomness + 32,
1077 public_key->rank);
1078 memcpy(out_shared_secret, key_and_randomness, 32);
1079}
1080
1081static void
1082decrypt_cpa(uint8_t out[32], const struct private_key *priv,
1083 const uint8_t *ciphertext, size_t rank)
1084{
1085 scalar u[RANK1024];
1086 scalar mask, v;
1087 int u_bits = kDU768;
1088 int v_bits = kDV768;
1089
1090 if (rank == RANK1024) {
1091 u_bits = kDU1024;
1092 v_bits = kDV1024;
1093 }
1094 vector_decode(&u[0], ciphertext, u_bits, rank);
1095 vector_decompress(&u[0], u_bits, rank);
1096 vector_ntt(&u[0], rank);
1097 scalar_decode(&v, ciphertext + compressed_vector_size(rank), v_bits);
1098 scalar_decompress(&v, v_bits);
1099 scalar_inner_product(&mask, &priv->s[0], &u[0], rank);
1100 scalar_inverse_ntt(&mask);
1101 scalar_sub(&v, &mask);
1102 scalar_compress(&v, 1);
1103 scalar_encode_1(out, &v);
1104}
1105
1106/* See section 6.3 */
1107int
1108mlkem_decap(const MLKEM_private_key *private_key, const uint8_t *ciphertext,
1109 size_t ciphertext_len, uint8_t out_shared_secret[MLKEM_SHARED_SECRET_LENGTH])
1110{
1111 struct private_key priv;
1112 size_t expected_ciphertext_length = private_key->rank == RANK768 ?
1113 MLKEM768_CIPHERTEXT_BYTES : MLKEM1024_CIPHERTEXT_BYTES;
1114 uint8_t *expected_ciphertext = NULL;
1115 uint8_t key_and_randomness[64];
1116 uint8_t failure_key[32];
1117 uint8_t decrypted[64];
1118 uint8_t mask;
1119 int i;
1120 int ret = 0;
1121
1122 if (ciphertext_len != expected_ciphertext_length) {
1123 arc4random_buf(out_shared_secret, MLKEM_SHARED_SECRET_LENGTH);
1124 goto err;
1125 }
1126
1127 if ((expected_ciphertext = calloc(1, expected_ciphertext_length)) ==
1128 NULL) {
1129 arc4random_buf(out_shared_secret, MLKEM_SHARED_SECRET_LENGTH);
1130 goto err;
1131 }
1132
1133 private_key_from_external(private_key, &priv);
1134 decrypt_cpa(decrypted, &priv, ciphertext, private_key->rank);
1135 memcpy(decrypted + 32, priv.pub.public_key_hash,
1136 sizeof(decrypted) - 32);
1137 hash_g(key_and_randomness, decrypted, sizeof(decrypted));
1138 encrypt_cpa(expected_ciphertext, &priv.pub, decrypted,
1139 key_and_randomness + 32, private_key->rank);
1140 kdf(failure_key, priv.fo_failure_secret, ciphertext, ciphertext_len);
1141 mask = constant_time_eq_int_8(memcmp(ciphertext, expected_ciphertext,
1142 expected_ciphertext_length), 0);
1143 for (i = 0; i < MLKEM_SHARED_SECRET_LENGTH; i++) {
1144 out_shared_secret[i] = constant_time_select_8(mask,
1145 key_and_randomness[i], failure_key[i]);
1146 }
1147
1148 ret = 1;
1149
1150 err:
1151 freezero(expected_ciphertext, expected_ciphertext_length);
1152
1153 return ret;
1154}
1155
1156int
1157mlkem_marshal_public_key(const MLKEM_public_key *public_key,
1158 uint8_t **output, size_t *output_len)
1159{
1160 struct public_key pub;
1161 int ret = 0;
1162 CBB cbb;
1163
1164 if (!CBB_init(&cbb, public_key->rank == RANK768 ?
1165 MLKEM768_PUBLIC_KEY_BYTES : MLKEM1024_PUBLIC_KEY_BYTES))
1166 goto err;
1167 public_key_from_external(public_key, &pub);
1168 if (!mlkem_marshal_public_key_internal(&cbb, &pub, public_key->rank))
1169 goto err;
1170 if (!CBB_finish(&cbb, output, output_len))
1171 goto err;
1172
1173 ret = 1;
1174
1175 err:
1176 CBB_cleanup(&cbb);
1177
1178 return ret;
1179}
1180
1181/*
1182 * mlkem_parse_public_key_no_hash parses |in| into |pub| but doesn't calculate
1183 * the value of |pub->public_key_hash|.
1184 */
1185static int
1186mlkem_parse_public_key_no_hash(struct public_key *pub, CBS *in, size_t rank)
1187{
1188 CBS t_bytes;
1189
1190 if (!CBS_get_bytes(in, &t_bytes, encoded_vector_size(rank)))
1191 return 0;
1192 if (!vector_decode(&pub->t[0], CBS_data(&t_bytes), kLog2Prime, rank))
1193 return 0;
1194
1195 memcpy(pub->rho, CBS_data(in), 32);
1196 if (!CBS_skip(in, 32))
1197 return 0;
1198 matrix_expand(pub->m, pub->rho, rank);
1199 return 1;
1200}
1201
1202int
1203mlkem_parse_public_key(const uint8_t *input, size_t input_len,
1204 MLKEM_public_key *public_key)
1205{
1206 struct public_key pub;
1207 CBS cbs;
1208
1209 public_key_from_external(public_key, &pub);
1210 CBS_init(&cbs, input, input_len);
1211 if (!mlkem_parse_public_key_no_hash(&pub, &cbs, public_key->rank))
1212 return 0;
1213 if (CBS_len(&cbs) != 0)
1214 return 0;
1215
1216 hash_h(pub.public_key_hash, input, input_len);
1217
1218 return 1;
1219}
1220
1221int
1222mlkem_marshal_private_key(const MLKEM_private_key *private_key,
1223 uint8_t **out_private_key, size_t *out_private_key_len)
1224{
1225 struct private_key priv;
1226 size_t key_length = private_key->rank == RANK768 ?
1227 MLKEM768_PRIVATE_KEY_BYTES : MLKEM1024_PRIVATE_KEY_BYTES;
1228 CBB cbb;
1229 int ret = 0;
1230
1231 private_key_from_external(private_key, &priv);
1232 if (!CBB_init(&cbb, key_length))
1233 goto err;
1234
1235 if (!vector_encode_cbb(&cbb, priv.s, kLog2Prime, private_key->rank))
1236 goto err;
1237 if (!mlkem_marshal_public_key_internal(&cbb, &priv.pub,
1238 private_key->rank))
1239 goto err;
1240 if (!CBB_add_bytes(&cbb, priv.pub.public_key_hash, 32))
1241 goto err;
1242 if (!CBB_add_bytes(&cbb, priv.fo_failure_secret, 32))
1243 goto err;
1244
1245 if (!CBB_finish(&cbb, out_private_key, out_private_key_len))
1246 goto err;
1247
1248 ret = 1;
1249
1250 err:
1251 CBB_cleanup(&cbb);
1252
1253 return ret;
1254}
1255
1256int
1257mlkem_parse_private_key(const uint8_t *input, size_t input_len,
1258 MLKEM_private_key *out_private_key)
1259{
1260 struct private_key priv;
1261 CBS cbs, s_bytes;
1262
1263 private_key_from_external(out_private_key, &priv);
1264 CBS_init(&cbs, input, input_len);
1265
1266 if (!CBS_get_bytes(&cbs, &s_bytes,
1267 encoded_vector_size(out_private_key->rank)))
1268 return 0;
1269 if (!vector_decode(priv.s, CBS_data(&s_bytes), kLog2Prime,
1270 out_private_key->rank))
1271 return 0;
1272 if (!mlkem_parse_public_key_no_hash(&priv.pub, &cbs,
1273 out_private_key->rank))
1274 return 0;
1275
1276 memcpy(priv.pub.public_key_hash, CBS_data(&cbs), 32);
1277 if (!CBS_skip(&cbs, 32))
1278 return 0;
1279 memcpy(priv.fo_failure_secret, CBS_data(&cbs), 32);
1280 if (!CBS_skip(&cbs, 32))
1281 return 0;
1282 if (CBS_len(&cbs) != 0)
1283 return 0;
1284
1285 return 1;
1286}