summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/lib/libcrypto/Makefile5
-rw-r--r--src/lib/libcrypto/mlkem/mlkem.c106
-rw-r--r--src/lib/libcrypto/mlkem/mlkem1024.c1183
-rw-r--r--src/lib/libcrypto/mlkem/mlkem_internal.c (renamed from src/lib/libcrypto/mlkem/mlkem768.c)557
-rw-r--r--src/lib/libcrypto/mlkem/mlkem_internal.h460
-rw-r--r--src/lib/libcrypto/mlkem/mlkem_key.c4
6 files changed, 527 insertions, 1788 deletions
diff --git a/src/lib/libcrypto/Makefile b/src/lib/libcrypto/Makefile
index 5ee28b0e6c..92866400c2 100644
--- a/src/lib/libcrypto/Makefile
+++ b/src/lib/libcrypto/Makefile
@@ -1,4 +1,4 @@
1# $OpenBSD: Makefile,v 1.243 2025/08/25 16:48:01 tb Exp $ 1# $OpenBSD: Makefile,v 1.244 2025/09/05 23:30:12 beck Exp $
2 2
3LIB= crypto 3LIB= crypto
4LIBREBUILD=y 4LIBREBUILD=y
@@ -375,8 +375,7 @@ SRCS+= md5.c
375 375
376# mlkem/ 376# mlkem/
377SRCS+= mlkem.c 377SRCS+= mlkem.c
378SRCS+= mlkem768.c 378SRCS+= mlkem_internal.c
379SRCS+= mlkem1024.c
380SRCS+= mlkem_key.c 379SRCS+= mlkem_key.c
381 380
382# modes/ 381# modes/
diff --git a/src/lib/libcrypto/mlkem/mlkem.c b/src/lib/libcrypto/mlkem/mlkem.c
index dcc73c2631..9461a338e9 100644
--- a/src/lib/libcrypto/mlkem/mlkem.c
+++ b/src/lib/libcrypto/mlkem/mlkem.c
@@ -1,4 +1,4 @@
1/* $OpenBSD: mlkem.c,v 1.3 2025/08/19 21:37:08 tb Exp $ */ 1/* $OpenBSD: mlkem.c,v 1.4 2025/09/05 23:30:12 beck Exp $ */
2/* 2/*
3 * Copyright (c) 2025, Bob Beck <beck@obtuse.com> 3 * Copyright (c) 2025, Bob Beck <beck@obtuse.com>
4 * 4 *
@@ -77,24 +77,15 @@ MLKEM_generate_key_external_entropy(MLKEM_private_key *private_key,
77 if ((k = calloc(1, k_len)) == NULL) 77 if ((k = calloc(1, k_len)) == NULL)
78 goto err; 78 goto err;
79 79
80 switch (private_key->rank) { 80 if (!mlkem_generate_key_external_entropy(k, private_key, entropy))
81 case RANK768: 81 goto err;
82 if (!MLKEM768_generate_key_external_entropy(k, private_key,
83 entropy))
84 goto err;
85 break;
86 case RANK1024:
87 if (!MLKEM1024_generate_key_external_entropy(k, private_key,
88 entropy))
89 goto err;
90 break;
91 }
92 82
93 private_key->state = MLKEM_PRIVATE_KEY_INITIALIZED; 83 private_key->state = MLKEM_PRIVATE_KEY_INITIALIZED;
94 84
95 *out_encoded_public_key = k; 85 *out_encoded_public_key = k;
96 *out_encoded_public_key_len = k_len; 86 *out_encoded_public_key_len = k_len;
97 k = NULL; 87 k = NULL;
88 k_len = 0;
98 89
99 ret = 1; 90 ret = 1;
100 91
@@ -154,18 +145,8 @@ MLKEM_private_key_from_seed(MLKEM_private_key *private_key,
154 if (seed_len != MLKEM_SEED_LENGTH) 145 if (seed_len != MLKEM_SEED_LENGTH)
155 goto err; 146 goto err;
156 147
157 switch (private_key->rank) { 148 if (!mlkem_private_key_from_seed(seed, seed_len, private_key))
158 case RANK768: 149 goto err;
159 if (!MLKEM768_private_key_from_seed(seed,
160 seed_len, private_key))
161 goto err;
162 break;
163 case RANK1024:
164 if (!MLKEM1024_private_key_from_seed(private_key,
165 seed, seed_len))
166 goto err;
167 break;
168 }
169 150
170 private_key->state = MLKEM_PRIVATE_KEY_INITIALIZED; 151 private_key->state = MLKEM_PRIVATE_KEY_INITIALIZED;
171 152
@@ -187,14 +168,8 @@ MLKEM_public_from_private(const MLKEM_private_key *private_key,
187 return 0; 168 return 0;
188 if (public_key->rank != private_key->rank) 169 if (public_key->rank != private_key->rank)
189 return 0; 170 return 0;
190 switch (private_key->rank) { 171
191 case RANK768: 172 mlkem_public_from_private(private_key, public_key);
192 MLKEM768_public_from_private(private_key, public_key);
193 break;
194 case RANK1024:
195 MLKEM1024_public_from_private(private_key, public_key);
196 break;
197 }
198 173
199 public_key->state = MLKEM_PUBLIC_KEY_INITIALIZED; 174 public_key->state = MLKEM_PUBLIC_KEY_INITIALIZED;
200 175
@@ -230,17 +205,8 @@ MLKEM_encap_external_entropy(const MLKEM_public_key *public_key,
230 if ((ciphertext = calloc(1, ciphertext_len)) == NULL) 205 if ((ciphertext = calloc(1, ciphertext_len)) == NULL)
231 goto err; 206 goto err;
232 207
233 switch (public_key->rank) { 208 mlkem_encap_external_entropy(ciphertext, secret, public_key, entropy);
234 case RANK768:
235 MLKEM768_encap_external_entropy(ciphertext, secret, public_key,
236 entropy);
237 break;
238 209
239 case RANK1024:
240 MLKEM1024_encap_external_entropy(ciphertext, secret, public_key,
241 entropy);
242 break;
243 }
244 *out_ciphertext = ciphertext; 210 *out_ciphertext = ciphertext;
245 *out_ciphertext_len = ciphertext_len; 211 *out_ciphertext_len = ciphertext_len;
246 ciphertext = NULL; 212 ciphertext = NULL;
@@ -291,15 +257,7 @@ MLKEM_decap(const MLKEM_private_key *private_key,
291 if ((s = calloc(1, MLKEM_SHARED_SECRET_LENGTH)) == NULL) 257 if ((s = calloc(1, MLKEM_SHARED_SECRET_LENGTH)) == NULL)
292 goto err; 258 goto err;
293 259
294 switch (private_key->rank) { 260 mlkem_decap(private_key, ciphertext, ciphertext_len, s);
295 case RANK768:
296 MLKEM768_decap(private_key, ciphertext, ciphertext_len, s);
297 break;
298
299 case RANK1024:
300 MLKEM1024_decap(private_key, ciphertext, ciphertext_len, s);
301 break;
302 }
303 261
304 *out_shared_secret = s; 262 *out_shared_secret = s;
305 *out_shared_secret_len = MLKEM_SHARED_SECRET_LENGTH; 263 *out_shared_secret_len = MLKEM_SHARED_SECRET_LENGTH;
@@ -324,14 +282,7 @@ MLKEM_marshal_public_key(const MLKEM_public_key *public_key, uint8_t **out,
324 if (!public_key_is_valid(public_key)) 282 if (!public_key_is_valid(public_key))
325 return 0; 283 return 0;
326 284
327 switch (public_key->rank) { 285 return mlkem_marshal_public_key(public_key, out, out_len);
328 case RANK768:
329 return MLKEM768_marshal_public_key(public_key, out, out_len);
330 case RANK1024:
331 return MLKEM1024_marshal_public_key(public_key, out, out_len);
332 default:
333 return 0;
334 }
335} 286}
336LCRYPTO_ALIAS(MLKEM_marshal_public_key); 287LCRYPTO_ALIAS(MLKEM_marshal_public_key);
337 288
@@ -349,14 +300,7 @@ MLKEM_marshal_private_key(const MLKEM_private_key *private_key, uint8_t **out,
349 if (!private_key_is_valid(private_key)) 300 if (!private_key_is_valid(private_key))
350 return 0; 301 return 0;
351 302
352 switch (private_key->rank) { 303 return mlkem_marshal_private_key(private_key, out, out_len);
353 case RANK768:
354 return MLKEM768_marshal_private_key(private_key, out, out_len);
355 case RANK1024:
356 return MLKEM1024_marshal_private_key(private_key, out, out_len);
357 default:
358 return 0;
359 }
360} 304}
361LCRYPTO_ALIAS(MLKEM_marshal_private_key); 305LCRYPTO_ALIAS(MLKEM_marshal_private_key);
362 306
@@ -370,18 +314,8 @@ MLKEM_parse_public_key(MLKEM_public_key *public_key, const uint8_t *in,
370 if (in_len != MLKEM_public_key_encoded_length(public_key)) 314 if (in_len != MLKEM_public_key_encoded_length(public_key))
371 return 0; 315 return 0;
372 316
373 switch (public_key->rank) { 317 if (!mlkem_parse_public_key(in, in_len, public_key))
374 case RANK768: 318 return 0;
375 if (!MLKEM768_parse_public_key(in, in_len,
376 public_key))
377 return 0;
378 break;
379 case RANK1024:
380 if (!MLKEM1024_parse_public_key(in, in_len,
381 public_key))
382 return 0;
383 break;
384 }
385 319
386 public_key->state = MLKEM_PUBLIC_KEY_INITIALIZED; 320 public_key->state = MLKEM_PUBLIC_KEY_INITIALIZED;
387 321
@@ -399,16 +333,8 @@ MLKEM_parse_private_key(MLKEM_private_key *private_key, const uint8_t *in,
399 if (in_len != MLKEM_private_key_encoded_length(private_key)) 333 if (in_len != MLKEM_private_key_encoded_length(private_key))
400 return 0; 334 return 0;
401 335
402 switch (private_key->rank) { 336 if (!mlkem_parse_private_key(in, in_len, private_key))
403 case RANK768: 337 return 0;
404 if (!MLKEM768_parse_private_key(in, in_len, private_key))
405 return 0;
406 break;
407 case RANK1024:
408 if (!MLKEM1024_parse_private_key(in, in_len, private_key))
409 return 0;
410 break;
411 }
412 338
413 private_key->state = MLKEM_PRIVATE_KEY_INITIALIZED; 339 private_key->state = MLKEM_PRIVATE_KEY_INITIALIZED;
414 340
diff --git a/src/lib/libcrypto/mlkem/mlkem1024.c b/src/lib/libcrypto/mlkem/mlkem1024.c
deleted file mode 100644
index 8f4f41f8ff..0000000000
--- a/src/lib/libcrypto/mlkem/mlkem1024.c
+++ /dev/null
@@ -1,1183 +0,0 @@
1/* $OpenBSD: mlkem1024.c,v 1.12 2025/08/14 15:48:48 beck Exp $ */
2/*
3 * Copyright (c) 2024, Google Inc.
4 * Copyright (c) 2024, 2025 Bob Beck <beck@obtuse.com>
5 *
6 * Permission to use, copy, modify, and/or distribute this software for any
7 * purpose with or without fee is hereby granted, provided that the above
8 * copyright notice and this permission notice appear in all copies.
9 *
10 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
13 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
15 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
16 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17 */
18
19#include <assert.h>
20#include <stdlib.h>
21#include <string.h>
22
23#include <openssl/mlkem.h>
24
25#include "bytestring.h"
26#include "sha3_internal.h"
27#include "mlkem_internal.h"
28#include "constant_time.h"
29#include "crypto_internal.h"
30
31/*
32 * See
33 * https://csrc.nist.gov/pubs/fips/203/final
34 */
35
36static void
37prf(uint8_t *out, size_t out_len, const uint8_t in[33])
38{
39 sha3_ctx ctx;
40 shake256_init(&ctx);
41 shake_update(&ctx, in, 33);
42 shake_xof(&ctx);
43 shake_out(&ctx, out, out_len);
44}
45
46/* Section 4.1 */
47static void
48hash_h(uint8_t out[32], const uint8_t *in, size_t len)
49{
50 sha3_ctx ctx;
51 sha3_init(&ctx, 32);
52 sha3_update(&ctx, in, len);
53 sha3_final(out, &ctx);
54}
55
56static void
57hash_g(uint8_t out[64], const uint8_t *in, size_t len)
58{
59 sha3_ctx ctx;
60 sha3_init(&ctx, 64);
61 sha3_update(&ctx, in, len);
62 sha3_final(out, &ctx);
63}
64
65/* this is called 'J' in the spec */
66static void
67kdf(uint8_t out[MLKEM_SHARED_SECRET_BYTES], const uint8_t failure_secret[32],
68 const uint8_t *in, size_t len)
69{
70 sha3_ctx ctx;
71 shake256_init(&ctx);
72 shake_update(&ctx, failure_secret, 32);
73 shake_update(&ctx, in, len);
74 shake_xof(&ctx);
75 shake_out(&ctx, out, MLKEM_SHARED_SECRET_BYTES);
76}
77
78#define DEGREE 256
79
80static const size_t kBarrettMultiplier = 5039;
81static const unsigned kBarrettShift = 24;
82static const uint16_t kPrime = 3329;
83static const int kLog2Prime = 12;
84static const uint16_t kHalfPrime = (/*kPrime=*/3329 - 1) / 2;
85static const int kDU1024 = 11;
86static const int kDV1024 = 5;
87
88/*
89 * kInverseDegree is 128^-1 mod 3329; 128 because kPrime does not have a 512th
90 * root of unity.
91 */
92static const uint16_t kInverseDegree = 3303;
93static const size_t kEncodedVectorSize =
94 (/*kLog2Prime=*/12 * DEGREE / 8) * RANK1024;
95static const size_t kCompressedVectorSize = /*kDU1024=*/ 11 * RANK1024 * DEGREE /
96 8;
97
98typedef struct scalar {
99 /* On every function entry and exit, 0 <= c < kPrime. */
100 uint16_t c[DEGREE];
101} scalar;
102
103typedef struct vector {
104 scalar v[RANK1024];
105} vector;
106
107typedef struct matrix {
108 scalar v[RANK1024][RANK1024];
109} matrix;
110
111/*
112 * This bit of Python will be referenced in some of the following comments:
113 *
114 * p = 3329
115 *
116 * def bitreverse(i):
117 * ret = 0
118 * for n in range(7):
119 * bit = i & 1
120 * ret <<= 1
121 * ret |= bit
122 * i >>= 1
123 * return ret
124 */
125
126/* kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)] */
127static const uint16_t kNTTRoots[128] = {
128 1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797,
129 2786, 3260, 569, 1746, 296, 2447, 1339, 1476, 3046, 56, 2240, 1333,
130 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, 289, 331, 3253, 1756,
131 1197, 2304, 2277, 2055, 650, 1977, 2513, 632, 2865, 33, 1320, 1915,
132 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648,
133 2474, 3110, 1227, 910, 17, 2761, 583, 2649, 1637, 723, 2288, 1100,
134 1409, 2662, 3281, 233, 756, 2156, 3015, 3050, 1703, 1651, 2789, 1789,
135 1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641,
136 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, 1063, 319, 2773, 757,
137 2099, 561, 2466, 2594, 2804, 1092, 403, 1026, 1143, 2150, 2775, 886,
138 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154,
139};
140
141/* kInverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)] */
142static const uint16_t kInverseNTTRoots[128] = {
143 1, 1600, 40, 749, 2481, 1432, 2699, 687, 1583, 2760, 69, 543,
144 2532, 3136, 1410, 2267, 2508, 1355, 450, 936, 447, 2794, 1235, 1903,
145 1996, 1089, 3273, 283, 1853, 1990, 882, 3033, 2419, 2102, 219, 855,
146 2681, 1848, 712, 682, 927, 1795, 461, 1891, 2877, 2522, 1894, 1010,
147 1414, 2009, 3296, 464, 2697, 816, 1352, 2679, 1274, 1052, 1025, 2132,
148 1573, 76, 2998, 3040, 1175, 2444, 394, 1219, 2300, 1455, 2117, 1607,
149 2443, 554, 1179, 2186, 2303, 2926, 2237, 525, 735, 863, 2768, 1230,
150 2572, 556, 3010, 2266, 1684, 1239, 780, 2954, 109, 1292, 1031, 1745,
151 2688, 3061, 992, 2596, 941, 892, 1021, 2390, 642, 1868, 2377, 1482,
152 1540, 540, 1678, 1626, 279, 314, 1173, 2573, 3096, 48, 667, 1920,
153 2229, 1041, 2606, 1692, 680, 2746, 568, 3312,
154};
155
156/* kModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)] */
157static const uint16_t kModRoots[128] = {
158 17, 3312, 2761, 568, 583, 2746, 2649, 680, 1637, 1692, 723, 2606,
159 2288, 1041, 1100, 2229, 1409, 1920, 2662, 667, 3281, 48, 233, 3096,
160 756, 2573, 2156, 1173, 3015, 314, 3050, 279, 1703, 1626, 1651, 1678,
161 2789, 540, 1789, 1540, 1847, 1482, 952, 2377, 1461, 1868, 2687, 642,
162 939, 2390, 2308, 1021, 2437, 892, 2388, 941, 733, 2596, 2337, 992,
163 268, 3061, 641, 2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
164 375, 2954, 2549, 780, 2090, 1239, 1645, 1684, 1063, 2266, 319, 3010,
165 2773, 556, 757, 2572, 2099, 1230, 561, 2768, 2466, 863, 2594, 735,
166 2804, 525, 1092, 2237, 403, 2926, 1026, 2303, 1143, 2186, 2150, 1179,
167 2775, 554, 886, 2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
168 2110, 1219, 2935, 394, 885, 2444, 2154, 1175,
169};
170
171/* reduce_once reduces 0 <= x < 2*kPrime, mod kPrime. */
172static uint16_t
173reduce_once(uint16_t x)
174{
175 assert(x < 2 * kPrime);
176 const uint16_t subtracted = x - kPrime;
177 uint16_t mask = 0u - (subtracted >> 15);
178
179 /*
180 * Although this is a constant-time select, we omit a value barrier here.
181 * Value barriers impede auto-vectorization (likely because it forces the
182 * value to transit through a general-purpose register). On AArch64, this
183 * is a difference of 2x.
184 *
185 * We usually add value barriers to selects because Clang turns
186 * consecutive selects with the same condition into a branch instead of
187 * CMOV/CSEL. This condition does not occur in ML-KEM, so omitting it
188 * seems to be safe so far but see
189 * |scalar_centered_binomial_distribution_eta_2_with_prf|.
190 */
191 return (mask & x) | (~mask & subtracted);
192}
193
194/*
195 * constant time reduce x mod kPrime using Barrett reduction. x must be less
196 * than kPrime + 2×kPrime².
197 */
198static uint16_t
199reduce(uint32_t x)
200{
201 uint64_t product = (uint64_t)x * kBarrettMultiplier;
202 uint32_t quotient = (uint32_t)(product >> kBarrettShift);
203 uint32_t remainder = x - quotient * kPrime;
204
205 assert(x < kPrime + 2u * kPrime * kPrime);
206 return reduce_once(remainder);
207}
208
209static void
210scalar_zero(scalar *out)
211{
212 memset(out, 0, sizeof(*out));
213}
214
215static void
216vector_zero(vector *out)
217{
218 memset(out, 0, sizeof(*out));
219}
220
221/*
222 * In place number theoretic transform of a given scalar.
223 * Note that MLKEM's kPrime 3329 does not have a 512th root of unity, so this
224 * transform leaves off the last iteration of the usual FFT code, with the 128
225 * relevant roots of unity being stored in |kNTTRoots|. This means the output
226 * should be seen as 128 elements in GF(3329^2), with the coefficients of the
227 * elements being consecutive entries in |s->c|.
228 */
229static void
230scalar_ntt(scalar *s)
231{
232 int offset = DEGREE;
233 int step;
234 /*
235 * `int` is used here because using `size_t` throughout caused a ~5% slowdown
236 * with Clang 14 on Aarch64.
237 */
238 for (step = 1; step < DEGREE / 2; step <<= 1) {
239 int i, j, k = 0;
240
241 offset >>= 1;
242 for (i = 0; i < step; i++) {
243 const uint32_t step_root = kNTTRoots[i + step];
244
245 for (j = k; j < k + offset; j++) {
246 uint16_t odd, even;
247
248 odd = reduce(step_root * s->c[j + offset]);
249 even = s->c[j];
250 s->c[j] = reduce_once(odd + even);
251 s->c[j + offset] = reduce_once(even - odd +
252 kPrime);
253 }
254 k += 2 * offset;
255 }
256 }
257}
258
259static void
260vector_ntt(vector *a)
261{
262 int i;
263
264 for (i = 0; i < RANK1024; i++) {
265 scalar_ntt(&a->v[i]);
266 }
267}
268
269/*
270 * In place inverse number theoretic transform of a given scalar, with pairs of
271 * entries of s->v being interpreted as elements of GF(3329^2). Just as with the
272 * number theoretic transform, this leaves off the first step of the normal iFFT
273 * to account for the fact that 3329 does not have a 512th root of unity, using
274 * the precomputed 128 roots of unity stored in |kInverseNTTRoots|.
275 */
276static void
277scalar_inverse_ntt(scalar *s)
278{
279 int i, j, k, offset, step = DEGREE / 2;
280
281 /*
282 * `int` is used here because using `size_t` throughout caused a ~5% slowdown
283 * with Clang 14 on Aarch64.
284 */
285 for (offset = 2; offset < DEGREE; offset <<= 1) {
286 step >>= 1;
287 k = 0;
288 for (i = 0; i < step; i++) {
289 uint32_t step_root = kInverseNTTRoots[i + step];
290 for (j = k; j < k + offset; j++) {
291 uint16_t odd, even;
292 odd = s->c[j + offset];
293 even = s->c[j];
294 s->c[j] = reduce_once(odd + even);
295 s->c[j + offset] = reduce(step_root *
296 (even - odd + kPrime));
297 }
298 k += 2 * offset;
299 }
300 }
301 for (i = 0; i < DEGREE; i++) {
302 s->c[i] = reduce(s->c[i] * kInverseDegree);
303 }
304}
305
306static void
307vector_inverse_ntt(vector *a)
308{
309 int i;
310
311 for (i = 0; i < RANK1024; i++) {
312 scalar_inverse_ntt(&a->v[i]);
313 }
314}
315
316static void
317scalar_add(scalar *lhs, const scalar *rhs)
318{
319 int i;
320
321 for (i = 0; i < DEGREE; i++) {
322 lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]);
323 }
324}
325
326static void
327scalar_sub(scalar *lhs, const scalar *rhs)
328{
329 int i;
330
331 for (i = 0; i < DEGREE; i++) {
332 lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime);
333 }
334}
335
336/*
337 * Multiplying two scalars in the number theoretically transformed state.
338 * Since 3329 does not have a 512th root of unity, this means we have to
339 * interpret the 2*ith and (2*i+1)th entries of the scalar as elements of
340 * GF(3329)[X]/(X^2 - 17^(2*bitreverse(i)+1)).
341 * The value of 17^(2*bitreverse(i)+1) mod 3329 is stored in the precomputed
342 * |kModRoots| table. Our Barrett transform only allows us to multiply two
343 * reduced numbers together, so we need some intermediate reduction steps,
344 * even if an uint64_t could hold 3 multiplied numbers.
345 */
346static void
347scalar_mult(scalar *out, const scalar *lhs, const scalar *rhs)
348{
349 int i;
350
351 for (i = 0; i < DEGREE / 2; i++) {
352 uint32_t real_real = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i];
353 uint32_t img_img = (uint32_t)lhs->c[2 * i + 1] *
354 rhs->c[2 * i + 1];
355 uint32_t real_img = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i + 1];
356 uint32_t img_real = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i];
357
358 out->c[2 * i] =
359 reduce(real_real +
360 (uint32_t)reduce(img_img) * kModRoots[i]);
361 out->c[2 * i + 1] = reduce(img_real + real_img);
362 }
363}
364
365static void
366vector_add(vector *lhs, const vector *rhs)
367{
368 int i;
369
370 for (i = 0; i < RANK1024; i++) {
371 scalar_add(&lhs->v[i], &rhs->v[i]);
372 }
373}
374
375static void
376matrix_mult(vector *out, const matrix *m, const vector *a)
377{
378 int i, j;
379
380 vector_zero(out);
381 for (i = 0; i < RANK1024; i++) {
382 for (j = 0; j < RANK1024; j++) {
383 scalar product;
384
385 scalar_mult(&product, &m->v[i][j], &a->v[j]);
386 scalar_add(&out->v[i], &product);
387 }
388 }
389}
390
391static void
392matrix_mult_transpose(vector *out, const matrix *m,
393 const vector *a)
394{
395 int i, j;
396
397 vector_zero(out);
398 for (i = 0; i < RANK1024; i++) {
399 for (j = 0; j < RANK1024; j++) {
400 scalar product;
401
402 scalar_mult(&product, &m->v[j][i], &a->v[j]);
403 scalar_add(&out->v[i], &product);
404 }
405 }
406}
407
408static void
409scalar_inner_product(scalar *out, const vector *lhs,
410 const vector *rhs)
411{
412 int i;
413 scalar_zero(out);
414 for (i = 0; i < RANK1024; i++) {
415 scalar product;
416
417 scalar_mult(&product, &lhs->v[i], &rhs->v[i]);
418 scalar_add(out, &product);
419 }
420}
421
422/*
423 * Algorithm 6 of spec. Rejection samples a Keccak stream to get uniformly
424 * distributed elements. This is used for matrix expansion and only operates on
425 * public inputs.
426 */
427static void
428scalar_from_keccak_vartime(scalar *out, sha3_ctx *keccak_ctx)
429{
430 int i, done = 0;
431
432 while (done < DEGREE) {
433 uint8_t block[168];
434
435 shake_out(keccak_ctx, block, sizeof(block));
436 for (i = 0; i < sizeof(block) && done < DEGREE; i += 3) {
437 uint16_t d1 = block[i] + 256 * (block[i + 1] % 16);
438 uint16_t d2 = block[i + 1] / 16 + 16 * block[i + 2];
439
440 if (d1 < kPrime) {
441 out->c[done++] = d1;
442 }
443 if (d2 < kPrime && done < DEGREE) {
444 out->c[done++] = d2;
445 }
446 }
447 }
448}
449
450/*
451 * Algorithm 7 of the spec, with eta fixed to two and the PRF call
452 * included. Creates binominally distributed elements by sampling 2*|eta| bits,
453 * and setting the coefficient to the count of the first bits minus the count of
454 * the second bits, resulting in a centered binomial distribution. Since eta is
455 * two this gives -2/2 with a probability of 1/16, -1/1 with probability 1/4,
456 * and 0 with probability 3/8.
457 */
458static void
459scalar_centered_binomial_distribution_eta_2_with_prf(scalar *out,
460 const uint8_t input[33])
461{
462 uint8_t entropy[128];
463 int i;
464
465 CTASSERT(sizeof(entropy) == 2 * /*kEta=*/ 2 * DEGREE / 8);
466 prf(entropy, sizeof(entropy), input);
467
468 for (i = 0; i < DEGREE; i += 2) {
469 uint8_t byte = entropy[i / 2];
470 uint16_t mask;
471 uint16_t value = (byte & 1) + ((byte >> 1) & 1);
472
473 value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
474
475 /*
476 * Add |kPrime| if |value| underflowed. See |reduce_once| for a
477 * discussion on why the value barrier is omitted. While this
478 * could have been written reduce_once(value + kPrime), this is
479 * one extra addition and small range of |value| tempts some
480 * versions of Clang to emit a branch.
481 */
482 mask = 0u - (value >> 15);
483 out->c[i] = ((value + kPrime) & mask) | (value & ~mask);
484
485 byte >>= 4;
486 value = (byte & 1) + ((byte >> 1) & 1);
487 value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
488 /* See above. */
489 mask = 0u - (value >> 15);
490 out->c[i + 1] = ((value + kPrime) & mask) | (value & ~mask);
491 }
492}
493
494/*
495 * Generates a secret vector by using
496 * |scalar_centered_binomial_distribution_eta_2_with_prf|, using the given seed
497 * appending and incrementing |counter| for entry of the vector.
498 */
499static void
500vector_generate_secret_eta_2(vector *out, uint8_t *counter,
501 const uint8_t seed[32])
502{
503 uint8_t input[33];
504 int i;
505
506 memcpy(input, seed, 32);
507 for (i = 0; i < RANK1024; i++) {
508 input[32] = (*counter)++;
509 scalar_centered_binomial_distribution_eta_2_with_prf(&out->v[i],
510 input);
511 }
512}
513
514/* Expands the matrix of a seed for key generation and for encaps-CPA. */
515static void
516matrix_expand(matrix *out, const uint8_t rho[32])
517{
518 uint8_t input[34];
519 int i, j;
520
521 memcpy(input, rho, 32);
522 for (i = 0; i < RANK1024; i++) {
523 for (j = 0; j < RANK1024; j++) {
524 sha3_ctx keccak_ctx;
525
526 input[32] = i;
527 input[33] = j;
528 shake128_init(&keccak_ctx);
529 shake_update(&keccak_ctx, input, sizeof(input));
530 shake_xof(&keccak_ctx);
531 scalar_from_keccak_vartime(&out->v[i][j], &keccak_ctx);
532 }
533 }
534}
535
536static const uint8_t kMasks[8] = {0x01, 0x03, 0x07, 0x0f,
537 0x1f, 0x3f, 0x7f, 0xff};
538
539static void
540scalar_encode(uint8_t *out, const scalar *s, int bits)
541{
542 uint8_t out_byte = 0;
543 int i, out_byte_bits = 0;
544
545 assert(bits <= (int)sizeof(*s->c) * 8 && bits != 1);
546 for (i = 0; i < DEGREE; i++) {
547 uint16_t element = s->c[i];
548 int element_bits_done = 0;
549
550 while (element_bits_done < bits) {
551 int chunk_bits = bits - element_bits_done;
552 int out_bits_remaining = 8 - out_byte_bits;
553
554 if (chunk_bits >= out_bits_remaining) {
555 chunk_bits = out_bits_remaining;
556 out_byte |= (element &
557 kMasks[chunk_bits - 1]) << out_byte_bits;
558 *out = out_byte;
559 out++;
560 out_byte_bits = 0;
561 out_byte = 0;
562 } else {
563 out_byte |= (element &
564 kMasks[chunk_bits - 1]) << out_byte_bits;
565 out_byte_bits += chunk_bits;
566 }
567
568 element_bits_done += chunk_bits;
569 element >>= chunk_bits;
570 }
571 }
572
573 if (out_byte_bits > 0) {
574 *out = out_byte;
575 }
576}
577
578/* scalar_encode_1 is |scalar_encode| specialised for |bits| == 1. */
579static void
580scalar_encode_1(uint8_t out[32], const scalar *s)
581{
582 int i, j;
583
584 for (i = 0; i < DEGREE; i += 8) {
585 uint8_t out_byte = 0;
586
587 for (j = 0; j < 8; j++) {
588 out_byte |= (s->c[i + j] & 1) << j;
589 }
590 *out = out_byte;
591 out++;
592 }
593}
594
595/*
596 * Encodes an entire vector into 32*|RANK1024|*|bits| bytes. Note that since 256
597 * (DEGREE) is divisible by 8, the individual vector entries will always fill a
598 * whole number of bytes, so we do not need to worry about bit packing here.
599 */
600static void
601vector_encode(uint8_t *out, const vector *a, int bits)
602{
603 int i;
604
605 for (i = 0; i < RANK1024; i++) {
606 scalar_encode(out + i * bits * DEGREE / 8, &a->v[i], bits);
607 }
608}
609
610/* Encodes an entire vector as above, but adding it to a CBB */
611static int
612vector_encode_cbb(CBB *cbb, const vector *a, int bits)
613{
614 uint8_t *encoded_vector;
615
616 if (!CBB_add_space(cbb, &encoded_vector, kEncodedVectorSize))
617 return 0;
618 vector_encode(encoded_vector, a, bits);
619
620 return 1;
621}
622
623/*
624 * scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
625 * |out|. It returns one on success and zero if any parsed value is >=
626 * |kPrime|.
627 */
628static int
629scalar_decode(scalar *out, const uint8_t *in, int bits)
630{
631 uint8_t in_byte = 0;
632 int i, in_byte_bits_left = 0;
633
634 assert(bits <= (int)sizeof(*out->c) * 8 && bits != 1);
635
636 for (i = 0; i < DEGREE; i++) {
637 uint16_t element = 0;
638 int element_bits_done = 0;
639
640 while (element_bits_done < bits) {
641 int chunk_bits = bits - element_bits_done;
642
643 if (in_byte_bits_left == 0) {
644 in_byte = *in;
645 in++;
646 in_byte_bits_left = 8;
647 }
648
649 if (chunk_bits > in_byte_bits_left) {
650 chunk_bits = in_byte_bits_left;
651 }
652
653 element |= (in_byte & kMasks[chunk_bits - 1]) <<
654 element_bits_done;
655 in_byte_bits_left -= chunk_bits;
656 in_byte >>= chunk_bits;
657
658 element_bits_done += chunk_bits;
659 }
660
661 if (element >= kPrime) {
662 return 0;
663 }
664 out->c[i] = element;
665 }
666
667 return 1;
668}
669
670/* scalar_decode_1 is |scalar_decode| specialised for |bits| == 1. */
671static void
672scalar_decode_1(scalar *out, const uint8_t in[32])
673{
674 int i, j;
675
676 for (i = 0; i < DEGREE; i += 8) {
677 uint8_t in_byte = *in;
678
679 in++;
680 for (j = 0; j < 8; j++) {
681 out->c[i + j] = in_byte & 1;
682 in_byte >>= 1;
683 }
684 }
685}
686
687/*
688 * Decodes 32*|RANK1024|*|bits| bytes from |in| into |out|. It returns one on
689 * success or zero if any parsed value is >= |kPrime|.
690 */
691static int
692vector_decode(vector *out, const uint8_t *in, int bits)
693{
694 int i;
695
696 for (i = 0; i < RANK1024; i++) {
697 if (!scalar_decode(&out->v[i], in + i * bits * DEGREE / 8,
698 bits)) {
699 return 0;
700 }
701 }
702 return 1;
703}
704
705/*
706 * Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
707 * numbers close to each other together. The formula used is
708 * round(2^|bits|/kPrime*x) mod 2^|bits|.
709 * Uses Barrett reduction to achieve constant time. Since we need both the
710 * remainder (for rounding) and the quotient (as the result), we cannot use
711 * |reduce| here, but need to do the Barrett reduction directly.
712 */
713static uint16_t
714compress(uint16_t x, int bits)
715{
716 uint32_t shifted = (uint32_t)x << bits;
717 uint64_t product = (uint64_t)shifted * kBarrettMultiplier;
718 uint32_t quotient = (uint32_t)(product >> kBarrettShift);
719 uint32_t remainder = shifted - quotient * kPrime;
720
721 /*
722 * Adjust the quotient to round correctly:
723 * 0 <= remainder <= kHalfPrime round to 0
724 * kHalfPrime < remainder <= kPrime + kHalfPrime round to 1
725 * kPrime + kHalfPrime < remainder < 2 * kPrime round to 2
726 */
727 assert(remainder < 2u * kPrime);
728 quotient += 1 & constant_time_lt(kHalfPrime, remainder);
729 quotient += 1 & constant_time_lt(kPrime + kHalfPrime, remainder);
730 return quotient & ((1 << bits) - 1);
731}
732
733/*
734 * Decompresses |x| by using an equi-distant representative. The formula is
735 * round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us to
736 * implement this logic using only bit operations.
737 */
738static uint16_t
739decompress(uint16_t x, int bits)
740{
741 uint32_t product = (uint32_t)x * kPrime;
742 uint32_t power = 1 << bits;
743 /* This is |product| % power, since |power| is a power of 2. */
744 uint32_t remainder = product & (power - 1);
745 /* This is |product| / power, since |power| is a power of 2. */
746 uint32_t lower = product >> bits;
747
748 /*
749 * The rounding logic works since the first half of numbers mod |power| have a
750 * 0 as first bit, and the second half has a 1 as first bit, since |power| is
751 * a power of 2. As a 12 bit number, |remainder| is always positive, so we
752 * will shift in 0s for a right shift.
753 */
754 return lower + (remainder >> (bits - 1));
755}
756
757static void
758scalar_compress(scalar *s, int bits)
759{
760 int i;
761
762 for (i = 0; i < DEGREE; i++) {
763 s->c[i] = compress(s->c[i], bits);
764 }
765}
766
767static void
768scalar_decompress(scalar *s, int bits)
769{
770 int i;
771
772 for (i = 0; i < DEGREE; i++) {
773 s->c[i] = decompress(s->c[i], bits);
774 }
775}
776
777static void
778vector_compress(vector *a, int bits)
779{
780 int i;
781
782 for (i = 0; i < RANK1024; i++) {
783 scalar_compress(&a->v[i], bits);
784 }
785}
786
787static void
788vector_decompress(vector *a, int bits)
789{
790 int i;
791
792 for (i = 0; i < RANK1024; i++) {
793 scalar_decompress(&a->v[i], bits);
794 }
795}
796
797struct public_key {
798 vector t;
799 uint8_t rho[32];
800 uint8_t public_key_hash[32];
801 matrix m;
802};
803
804CTASSERT(sizeof(struct MLKEM1024_public_key) == sizeof(struct public_key));
805
806static struct public_key *
807public_key_1024_from_external(const MLKEM_public_key *external)
808{
809 if (external->rank != RANK1024)
810 return NULL;
811 return (struct public_key *)external->key_1024;
812}
813
814struct private_key {
815 struct public_key pub;
816 vector s;
817 uint8_t fo_failure_secret[32];
818};
819
820CTASSERT(sizeof(struct MLKEM1024_private_key) == sizeof(struct private_key));
821
822static struct private_key *
823private_key_1024_from_external(const MLKEM_private_key *external)
824{
825 if (external->rank != RANK1024)
826 return NULL;
827 return (struct private_key *)external->key_1024;
828}
829
830/*
831 * Calls |MLKEM1024_generate_key_external_entropy| with random bytes from
832 * |RAND_bytes|.
833 */
834int
835MLKEM1024_generate_key(uint8_t out_encoded_public_key[MLKEM1024_PUBLIC_KEY_BYTES],
836 uint8_t optional_out_seed[MLKEM_SEED_BYTES],
837 MLKEM_private_key *out_private_key)
838{
839 uint8_t entropy_buf[MLKEM_SEED_BYTES];
840 uint8_t *entropy = optional_out_seed != NULL ? optional_out_seed :
841 entropy_buf;
842
843 arc4random_buf(entropy, MLKEM_SEED_BYTES);
844 return MLKEM1024_generate_key_external_entropy(out_encoded_public_key,
845 out_private_key, entropy);
846}
847
848int
849MLKEM1024_private_key_from_seed(MLKEM_private_key *out_private_key,
850 const uint8_t *seed, size_t seed_len)
851{
852 uint8_t public_key_bytes[MLKEM1024_PUBLIC_KEY_BYTES];
853
854 if (seed_len != MLKEM_SEED_BYTES) {
855 return 0;
856 }
857 return MLKEM1024_generate_key_external_entropy(public_key_bytes,
858 out_private_key, seed);
859}
860
861static int
862mlkem_marshal_public_key(CBB *out, const struct public_key *pub)
863{
864 if (!vector_encode_cbb(out, &pub->t, kLog2Prime))
865 return 0;
866 return CBB_add_bytes(out, pub->rho, sizeof(pub->rho));
867}
868
869int
870MLKEM1024_generate_key_external_entropy(
871 uint8_t out_encoded_public_key[MLKEM1024_PUBLIC_KEY_BYTES],
872 MLKEM_private_key *out_private_key,
873 const uint8_t entropy[MLKEM_SEED_BYTES])
874{
875 struct private_key *priv = private_key_1024_from_external(
876 out_private_key);
877 uint8_t augmented_seed[33];
878 uint8_t *rho, *sigma;
879 uint8_t counter = 0;
880 uint8_t hashed[64];
881 vector error;
882 CBB cbb;
883 int ret = 0;
884
885 memset(&cbb, 0, sizeof(CBB));
886 memcpy(augmented_seed, entropy, 32);
887 augmented_seed[32] = RANK1024;
888 hash_g(hashed, augmented_seed, 33);
889 rho = hashed;
890 sigma = hashed + 32;
891 memcpy(priv->pub.rho, hashed, sizeof(priv->pub.rho));
892 matrix_expand(&priv->pub.m, rho);
893 vector_generate_secret_eta_2(&priv->s, &counter, sigma);
894 vector_ntt(&priv->s);
895 vector_generate_secret_eta_2(&error, &counter, sigma);
896 vector_ntt(&error);
897 matrix_mult_transpose(&priv->pub.t, &priv->pub.m, &priv->s);
898 vector_add(&priv->pub.t, &error);
899
900 if (!CBB_init_fixed(&cbb, out_encoded_public_key,
901 MLKEM1024_PUBLIC_KEY_BYTES))
902 goto err;
903
904 if (!mlkem_marshal_public_key(&cbb, &priv->pub))
905 goto err;
906
907 hash_h(priv->pub.public_key_hash, out_encoded_public_key,
908 MLKEM1024_PUBLIC_KEY_BYTES);
909 memcpy(priv->fo_failure_secret, entropy + 32, 32);
910
911 ret = 1;
912
913 err:
914 CBB_cleanup(&cbb);
915
916 return ret;
917}
918
919void
920MLKEM1024_public_from_private(const MLKEM_private_key *private_key,
921 MLKEM_public_key *out_public_key)
922{
923 struct public_key *const pub = public_key_1024_from_external(
924 out_public_key);
925 const struct private_key *const priv = private_key_1024_from_external(
926 private_key);
927
928 *pub = priv->pub;
929}
930
931/*
932 * Encrypts a message with given randomness to the ciphertext in |out|. Without
933 * applying the Fujisaki-Okamoto transform this would not result in a CCA secure
934 * scheme, since lattice schemes are vulnerable to decryption failure oracles.
935 */
936static void
937encrypt_cpa(uint8_t out[MLKEM1024_CIPHERTEXT_BYTES],
938 const struct public_key *pub, const uint8_t message[32],
939 const uint8_t randomness[32])
940{
941 scalar expanded_message, scalar_error;
942 vector secret, error, u;
943 uint8_t counter = 0;
944 uint8_t input[33];
945 scalar v;
946
947 vector_generate_secret_eta_2(&secret, &counter, randomness);
948 vector_ntt(&secret);
949 vector_generate_secret_eta_2(&error, &counter, randomness);
950 memcpy(input, randomness, 32);
951 input[32] = counter;
952 scalar_centered_binomial_distribution_eta_2_with_prf(&scalar_error,
953 input);
954 matrix_mult(&u, &pub->m, &secret);
955 vector_inverse_ntt(&u);
956 vector_add(&u, &error);
957 scalar_inner_product(&v, &pub->t, &secret);
958 scalar_inverse_ntt(&v);
959 scalar_add(&v, &scalar_error);
960 scalar_decode_1(&expanded_message, message);
961 scalar_decompress(&expanded_message, 1);
962 scalar_add(&v, &expanded_message);
963 vector_compress(&u, kDU1024);
964 vector_encode(out, &u, kDU1024);
965 scalar_compress(&v, kDV1024);
966 scalar_encode(out + kCompressedVectorSize, &v, kDV1024);
967}
968
969/* Calls MLKEM1024_encap_external_entropy| with random bytes */
970void
971MLKEM1024_encap(const MLKEM_public_key *public_key,
972 uint8_t out_ciphertext[MLKEM1024_CIPHERTEXT_BYTES],
973 uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES])
974{
975 uint8_t entropy[MLKEM_ENCAP_ENTROPY];
976
977 arc4random_buf(entropy, MLKEM_ENCAP_ENTROPY);
978 MLKEM1024_encap_external_entropy(out_ciphertext, out_shared_secret,
979 public_key, entropy);
980}
981
982/* See section 6.2 of the spec. */
983void
984MLKEM1024_encap_external_entropy(
985 uint8_t out_ciphertext[MLKEM1024_CIPHERTEXT_BYTES],
986 uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
987 const MLKEM_public_key *public_key,
988 const uint8_t entropy[MLKEM_ENCAP_ENTROPY])
989{
990 const struct public_key *pub = public_key_1024_from_external(public_key);
991 uint8_t key_and_randomness[64];
992 uint8_t input[64];
993
994 memcpy(input, entropy, MLKEM_ENCAP_ENTROPY);
995 memcpy(input + MLKEM_ENCAP_ENTROPY, pub->public_key_hash,
996 sizeof(input) - MLKEM_ENCAP_ENTROPY);
997 hash_g(key_and_randomness, input, sizeof(input));
998 encrypt_cpa(out_ciphertext, pub, entropy, key_and_randomness + 32);
999 memcpy(out_shared_secret, key_and_randomness, 32);
1000}
1001
1002static void
1003decrypt_cpa(uint8_t out[32], const struct private_key *priv,
1004 const uint8_t ciphertext[MLKEM1024_CIPHERTEXT_BYTES])
1005{
1006 scalar mask, v;
1007 vector u;
1008
1009 vector_decode(&u, ciphertext, kDU1024);
1010 vector_decompress(&u, kDU1024);
1011 vector_ntt(&u);
1012 scalar_decode(&v, ciphertext + kCompressedVectorSize, kDV1024);
1013 scalar_decompress(&v, kDV1024);
1014 scalar_inner_product(&mask, &priv->s, &u);
1015 scalar_inverse_ntt(&mask);
1016 scalar_sub(&v, &mask);
1017 scalar_compress(&v, 1);
1018 scalar_encode_1(out, &v);
1019}
1020
1021/* See section 6.3 */
1022int
1023MLKEM1024_decap(const MLKEM_private_key *private_key,
1024 const uint8_t *ciphertext, size_t ciphertext_len,
1025 uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES])
1026 {
1027 const struct private_key *priv = private_key_1024_from_external(
1028 private_key);
1029 uint8_t expected_ciphertext[MLKEM1024_CIPHERTEXT_BYTES];
1030 uint8_t key_and_randomness[64];
1031 uint8_t failure_key[32];
1032 uint8_t decrypted[64];
1033 uint8_t mask;
1034 int i;
1035
1036 if (ciphertext_len != MLKEM1024_CIPHERTEXT_BYTES) {
1037 arc4random_buf(out_shared_secret, MLKEM_SHARED_SECRET_BYTES);
1038 return 0;
1039 }
1040
1041 decrypt_cpa(decrypted, priv, ciphertext);
1042 memcpy(decrypted + 32, priv->pub.public_key_hash,
1043 sizeof(decrypted) - 32);
1044 hash_g(key_and_randomness, decrypted, sizeof(decrypted));
1045 encrypt_cpa(expected_ciphertext, &priv->pub, decrypted,
1046 key_and_randomness + 32);
1047 kdf(failure_key, priv->fo_failure_secret, ciphertext, ciphertext_len);
1048 mask = constant_time_eq_int_8(memcmp(ciphertext, expected_ciphertext,
1049 sizeof(expected_ciphertext)), 0);
1050 for (i = 0; i < MLKEM_SHARED_SECRET_BYTES; i++) {
1051 out_shared_secret[i] = constant_time_select_8(mask,
1052 key_and_randomness[i], failure_key[i]);
1053 }
1054
1055 return 1;
1056}
1057
1058int
1059MLKEM1024_marshal_public_key(const MLKEM_public_key *public_key,
1060 uint8_t **output, size_t *output_len)
1061{
1062 int ret = 0;
1063 CBB cbb;
1064
1065 if (!CBB_init(&cbb, MLKEM1024_PUBLIC_KEY_BYTES))
1066 goto err;
1067 if (!mlkem_marshal_public_key(&cbb,
1068 public_key_1024_from_external(public_key)))
1069 goto err;
1070 if (!CBB_finish(&cbb, output, output_len))
1071 goto err;
1072
1073 ret = 1;
1074
1075 err:
1076 CBB_cleanup(&cbb);
1077
1078 return ret;
1079}
1080
1081/*
1082 * mlkem_parse_public_key_no_hash parses |in| into |pub| but doesn't calculate
1083 * the value of |pub->public_key_hash|.
1084 */
1085static int
1086mlkem_parse_public_key_no_hash(struct public_key *pub, CBS *in)
1087{
1088 CBS t_bytes;
1089
1090 if (!CBS_get_bytes(in, &t_bytes, kEncodedVectorSize))
1091 return 0;
1092 if (!vector_decode(&pub->t, CBS_data(&t_bytes), kLog2Prime))
1093 return 0;
1094
1095 memcpy(pub->rho, CBS_data(in), sizeof(pub->rho));
1096 if (!CBS_skip(in, sizeof(pub->rho)))
1097 return 0;
1098 matrix_expand(&pub->m, pub->rho);
1099 return 1;
1100}
1101
1102int
1103MLKEM1024_parse_public_key(const uint8_t *input, size_t input_len,
1104 MLKEM_public_key *public_key)
1105{
1106 struct public_key *pub = public_key_1024_from_external(public_key);
1107 CBS cbs;
1108
1109 CBS_init(&cbs, input, input_len);
1110 if (!mlkem_parse_public_key_no_hash(pub, &cbs))
1111 return 0;
1112 if (CBS_len(&cbs) != 0)
1113 return 0;
1114
1115 hash_h(pub->public_key_hash, input, input_len);
1116
1117 return 1;
1118}
1119
1120int
1121MLKEM1024_marshal_private_key(const MLKEM_private_key *private_key,
1122 uint8_t **out_private_key, size_t *out_private_key_len)
1123{
1124 const struct private_key *const priv = private_key_1024_from_external(
1125 private_key);
1126 CBB cbb;
1127 int ret = 0;
1128
1129 if (!CBB_init(&cbb, MLKEM1024_PRIVATE_KEY_BYTES))
1130 goto err;
1131
1132 if (!vector_encode_cbb(&cbb, &priv->s, kLog2Prime))
1133 goto err;
1134 if (!mlkem_marshal_public_key(&cbb, &priv->pub))
1135 goto err;
1136 if (!CBB_add_bytes(&cbb, priv->pub.public_key_hash,
1137 sizeof(priv->pub.public_key_hash)))
1138 goto err;
1139 if (!CBB_add_bytes(&cbb, priv->fo_failure_secret,
1140 sizeof(priv->fo_failure_secret)))
1141 goto err;
1142
1143 if (!CBB_finish(&cbb, out_private_key, out_private_key_len))
1144 goto err;
1145
1146 ret = 1;
1147
1148 err:
1149 CBB_cleanup(&cbb);
1150
1151 return ret;
1152}
1153
1154int
1155MLKEM1024_parse_private_key(const uint8_t *input, size_t input_len,
1156 MLKEM_private_key *out_private_key)
1157{
1158 struct private_key *const priv = private_key_1024_from_external(
1159 out_private_key);
1160 CBS cbs, s_bytes;
1161
1162 CBS_init(&cbs, input, input_len);
1163
1164 if (!CBS_get_bytes(&cbs, &s_bytes, kEncodedVectorSize))
1165 return 0;
1166 if (!vector_decode(&priv->s, CBS_data(&s_bytes), kLog2Prime))
1167 return 0;
1168 if (!mlkem_parse_public_key_no_hash(&priv->pub, &cbs))
1169 return 0;
1170
1171 memcpy(priv->pub.public_key_hash, CBS_data(&cbs),
1172 sizeof(priv->pub.public_key_hash));
1173 if (!CBS_skip(&cbs, sizeof(priv->pub.public_key_hash)))
1174 return 0;
1175 memcpy(priv->fo_failure_secret, CBS_data(&cbs),
1176 sizeof(priv->fo_failure_secret));
1177 if (!CBS_skip(&cbs, sizeof(priv->fo_failure_secret)))
1178 return 0;
1179 if (CBS_len(&cbs) != 0)
1180 return 0;
1181
1182 return 1;
1183}
diff --git a/src/lib/libcrypto/mlkem/mlkem768.c b/src/lib/libcrypto/mlkem/mlkem_internal.c
index 1a44b9413f..653b2f332d 100644
--- a/src/lib/libcrypto/mlkem/mlkem768.c
+++ b/src/lib/libcrypto/mlkem/mlkem_internal.c
@@ -1,7 +1,7 @@
1/* $OpenBSD: mlkem768.c,v 1.13 2025/08/14 15:48:48 beck Exp $ */ 1/* $OpenBSD: mlkem_internal.c,v 1.1 2025/09/05 23:30:12 beck Exp $ */
2/* 2/*
3 * Copyright (c) 2024, Google Inc. 3 * Copyright (c) 2024, Google Inc.
4 * Copyright (c) 2024, Bob Beck <beck@obtuse.com> 4 * Copyright (c) 2024, 2025 Bob Beck <beck@obtuse.com>
5 * 5 *
6 * Permission to use, copy, modify, and/or distribute this software for any 6 * Permission to use, copy, modify, and/or distribute this software for any
7 * purpose with or without fee is hereby granted, provided that the above 7 * purpose with or without fee is hereby granted, provided that the above
@@ -65,7 +65,7 @@ hash_g(uint8_t out[64], const uint8_t *in, size_t len)
65 65
66/* this is called 'J' in the spec */ 66/* this is called 'J' in the spec */
67static void 67static void
68kdf(uint8_t out[MLKEM_SHARED_SECRET_BYTES], const uint8_t failure_secret[32], 68kdf(uint8_t out[MLKEM_SHARED_SECRET_LENGTH], const uint8_t failure_secret[32],
69 const uint8_t *in, size_t len) 69 const uint8_t *in, size_t len)
70{ 70{
71 sha3_ctx ctx; 71 sha3_ctx ctx;
@@ -73,7 +73,7 @@ kdf(uint8_t out[MLKEM_SHARED_SECRET_BYTES], const uint8_t failure_secret[32],
73 shake_update(&ctx, failure_secret, 32); 73 shake_update(&ctx, failure_secret, 32);
74 shake_update(&ctx, in, len); 74 shake_update(&ctx, in, len);
75 shake_xof(&ctx); 75 shake_xof(&ctx);
76 shake_out(&ctx, out, MLKEM_SHARED_SECRET_BYTES); 76 shake_out(&ctx, out, MLKEM_SHARED_SECRET_LENGTH);
77} 77}
78 78
79#define DEGREE 256 79#define DEGREE 256
@@ -85,29 +85,49 @@ static const int kLog2Prime = 12;
85static const uint16_t kHalfPrime = (/*kPrime=*/3329 - 1) / 2; 85static const uint16_t kHalfPrime = (/*kPrime=*/3329 - 1) / 2;
86static const int kDU768 = 10; 86static const int kDU768 = 10;
87static const int kDV768 = 4; 87static const int kDV768 = 4;
88static const int kDU1024 = 11;
89static const int kDV1024 = 5;
88 90
89/* 91/*
90 * kInverseDegree is 128^-1 mod 3329; 128 because kPrime does not have a 512th 92 * kInverseDegree is 128^-1 mod 3329; 128 because kPrime does not have a 512th
91 * root of unity. 93 * root of unity.
92 */ 94 */
93static const uint16_t kInverseDegree = 3303; 95static const uint16_t kInverseDegree = 3303;
94static const size_t kEncodedVectorSize = 96
95 (/*kLog2Prime=*/12 * DEGREE / 8) * RANK768; 97static inline size_t
96static const size_t kCompressedVectorSize = /*kDU768=*/ 10 * RANK768 * DEGREE / 98encoded_vector_size(uint16_t rank)
97 8; 99{
100 return (kLog2Prime * DEGREE / 8) * rank;
101}
102
103static inline size_t
104compressed_vector_size(uint16_t rank)
105{
106 return ((rank == RANK768) ? kDU768 : kDU1024) * rank * DEGREE / 8;
107}
98 108
99typedef struct scalar { 109typedef struct scalar {
100 /* On every function entry and exit, 0 <= c < kPrime. */ 110 /* On every function entry and exit, 0 <= c < kPrime. */
101 uint16_t c[DEGREE]; 111 uint16_t c[DEGREE];
102} scalar; 112} scalar;
103 113
104typedef struct vector { 114/*
105 scalar v[RANK768]; 115 * Retrieve a const scalar from const matrix of |rank| at position [row][col]
106} vector; 116 */
117static inline const scalar *
118const_m2s(const scalar *v, size_t row, size_t col, uint16_t rank)
119{
120 return ((scalar *)v) + row * rank + col;
121}
107 122
108typedef struct matrix { 123/*
109 scalar v[RANK768][RANK768]; 124 * Retrieve a scalar from matrix of |rank| at position [row][col]
110} matrix; 125 */
126static inline scalar *
127m2s(scalar *v, size_t row, size_t col, uint16_t rank)
128{
129 return ((scalar *)v) + row * rank + col;
130}
111 131
112/* 132/*
113 * This bit of Python will be referenced in some of the following comments: 133 * This bit of Python will be referenced in some of the following comments:
@@ -184,10 +204,10 @@ reduce_once(uint16_t x)
184 * is a difference of 2x. 204 * is a difference of 2x.
185 * 205 *
186 * We usually add value barriers to selects because Clang turns 206 * We usually add value barriers to selects because Clang turns
187 * consecutive selects with the same condition into a branch instead of 207 * consecutive selects with the same condition into a branch instead of
188 * CMOV/CSEL. This condition does not occur in ML-KEM, so omitting it 208 * CMOV/CSEL. This condition does not occur in ML-KEM, so omitting it
189 * seems to be safe so far but see 209 * seems to be safe so far but see
190 * |scalar_centered_binomial_distribution_eta_2_with_prf|. 210 * |scalar_centered_binomial_distribution_eta_2_with_prf|.
191 */ 211 */
192 return (mask & x) | (~mask & subtracted); 212 return (mask & x) | (~mask & subtracted);
193} 213}
@@ -214,9 +234,9 @@ scalar_zero(scalar *out)
214} 234}
215 235
216static void 236static void
217vector_zero(vector *out) 237vector_zero(scalar *out, size_t rank)
218{ 238{
219 memset(out, 0, sizeof(*out)); 239 memset(out, 0, sizeof(*out) * rank);
220} 240}
221 241
222/* 242/*
@@ -258,12 +278,12 @@ scalar_ntt(scalar *s)
258} 278}
259 279
260static void 280static void
261vector_ntt(vector *a) 281vector_ntt(scalar *v, size_t rank)
262{ 282{
263 int i; 283 size_t i;
264 284
265 for (i = 0; i < RANK768; i++) { 285 for (i = 0; i < rank; i++) {
266 scalar_ntt(&a->v[i]); 286 scalar_ntt(&v[i]);
267 } 287 }
268} 288}
269 289
@@ -305,12 +325,12 @@ scalar_inverse_ntt(scalar *s)
305} 325}
306 326
307static void 327static void
308vector_inverse_ntt(vector *a) 328vector_inverse_ntt(scalar *v, size_t rank)
309{ 329{
310 int i; 330 size_t i;
311 331
312 for (i = 0; i < RANK768; i++) { 332 for (i = 0; i < rank; i++) {
313 scalar_inverse_ntt(&a->v[i]); 333 scalar_inverse_ntt(&v[i]);
314 } 334 }
315} 335}
316 336
@@ -364,58 +384,58 @@ scalar_mult(scalar *out, const scalar *lhs, const scalar *rhs)
364} 384}
365 385
366static void 386static void
367vector_add(vector *lhs, const vector *rhs) 387vector_add(scalar *lhs, const scalar *rhs, size_t rank)
368{ 388{
369 int i; 389 size_t i;
370 390
371 for (i = 0; i < RANK768; i++) { 391 for (i = 0; i < rank; i++) {
372 scalar_add(&lhs->v[i], &rhs->v[i]); 392 scalar_add(&lhs[i], &rhs[i]);
373 } 393 }
374} 394}
375 395
376static void 396static void
377matrix_mult(vector *out, const matrix *m, const vector *a) 397matrix_mult(scalar *out, const void *m, const scalar *a, size_t rank)
378{ 398{
379 int i, j; 399 size_t i, j;
380 400
381 vector_zero(out); 401 vector_zero(&out[0], rank);
382 for (i = 0; i < RANK768; i++) { 402 for (i = 0; i < rank; i++) {
383 for (j = 0; j < RANK768; j++) { 403 for (j = 0; j < rank; j++) {
384 scalar product; 404 scalar product;
385 405
386 scalar_mult(&product, &m->v[i][j], &a->v[j]); 406 scalar_mult(&product, const_m2s(m, i, j, rank), &a[j]);
387 scalar_add(&out->v[i], &product); 407 scalar_add(&out[i], &product);
388 } 408 }
389 } 409 }
390} 410}
391 411
392static void 412static void
393matrix_mult_transpose(vector *out, const matrix *m, 413matrix_mult_transpose(scalar *out, const void *m, const scalar *a, size_t rank)
394 const vector *a)
395{ 414{
396 int i, j; 415 int i, j;
397 416
398 vector_zero(out); 417 vector_zero(&out[0], rank);
399 for (i = 0; i < RANK768; i++) { 418 for (i = 0; i < rank; i++) {
400 for (j = 0; j < RANK768; j++) { 419 for (j = 0; j < rank; j++) {
401 scalar product; 420 scalar product;
402 421
403 scalar_mult(&product, &m->v[j][i], &a->v[j]); 422 scalar_mult(&product, const_m2s(m, j, i, rank), &a[j]);
404 scalar_add(&out->v[i], &product); 423 scalar_add(&out[i], &product);
405 } 424 }
406 } 425 }
407} 426}
408 427
409static void 428static void
410scalar_inner_product(scalar *out, const vector *lhs, 429scalar_inner_product(scalar *out, const scalar *lhs,
411 const vector *rhs) 430 const scalar *rhs, size_t rank)
412{ 431{
413 int i; 432 size_t i;
433
414 scalar_zero(out); 434 scalar_zero(out);
415 for (i = 0; i < RANK768; i++) { 435 for (i = 0; i < rank; i++) {
416 scalar product; 436 scalar product;
417 437
418 scalar_mult(&product, &lhs->v[i], &rhs->v[i]); 438 scalar_mult(&product, &lhs[i], &rhs[i]);
419 scalar_add(out, &product); 439 scalar_add(out, &product);
420 } 440 }
421} 441}
@@ -498,30 +518,30 @@ scalar_centered_binomial_distribution_eta_2_with_prf(scalar *out,
498 * appending and incrementing |counter| for entry of the vector. 518 * appending and incrementing |counter| for entry of the vector.
499 */ 519 */
500static void 520static void
501vector_generate_secret_eta_2(vector *out, uint8_t *counter, 521vector_generate_secret_eta_2(scalar *out, uint8_t *counter,
502 const uint8_t seed[32]) 522 const uint8_t seed[32], size_t rank)
503{ 523{
504 uint8_t input[33]; 524 uint8_t input[33];
505 int i; 525 size_t i;
506 526
507 memcpy(input, seed, 32); 527 memcpy(input, seed, 32);
508 for (i = 0; i < RANK768; i++) { 528 for (i = 0; i < rank; i++) {
509 input[32] = (*counter)++; 529 input[32] = (*counter)++;
510 scalar_centered_binomial_distribution_eta_2_with_prf(&out->v[i], 530 scalar_centered_binomial_distribution_eta_2_with_prf(&out[i],
511 input); 531 input);
512 } 532 }
513} 533}
514 534
515/* Expands the matrix of a seed for key generation and for encaps-CPA. */ 535/* Expands the matrix of a seed for key generation and for encaps-CPA. */
516static void 536static void
517matrix_expand(matrix *out, const uint8_t rho[32]) 537matrix_expand(void *out, const uint8_t rho[32], size_t rank)
518{ 538{
519 uint8_t input[34]; 539 uint8_t input[34];
520 int i, j; 540 size_t i, j;
521 541
522 memcpy(input, rho, 32); 542 memcpy(input, rho, 32);
523 for (i = 0; i < RANK768; i++) { 543 for (i = 0; i < rank; i++) {
524 for (j = 0; j < RANK768; j++) { 544 for (j = 0; j < rank; j++) {
525 sha3_ctx keccak_ctx; 545 sha3_ctx keccak_ctx;
526 546
527 input[32] = i; 547 input[32] = i;
@@ -529,7 +549,8 @@ matrix_expand(matrix *out, const uint8_t rho[32])
529 shake128_init(&keccak_ctx); 549 shake128_init(&keccak_ctx);
530 shake_update(&keccak_ctx, input, sizeof(input)); 550 shake_update(&keccak_ctx, input, sizeof(input));
531 shake_xof(&keccak_ctx); 551 shake_xof(&keccak_ctx);
532 scalar_from_keccak_vartime(&out->v[i][j], &keccak_ctx); 552 scalar_from_keccak_vartime(m2s(out, i, j, rank),
553 &keccak_ctx);
533 } 554 }
534 } 555 }
535} 556}
@@ -599,24 +620,24 @@ scalar_encode_1(uint8_t out[32], const scalar *s)
599 * whole number of bytes, so we do not need to worry about bit packing here. 620 * whole number of bytes, so we do not need to worry about bit packing here.
600 */ 621 */
601static void 622static void
602vector_encode(uint8_t *out, const vector *a, int bits) 623vector_encode(uint8_t *out, const scalar *a, int bits, size_t rank)
603{ 624{
604 int i; 625 int i;
605 626
606 for (i = 0; i < RANK768; i++) { 627 for (i = 0; i < rank; i++) {
607 scalar_encode(out + i * bits * DEGREE / 8, &a->v[i], bits); 628 scalar_encode(out + i * bits * DEGREE / 8, &a[i], bits);
608 } 629 }
609} 630}
610 631
611/* Encodes an entire vector as above, but adding it to a CBB */ 632/* Encodes an entire vector as above, but adding it to a CBB */
612static int 633static int
613vector_encode_cbb(CBB *cbb, const vector *a, int bits) 634vector_encode_cbb(CBB *cbb, const scalar *a, int bits, size_t rank)
614{ 635{
615 uint8_t *encoded_vector; 636 uint8_t *encoded_vector;
616 637
617 if (!CBB_add_space(cbb, &encoded_vector, kEncodedVectorSize)) 638 if (!CBB_add_space(cbb, &encoded_vector, encoded_vector_size(rank)))
618 return 0; 639 return 0;
619 vector_encode(encoded_vector, a, bits); 640 vector_encode(encoded_vector, a, bits, rank);
620 641
621 return 1; 642 return 1;
622} 643}
@@ -690,12 +711,12 @@ scalar_decode_1(scalar *out, const uint8_t in[32])
690 * success or zero if any parsed value is >= |kPrime|. 711 * success or zero if any parsed value is >= |kPrime|.
691 */ 712 */
692static int 713static int
693vector_decode(vector *out, const uint8_t *in, int bits) 714vector_decode(scalar *out, const uint8_t *in, int bits, size_t rank)
694{ 715{
695 int i; 716 size_t i;
696 717
697 for (i = 0; i < RANK768; i++) { 718 for (i = 0; i < rank; i++) {
698 if (!scalar_decode(&out->v[i], in + i * bits * DEGREE / 8, 719 if (!scalar_decode(&out[i], in + i * bits * DEGREE / 8,
699 bits)) { 720 bits)) {
700 return 0; 721 return 0;
701 } 722 }
@@ -776,139 +797,182 @@ scalar_decompress(scalar *s, int bits)
776} 797}
777 798
778static void 799static void
779vector_compress(vector *a, int bits) 800vector_compress(scalar *v, int bits, size_t rank)
780{ 801{
781 int i; 802 size_t i;
782 803
783 for (i = 0; i < RANK768; i++) { 804 for (i = 0; i < rank; i++) {
784 scalar_compress(&a->v[i], bits); 805 scalar_compress(&v[i], bits);
785 } 806 }
786} 807}
787 808
788static void 809static void
789vector_decompress(vector *a, int bits) 810vector_decompress(scalar *v, int bits, size_t rank)
790{ 811{
791 int i; 812 int i;
792 813
793 for (i = 0; i < RANK768; i++) { 814 for (i = 0; i < rank; i++) {
794 scalar_decompress(&a->v[i], bits); 815 scalar_decompress(&v[i], bits);
795 } 816 }
796} 817}
797 818
798struct public_key { 819struct public_key {
799 vector t; 820 scalar *t;
800 uint8_t rho[32]; 821 uint8_t *rho;
801 uint8_t public_key_hash[32]; 822 uint8_t *public_key_hash;
802 matrix m; 823 scalar *m;
803}; 824};
804 825
805CTASSERT(sizeof(struct MLKEM768_public_key) == sizeof(struct public_key)); 826static void
806 827public_key_from_external(const MLKEM_public_key *external,
807static struct public_key * 828 struct public_key *pub)
808public_key_768_from_external(const MLKEM_public_key *external)
809{ 829{
810 if (external->rank != RANK768) 830 size_t vector_size = external->rank * sizeof(scalar);
811 return NULL; 831 uint8_t *bytes = external->key_768->bytes;
812 return (struct public_key *)external->key_768; 832 size_t offset = 0;
833
834 if (external->rank == RANK1024)
835 bytes = external->key_1024->bytes;
836
837 pub->t = (struct scalar *)bytes + offset;
838 offset += vector_size;
839 pub->rho = bytes + offset;
840 offset += 32;
841 pub->public_key_hash = bytes + offset;
842 offset += 32;
843 pub->m = (void *)(bytes + offset);
844 offset += vector_size * external->rank;
813} 845}
814 846
815struct private_key { 847struct private_key {
816 struct public_key pub; 848 struct public_key pub;
817 vector s; 849 scalar *s;
818 uint8_t fo_failure_secret[32]; 850 uint8_t *fo_failure_secret;
819}; 851};
820 852
821CTASSERT(sizeof(struct MLKEM768_private_key) == sizeof(struct private_key)); 853static void
822 854private_key_from_external(const MLKEM_private_key *external,
823static struct private_key * 855 struct private_key *priv)
824private_key_768_from_external(const MLKEM_private_key *external)
825{ 856{
826 if (external->rank != RANK768) 857 size_t vector_size = external->rank * sizeof(scalar);
827 return NULL; 858 size_t offset = 0;
828 return (struct private_key *)external->key_768; 859 uint8_t *bytes = external->key_768->bytes;
860
861 if (external->rank == RANK1024)
862 bytes = external->key_1024->bytes;
863
864 priv->pub.t = (struct scalar *)(bytes + offset);
865 offset += vector_size;
866 priv->pub.rho = bytes + offset;
867 offset += 32;
868 priv->pub.public_key_hash = bytes + offset;
869 offset += 32;
870 priv->pub.m = (void *)(bytes + offset);
871 offset += vector_size * external->rank;
872 priv->s = (void *)(bytes + offset);
873 offset += vector_size;
874 priv->fo_failure_secret = bytes + offset;
875 offset += 32;
829} 876}
830 877
831/* 878/*
832 * Calls |MLKEM768_generate_key_external_entropy| with random bytes from 879 * Calls |mlkem_generate_key_external_entropy| with random bytes from
833 * |RAND_bytes|. 880 * |RAND_bytes|.
834 */ 881 */
835int 882int
836MLKEM768_generate_key(uint8_t out_encoded_public_key[MLKEM768_PUBLIC_KEY_BYTES], 883mlkem_generate_key(uint8_t *out_encoded_public_key,
837 uint8_t optional_out_seed[MLKEM_SEED_BYTES], 884 uint8_t optional_out_seed[MLKEM_SEED_LENGTH],
838 MLKEM_private_key *out_private_key) 885 MLKEM_private_key *out_private_key)
839{ 886{
840 uint8_t entropy_buf[MLKEM_SEED_BYTES]; 887 uint8_t entropy_buf[MLKEM_SEED_LENGTH];
841 uint8_t *entropy = optional_out_seed != NULL ? optional_out_seed : 888 uint8_t *entropy = optional_out_seed != NULL ? optional_out_seed :
842 entropy_buf; 889 entropy_buf;
843 890
844 arc4random_buf(entropy, MLKEM_SEED_BYTES); 891 arc4random_buf(entropy, MLKEM_SEED_LENGTH);
845 return MLKEM768_generate_key_external_entropy(out_encoded_public_key, 892 return mlkem_generate_key_external_entropy(out_encoded_public_key,
846 out_private_key, entropy); 893 out_private_key, entropy);
847} 894}
848 895
849int 896int
850MLKEM768_private_key_from_seed(const uint8_t *seed, size_t seed_len, 897mlkem_private_key_from_seed(const uint8_t *seed, size_t seed_len,
851 MLKEM_private_key *out_private_key) 898 MLKEM_private_key *out_private_key)
852{ 899{
853 /* XXX stack */ 900 uint8_t *public_key_buf = NULL;
854 uint8_t public_key_bytes[MLKEM768_PUBLIC_KEY_BYTES]; 901 size_t public_key_buf_len = out_private_key->rank == RANK768 ?
902 MLKEM768_PUBLIC_KEY_BYTES : MLKEM1024_PUBLIC_KEY_BYTES;
903 int ret = 0;
855 904
856 if (seed_len != MLKEM_SEED_BYTES) { 905 if (seed_len != MLKEM_SEED_LENGTH) {
857 return 0; 906 goto err;
858 } 907 }
859 return MLKEM768_generate_key_external_entropy(public_key_bytes, 908
909 if ((public_key_buf = calloc(1, public_key_buf_len)) == NULL)
910 goto err;
911
912 ret = mlkem_generate_key_external_entropy(public_key_buf,
860 out_private_key, seed); 913 out_private_key, seed);
914
915 err:
916 freezero(public_key_buf, public_key_buf_len);
917
918 return ret;
861} 919}
862 920
863static int 921static int
864mlkem_marshal_public_key(CBB *out, const struct public_key *pub) 922mlkem_marshal_public_key_internal(CBB *out, const struct public_key *pub,
923 size_t rank)
865{ 924{
866 if (!vector_encode_cbb(out, &pub->t, kLog2Prime)) 925 if (!vector_encode_cbb(out, &pub->t[0], kLog2Prime, rank))
867 return 0; 926 return 0;
868 return CBB_add_bytes(out, pub->rho, sizeof(pub->rho)); 927 return CBB_add_bytes(out, pub->rho, 32);
869} 928}
870 929
871int 930int
872MLKEM768_generate_key_external_entropy( 931mlkem_generate_key_external_entropy(uint8_t *out_encoded_public_key,
873 uint8_t out_encoded_public_key[MLKEM768_PUBLIC_KEY_BYTES],
874 MLKEM_private_key *out_private_key, 932 MLKEM_private_key *out_private_key,
875 const uint8_t entropy[MLKEM_SEED_BYTES]) 933 const uint8_t entropy[MLKEM_SEED_LENGTH])
876{ 934{
877 struct private_key *priv = private_key_768_from_external( 935 struct private_key priv;
878 out_private_key);
879 uint8_t augmented_seed[33]; 936 uint8_t augmented_seed[33];
880 uint8_t *rho, *sigma; 937 uint8_t *rho, *sigma;
881 uint8_t counter = 0; 938 uint8_t counter = 0;
882 uint8_t hashed[64]; 939 uint8_t hashed[64];
883 vector error; 940 scalar error[RANK1024];
884 CBB cbb; 941 CBB cbb;
885 int ret = 0; 942 int ret = 0;
886 943
944 private_key_from_external(out_private_key, &priv);
887 memset(&cbb, 0, sizeof(CBB)); 945 memset(&cbb, 0, sizeof(CBB));
888 memcpy(augmented_seed, entropy, 32); 946 memcpy(augmented_seed, entropy, 32);
889 augmented_seed[32] = RANK768; 947 augmented_seed[32] = out_private_key->rank;
890 hash_g(hashed, augmented_seed, 33); 948 hash_g(hashed, augmented_seed, 33);
891 rho = hashed; 949 rho = hashed;
892 sigma = hashed + 32; 950 sigma = hashed + 32;
893 memcpy(priv->pub.rho, hashed, sizeof(priv->pub.rho)); 951 memcpy(priv.pub.rho, hashed, 32);
894 matrix_expand(&priv->pub.m, rho); 952 matrix_expand(priv.pub.m, rho, out_private_key->rank);
895 vector_generate_secret_eta_2(&priv->s, &counter, sigma); 953 vector_generate_secret_eta_2(priv.s, &counter, sigma,
896 vector_ntt(&priv->s); 954 out_private_key->rank);
897 vector_generate_secret_eta_2(&error, &counter, sigma); 955 vector_ntt(priv.s, out_private_key->rank);
898 vector_ntt(&error); 956 vector_generate_secret_eta_2(&error[0], &counter, sigma,
899 matrix_mult_transpose(&priv->pub.t, &priv->pub.m, &priv->s); 957 out_private_key->rank);
900 vector_add(&priv->pub.t, &error); 958 vector_ntt(&error[0], out_private_key->rank);
959 matrix_mult_transpose(priv.pub.t, priv.pub.m, priv.s,
960 out_private_key->rank);
961 vector_add(priv.pub.t, &error[0], out_private_key->rank);
901 962
902 if (!CBB_init_fixed(&cbb, out_encoded_public_key, 963 if (!CBB_init_fixed(&cbb, out_encoded_public_key,
903 MLKEM768_PUBLIC_KEY_BYTES)) 964 out_private_key->rank == RANK768 ? MLKEM768_PUBLIC_KEY_BYTES :
965 MLKEM1024_PUBLIC_KEY_BYTES))
904 goto err; 966 goto err;
905 967
906 if (!mlkem_marshal_public_key(&cbb, &priv->pub)) 968 if (!mlkem_marshal_public_key_internal(&cbb, &priv.pub,
969 out_private_key->rank))
907 goto err; 970 goto err;
908 971
909 hash_h(priv->pub.public_key_hash, out_encoded_public_key, 972 hash_h(priv.pub.public_key_hash, out_encoded_public_key,
910 MLKEM768_PUBLIC_KEY_BYTES); 973 out_private_key->rank == RANK768 ? MLKEM768_PUBLIC_KEY_BYTES :
911 memcpy(priv->fo_failure_secret, entropy + 32, 32); 974 MLKEM1024_PUBLIC_KEY_BYTES);
975 memcpy(priv.fo_failure_secret, entropy + 32, 32);
912 976
913 ret = 1; 977 ret = 1;
914 978
@@ -919,14 +983,21 @@ MLKEM768_generate_key_external_entropy(
919} 983}
920 984
921void 985void
922MLKEM768_public_from_private(const MLKEM_private_key *private_key, 986mlkem_public_from_private(const MLKEM_private_key *private_key,
923 MLKEM_public_key *out_public_key) { 987 MLKEM_public_key *out_public_key)
924 struct public_key *const pub = public_key_768_from_external( 988{
925 out_public_key); 989 switch (private_key->rank) {
926 const struct private_key *const priv = private_key_768_from_external( 990 case RANK768:
927 private_key); 991 memcpy(out_public_key->key_768->bytes,
928 992 private_key->key_768->bytes,
929 *pub = priv->pub; 993 sizeof(struct MLKEM768_public_key));
994 break;
995 case RANK1024:
996 memcpy(out_public_key->key_1024->bytes,
997 private_key->key_1024->bytes,
998 sizeof(struct MLKEM1024_public_key));
999 break;
1000 }
930} 1001}
931 1002
932/* 1003/*
@@ -935,84 +1006,97 @@ MLKEM768_public_from_private(const MLKEM_private_key *private_key,
935 * scheme, since lattice schemes are vulnerable to decryption failure oracles. 1006 * scheme, since lattice schemes are vulnerable to decryption failure oracles.
936 */ 1007 */
937static void 1008static void
938encrypt_cpa(uint8_t out[MLKEM768_CIPHERTEXT_BYTES], 1009encrypt_cpa(uint8_t *out, const struct public_key *pub,
939 const struct public_key *pub, const uint8_t message[32], 1010 const uint8_t message[32], const uint8_t randomness[32],
940 const uint8_t randomness[32]) 1011 size_t rank)
941{ 1012{
1013 scalar secret[RANK1024], error[RANK1024], u[RANK1024];
942 scalar expanded_message, scalar_error; 1014 scalar expanded_message, scalar_error;
943 vector secret, error, u;
944 uint8_t counter = 0; 1015 uint8_t counter = 0;
945 uint8_t input[33]; 1016 uint8_t input[33];
946 scalar v; 1017 scalar v;
1018 int u_bits = kDU768;
1019 int v_bits = kDV768;
947 1020
948 vector_generate_secret_eta_2(&secret, &counter, randomness); 1021 if (rank == RANK1024) {
949 vector_ntt(&secret); 1022 u_bits = kDU1024;
950 vector_generate_secret_eta_2(&error, &counter, randomness); 1023 v_bits = kDV1024;
1024 }
1025 vector_generate_secret_eta_2(&secret[0], &counter, randomness, rank);
1026 vector_ntt(&secret[0], rank);
1027 vector_generate_secret_eta_2(&error[0], &counter, randomness, rank);
951 memcpy(input, randomness, 32); 1028 memcpy(input, randomness, 32);
952 input[32] = counter; 1029 input[32] = counter;
953 scalar_centered_binomial_distribution_eta_2_with_prf(&scalar_error, 1030 scalar_centered_binomial_distribution_eta_2_with_prf(&scalar_error,
954 input); 1031 input);
955 matrix_mult(&u, &pub->m, &secret); 1032 matrix_mult(&u[0], pub->m, &secret[0], rank);
956 vector_inverse_ntt(&u); 1033 vector_inverse_ntt(&u[0], rank);
957 vector_add(&u, &error); 1034 vector_add(&u[0], &error[0], rank);
958 scalar_inner_product(&v, &pub->t, &secret); 1035 scalar_inner_product(&v, &pub->t[0], &secret[0], rank);
959 scalar_inverse_ntt(&v); 1036 scalar_inverse_ntt(&v);
960 scalar_add(&v, &scalar_error); 1037 scalar_add(&v, &scalar_error);
961 scalar_decode_1(&expanded_message, message); 1038 scalar_decode_1(&expanded_message, message);
962 scalar_decompress(&expanded_message, 1); 1039 scalar_decompress(&expanded_message, 1);
963 scalar_add(&v, &expanded_message); 1040 scalar_add(&v, &expanded_message);
964 vector_compress(&u, kDU768); 1041 vector_compress(&u[0], u_bits, rank);
965 vector_encode(out, &u, kDU768); 1042 vector_encode(out, &u[0], u_bits, rank);
966 scalar_compress(&v, kDV768); 1043 scalar_compress(&v, v_bits);
967 scalar_encode(out + kCompressedVectorSize, &v, kDV768); 1044 scalar_encode(out + compressed_vector_size(rank), &v, v_bits);
968} 1045}
969 1046
970/* Calls MLKEM768_encap_external_entropy| with random bytes */ 1047/* Calls mlkem_encap_external_entropy| with random bytes */
971void 1048void
972MLKEM768_encap(const MLKEM_public_key *public_key, 1049mlkem_encap(const MLKEM_public_key *public_key,
973 uint8_t out_ciphertext[MLKEM768_CIPHERTEXT_BYTES], 1050 uint8_t *out_ciphertext,
974 uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES]) 1051 uint8_t out_shared_secret[MLKEM_SHARED_SECRET_LENGTH])
975{ 1052{
976 uint8_t entropy[MLKEM_ENCAP_ENTROPY]; 1053 uint8_t entropy[MLKEM_ENCAP_ENTROPY];
977 1054
978 arc4random_buf(entropy, MLKEM_ENCAP_ENTROPY); 1055 arc4random_buf(entropy, MLKEM_ENCAP_ENTROPY);
979 MLKEM768_encap_external_entropy(out_ciphertext, 1056 mlkem_encap_external_entropy(out_ciphertext,
980 out_shared_secret, public_key, entropy); 1057 out_shared_secret, public_key, entropy);
981} 1058}
982 1059
983/* See section 6.2 of the spec. */ 1060/* See section 6.2 of the spec. */
984void 1061void
985MLKEM768_encap_external_entropy( 1062mlkem_encap_external_entropy(uint8_t *out_ciphertext,
986 uint8_t out_ciphertext[MLKEM768_CIPHERTEXT_BYTES], 1063 uint8_t out_shared_secret[MLKEM_SHARED_SECRET_LENGTH],
987 uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
988 const MLKEM_public_key *public_key, 1064 const MLKEM_public_key *public_key,
989 const uint8_t entropy[MLKEM_ENCAP_ENTROPY]) 1065 const uint8_t entropy[MLKEM_ENCAP_ENTROPY])
990{ 1066{
991 const struct public_key *pub = public_key_768_from_external(public_key); 1067 struct public_key pub;
992 uint8_t key_and_randomness[64]; 1068 uint8_t key_and_randomness[64];
993 uint8_t input[64]; 1069 uint8_t input[64];
994 1070
1071 public_key_from_external(public_key, &pub);
995 memcpy(input, entropy, MLKEM_ENCAP_ENTROPY); 1072 memcpy(input, entropy, MLKEM_ENCAP_ENTROPY);
996 memcpy(input + MLKEM_ENCAP_ENTROPY, pub->public_key_hash, 1073 memcpy(input + MLKEM_ENCAP_ENTROPY, pub.public_key_hash,
997 sizeof(input) - MLKEM_ENCAP_ENTROPY); 1074 sizeof(input) - MLKEM_ENCAP_ENTROPY);
998 hash_g(key_and_randomness, input, sizeof(input)); 1075 hash_g(key_and_randomness, input, sizeof(input));
999 encrypt_cpa(out_ciphertext, pub, entropy, key_and_randomness + 32); 1076 encrypt_cpa(out_ciphertext, &pub, entropy, key_and_randomness + 32,
1077 public_key->rank);
1000 memcpy(out_shared_secret, key_and_randomness, 32); 1078 memcpy(out_shared_secret, key_and_randomness, 32);
1001} 1079}
1002 1080
1003static void 1081static void
1004decrypt_cpa(uint8_t out[32], const struct private_key *priv, 1082decrypt_cpa(uint8_t out[32], const struct private_key *priv,
1005 const uint8_t ciphertext[MLKEM768_CIPHERTEXT_BYTES]) 1083 const uint8_t *ciphertext, size_t rank)
1006{ 1084{
1085 scalar u[RANK1024];
1007 scalar mask, v; 1086 scalar mask, v;
1008 vector u; 1087 int u_bits = kDU768;
1009 1088 int v_bits = kDV768;
1010 vector_decode(&u, ciphertext, kDU768); 1089
1011 vector_decompress(&u, kDU768); 1090 if (rank == RANK1024) {
1012 vector_ntt(&u); 1091 u_bits = kDU1024;
1013 scalar_decode(&v, ciphertext + kCompressedVectorSize, kDV768); 1092 v_bits = kDV1024;
1014 scalar_decompress(&v, kDV768); 1093 }
1015 scalar_inner_product(&mask, &priv->s, &u); 1094 vector_decode(&u[0], ciphertext, u_bits, rank);
1095 vector_decompress(&u[0], u_bits, rank);
1096 vector_ntt(&u[0], rank);
1097 scalar_decode(&v, ciphertext + compressed_vector_size(rank), v_bits);
1098 scalar_decompress(&v, v_bits);
1099 scalar_inner_product(&mask, &priv->s[0], &u[0], rank);
1016 scalar_inverse_ntt(&mask); 1100 scalar_inverse_ntt(&mask);
1017 scalar_sub(&v, &mask); 1101 scalar_sub(&v, &mask);
1018 scalar_compress(&v, 1); 1102 scalar_compress(&v, 1);
@@ -1021,51 +1105,67 @@ decrypt_cpa(uint8_t out[32], const struct private_key *priv,
1021 1105
1022/* See section 6.3 */ 1106/* See section 6.3 */
1023int 1107int
1024MLKEM768_decap(const MLKEM_private_key *private_key, const uint8_t *ciphertext, 1108mlkem_decap(const MLKEM_private_key *private_key, const uint8_t *ciphertext,
1025 size_t ciphertext_len, uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES]) 1109 size_t ciphertext_len, uint8_t out_shared_secret[MLKEM_SHARED_SECRET_LENGTH])
1026{ 1110{
1027 const struct private_key *priv = private_key_768_from_external( 1111 struct private_key priv;
1028 private_key); 1112 size_t expected_ciphertext_length = private_key->rank == RANK768 ?
1029 uint8_t expected_ciphertext[MLKEM768_CIPHERTEXT_BYTES]; 1113 MLKEM768_CIPHERTEXT_BYTES : MLKEM1024_CIPHERTEXT_BYTES;
1114 uint8_t *expected_ciphertext = NULL;
1030 uint8_t key_and_randomness[64]; 1115 uint8_t key_and_randomness[64];
1031 uint8_t failure_key[32]; 1116 uint8_t failure_key[32];
1032 uint8_t decrypted[64]; 1117 uint8_t decrypted[64];
1033 uint8_t mask; 1118 uint8_t mask;
1034 int i; 1119 int i;
1120 int ret = 0;
1035 1121
1036 if (ciphertext_len != MLKEM768_CIPHERTEXT_BYTES) { 1122 if (ciphertext_len != expected_ciphertext_length) {
1037 arc4random_buf(out_shared_secret, MLKEM_SHARED_SECRET_BYTES); 1123 arc4random_buf(out_shared_secret, MLKEM_SHARED_SECRET_LENGTH);
1038 return 0; 1124 goto err;
1039 } 1125 }
1040 1126
1041 decrypt_cpa(decrypted, priv, ciphertext); 1127 if ((expected_ciphertext = calloc(1, expected_ciphertext_length)) ==
1042 memcpy(decrypted + 32, priv->pub.public_key_hash, 1128 NULL) {
1129 arc4random_buf(out_shared_secret, MLKEM_SHARED_SECRET_LENGTH);
1130 goto err;
1131 }
1132
1133 private_key_from_external(private_key, &priv);
1134 decrypt_cpa(decrypted, &priv, ciphertext, private_key->rank);
1135 memcpy(decrypted + 32, priv.pub.public_key_hash,
1043 sizeof(decrypted) - 32); 1136 sizeof(decrypted) - 32);
1044 hash_g(key_and_randomness, decrypted, sizeof(decrypted)); 1137 hash_g(key_and_randomness, decrypted, sizeof(decrypted));
1045 encrypt_cpa(expected_ciphertext, &priv->pub, decrypted, 1138 encrypt_cpa(expected_ciphertext, &priv.pub, decrypted,
1046 key_and_randomness + 32); 1139 key_and_randomness + 32, private_key->rank);
1047 kdf(failure_key, priv->fo_failure_secret, ciphertext, ciphertext_len); 1140 kdf(failure_key, priv.fo_failure_secret, ciphertext, ciphertext_len);
1048 mask = constant_time_eq_int_8(memcmp(ciphertext, expected_ciphertext, 1141 mask = constant_time_eq_int_8(memcmp(ciphertext, expected_ciphertext,
1049 sizeof(expected_ciphertext)), 0); 1142 expected_ciphertext_length), 0);
1050 for (i = 0; i < MLKEM_SHARED_SECRET_BYTES; i++) { 1143 for (i = 0; i < MLKEM_SHARED_SECRET_LENGTH; i++) {
1051 out_shared_secret[i] = constant_time_select_8(mask, 1144 out_shared_secret[i] = constant_time_select_8(mask,
1052 key_and_randomness[i], failure_key[i]); 1145 key_and_randomness[i], failure_key[i]);
1053 } 1146 }
1054 1147
1055 return 1; 1148 ret = 1;
1149
1150 err:
1151 freezero(expected_ciphertext, expected_ciphertext_length);
1152
1153 return ret;
1056} 1154}
1057 1155
1058int 1156int
1059MLKEM768_marshal_public_key(const MLKEM_public_key *public_key, 1157mlkem_marshal_public_key(const MLKEM_public_key *public_key,
1060 uint8_t **output, size_t *output_len) 1158 uint8_t **output, size_t *output_len)
1061{ 1159{
1160 struct public_key pub;
1062 int ret = 0; 1161 int ret = 0;
1063 CBB cbb; 1162 CBB cbb;
1064 1163
1065 if (!CBB_init(&cbb, MLKEM768_PUBLIC_KEY_BYTES)) 1164 if (!CBB_init(&cbb, public_key->rank == RANK768 ?
1165 MLKEM768_PUBLIC_KEY_BYTES : MLKEM1024_PUBLIC_KEY_BYTES))
1066 goto err; 1166 goto err;
1067 if (!mlkem_marshal_public_key(&cbb, 1167 public_key_from_external(public_key, &pub);
1068 public_key_768_from_external(public_key))) 1168 if (!mlkem_marshal_public_key_internal(&cbb, &pub, public_key->rank))
1069 goto err; 1169 goto err;
1070 if (!CBB_finish(&cbb, output, output_len)) 1170 if (!CBB_finish(&cbb, output, output_len))
1071 goto err; 1171 goto err;
@@ -1083,61 +1183,63 @@ MLKEM768_marshal_public_key(const MLKEM_public_key *public_key,
1083 * the value of |pub->public_key_hash|. 1183 * the value of |pub->public_key_hash|.
1084 */ 1184 */
1085static int 1185static int
1086mlkem_parse_public_key_no_hash(struct public_key *pub, CBS *in) 1186mlkem_parse_public_key_no_hash(struct public_key *pub, CBS *in, size_t rank)
1087{ 1187{
1088 CBS t_bytes; 1188 CBS t_bytes;
1089 1189
1090 if (!CBS_get_bytes(in, &t_bytes, kEncodedVectorSize)) 1190 if (!CBS_get_bytes(in, &t_bytes, encoded_vector_size(rank)))
1091 return 0; 1191 return 0;
1092 if (!vector_decode(&pub->t, CBS_data(&t_bytes), kLog2Prime)) 1192 if (!vector_decode(&pub->t[0], CBS_data(&t_bytes), kLog2Prime, rank))
1093 return 0; 1193 return 0;
1094 1194
1095 memcpy(pub->rho, CBS_data(in), sizeof(pub->rho)); 1195 memcpy(pub->rho, CBS_data(in), 32);
1096 if (!CBS_skip(in, sizeof(pub->rho))) 1196 if (!CBS_skip(in, 32))
1097 return 0; 1197 return 0;
1098 matrix_expand(&pub->m, pub->rho); 1198 matrix_expand(pub->m, pub->rho, rank);
1099 return 1; 1199 return 1;
1100} 1200}
1101 1201
1102int 1202int
1103MLKEM768_parse_public_key(const uint8_t *input, size_t input_len, 1203mlkem_parse_public_key(const uint8_t *input, size_t input_len,
1104 MLKEM_public_key *public_key) 1204 MLKEM_public_key *public_key)
1105{ 1205{
1106 struct public_key *pub = public_key_768_from_external(public_key); 1206 struct public_key pub;
1107 CBS cbs; 1207 CBS cbs;
1108 1208
1209 public_key_from_external(public_key, &pub);
1109 CBS_init(&cbs, input, input_len); 1210 CBS_init(&cbs, input, input_len);
1110 if (!mlkem_parse_public_key_no_hash(pub, &cbs)) 1211 if (!mlkem_parse_public_key_no_hash(&pub, &cbs, public_key->rank))
1111 return 0; 1212 return 0;
1112 if (CBS_len(&cbs) != 0) 1213 if (CBS_len(&cbs) != 0)
1113 return 0; 1214 return 0;
1114 1215
1115 hash_h(pub->public_key_hash, input, input_len); 1216 hash_h(pub.public_key_hash, input, input_len);
1116 1217
1117 return 1; 1218 return 1;
1118} 1219}
1119 1220
1120int 1221int
1121MLKEM768_marshal_private_key(const MLKEM_private_key *private_key, 1222mlkem_marshal_private_key(const MLKEM_private_key *private_key,
1122 uint8_t **out_private_key, size_t *out_private_key_len) 1223 uint8_t **out_private_key, size_t *out_private_key_len)
1123{ 1224{
1124 const struct private_key *const priv = private_key_768_from_external( 1225 struct private_key priv;
1125 private_key); 1226 size_t key_length = private_key->rank == RANK768 ?
1227 MLKEM768_PRIVATE_KEY_BYTES : MLKEM1024_PRIVATE_KEY_BYTES;
1126 CBB cbb; 1228 CBB cbb;
1127 int ret = 0; 1229 int ret = 0;
1128 1230
1129 if (!CBB_init(&cbb, MLKEM768_PRIVATE_KEY_BYTES)) 1231 private_key_from_external(private_key, &priv);
1232 if (!CBB_init(&cbb, key_length))
1130 goto err; 1233 goto err;
1131 1234
1132 if (!vector_encode_cbb(&cbb, &priv->s, kLog2Prime)) 1235 if (!vector_encode_cbb(&cbb, priv.s, kLog2Prime, private_key->rank))
1133 goto err; 1236 goto err;
1134 if (!mlkem_marshal_public_key(&cbb, &priv->pub)) 1237 if (!mlkem_marshal_public_key_internal(&cbb, &priv.pub,
1238 private_key->rank))
1135 goto err; 1239 goto err;
1136 if (!CBB_add_bytes(&cbb, priv->pub.public_key_hash, 1240 if (!CBB_add_bytes(&cbb, priv.pub.public_key_hash, 32))
1137 sizeof(priv->pub.public_key_hash)))
1138 goto err; 1241 goto err;
1139 if (!CBB_add_bytes(&cbb, priv->fo_failure_secret, 1242 if (!CBB_add_bytes(&cbb, priv.fo_failure_secret, 32))
1140 sizeof(priv->fo_failure_secret)))
1141 goto err; 1243 goto err;
1142 1244
1143 if (!CBB_finish(&cbb, out_private_key, out_private_key_len)) 1245 if (!CBB_finish(&cbb, out_private_key, out_private_key_len))
@@ -1152,29 +1254,30 @@ MLKEM768_marshal_private_key(const MLKEM_private_key *private_key,
1152} 1254}
1153 1255
1154int 1256int
1155MLKEM768_parse_private_key(const uint8_t *input, size_t input_len, 1257mlkem_parse_private_key(const uint8_t *input, size_t input_len,
1156 MLKEM_private_key *out_private_key) 1258 MLKEM_private_key *out_private_key)
1157{ 1259{
1158 struct private_key *const priv = private_key_768_from_external( 1260 struct private_key priv;
1159 out_private_key);
1160 CBS cbs, s_bytes; 1261 CBS cbs, s_bytes;
1161 1262
1263 private_key_from_external(out_private_key, &priv);
1162 CBS_init(&cbs, input, input_len); 1264 CBS_init(&cbs, input, input_len);
1163 1265
1164 if (!CBS_get_bytes(&cbs, &s_bytes, kEncodedVectorSize)) 1266 if (!CBS_get_bytes(&cbs, &s_bytes,
1267 encoded_vector_size(out_private_key->rank)))
1165 return 0; 1268 return 0;
1166 if (!vector_decode(&priv->s, CBS_data(&s_bytes), kLog2Prime)) 1269 if (!vector_decode(priv.s, CBS_data(&s_bytes), kLog2Prime,
1270 out_private_key->rank))
1167 return 0; 1271 return 0;
1168 if (!mlkem_parse_public_key_no_hash(&priv->pub, &cbs)) 1272 if (!mlkem_parse_public_key_no_hash(&priv.pub, &cbs,
1273 out_private_key->rank))
1169 return 0; 1274 return 0;
1170 1275
1171 memcpy(priv->pub.public_key_hash, CBS_data(&cbs), 1276 memcpy(priv.pub.public_key_hash, CBS_data(&cbs), 32);
1172 sizeof(priv->pub.public_key_hash)); 1277 if (!CBS_skip(&cbs, 32))
1173 if (!CBS_skip(&cbs, sizeof(priv->pub.public_key_hash)))
1174 return 0; 1278 return 0;
1175 memcpy(priv->fo_failure_secret, CBS_data(&cbs), 1279 memcpy(priv.fo_failure_secret, CBS_data(&cbs), 32);
1176 sizeof(priv->fo_failure_secret)); 1280 if (!CBS_skip(&cbs, 32))
1177 if (!CBS_skip(&cbs, sizeof(priv->fo_failure_secret)))
1178 return 0; 1281 return 0;
1179 if (CBS_len(&cbs) != 0) 1282 if (CBS_len(&cbs) != 0)
1180 return 0; 1283 return 0;
diff --git a/src/lib/libcrypto/mlkem/mlkem_internal.h b/src/lib/libcrypto/mlkem/mlkem_internal.h
index 7e6c313aa9..2b3157256e 100644
--- a/src/lib/libcrypto/mlkem/mlkem_internal.h
+++ b/src/lib/libcrypto/mlkem/mlkem_internal.h
@@ -1,6 +1,7 @@
1/* $OpenBSD: mlkem_internal.h,v 1.9 2025/08/19 21:37:08 tb Exp $ */ 1/* $OpenBSD: mlkem_internal.h,v 1.10 2025/09/05 23:30:12 beck Exp $ */
2/* 2/*
3 * Copyright (c) 2023, Google Inc. 3 * Copyright (c) 2023, Google Inc.
4 * Copyright (c) 2025, Bob Beck <beck@obtuse.com>
4 * 5 *
5 * Permission to use, copy, modify, and/or distribute this software for any 6 * Permission to use, copy, modify, and/or distribute this software for any
6 * purpose with or without fee is hereby granted, provided that the above 7 * purpose with or without fee is hereby granted, provided that the above
@@ -26,402 +27,295 @@ extern "C" {
26#endif 27#endif
27 28
28__BEGIN_HIDDEN_DECLS 29__BEGIN_HIDDEN_DECLS
29/*
30 * MLKEM_SEED_LENGTH is the number of bytes in an ML-KEM seed. An ML-KEM
31 * seed is normally used to represent a private key.
32 */
33#define MLKEM_SEED_LENGTH 64
34 30
35/* 31/* Public opaque ML-KEM key structures. */
36 * MLKEM_SHARED_SECRET_LENGTH is the number of bytes in an ML-KEM shared
37 * secret.
38 */
39#define MLKEM_SHARED_SECRET_LENGTH 32
40 32
41/* 33#define MLKEM_PUBLIC_KEY_UNINITIALIZED 1
42 * |MLKEM_encap_external_entropy| behaves exactly like the public |MLKEM_encap| 34#define MLKEM_PUBLIC_KEY_INITIALIZED 2
43 * with the entropy provided by the caller. It is directly called internally 35#define MLKEM_PRIVATE_KEY_UNINITIALIZED 3
44 * and by tests. 36#define MLKEM_PRIVATE_KEY_INITIALIZED 4
45 */
46int
47MLKEM_encap_external_entropy(const MLKEM_public_key *public_key,
48 const uint8_t *entropy, uint8_t **out_ciphertext,
49 size_t *out_ciphertext_len, uint8_t **out_shared_secret,
50 size_t *out_shared_secret_len);
51 37
52/* 38struct MLKEM_public_key_st {
53 * |MLKEM_generate_key_external_entropy| behaves exactly like the public 39 uint16_t rank;
54 * |MLKEM_generate_key| with the entropy provided by the caller. 40 int state;
55 * It is directly called internally and by tests. 41 struct MLKEM768_public_key *key_768;
56 */ 42 struct MLKEM1024_public_key *key_1024;
57int 43};
58MLKEM_generate_key_external_entropy(MLKEM_private_key *private_key, 44
59 uint8_t **out_encoded_public_key, size_t *out_encoded_public_key_len, 45struct MLKEM_private_key_st {
60 const uint8_t *entropy); 46 uint16_t rank;
47 int state;
48 struct MLKEM768_private_key *key_768;
49 struct MLKEM1024_private_key *key_1024;
50};
61 51
62/* 52/*
63 * ML-KEM-768 53 * ML-KEM-768 and ML-KEM-1024
64 * 54 *
65 * This implements the Module-Lattice-Based Key-Encapsulation Mechanism from 55 * This implements the Module-Lattice-Based Key-Encapsulation Mechanism from
66 * https://csrc.nist.gov/pubs/fips/204/final 56 * https://csrc.nist.gov/pubs/fips/204/final
57 *
58 * You should prefer ML-KEM-768 where possible. ML-KEM-1024 is larger and exists
59 * for people who are obsessed with more 'bits of crypto', and who are also
60 * lacking the knowledge to realize that anything that can count to 256 bits
61 * must likely use an equivalent amount of energy to that of an entire star to
62 * do so.
63 *
64 * ML-KEM-768 is adequate to protect against a future cryptographically relevant
65 * quantum computer, VIC 20, abacus, or carefully calibrated reference dog. I
66 * for one plan on welcoming our new Kardashev-II civilization overlords with
67 * open arms. In the meantime will not waste bytes on the wire by to adding
68 * the fear of the possible future existence of a cryptographically relevant
69 * Dyson sphere to the aforementioned list of fear-inducing future
70 * cryptographically relevant hypotheticals.
71 *
72 * If your carefully calibrated reference dog notices the sun starting to dim,
73 * you might need ML-KEM-1024, but you probably have bigger concerns than
74 * the decryption of your stored past TLS sessions at that point.
67 */ 75 */
68 76
69/* 77/*
70 * MLKEM768_PUBLIC_KEY_BYTES is the number of bytes in an encoded ML-KEM768 public 78 * MLKEM1024_public_key contains an ML-KEM-1024 public key. The contents of this
71 * key. 79 * object should never leave the address space since the format is unstable.
72 */ 80 */
73#define MLKEM768_PUBLIC_KEY_BYTES 1184 81struct MLKEM1024_public_key {
74 82 uint8_t bytes[512 * (4 + 16) + 32 + 32];
75/* MLKEM_SEED_BYTES is the number of bytes in an ML-KEM seed. */ 83 uint16_t alignment;
76#define MLKEM_SEED_BYTES 64 84};
77 85
78/* 86/*
79 * MLKEM_SHARED_SECRET_BYTES is the number of bytes in the ML-KEM768 shared 87 * MLKEM1024_private_key contains a ML-KEM-1024 private key. The contents of
80 * secret. Although the round-3 specification has a variable-length output, the 88 * this object should never leave the address space since the format is
81 * final ML-KEM construction is expected to use a fixed 32-byte output. To 89 * unstable.
82 * simplify the future transition, we apply the same restriction.
83 */ 90 */
84#define MLKEM_SHARED_SECRET_BYTES 32 91struct MLKEM1024_private_key {
92 uint8_t bytes[512 * (4 + 4 + 16) + 32 + 32 + 32];
93 uint16_t alignment;
94};
85 95
86/* 96/*
87 * MLKEM_generate_key generates a random public/private key pair, writes the 97 * MLKEM768_public_key contains a ML-KEM-768 public key. The contents of this
88 * encoded public key to |out_encoded_public_key| and sets |out_private_key| to 98 * object should never leave the address space since the format is unstable.
89 * the private key. If |optional_out_seed| is not NULL then the seed used to
90 * generate the private key is written to it.
91 */ 99 */
92int MLKEM768_generate_key( 100struct MLKEM768_public_key {
93 uint8_t out_encoded_public_key[MLKEM768_PUBLIC_KEY_BYTES], 101 uint8_t bytes[512 * (3 + 9) + 32 + 32];
94 uint8_t optional_out_seed[MLKEM_SEED_BYTES], 102 uint16_t alignment;
95 MLKEM_private_key *out_private_key); 103};
96 104
97/* 105/*
98 * MLKEM768_private_key_from_seed derives a private key from a seed that was 106 * MLKEM768_private_key contains a ML-KEM-768 private key. The contents of this
99 * generated by |MLKEM768_generate_key|. It fails and returns 0 if |seed_len| is 107 * object should never leave the address space since the format is unstable.
100 * incorrect, otherwise it writes |*out_private_key| and returns 1.
101 */ 108 */
102int MLKEM768_private_key_from_seed(const uint8_t *seed, size_t seed_len, 109struct MLKEM768_private_key {
103 MLKEM_private_key *out_private_key); 110 uint8_t bytes[512 * (3 + 3 + 9) + 32 + 32 + 32];
111 uint16_t alignment;
112};
104 113
105/* 114/*
106 * MLKEM_public_from_private sets |*out_public_key| to the public key that 115 * MLKEM_SEED_LENGTH is the number of bytes in an ML-KEM seed. An ML-KEM
107 * corresponds to |private_key|. (This is faster than parsing the output of 116 * seed is normally used to represent a private key.
108 * |MLKEM_generate_key| if, for some reason, you need to encapsulate to a key
109 * that was just generated.)
110 */ 117 */
111void MLKEM768_public_from_private(const MLKEM_private_key *private_key, 118#define MLKEM_SEED_LENGTH 64
112 MLKEM_public_key *out_public_key);
113
114/* MLKEM768_CIPHERTEXT_BYTES is number of bytes in the ML-KEM768 ciphertext. */
115#define MLKEM768_CIPHERTEXT_BYTES 1088
116 119
117/* 120/*
118 * MLKEM768_encap encrypts a random shared secret for |public_key|, writes the 121 * MLKEM_SHARED_SECRET_LENGTH is the number of bytes in an ML-KEM shared
119 * ciphertext to |out_ciphertext|, and writes the random shared secret to 122 * secret.
120 * |out_shared_secret|.
121 */ 123 */
122void MLKEM768_encap(const MLKEM_public_key *public_key, 124#define MLKEM_SHARED_SECRET_LENGTH 32
123 uint8_t out_ciphertext[MLKEM768_CIPHERTEXT_BYTES],
124 uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES]);
125 125
126/* 126/*
127 * MLKEM768_decap decrypts a shared secret from |ciphertext| using |private_key| 127 * MLKEM_ENCAP_ENTROPY is the number of bytes of uniformly random entropy
128 * and writes it to |out_shared_secret|. If |ciphertext_len| is incorrect it 128 * necessary to encapsulate a secret. The entropy will be leaked to the
129 * returns 0, otherwise it rreturns 1. If |ciphertext| is invalid, 129 * decapsulating party.
130 * |out_shared_secret| is filled with a key that will always be the same for the
131 * same |ciphertext| and |private_key|, but which appears to be random unless
132 * one has access to |private_key|. These alternatives occur in constant time.
133 * Any subsequent symmetric encryption using |out_shared_secret| must use an
134 * authenticated encryption scheme in order to discover the decapsulation
135 * failure.
136 */ 130 */
137int MLKEM768_decap(const MLKEM_private_key *private_key, 131#define MLKEM_ENCAP_ENTROPY 32
138 const uint8_t *ciphertext, size_t ciphertext_len,
139 uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES]);
140 132
141/* Serialisation of keys. */ 133/* MLKEM1024_CIPHERTEXT_BYTES is number of bytes in the ML-KEM-1024 ciphertext. */
134#define MLKEM1024_CIPHERTEXT_BYTES 1568
135
136/* MLKEM768_CIPHERTEXT_BYTES is number of bytes in the ML-KEM768 ciphertext. */
137#define MLKEM768_CIPHERTEXT_BYTES 1088
142 138
143/* 139/*
144 * MLKEM768_marshal_public_key serializes |public_key| to |out| in the standard 140 * MLKEM768_PUBLIC_KEY_BYTES is the number of bytes in an encoded ML-KEM768 public
145 * format for ML-KEM public keys. It returns one on success or zero on allocation 141 * key.
146 * error.
147 */ 142 */
148int MLKEM768_marshal_public_key(const MLKEM_public_key *public_key, 143#define MLKEM768_PUBLIC_KEY_BYTES 1184
149 uint8_t **output, size_t *output_len);
150 144
151/* 145/*
152 * MLKEM768_parse_public_key parses a public key, in the format generated by 146 * MLKEM1024_PUBLIC_KEY_BYTES is the number of bytes in an encoded ML-KEM-1024
153 * |MLKEM_marshal_public_key|, from |in| and writes the result to 147 * public key.
154 * |out_public_key|. It returns one on success or zero on parse error or if
155 * there are trailing bytes in |in|.
156 */ 148 */
157int MLKEM768_parse_public_key(const uint8_t *input, size_t input_len, 149#define MLKEM1024_PUBLIC_KEY_BYTES 1568
158 MLKEM_public_key *out_public_key);
159 150
160/* 151/*
161 * MLKEM_parse_private_key parses a private key, in the format generated by 152 * MLKEM768_PRIVATE_KEY_BYTES is the length of the data produced by
162 * |MLKEM_marshal_private_key|, from |in| and writes the result to 153 * |marshal_private_key| for a RANK768 MLKEM_private_key.
163 * |out_private_key|. It returns one on success or zero on parse error or if
164 * there are trailing bytes in |in|. This formate is verbose and should be avoided.
165 * Private keys should be stored as seeds and parsed using |MLKEM768_private_key_from_seed|.
166 */ 154 */
167int MLKEM768_parse_private_key(const uint8_t *input, size_t input_len, 155#define MLKEM768_PRIVATE_KEY_BYTES 2400
168 MLKEM_private_key *out_private_key);
169 156
170/* 157/*
171 * ML-KEM-1024 158 * MLKEM1024_PRIVATE_KEY_BYTES is the length of the data produced by
172 * 159 * |marshal_private_key| for a RANK1024 MLKEM_private_key.
173 * ML-KEM-1024 also exists. You should prefer ML-KEM-768 where possible.
174 */ 160 */
161#define MLKEM1024_PRIVATE_KEY_BYTES 3168
175 162
176/* 163/*
177 * MLKEM1024_PUBLIC_KEY_BYTES is the number of bytes in an encoded ML-KEM-1024 164 * Internal MLKEM 768 and MLKEM 1024 functions come largely from BoringSSL, but
178 * public key. 165 * converted to C from templated C++. Due to this history, most internal
166 * functions do not allocate, and are expected to be handed memory allocated by
167 * the caller. The caller is generally expected to know what sizes to allocate
168 * based upon the rank of the key (either public or private) that they are
169 * starting with. This avoids the need to handle memory allocation failures
170 * (which boring in C++ just crashes by choice) deep in the implementation, as
171 * what is needed is allocated up front in the public facing functions, and
172 * failure is handled there.
179 */ 173 */
180#define MLKEM1024_PUBLIC_KEY_BYTES 1568 174
175/* Key generation. */
181 176
182/* 177/*
183 * MLKEM1024_generate_key generates a random public/private key pair, writes the 178 * mlkem_generate_key generates a random public/private key pair, writes the
184 * encoded public key to |out_encoded_public_key| and sets |out_private_key| to 179 * encoded public key to |out_encoded_public_key| and sets |out_private_key| to
185 * the private key. If |optional_out_seed| is not NULL then the seed used to 180 * the private key. If |optional_out_seed| is not NULL then the seed used to
186 * generate the private key is written to it. 181 * generate the private key is written to it. The caller is responsible for
182 * ensuring that |out_encoded_public_key| and |out_optonal_seed| point to
183 * enough memory to contain a key and seed for the rank of |out_private_key|.
187 */ 184 */
188int MLKEM1024_generate_key( 185int mlkem_generate_key(uint8_t *out_encoded_public_key,
189 uint8_t out_encoded_public_key[MLKEM1024_PUBLIC_KEY_BYTES], 186 uint8_t *optional_out_seed, MLKEM_private_key *out_private_key);
190 uint8_t optional_out_seed[MLKEM_SEED_BYTES],
191 MLKEM_private_key *out_private_key);
192 187
193/* 188/*
194 * MLKEM1024_private_key_from_seed derives a private key from a seed that was 189 * mlkem_private_key_from_seed modifies |out_private_key| to generate a key of
195 * generated by |MLKEM1024_generate_key|. It fails and returns 0 if |seed_len| 190 * the rank of |*out_private_key| from a seed that was generated by
196 * is incorrect, otherwise it writes |*out_private_key| and returns 1. 191 * |mlkem_generate_key|. It fails and returns 0 if |seed_len| is incorrect, or
192 * if |*out_private_key| has not been initialized. otherwise it writes to
193 * |*out_private_key| and returns 1.
197 */ 194 */
198int MLKEM1024_private_key_from_seed( 195int mlkem_private_key_from_seed(const uint8_t *seed, size_t seed_len,
199 MLKEM_private_key *out_private_key, const uint8_t *seed, 196 MLKEM_private_key *out_private_key);
200 size_t seed_len);
201 197
202/* 198/*
203 * MLKEM1024_public_from_private sets |*out_public_key| to the public key that 199 * mlkem_public_from_private sets |*out_public_key| to the public key that
204 * corresponds to |private_key|. (This is faster than parsing the output of 200 * corresponds to |*private_key|. (This is faster than parsing the output of
205 * |MLKEM1024_generate_key| if, for some reason, you need to encapsulate to a 201 * |MLKEM_generate_key| if, for some reason, you need to encapsulate to a key
206 * key that was just generated.) 202 * that was just generated.)
207 */ 203 */
208void MLKEM1024_public_from_private(const MLKEM_private_key *private_key, 204void mlkem_public_from_private(const MLKEM_private_key *private_key,
209 MLKEM_public_key *out_public_key); 205 MLKEM_public_key *out_public_key);
210 206
211/* MLKEM1024_CIPHERTEXT_BYTES is number of bytes in the ML-KEM-1024 ciphertext. */ 207
212#define MLKEM1024_CIPHERTEXT_BYTES 1568 208/* Encapsulation and decapsulation of secrets. */
213 209
214/* 210/*
215 * MLKEM1024_encap encrypts a random shared secret for |public_key|, writes the 211 * mlkem_encap encrypts a random shared secret for |public_key|, writes the
216 * ciphertext to |out_ciphertext|, and writes the random shared secret to 212 * ciphertext to |out_ciphertext|, and writes the random shared secret to
217 * |out_shared_secret|. 213 * |out_shared_secret|.
218 */ 214 */
219void MLKEM1024_encap(const MLKEM_public_key *public_key, 215void mlkem_encap(const MLKEM_public_key *public_key,
220 uint8_t out_ciphertext[MLKEM1024_CIPHERTEXT_BYTES], 216 uint8_t out_ciphertext[MLKEM768_CIPHERTEXT_BYTES],
221 uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES]); 217 uint8_t out_shared_secret[MLKEM_SHARED_SECRET_LENGTH]);
222
223 218
224/* 219/*
225 * MLKEM1024_decap decrypts a shared secret from |ciphertext| using 220 * mlkem_decap decrypts a shared secret from |ciphertext| using |private_key|
226 * |private_key| and writes it to |out_shared_secret|. If |ciphertext_len| is 221 * and writes it to |out_shared_secret|. If |ciphertext_len| is incorrect it
227 * incorrect it returns 0, otherwise it returns 1. If |ciphertext| is invalid 222 * returns 0, otherwise it returns 1. If |ciphertext| is invalid,
228 * (but of the correct length), |out_shared_secret| is filled with a key that 223 * |out_shared_secret| is filled with a key that will always be the same for the
229 * will always be the same for the same |ciphertext| and |private_key|, but 224 * same |ciphertext| and |private_key|, but which appears to be random unless
230 * which appears to be random unless one has access to |private_key|. These 225 * one has access to |private_key|. These alternatives occur in constant time.
231 * alternatives occur in constant time. Any subsequent symmetric encryption 226 * Any subsequent symmetric encryption using |out_shared_secret| must use an
232 * using |out_shared_secret| must use an authenticated encryption scheme in 227 * authenticated encryption scheme in order to discover the decapsulation
233 * order to discover the decapsulation failure. 228 * failure.
234 */ 229 */
235int MLKEM1024_decap(const MLKEM_private_key *private_key, 230int mlkem_decap(const MLKEM_private_key *private_key,
236 const uint8_t *ciphertext, size_t ciphertext_len, 231 const uint8_t *ciphertext, size_t ciphertext_len,
237 uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES]); 232 uint8_t out_shared_secret[MLKEM_SHARED_SECRET_LENGTH]);
233
234
235/* Serialisation of keys. */
238 236
239/* 237/*
240 * Serialisation of ML-KEM-1024 keys. 238 * mlkem_marshal_public_key serializes |public_key| to |output| in the standard
241 * MLKEM1024_marshal_public_key serializes |public_key| to |out| in the standard 239 * format for ML-KEM public keys. It returns one on success or zero on allocation
242 * format for ML-KEM-1024 public keys. It returns one on success or zero on 240 * error.
243 * allocation error.
244 */ 241 */
245int MLKEM1024_marshal_public_key(const MLKEM_public_key *public_key, 242int mlkem_marshal_public_key(const MLKEM_public_key *public_key,
246 uint8_t **output, size_t *output_len); 243 uint8_t **output, size_t *output_len);
247 244
248
249/* 245/*
250 * MLKEM1024_parse_public_key parses a public key, in the format generated by 246 * mlkem_parse_public_key parses a public key, in the format generated by
251 * |MLKEM1024_marshal_public_key|, from |in| and writes the result to 247 * |MLKEM_marshal_public_key|, from |input| and writes the result to
252 * |out_public_key|. It returns one on success or zero on parse error or if 248 * |out_public_key|. It returns one on success or zero on parse error or if
253 * there are trailing bytes in |in|. 249 * there are trailing bytes in |input|.
254 */ 250 */
255int MLKEM1024_parse_public_key(const uint8_t *input, size_t input_len, 251int mlkem_parse_public_key(const uint8_t *input, size_t input_len,
256 MLKEM_public_key *out_public_key); 252 MLKEM_public_key *out_public_key);
257 253
258
259/* 254/*
260 * MLKEM1024_parse_private_key parses a private key, in NIST's format for 255 * mlkem_parse_private_key parses a private key, in the format generated by
261 * private keys, from |in| and writes the result to |out_private_key|. It 256 * |MLKEM_marshal_private_key|, from |input| and writes the result to
262 * returns one on success or zero on parse error or if there are trailing bytes 257 * |out_private_key|. It returns one on success or zero on parse error or if
263 * in |in|. This format is verbose and should be avoided. Private keys should be 258 * there are trailing bytes in |input|. This formate is verbose and should be avoided.
264 * stored as seeds and parsed using |MLKEM1024_private_key_from_seed|. 259 * Private keys should be stored as seeds and parsed using |mlkem_private_key_from_seed|.
265 */ 260 */
266int MLKEM1024_parse_private_key(const uint8_t *input, size_t input_len, 261int mlkem_parse_private_key(const uint8_t *input, size_t input_len,
267 MLKEM_private_key *out_private_key); 262 MLKEM_private_key *out_private_key);
268
269
270/* XXXX Truly internal stuff below, also in need of de-duping */
271 263
272/*
273 * MLKEM_ENCAP_ENTROPY is the number of bytes of uniformly random entropy
274 * necessary to encapsulate a secret. The entropy will be leaked to the
275 * decapsulating party.
276 */
277#define MLKEM_ENCAP_ENTROPY 32
278 264
279/* 265/* Functions that are only used for test purposes. */
280 * MLKEM768_public_key contains a ML-KEM-768 public key. The contents of this
281 * object should never leave the address space since the format is unstable.
282 */
283struct MLKEM768_public_key {
284 union {
285 uint8_t bytes[512 * (3 + 9) + 32 + 32];
286 uint16_t alignment;
287 } opaque;
288};
289 266
290/* 267/*
291 * MLKEM768_private_key contains a ML-KEM-768 private key. The contents of this 268 * mlkem_generate_key_external_entropy is a deterministic function to create a
292 * object should never leave the address space since the format is unstable.
293 */
294struct MLKEM768_private_key {
295 union {
296 uint8_t bytes[512 * (3 + 3 + 9) + 32 + 32 + 32];
297 uint16_t alignment;
298 } opaque;
299};
300
301/* Public opaque ML-KEM key structures. */
302
303#define MLKEM_PUBLIC_KEY_UNINITIALIZED 1
304#define MLKEM_PUBLIC_KEY_INITIALIZED 2
305#define MLKEM_PRIVATE_KEY_UNINITIALIZED 3
306#define MLKEM_PRIVATE_KEY_INITIALIZED 4
307
308struct MLKEM_public_key_st {
309 uint16_t rank;
310 int state;
311 struct MLKEM768_public_key *key_768;
312 struct MLKEM1024_public_key *key_1024;
313};
314
315struct MLKEM_private_key_st {
316 uint16_t rank;
317 int state;
318 struct MLKEM768_private_key *key_768;
319 struct MLKEM1024_private_key *key_1024;
320};
321
322/*
323 * MLKEM768_generate_key_external_entropy is a deterministic function to create a
324 * pair of ML-KEM 768 keys, using the supplied entropy. The entropy needs to be 269 * pair of ML-KEM 768 keys, using the supplied entropy. The entropy needs to be
325 * uniformly random generated. This function is should only be used for tests, 270 * uniformly random generated. This function is should only be used for tests,
326 * regular callers should use the non-deterministic |MLKEM_generate_key| 271 * regular callers should use the non-deterministic |MLKEM_generate_key|
327 * directly. 272 * directly.
328 */ 273 */
329int MLKEM768_generate_key_external_entropy( 274int mlkem_generate_key_external_entropy(
330 uint8_t out_encoded_public_key[MLKEM768_PUBLIC_KEY_BYTES], 275 uint8_t out_encoded_public_key[MLKEM768_PUBLIC_KEY_BYTES],
331 MLKEM_private_key *out_private_key, 276 MLKEM_private_key *out_private_key,
332 const uint8_t entropy[MLKEM_SEED_BYTES]); 277 const uint8_t entropy[MLKEM_SEED_LENGTH]);
333
334/*
335 * MLKEM768_PRIVATE_KEY_BYTES is the length of the data produced by
336 * |MLKEM768_marshal_private_key|.
337 */
338#define MLKEM768_PRIVATE_KEY_BYTES 2400
339 278
340/* 279/*
341 * MLKEM768_marshal_private_key serializes |private_key| to |out| in the standard 280 * mlkem_marshal_private_key serializes |private_key| to |out_private_key| in the standard
342 * format for ML-KEM private keys. It returns one on success or zero on 281 * format for ML-KEM private keys. It returns one on success or zero on
343 * allocation error. 282 * allocation error.
344 */ 283 */
345int MLKEM768_marshal_private_key(const MLKEM_private_key *private_key, 284int mlkem_marshal_private_key(const MLKEM_private_key *private_key,
346 uint8_t **out_private_key, size_t *out_private_key_len); 285 uint8_t **out_private_key, size_t *out_private_key_len);
347 286
348/* 287/*
349 * MLKEM768_encap_external_entropy behaves like |MLKEM768_encap|, but uses 288 * mlkem_encap_external_entropy behaves like |mlkem_encap|, but uses
350 * |MLKEM_ENCAP_ENTROPY| bytes of |entropy| for randomization. The decapsulating 289 * |MLKEM_ENCAP_ENTROPY| bytes of |entropy| for randomization. The decapsulating
351 * side will be able to recover |entropy| in full. This function should only be 290 * side will be able to recover |entropy| in full. This function should only be
352 * used for tests, regular callers should use the non-deterministic 291 * used for tests, regular callers should use the non-deterministic
353 * |MLKEM_encap| directly. 292 * |MLKEM_encap| directly.
354 */ 293 */
355void MLKEM768_encap_external_entropy( 294void mlkem_encap_external_entropy(
356 uint8_t out_ciphertext[MLKEM768_CIPHERTEXT_BYTES], 295 uint8_t out_ciphertext[MLKEM768_CIPHERTEXT_BYTES],
357 uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES], 296 uint8_t out_shared_secret[MLKEM_SHARED_SECRET_LENGTH],
358 const MLKEM_public_key *public_key, 297 const MLKEM_public_key *public_key,
359 const uint8_t entropy[MLKEM_ENCAP_ENTROPY]); 298 const uint8_t entropy[MLKEM_ENCAP_ENTROPY]);
360 299
361
362/* 300/*
363 * MLKEM1024_public_key contains an ML-KEM-1024 public key. The contents of this 301 * |MLKEM_encap_external_entropy| behaves exactly like the public |MLKEM_encap|
364 * object should never leave the address space since the format is unstable. 302 * with the entropy provided by the caller. It is directly called internally
365 */ 303 * and by tests.
366struct MLKEM1024_public_key {
367 union {
368 uint8_t bytes[512 * (4 + 16) + 32 + 32];
369 uint16_t alignment;
370 } opaque;
371};
372
373/*
374 * MLKEM1024_private_key contains a ML-KEM-1024 private key. The contents of
375 * this object should never leave the address space since the format is
376 * unstable.
377 */
378struct MLKEM1024_private_key {
379 union {
380 uint8_t bytes[512 * (4 + 4 + 16) + 32 + 32 + 32];
381 uint16_t alignment;
382 } opaque;
383};
384
385
386/*
387 * MLKEM1024_generate_key_external_entropy is a deterministic function to create a
388 * pair of ML-KEM 1024 keys, using the supplied entropy. The entropy needs to be
389 * uniformly random generated. This function is should only be used for tests,
390 * regular callers should use the non-deterministic |MLKEM_generate_key|
391 * directly.
392 */
393int MLKEM1024_generate_key_external_entropy(
394 uint8_t out_encoded_public_key[MLKEM1024_PUBLIC_KEY_BYTES],
395 MLKEM_private_key *out_private_key,
396 const uint8_t entropy[MLKEM_SEED_BYTES]);
397
398/*
399 * MLKEM1024_PRIVATE_KEY_BYTES is the length of the data produced by
400 * |MLKEM1024_marshal_private_key|.
401 */ 304 */
402#define MLKEM1024_PRIVATE_KEY_BYTES 3168 305int MLKEM_encap_external_entropy(const MLKEM_public_key *public_key,
306 const uint8_t *entropy, uint8_t **out_ciphertext,
307 size_t *out_ciphertext_len, uint8_t **out_shared_secret,
308 size_t *out_shared_secret_len);
403 309
404/* 310/*
405 * MLKEM1024_marshal_private_key serializes |private_key| to |out| in the 311 * |MLKEM_generate_key_external_entropy| behaves exactly like the public
406 * standard format for ML-KEM private keys. It returns one on success or zero on 312 * |MLKEM_generate_key| with the entropy provided by the caller.
407 * allocation error. 313 * It is directly called internally and by tests.
408 */ 314 */
409int MLKEM1024_marshal_private_key( 315int MLKEM_generate_key_external_entropy(MLKEM_private_key *private_key,
410 const MLKEM_private_key *private_key, uint8_t **out_private_key, 316 uint8_t **out_encoded_public_key, size_t *out_encoded_public_key_len,
411 size_t *out_private_key_len); 317 const uint8_t *entropy);
412 318
413/*
414 * MLKEM_encap_external_entropy behaves like |MLKEM_encap|, but uses
415 * |MLKEM_ENCAP_ENTROPY| bytes of |entropy| for randomization. The decapsulating
416 * side will be able to recover |entropy| in full. This function should only be
417 * used for tests, regular callers should use the non-deterministic
418 * |MLKEM_encap| directly.
419 */
420void MLKEM1024_encap_external_entropy(
421 uint8_t out_ciphertext[MLKEM1024_CIPHERTEXT_BYTES],
422 uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
423 const MLKEM_public_key *public_key,
424 const uint8_t entropy[MLKEM_ENCAP_ENTROPY]);
425 319
426__END_HIDDEN_DECLS 320__END_HIDDEN_DECLS
427 321
diff --git a/src/lib/libcrypto/mlkem/mlkem_key.c b/src/lib/libcrypto/mlkem/mlkem_key.c
index 051d8f2b88..146814d040 100644
--- a/src/lib/libcrypto/mlkem/mlkem_key.c
+++ b/src/lib/libcrypto/mlkem/mlkem_key.c
@@ -1,6 +1,6 @@
1/* $OpenBSD: mlkem_key.c,v 1.1 2025/08/14 15:48:48 beck Exp $ */ 1/* $OpenBSD: mlkem_key.c,v 1.2 2025/09/05 23:30:12 beck Exp $ */
2/* 2/*
3 * Copyright (c) 2025 Bob Beck <beck@openbsd.org> 3 * Copyright (c) 2025 Bob Beck <beck@obtuse.com>
4 * 4 *
5 * Permission to use, copy, modify, and distribute this software for any 5 * Permission to use, copy, modify, and distribute this software for any
6 * purpose with or without fee is hereby granted, provided that the above 6 * purpose with or without fee is hereby granted, provided that the above