Diffstat (limited to 'src/lib/libcrypto/mlkem/mlkem1024.c')
-rw-r--r--  src/lib/libcrypto/mlkem/mlkem1024.c | 1183
1 file changed, 0 insertions, 1183 deletions
diff --git a/src/lib/libcrypto/mlkem/mlkem1024.c b/src/lib/libcrypto/mlkem/mlkem1024.c
deleted file mode 100644
index 8f4f41f8ff..0000000000
--- a/src/lib/libcrypto/mlkem/mlkem1024.c
+++ /dev/null
@@ -1,1183 +0,0 @@
1/* $OpenBSD: mlkem1024.c,v 1.12 2025/08/14 15:48:48 beck Exp $ */
2/*
3 * Copyright (c) 2024, Google Inc.
4 * Copyright (c) 2024, 2025 Bob Beck <beck@obtuse.com>
5 *
6 * Permission to use, copy, modify, and/or distribute this software for any
7 * purpose with or without fee is hereby granted, provided that the above
8 * copyright notice and this permission notice appear in all copies.
9 *
10 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
13 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
15 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
16 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17 */
18
19#include <assert.h>
20#include <stdlib.h>
21#include <string.h>
22
23#include <openssl/mlkem.h>
24
25#include "bytestring.h"
26#include "sha3_internal.h"
27#include "mlkem_internal.h"
28#include "constant_time.h"
29#include "crypto_internal.h"
30
31/*
32 * See
33 * https://csrc.nist.gov/pubs/fips/203/final
34 */
35
36static void
37prf(uint8_t *out, size_t out_len, const uint8_t in[33])
38{
39 sha3_ctx ctx;
40 shake256_init(&ctx);
41 shake_update(&ctx, in, 33);
42 shake_xof(&ctx);
43 shake_out(&ctx, out, out_len);
44}
45
46/* Section 4.1 */
47static void
48hash_h(uint8_t out[32], const uint8_t *in, size_t len)
49{
50 sha3_ctx ctx;
51 sha3_init(&ctx, 32);
52 sha3_update(&ctx, in, len);
53 sha3_final(out, &ctx);
54}
55
56static void
57hash_g(uint8_t out[64], const uint8_t *in, size_t len)
58{
59 sha3_ctx ctx;
60 sha3_init(&ctx, 64);
61 sha3_update(&ctx, in, len);
62 sha3_final(out, &ctx);
63}
64
65/* this is called 'J' in the spec */
66static void
67kdf(uint8_t out[MLKEM_SHARED_SECRET_BYTES], const uint8_t failure_secret[32],
68 const uint8_t *in, size_t len)
69{
70 sha3_ctx ctx;
71 shake256_init(&ctx);
72 shake_update(&ctx, failure_secret, 32);
73 shake_update(&ctx, in, len);
74 shake_xof(&ctx);
75 shake_out(&ctx, out, MLKEM_SHARED_SECRET_BYTES);
76}
77
78#define DEGREE 256
79
80static const size_t kBarrettMultiplier = 5039;
81static const unsigned kBarrettShift = 24;
82static const uint16_t kPrime = 3329;
83static const int kLog2Prime = 12;
84static const uint16_t kHalfPrime = (/*kPrime=*/3329 - 1) / 2;
85static const int kDU1024 = 11;
86static const int kDV1024 = 5;
87
88/*
89 * kInverseDegree is 128^-1 mod 3329; 128 because kPrime does not have a 512th
90 * root of unity.
91 */
92static const uint16_t kInverseDegree = 3303;
93static const size_t kEncodedVectorSize =
94 (/*kLog2Prime=*/12 * DEGREE / 8) * RANK1024;
95static const size_t kCompressedVectorSize = /*kDU1024=*/ 11 * RANK1024 * DEGREE /
96 8;
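/*
 * A quick sanity check on the constants above: 128 * kInverseDegree =
 * 128 * 3303 = 422784 = 127 * 3329 + 1, i.e. 128 * kInverseDegree == 1
 * (mod kPrime). With RANK1024 = 4 the sizes evaluate to kEncodedVectorSize
 * = (12 * 256 / 8) * 4 = 1536 bytes and kCompressedVectorSize =
 * 11 * 4 * 256 / 8 = 1408 bytes. If desired, these could be pinned down at
 * compile time along the lines of the following sketch (using CTASSERT as
 * elsewhere in this file):
 *
 *	CTASSERT((12 * DEGREE / 8) * RANK1024 == 1536);
 *	CTASSERT(11 * RANK1024 * DEGREE / 8 == 1408);
 */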
97
98typedef struct scalar {
99 /* On every function entry and exit, 0 <= c < kPrime. */
100 uint16_t c[DEGREE];
101} scalar;
102
103typedef struct vector {
104 scalar v[RANK1024];
105} vector;
106
107typedef struct matrix {
108 scalar v[RANK1024][RANK1024];
109} matrix;
110
111/*
112 * This bit of Python will be referenced in some of the following comments:
113 *
114 * p = 3329
115 *
116 * def bitreverse(i):
117 * ret = 0
118 * for n in range(7):
119 * bit = i & 1
120 * ret <<= 1
121 * ret |= bit
122 * i >>= 1
123 * return ret
124 */
125
126/* kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)] */
127static const uint16_t kNTTRoots[128] = {
128 1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797,
129 2786, 3260, 569, 1746, 296, 2447, 1339, 1476, 3046, 56, 2240, 1333,
130 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, 289, 331, 3253, 1756,
131 1197, 2304, 2277, 2055, 650, 1977, 2513, 632, 2865, 33, 1320, 1915,
132 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648,
133 2474, 3110, 1227, 910, 17, 2761, 583, 2649, 1637, 723, 2288, 1100,
134 1409, 2662, 3281, 233, 756, 2156, 3015, 3050, 1703, 1651, 2789, 1789,
135 1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641,
136 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, 1063, 319, 2773, 757,
137 2099, 561, 2466, 2594, 2804, 1092, 403, 1026, 1143, 2150, 2775, 886,
138 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154,
139};
140
141/* kInverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)] */
142static const uint16_t kInverseNTTRoots[128] = {
143 1, 1600, 40, 749, 2481, 1432, 2699, 687, 1583, 2760, 69, 543,
144 2532, 3136, 1410, 2267, 2508, 1355, 450, 936, 447, 2794, 1235, 1903,
145 1996, 1089, 3273, 283, 1853, 1990, 882, 3033, 2419, 2102, 219, 855,
146 2681, 1848, 712, 682, 927, 1795, 461, 1891, 2877, 2522, 1894, 1010,
147 1414, 2009, 3296, 464, 2697, 816, 1352, 2679, 1274, 1052, 1025, 2132,
148 1573, 76, 2998, 3040, 1175, 2444, 394, 1219, 2300, 1455, 2117, 1607,
149 2443, 554, 1179, 2186, 2303, 2926, 2237, 525, 735, 863, 2768, 1230,
150 2572, 556, 3010, 2266, 1684, 1239, 780, 2954, 109, 1292, 1031, 1745,
151 2688, 3061, 992, 2596, 941, 892, 1021, 2390, 642, 1868, 2377, 1482,
152 1540, 540, 1678, 1626, 279, 314, 1173, 2573, 3096, 48, 667, 1920,
153 2229, 1041, 2606, 1692, 680, 2746, 568, 3312,
154};
155
156/* kModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)] */
157static const uint16_t kModRoots[128] = {
158 17, 3312, 2761, 568, 583, 2746, 2649, 680, 1637, 1692, 723, 2606,
159 2288, 1041, 1100, 2229, 1409, 1920, 2662, 667, 3281, 48, 233, 3096,
160 756, 2573, 2156, 1173, 3015, 314, 3050, 279, 1703, 1626, 1651, 1678,
161 2789, 540, 1789, 1540, 1847, 1482, 952, 2377, 1461, 1868, 2687, 642,
162 939, 2390, 2308, 1021, 2437, 892, 2388, 941, 733, 2596, 2337, 992,
163 268, 3061, 641, 2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
164 375, 2954, 2549, 780, 2090, 1239, 1645, 1684, 1063, 2266, 319, 3010,
165 2773, 556, 757, 2572, 2099, 1230, 561, 2768, 2466, 863, 2594, 735,
166 2804, 525, 1092, 2237, 403, 2926, 1026, 2303, 1143, 2186, 2150, 1179,
167 2775, 554, 886, 2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
168 2110, 1219, 2935, 394, 885, 2444, 2154, 1175,
169};
170
171/* reduce_once reduces 0 <= x < 2*kPrime, mod kPrime. */
172static uint16_t
173reduce_once(uint16_t x)
174{
175 assert(x < 2 * kPrime);
176 const uint16_t subtracted = x - kPrime;
177 uint16_t mask = 0u - (subtracted >> 15);
178
179 /*
180 * Although this is a constant-time select, we omit a value barrier here.
181 * Value barriers impede auto-vectorization (likely because they force the
182 * value to transit through a general-purpose register). On AArch64, this
183 * is a difference of 2x.
184 *
185 * We usually add value barriers to selects because Clang turns
186 * consecutive selects with the same condition into a branch instead of
187 * CMOV/CSEL. This condition does not occur in ML-KEM, so omitting it
188 * seems to be safe so far but see
189 * |scalar_centered_binomial_distribution_eta_2_with_prf|.
190 */
191 return (mask & x) | (~mask & subtracted);
192}
193
194/*
195 * constant time reduce x mod kPrime using Barrett reduction. x must be less
196 * than kPrime + 2×kPrime².
197 */
198static uint16_t
199reduce(uint32_t x)
200{
201 uint64_t product = (uint64_t)x * kBarrettMultiplier;
202 uint32_t quotient = (uint32_t)(product >> kBarrettShift);
203 uint32_t remainder = x - quotient * kPrime;
204
205 assert(x < kPrime + 2u * kPrime * kPrime);
206 return reduce_once(remainder);
207}
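/*
 * The Barrett constants above satisfy kBarrettMultiplier =
 * floor(2^kBarrettShift / kPrime) = floor(16777216 / 3329) = 5039. Because
 * 5039 / 2^24 slightly underestimates 1 / 3329, |quotient| is either
 * floor(x / kPrime) or one less over the allowed input range, so
 * |remainder| stays below 2 * kPrime and |reduce_once| finishes the
 * reduction. A brute-force check of that bound could look like the
 * following hypothetical (offline) sketch:
 *
 *	uint32_t x;
 *	for (x = 0; x < 3329 + 2 * 3329 * 3329; x++) {
 *		uint32_t q = (uint32_t)(((uint64_t)x * 5039) >> 24);
 *		assert(x - q * 3329 < 2 * 3329);
 *	}
 */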
208
209static void
210scalar_zero(scalar *out)
211{
212 memset(out, 0, sizeof(*out));
213}
214
215static void
216vector_zero(vector *out)
217{
218 memset(out, 0, sizeof(*out));
219}
220
221/*
222 * In place number theoretic transform of a given scalar.
223 * Note that MLKEM's kPrime 3329 does not have a 512th root of unity, so this
224 * transform leaves off the last iteration of the usual FFT code, with the 128
225 * relevant roots of unity being stored in |kNTTRoots|. This means the output
226 * should be seen as 128 elements in GF(3329^2), with the coefficients of the
227 * elements being consecutive entries in |s->c|.
228 */
229static void
230scalar_ntt(scalar *s)
231{
232 int offset = DEGREE;
233 int step;
234 /*
235 * `int` is used here because using `size_t` throughout caused a ~5% slowdown
236 * with Clang 14 on AArch64.
237 */
238 for (step = 1; step < DEGREE / 2; step <<= 1) {
239 int i, j, k = 0;
240
241 offset >>= 1;
242 for (i = 0; i < step; i++) {
243 const uint32_t step_root = kNTTRoots[i + step];
244
245 for (j = k; j < k + offset; j++) {
246 uint16_t odd, even;
247
248 odd = reduce(step_root * s->c[j + offset]);
249 even = s->c[j];
250 s->c[j] = reduce_once(odd + even);
251 s->c[j + offset] = reduce_once(even - odd +
252 kPrime);
253 }
254 k += 2 * offset;
255 }
256 }
257}
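/*
 * Note that the outer loop above runs for step = 1, 2, ..., 64, i.e. seven
 * butterfly levels rather than the eight a full size-256 NTT would use.
 * The pairs (s->c[2*i], s->c[2*i + 1]) that are never mixed are the 128
 * degree-one elements of GF(3329^2) described in the comment above
 * |scalar_ntt|.
 */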
258
259static void
260vector_ntt(vector *a)
261{
262 int i;
263
264 for (i = 0; i < RANK1024; i++) {
265 scalar_ntt(&a->v[i]);
266 }
267}
268
269/*
270 * In place inverse number theoretic transform of a given scalar, with pairs of
271 * entries of s->v being interpreted as elements of GF(3329^2). Just as with the
272 * number theoretic transform, this leaves off the first step of the normal iFFT
273 * to account for the fact that 3329 does not have a 512th root of unity, using
274 * the precomputed 128 roots of unity stored in |kInverseNTTRoots|.
275 */
276static void
277scalar_inverse_ntt(scalar *s)
278{
279 int i, j, k, offset, step = DEGREE / 2;
280
281 /*
282 * `int` is used here because using `size_t` throughout caused a ~5% slowdown
283 * with Clang 14 on AArch64.
284 */
285 for (offset = 2; offset < DEGREE; offset <<= 1) {
286 step >>= 1;
287 k = 0;
288 for (i = 0; i < step; i++) {
289 uint32_t step_root = kInverseNTTRoots[i + step];
290 for (j = k; j < k + offset; j++) {
291 uint16_t odd, even;
292 odd = s->c[j + offset];
293 even = s->c[j];
294 s->c[j] = reduce_once(odd + even);
295 s->c[j + offset] = reduce(step_root *
296 (even - odd + kPrime));
297 }
298 k += 2 * offset;
299 }
300 }
301 for (i = 0; i < DEGREE; i++) {
302 s->c[i] = reduce(s->c[i] * kInverseDegree);
303 }
304}
305
306static void
307vector_inverse_ntt(vector *a)
308{
309 int i;
310
311 for (i = 0; i < RANK1024; i++) {
312 scalar_inverse_ntt(&a->v[i]);
313 }
314}
315
316static void
317scalar_add(scalar *lhs, const scalar *rhs)
318{
319 int i;
320
321 for (i = 0; i < DEGREE; i++) {
322 lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]);
323 }
324}
325
326static void
327scalar_sub(scalar *lhs, const scalar *rhs)
328{
329 int i;
330
331 for (i = 0; i < DEGREE; i++) {
332 lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime);
333 }
334}
335
336/*
337 * Multiplying two scalars in the number theoretically transformed state.
338 * Since 3329 does not have a 512th root of unity, this means we have to
339 * interpret the 2*ith and (2*i+1)th entries of the scalar as elements of
340 * GF(3329)[X]/(X^2 - 17^(2*bitreverse(i)+1)).
341 * The value of 17^(2*bitreverse(i)+1) mod 3329 is stored in the precomputed
342 * |kModRoots| table. Our Barrett transform only allows us to multiply two
343 * reduced numbers together, so we need some intermediate reduction steps,
344 * even if a uint64_t could hold the product of 3 reduced numbers.
345 */
346static void
347scalar_mult(scalar *out, const scalar *lhs, const scalar *rhs)
348{
349 int i;
350
351 for (i = 0; i < DEGREE / 2; i++) {
352 uint32_t real_real = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i];
353 uint32_t img_img = (uint32_t)lhs->c[2 * i + 1] *
354 rhs->c[2 * i + 1];
355 uint32_t real_img = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i + 1];
356 uint32_t img_real = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i];
357
358 out->c[2 * i] =
359 reduce(real_real +
360 (uint32_t)reduce(img_img) * kModRoots[i]);
361 out->c[2 * i + 1] = reduce(img_real + real_img);
362 }
363}
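/*
 * In other words, writing the pair (lhs->c[2*i], lhs->c[2*i + 1]) as
 * a0 + a1*X and the matching rhs pair as b0 + b1*X, the product in
 * GF(3329)[X]/(X^2 - gamma), with gamma = kModRoots[i], is
 *
 *	(a0*b0 + a1*b1*gamma) + (a0*b1 + a1*b0)*X
 *
 * which is exactly what the two |reduce| expressions above compute.
 */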
364
365static void
366vector_add(vector *lhs, const vector *rhs)
367{
368 int i;
369
370 for (i = 0; i < RANK1024; i++) {
371 scalar_add(&lhs->v[i], &rhs->v[i]);
372 }
373}
374
375static void
376matrix_mult(vector *out, const matrix *m, const vector *a)
377{
378 int i, j;
379
380 vector_zero(out);
381 for (i = 0; i < RANK1024; i++) {
382 for (j = 0; j < RANK1024; j++) {
383 scalar product;
384
385 scalar_mult(&product, &m->v[i][j], &a->v[j]);
386 scalar_add(&out->v[i], &product);
387 }
388 }
389}
390
391static void
392matrix_mult_transpose(vector *out, const matrix *m,
393 const vector *a)
394{
395 int i, j;
396
397 vector_zero(out);
398 for (i = 0; i < RANK1024; i++) {
399 for (j = 0; j < RANK1024; j++) {
400 scalar product;
401
402 scalar_mult(&product, &m->v[j][i], &a->v[j]);
403 scalar_add(&out->v[i], &product);
404 }
405 }
406}
407
408static void
409scalar_inner_product(scalar *out, const vector *lhs,
410 const vector *rhs)
411{
412 int i;
413 scalar_zero(out);
414 for (i = 0; i < RANK1024; i++) {
415 scalar product;
416
417 scalar_mult(&product, &lhs->v[i], &rhs->v[i]);
418 scalar_add(out, &product);
419 }
420}
421
422/*
423 * Algorithm 6 of the spec. Rejection samples a Keccak stream to get uniformly
424 * distributed elements. This is used for matrix expansion and only operates on
425 * public inputs.
426 */
427static void
428scalar_from_keccak_vartime(scalar *out, sha3_ctx *keccak_ctx)
429{
430 int i, done = 0;
431
432 while (done < DEGREE) {
433 uint8_t block[168];
434
435 shake_out(keccak_ctx, block, sizeof(block));
436 for (i = 0; i < sizeof(block) && done < DEGREE; i += 3) {
437 uint16_t d1 = block[i] + 256 * (block[i + 1] % 16);
438 uint16_t d2 = block[i + 1] / 16 + 16 * block[i + 2];
439
440 if (d1 < kPrime) {
441 out->c[done++] = d1;
442 }
443 if (d2 < kPrime && done < DEGREE) {
444 out->c[done++] = d2;
445 }
446 }
447 }
448}
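/*
 * Each group of three squeezed bytes b0, b1, b2 above yields two candidate
 * 12-bit values, d1 = b0 + 256 * (b1 mod 16) and d2 = floor(b1 / 16) +
 * 16 * b2; candidates >= kPrime are discarded. For example, the bytes
 * 0x01, 0x23, 0x45 give d1 = 0x301 = 769 and d2 = 0x452 = 1106 (both
 * kept), while 0xff, 0xff, 0xff gives d1 = d2 = 4095 (both rejected).
 */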
449
450/*
451 * Algorithm 7 of the spec, with eta fixed to two and the PRF call
452 * included. Creates binomially distributed elements by sampling 2*|eta| bits,
453 * and setting the coefficient to the count of the first |eta| bits minus the
454 * count of the second |eta| bits, resulting in a centered binomial distribution.
455 * Since eta is two this gives -2 and 2 each with probability 1/16, -1 and 1 each
456 * with probability 1/4, and 0 with probability 3/8.
457 */
458static void
459scalar_centered_binomial_distribution_eta_2_with_prf(scalar *out,
460 const uint8_t input[33])
461{
462 uint8_t entropy[128];
463 int i;
464
465 CTASSERT(sizeof(entropy) == 2 * /*kEta=*/ 2 * DEGREE / 8);
466 prf(entropy, sizeof(entropy), input);
467
468 for (i = 0; i < DEGREE; i += 2) {
469 uint8_t byte = entropy[i / 2];
470 uint16_t mask;
471 uint16_t value = (byte & 1) + ((byte >> 1) & 1);
472
473 value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
474
475 /*
476 * Add |kPrime| if |value| underflowed. See |reduce_once| for a
477 * discussion on why the value barrier is omitted. While this
478 * could have been written reduce_once(value + kPrime), this is
479 * one extra addition and small range of |value| tempts some
480 * versions of Clang to emit a branch.
481 */
482 mask = 0u - (value >> 15);
483 out->c[i] = ((value + kPrime) & mask) | (value & ~mask);
484
485 byte >>= 4;
486 value = (byte & 1) + ((byte >> 1) & 1);
487 value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
488 /* See above. */
489 mask = 0u - (value >> 15);
490 out->c[i + 1] = ((value + kPrime) & mask) | (value & ~mask);
491 }
492}
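/*
 * Where the probabilities in the comment above come from: each coefficient
 * is (b0 + b1) - (b2 + b3) for independent uniform bits b0..b3. Each sum
 * is 0, 1 or 2 with probability 1/4, 1/2, 1/4 respectively, so the
 * difference is -2 or 2 with probability 1/16 each, -1 or 1 with
 * probability 1/4 * 1/2 + 1/2 * 1/4 = 1/4 each, and 0 with probability
 * 1/16 + 1/4 + 1/16 = 3/8.
 */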
493
494/*
495 * Generates a secret vector by using
496 * |scalar_centered_binomial_distribution_eta_2_with_prf| with the given seed,
497 * appending and incrementing |counter| for each entry of the vector.
498 */
499static void
500vector_generate_secret_eta_2(vector *out, uint8_t *counter,
501 const uint8_t seed[32])
502{
503 uint8_t input[33];
504 int i;
505
506 memcpy(input, seed, 32);
507 for (i = 0; i < RANK1024; i++) {
508 input[32] = (*counter)++;
509 scalar_centered_binomial_distribution_eta_2_with_prf(&out->v[i],
510 input);
511 }
512}
513
514/* Expands the matrix from a seed, for key generation and for CPA encryption. */
515static void
516matrix_expand(matrix *out, const uint8_t rho[32])
517{
518 uint8_t input[34];
519 int i, j;
520
521 memcpy(input, rho, 32);
522 for (i = 0; i < RANK1024; i++) {
523 for (j = 0; j < RANK1024; j++) {
524 sha3_ctx keccak_ctx;
525
526 input[32] = i;
527 input[33] = j;
528 shake128_init(&keccak_ctx);
529 shake_update(&keccak_ctx, input, sizeof(input));
530 shake_xof(&keccak_ctx);
531 scalar_from_keccak_vartime(&out->v[i][j], &keccak_ctx);
532 }
533 }
534}
535
536static const uint8_t kMasks[8] = {0x01, 0x03, 0x07, 0x0f,
537 0x1f, 0x3f, 0x7f, 0xff};
538
539static void
540scalar_encode(uint8_t *out, const scalar *s, int bits)
541{
542 uint8_t out_byte = 0;
543 int i, out_byte_bits = 0;
544
545 assert(bits <= (int)sizeof(*s->c) * 8 && bits != 1);
546 for (i = 0; i < DEGREE; i++) {
547 uint16_t element = s->c[i];
548 int element_bits_done = 0;
549
550 while (element_bits_done < bits) {
551 int chunk_bits = bits - element_bits_done;
552 int out_bits_remaining = 8 - out_byte_bits;
553
554 if (chunk_bits >= out_bits_remaining) {
555 chunk_bits = out_bits_remaining;
556 out_byte |= (element &
557 kMasks[chunk_bits - 1]) << out_byte_bits;
558 *out = out_byte;
559 out++;
560 out_byte_bits = 0;
561 out_byte = 0;
562 } else {
563 out_byte |= (element &
564 kMasks[chunk_bits - 1]) << out_byte_bits;
565 out_byte_bits += chunk_bits;
566 }
567
568 element_bits_done += chunk_bits;
569 element >>= chunk_bits;
570 }
571 }
572
573 if (out_byte_bits > 0) {
574 *out = out_byte;
575 }
576}
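/*
 * For the common case bits == 12, the packing above works out to: two
 * 12-bit coefficients a and b become three bytes,
 *
 *	out[0] = a & 0xff;
 *	out[1] = (a >> 8) | ((b & 0xf) << 4);
 *	out[2] = b >> 4;
 *
 * i.e. the same bit layout as the 12-bit extraction in
 * |scalar_from_keccak_vartime|, and the inverse of |scalar_decode| with
 * bits == 12.
 */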
577
578/* scalar_encode_1 is |scalar_encode| specialised for |bits| == 1. */
579static void
580scalar_encode_1(uint8_t out[32], const scalar *s)
581{
582 int i, j;
583
584 for (i = 0; i < DEGREE; i += 8) {
585 uint8_t out_byte = 0;
586
587 for (j = 0; j < 8; j++) {
588 out_byte |= (s->c[i + j] & 1) << j;
589 }
590 *out = out_byte;
591 out++;
592 }
593}
594
595/*
596 * Encodes an entire vector into 32*|RANK1024|*|bits| bytes. Note that since 256
597 * (DEGREE) is divisible by 8, the individual vector entries will always fill a
598 * whole number of bytes, so we do not need to worry about bit packing here.
599 */
600static void
601vector_encode(uint8_t *out, const vector *a, int bits)
602{
603 int i;
604
605 for (i = 0; i < RANK1024; i++) {
606 scalar_encode(out + i * bits * DEGREE / 8, &a->v[i], bits);
607 }
608}
609
610/* Encodes an entire vector as above, but adds it to a CBB. */
611static int
612vector_encode_cbb(CBB *cbb, const vector *a, int bits)
613{
614 uint8_t *encoded_vector;
615
616 if (!CBB_add_space(cbb, &encoded_vector, kEncodedVectorSize))
617 return 0;
618 vector_encode(encoded_vector, a, bits);
619
620 return 1;
621}
622
623/*
624 * scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
625 * |out|. It returns one on success and zero if any parsed value is >=
626 * |kPrime|.
627 */
628static int
629scalar_decode(scalar *out, const uint8_t *in, int bits)
630{
631 uint8_t in_byte = 0;
632 int i, in_byte_bits_left = 0;
633
634 assert(bits <= (int)sizeof(*out->c) * 8 && bits != 1);
635
636 for (i = 0; i < DEGREE; i++) {
637 uint16_t element = 0;
638 int element_bits_done = 0;
639
640 while (element_bits_done < bits) {
641 int chunk_bits = bits - element_bits_done;
642
643 if (in_byte_bits_left == 0) {
644 in_byte = *in;
645 in++;
646 in_byte_bits_left = 8;
647 }
648
649 if (chunk_bits > in_byte_bits_left) {
650 chunk_bits = in_byte_bits_left;
651 }
652
653 element |= (in_byte & kMasks[chunk_bits - 1]) <<
654 element_bits_done;
655 in_byte_bits_left -= chunk_bits;
656 in_byte >>= chunk_bits;
657
658 element_bits_done += chunk_bits;
659 }
660
661 if (element >= kPrime) {
662 return 0;
663 }
664 out->c[i] = element;
665 }
666
667 return 1;
668}
669
670/* scalar_decode_1 is |scalar_decode| specialised for |bits| == 1. */
671static void
672scalar_decode_1(scalar *out, const uint8_t in[32])
673{
674 int i, j;
675
676 for (i = 0; i < DEGREE; i += 8) {
677 uint8_t in_byte = *in;
678
679 in++;
680 for (j = 0; j < 8; j++) {
681 out->c[i + j] = in_byte & 1;
682 in_byte >>= 1;
683 }
684 }
685}
686
687/*
688 * Decodes 32*|RANK1024|*|bits| bytes from |in| into |out|. It returns one on
689 * success or zero if any parsed value is >= |kPrime|.
690 */
691static int
692vector_decode(vector *out, const uint8_t *in, int bits)
693{
694 int i;
695
696 for (i = 0; i < RANK1024; i++) {
697 if (!scalar_decode(&out->v[i], in + i * bits * DEGREE / 8,
698 bits)) {
699 return 0;
700 }
701 }
702 return 1;
703}
704
705/*
706 * Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
707 * numbers close to each other together. The formula used is
708 * round(2^|bits|/kPrime*x) mod 2^|bits|.
709 * Uses Barrett reduction to achieve constant time. Since we need both the
710 * remainder (for rounding) and the quotient (as the result), we cannot use
711 * |reduce| here, but need to do the Barrett reduction directly.
712 */
713static uint16_t
714compress(uint16_t x, int bits)
715{
716 uint32_t shifted = (uint32_t)x << bits;
717 uint64_t product = (uint64_t)shifted * kBarrettMultiplier;
718 uint32_t quotient = (uint32_t)(product >> kBarrettShift);
719 uint32_t remainder = shifted - quotient * kPrime;
720
721 /*
722 * Adjust the quotient to round correctly:
723 * 0 <= remainder <= kHalfPrime round to 0
724 * kHalfPrime < remainder <= kPrime + kHalfPrime round to 1
725 * kPrime + kHalfPrime < remainder < 2 * kPrime round to 2
726 */
727 assert(remainder < 2u * kPrime);
728 quotient += 1 & constant_time_lt(kHalfPrime, remainder);
729 quotient += 1 & constant_time_lt(kPrime + kHalfPrime, remainder);
730 return quotient & ((1 << bits) - 1);
731}
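/*
 * For example, with bits == 1 this computes round(2 * x / 3329) mod 2:
 * compress(832, 1) leaves |remainder| = 1664 = kHalfPrime, so the result
 * stays 0, while compress(833, 1) gives |remainder| = 1666 > kHalfPrime
 * and the result becomes 1. Coefficients closest to 0 (mod kPrime) thus
 * compress to 0 and those closest to kHalfPrime + 1 compress to 1, which
 * is what the one-bit message handling in |encrypt_cpa| and |decrypt_cpa|
 * relies on.
 */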
732
733/*
734 * Decompresses |x| by using an equi-distant representative. The formula is
735 * round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us to
736 * implement this logic using only bit operations.
737 */
738static uint16_t
739decompress(uint16_t x, int bits)
740{
741 uint32_t product = (uint32_t)x * kPrime;
742 uint32_t power = 1 << bits;
743 /* This is |product| % power, since |power| is a power of 2. */
744 uint32_t remainder = product & (power - 1);
745 /* This is |product| / power, since |power| is a power of 2. */
746 uint32_t lower = product >> bits;
747
748 /*
749 * The rounding logic works since the first half of the numbers mod |power|
750 * have a 0 as their top bit and the second half have a 1, because |power| is
751 * a power of 2. As a 12 bit number, |remainder| is always non-negative, so
752 * the right shift brings in 0s.
753 */
754 return lower + (remainder >> (bits - 1));
755}
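/*
 * For example, decompress(0, 1) = 0 and decompress(1, 1) = 1664 + 1 = 1665
 * (the code rounds 1664.5 up), so a one-bit message maps to 0 or
 * kHalfPrime + 1. Together with the example after |compress|, one can
 * check that compress(decompress(m, 1), 1) == m for m in {0, 1}, the
 * noise-free round trip that |encrypt_cpa| and |decrypt_cpa| rely on.
 */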
756
757static void
758scalar_compress(scalar *s, int bits)
759{
760 int i;
761
762 for (i = 0; i < DEGREE; i++) {
763 s->c[i] = compress(s->c[i], bits);
764 }
765}
766
767static void
768scalar_decompress(scalar *s, int bits)
769{
770 int i;
771
772 for (i = 0; i < DEGREE; i++) {
773 s->c[i] = decompress(s->c[i], bits);
774 }
775}
776
777static void
778vector_compress(vector *a, int bits)
779{
780 int i;
781
782 for (i = 0; i < RANK1024; i++) {
783 scalar_compress(&a->v[i], bits);
784 }
785}
786
787static void
788vector_decompress(vector *a, int bits)
789{
790 int i;
791
792 for (i = 0; i < RANK1024; i++) {
793 scalar_decompress(&a->v[i], bits);
794 }
795}
796
797struct public_key {
798 vector t;
799 uint8_t rho[32];
800 uint8_t public_key_hash[32];
801 matrix m;
802};
803
804CTASSERT(sizeof(struct MLKEM1024_public_key) == sizeof(struct public_key));
805
806static struct public_key *
807public_key_1024_from_external(const MLKEM_public_key *external)
808{
809 if (external->rank != RANK1024)
810 return NULL;
811 return (struct public_key *)external->key_1024;
812}
813
814struct private_key {
815 struct public_key pub;
816 vector s;
817 uint8_t fo_failure_secret[32];
818};
819
820CTASSERT(sizeof(struct MLKEM1024_private_key) == sizeof(struct private_key));
821
822static struct private_key *
823private_key_1024_from_external(const MLKEM_private_key *external)
824{
825 if (external->rank != RANK1024)
826 return NULL;
827 return (struct private_key *)external->key_1024;
828}
829
830/*
831 * Calls |MLKEM1024_generate_key_external_entropy| with random bytes from
832 * |arc4random_buf|.
833 */
834int
835MLKEM1024_generate_key(uint8_t out_encoded_public_key[MLKEM1024_PUBLIC_KEY_BYTES],
836 uint8_t optional_out_seed[MLKEM_SEED_BYTES],
837 MLKEM_private_key *out_private_key)
838{
839 uint8_t entropy_buf[MLKEM_SEED_BYTES];
840 uint8_t *entropy = optional_out_seed != NULL ? optional_out_seed :
841 entropy_buf;
842
843 arc4random_buf(entropy, MLKEM_SEED_BYTES);
844 return MLKEM1024_generate_key_external_entropy(out_encoded_public_key,
845 out_private_key, entropy);
846}
847
848int
849MLKEM1024_private_key_from_seed(MLKEM_private_key *out_private_key,
850 const uint8_t *seed, size_t seed_len)
851{
852 uint8_t public_key_bytes[MLKEM1024_PUBLIC_KEY_BYTES];
853
854 if (seed_len != MLKEM_SEED_BYTES) {
855 return 0;
856 }
857 return MLKEM1024_generate_key_external_entropy(public_key_bytes,
858 out_private_key, seed);
859}
860
861static int
862mlkem_marshal_public_key(CBB *out, const struct public_key *pub)
863{
864 if (!vector_encode_cbb(out, &pub->t, kLog2Prime))
865 return 0;
866 return CBB_add_bytes(out, pub->rho, sizeof(pub->rho));
867}
868
869int
870MLKEM1024_generate_key_external_entropy(
871 uint8_t out_encoded_public_key[MLKEM1024_PUBLIC_KEY_BYTES],
872 MLKEM_private_key *out_private_key,
873 const uint8_t entropy[MLKEM_SEED_BYTES])
874{
875 struct private_key *priv = private_key_1024_from_external(
876 out_private_key);
877 uint8_t augmented_seed[33];
878 uint8_t *rho, *sigma;
879 uint8_t counter = 0;
880 uint8_t hashed[64];
881 vector error;
882 CBB cbb;
883 int ret = 0;
884
885 memset(&cbb, 0, sizeof(CBB));
886 memcpy(augmented_seed, entropy, 32);
887 augmented_seed[32] = RANK1024;
888 hash_g(hashed, augmented_seed, 33);
889 rho = hashed;
890 sigma = hashed + 32;
891 memcpy(priv->pub.rho, hashed, sizeof(priv->pub.rho));
892 matrix_expand(&priv->pub.m, rho);
893 vector_generate_secret_eta_2(&priv->s, &counter, sigma);
894 vector_ntt(&priv->s);
895 vector_generate_secret_eta_2(&error, &counter, sigma);
896 vector_ntt(&error);
897 matrix_mult_transpose(&priv->pub.t, &priv->pub.m, &priv->s);
898 vector_add(&priv->pub.t, &error);
899
900 if (!CBB_init_fixed(&cbb, out_encoded_public_key,
901 MLKEM1024_PUBLIC_KEY_BYTES))
902 goto err;
903
904 if (!mlkem_marshal_public_key(&cbb, &priv->pub))
905 goto err;
906
907 hash_h(priv->pub.public_key_hash, out_encoded_public_key,
908 MLKEM1024_PUBLIC_KEY_BYTES);
909 memcpy(priv->fo_failure_secret, entropy + 32, 32);
910
911 ret = 1;
912
913 err:
914 CBB_cleanup(&cbb);
915
916 return ret;
917}
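/*
 * In the notation of the spec, the 64 bytes of |entropy| are d || z:
 * G(d || k) yields rho (the matrix seed) and sigma (the noise seed), while
 * z is stored as |fo_failure_secret| and is only used for implicit
 * rejection in |MLKEM1024_decap|.
 */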
918
919void
920MLKEM1024_public_from_private(const MLKEM_private_key *private_key,
921 MLKEM_public_key *out_public_key)
922{
923 struct public_key *const pub = public_key_1024_from_external(
924 out_public_key);
925 const struct private_key *const priv = private_key_1024_from_external(
926 private_key);
927
928 *pub = priv->pub;
929}
930
931/*
932 * Encrypts a message with given randomness to the ciphertext in |out|. Without
933 * applying the Fujisaki-Okamoto transform this would not result in a CCA secure
934 * scheme, since lattice schemes are vulnerable to decryption failure oracles.
935 */
936static void
937encrypt_cpa(uint8_t out[MLKEM1024_CIPHERTEXT_BYTES],
938 const struct public_key *pub, const uint8_t message[32],
939 const uint8_t randomness[32])
940{
941 scalar expanded_message, scalar_error;
942 vector secret, error, u;
943 uint8_t counter = 0;
944 uint8_t input[33];
945 scalar v;
946
947 vector_generate_secret_eta_2(&secret, &counter, randomness);
948 vector_ntt(&secret);
949 vector_generate_secret_eta_2(&error, &counter, randomness);
950 memcpy(input, randomness, 32);
951 input[32] = counter;
952 scalar_centered_binomial_distribution_eta_2_with_prf(&scalar_error,
953 input);
954 matrix_mult(&u, &pub->m, &secret);
955 vector_inverse_ntt(&u);
956 vector_add(&u, &error);
957 scalar_inner_product(&v, &pub->t, &secret);
958 scalar_inverse_ntt(&v);
959 scalar_add(&v, &scalar_error);
960 scalar_decode_1(&expanded_message, message);
961 scalar_decompress(&expanded_message, 1);
962 scalar_add(&v, &expanded_message);
963 vector_compress(&u, kDU1024);
964 vector_encode(out, &u, kDU1024);
965 scalar_compress(&v, kDV1024);
966 scalar_encode(out + kCompressedVectorSize, &v, kDV1024);
967}
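/*
 * The resulting ciphertext is |u| compressed to kDU1024 = 11 bits per
 * coefficient (kCompressedVectorSize = 1408 bytes), followed by |v|
 * compressed to kDV1024 = 5 bits per coefficient (DEGREE * 5 / 8 = 160
 * bytes), i.e. MLKEM1024_CIPHERTEXT_BYTES = 1568 bytes in total.
 */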
968
969/* Calls |MLKEM1024_encap_external_entropy| with random bytes. */
970void
971MLKEM1024_encap(const MLKEM_public_key *public_key,
972 uint8_t out_ciphertext[MLKEM1024_CIPHERTEXT_BYTES],
973 uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES])
974{
975 uint8_t entropy[MLKEM_ENCAP_ENTROPY];
976
977 arc4random_buf(entropy, MLKEM_ENCAP_ENTROPY);
978 MLKEM1024_encap_external_entropy(out_ciphertext, out_shared_secret,
979 public_key, entropy);
980}
981
982/* See section 6.2 of the spec. */
983void
984MLKEM1024_encap_external_entropy(
985 uint8_t out_ciphertext[MLKEM1024_CIPHERTEXT_BYTES],
986 uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
987 const MLKEM_public_key *public_key,
988 const uint8_t entropy[MLKEM_ENCAP_ENTROPY])
989{
990 const struct public_key *pub = public_key_1024_from_external(public_key);
991 uint8_t key_and_randomness[64];
992 uint8_t input[64];
993
994 memcpy(input, entropy, MLKEM_ENCAP_ENTROPY);
995 memcpy(input + MLKEM_ENCAP_ENTROPY, pub->public_key_hash,
996 sizeof(input) - MLKEM_ENCAP_ENTROPY);
997 hash_g(key_and_randomness, input, sizeof(input));
998 encrypt_cpa(out_ciphertext, pub, entropy, key_and_randomness + 32);
999 memcpy(out_shared_secret, key_and_randomness, 32);
1000}
1001
1002static void
1003decrypt_cpa(uint8_t out[32], const struct private_key *priv,
1004 const uint8_t ciphertext[MLKEM1024_CIPHERTEXT_BYTES])
1005{
1006 scalar mask, v;
1007 vector u;
1008
1009 vector_decode(&u, ciphertext, kDU1024);
1010 vector_decompress(&u, kDU1024);
1011 vector_ntt(&u);
1012 scalar_decode(&v, ciphertext + kCompressedVectorSize, kDV1024);
1013 scalar_decompress(&v, kDV1024);
1014 scalar_inner_product(&mask, &priv->s, &u);
1015 scalar_inverse_ntt(&mask);
1016 scalar_sub(&v, &mask);
1017 scalar_compress(&v, 1);
1018 scalar_encode_1(out, &v);
1019}
1020
1021/* See section 6.3 */
1022int
1023MLKEM1024_decap(const MLKEM_private_key *private_key,
1024 const uint8_t *ciphertext, size_t ciphertext_len,
1025 uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES])
1026{
1027 const struct private_key *priv = private_key_1024_from_external(
1028 private_key);
1029 uint8_t expected_ciphertext[MLKEM1024_CIPHERTEXT_BYTES];
1030 uint8_t key_and_randomness[64];
1031 uint8_t failure_key[32];
1032 uint8_t decrypted[64];
1033 uint8_t mask;
1034 int i;
1035
1036 if (ciphertext_len != MLKEM1024_CIPHERTEXT_BYTES) {
1037 arc4random_buf(out_shared_secret, MLKEM_SHARED_SECRET_BYTES);
1038 return 0;
1039 }
1040
1041 decrypt_cpa(decrypted, priv, ciphertext);
1042 memcpy(decrypted + 32, priv->pub.public_key_hash,
1043 sizeof(decrypted) - 32);
1044 hash_g(key_and_randomness, decrypted, sizeof(decrypted));
1045 encrypt_cpa(expected_ciphertext, &priv->pub, decrypted,
1046 key_and_randomness + 32);
1047 kdf(failure_key, priv->fo_failure_secret, ciphertext, ciphertext_len);
1048 mask = constant_time_eq_int_8(memcmp(ciphertext, expected_ciphertext,
1049 sizeof(expected_ciphertext)), 0);
1050 for (i = 0; i < MLKEM_SHARED_SECRET_BYTES; i++) {
1051 out_shared_secret[i] = constant_time_select_8(mask,
1052 key_and_randomness[i], failure_key[i]);
1053 }
1054
1055 return 1;
1056}
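/*
 * A minimal round-trip sketch using only the functions defined in this
 * file; allocation and rank setup of the MLKEM_private_key and
 * MLKEM_public_key containers is assumed to have happened elsewhere and is
 * not shown:
 *
 *	static int
 *	roundtrip(MLKEM_private_key *priv, MLKEM_public_key *pub)
 *	{
 *		uint8_t ek[MLKEM1024_PUBLIC_KEY_BYTES];
 *		uint8_t ct[MLKEM1024_CIPHERTEXT_BYTES];
 *		uint8_t k1[MLKEM_SHARED_SECRET_BYTES];
 *		uint8_t k2[MLKEM_SHARED_SECRET_BYTES];
 *
 *		if (!MLKEM1024_generate_key(ek, NULL, priv))
 *			return 0;
 *		MLKEM1024_public_from_private(priv, pub);
 *		MLKEM1024_encap(pub, ct, k1);
 *		if (!MLKEM1024_decap(priv, ct, sizeof(ct), k2))
 *			return 0;
 *		return memcmp(k1, k2, sizeof(k1)) == 0;
 *	}
 */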
1057
1058int
1059MLKEM1024_marshal_public_key(const MLKEM_public_key *public_key,
1060 uint8_t **output, size_t *output_len)
1061{
1062 int ret = 0;
1063 CBB cbb;
1064
1065 if (!CBB_init(&cbb, MLKEM1024_PUBLIC_KEY_BYTES))
1066 goto err;
1067 if (!mlkem_marshal_public_key(&cbb,
1068 public_key_1024_from_external(public_key)))
1069 goto err;
1070 if (!CBB_finish(&cbb, output, output_len))
1071 goto err;
1072
1073 ret = 1;
1074
1075 err:
1076 CBB_cleanup(&cbb);
1077
1078 return ret;
1079}
1080
1081/*
1082 * mlkem_parse_public_key_no_hash parses |in| into |pub| but doesn't calculate
1083 * the value of |pub->public_key_hash|.
1084 */
1085static int
1086mlkem_parse_public_key_no_hash(struct public_key *pub, CBS *in)
1087{
1088 CBS t_bytes;
1089
1090 if (!CBS_get_bytes(in, &t_bytes, kEncodedVectorSize))
1091 return 0;
1092 if (!vector_decode(&pub->t, CBS_data(&t_bytes), kLog2Prime))
1093 return 0;
1094
1095 memcpy(pub->rho, CBS_data(in), sizeof(pub->rho));
1096 if (!CBS_skip(in, sizeof(pub->rho)))
1097 return 0;
1098 matrix_expand(&pub->m, pub->rho);
1099 return 1;
1100}
1101
1102int
1103MLKEM1024_parse_public_key(const uint8_t *input, size_t input_len,
1104 MLKEM_public_key *public_key)
1105{
1106 struct public_key *pub = public_key_1024_from_external(public_key);
1107 CBS cbs;
1108
1109 CBS_init(&cbs, input, input_len);
1110 if (!mlkem_parse_public_key_no_hash(pub, &cbs))
1111 return 0;
1112 if (CBS_len(&cbs) != 0)
1113 return 0;
1114
1115 hash_h(pub->public_key_hash, input, input_len);
1116
1117 return 1;
1118}
1119
1120int
1121MLKEM1024_marshal_private_key(const MLKEM_private_key *private_key,
1122 uint8_t **out_private_key, size_t *out_private_key_len)
1123{
1124 const struct private_key *const priv = private_key_1024_from_external(
1125 private_key);
1126 CBB cbb;
1127 int ret = 0;
1128
1129 if (!CBB_init(&cbb, MLKEM1024_PRIVATE_KEY_BYTES))
1130 goto err;
1131
1132 if (!vector_encode_cbb(&cbb, &priv->s, kLog2Prime))
1133 goto err;
1134 if (!mlkem_marshal_public_key(&cbb, &priv->pub))
1135 goto err;
1136 if (!CBB_add_bytes(&cbb, priv->pub.public_key_hash,
1137 sizeof(priv->pub.public_key_hash)))
1138 goto err;
1139 if (!CBB_add_bytes(&cbb, priv->fo_failure_secret,
1140 sizeof(priv->fo_failure_secret)))
1141 goto err;
1142
1143 if (!CBB_finish(&cbb, out_private_key, out_private_key_len))
1144 goto err;
1145
1146 ret = 1;
1147
1148 err:
1149 CBB_cleanup(&cbb);
1150
1151 return ret;
1152}
1153
1154int
1155MLKEM1024_parse_private_key(const uint8_t *input, size_t input_len,
1156 MLKEM_private_key *out_private_key)
1157{
1158 struct private_key *const priv = private_key_1024_from_external(
1159 out_private_key);
1160 CBS cbs, s_bytes;
1161
1162 CBS_init(&cbs, input, input_len);
1163
1164 if (!CBS_get_bytes(&cbs, &s_bytes, kEncodedVectorSize))
1165 return 0;
1166 if (!vector_decode(&priv->s, CBS_data(&s_bytes), kLog2Prime))
1167 return 0;
1168 if (!mlkem_parse_public_key_no_hash(&priv->pub, &cbs))
1169 return 0;
1170
1171 memcpy(priv->pub.public_key_hash, CBS_data(&cbs),
1172 sizeof(priv->pub.public_key_hash));
1173 if (!CBS_skip(&cbs, sizeof(priv->pub.public_key_hash)))
1174 return 0;
1175 memcpy(priv->fo_failure_secret, CBS_data(&cbs),
1176 sizeof(priv->fo_failure_secret));
1177 if (!CBS_skip(&cbs, sizeof(priv->fo_failure_secret)))
1178 return 0;
1179 if (CBS_len(&cbs) != 0)
1180 return 0;
1181
1182 return 1;
1183}