summaryrefslogtreecommitdiff
path: root/src/lib
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib')
-rw-r--r--src/lib/libcrypto/Makefile3
-rw-r--r--src/lib/libcrypto/hidden/openssl/mlkem.h10
-rw-r--r--src/lib/libcrypto/mlkem/mlkem.h119
-rw-r--r--src/lib/libcrypto/mlkem/mlkem1024.c1121
-rw-r--r--src/lib/libcrypto/mlkem/mlkem_internal.h39
5 files changed, 1290 insertions, 2 deletions
diff --git a/src/lib/libcrypto/Makefile b/src/lib/libcrypto/Makefile
index f43b09d176..ab2349103d 100644
--- a/src/lib/libcrypto/Makefile
+++ b/src/lib/libcrypto/Makefile
@@ -1,4 +1,4 @@
1# $OpenBSD: Makefile,v 1.229 2024/12/13 00:03:57 beck Exp $ 1# $OpenBSD: Makefile,v 1.230 2024/12/13 00:17:17 beck Exp $
2 2
3LIB= crypto 3LIB= crypto
4LIBREBUILD=y 4LIBREBUILD=y
@@ -374,6 +374,7 @@ SRCS+= md5.c
374 374
375# mlkem/ 375# mlkem/
376SRCS+= mlkem768.c 376SRCS+= mlkem768.c
377SRCS+= mlkem1024.c
377 378
378# modes/ 379# modes/
379SRCS+= cbc128.c 380SRCS+= cbc128.c
diff --git a/src/lib/libcrypto/hidden/openssl/mlkem.h b/src/lib/libcrypto/hidden/openssl/mlkem.h
index 01ac28cffd..103144d1a1 100644
--- a/src/lib/libcrypto/hidden/openssl/mlkem.h
+++ b/src/lib/libcrypto/hidden/openssl/mlkem.h
@@ -1,4 +1,4 @@
1/* $OpenBSD: mlkem.h,v 1.1 2024/12/13 00:03:57 beck Exp $ */ 1/* $OpenBSD: mlkem.h,v 1.2 2024/12/13 00:17:17 beck Exp $ */
2/* 2/*
3 * Copyright (c) 2024 Bob Beck <beck@obtuse.com> 3 * Copyright (c) 2024 Bob Beck <beck@obtuse.com>
4 * 4 *
@@ -35,6 +35,14 @@ LCRYPTO_USED(MLKEM768_marshal_public_key);
35LCRYPTO_USED(MLKEM768_parse_public_key); 35LCRYPTO_USED(MLKEM768_parse_public_key);
36LCRYPTO_USED(MLKEM768_private_key_from_seed); 36LCRYPTO_USED(MLKEM768_private_key_from_seed);
37LCRYPTO_USED(MLKEM768_parse_private_key); 37LCRYPTO_USED(MLKEM768_parse_private_key);
38LCRYPTO_USED(MLKEM1024_generate_key);
39LCRYPTO_USED(MLKEM1024_public_from_private);
40LCRYPTO_USED(MLKEM1024_encap);
41LCRYPTO_USED(MLKEM1024_decap);
42LCRYPTO_USED(MLKEM1024_marshal_public_key);
43LCRYPTO_USED(MLKEM1024_parse_public_key);
44LCRYPTO_USED(MLKEM1024_private_key_from_seed);
45LCRYPTO_USED(MLKEM1024_parse_private_key);
38#endif 46#endif
39 47
40#endif /* _LIBCRYPTO_MLKEM_H */ 48#endif /* _LIBCRYPTO_MLKEM_H */
diff --git a/src/lib/libcrypto/mlkem/mlkem.h b/src/lib/libcrypto/mlkem/mlkem.h
index 8040f4844b..1033b89a60 100644
--- a/src/lib/libcrypto/mlkem/mlkem.h
+++ b/src/lib/libcrypto/mlkem/mlkem.h
@@ -161,6 +161,125 @@ int MLKEM768_parse_public_key(struct MLKEM768_public_key *out_public_key,
161int MLKEM768_parse_private_key(struct MLKEM768_private_key *out_private_key, 161int MLKEM768_parse_private_key(struct MLKEM768_private_key *out_private_key,
162 struct cbs_st *in); 162 struct cbs_st *in);
163 163
164/*
165 * ML-KEM-1024
166 *
167 * ML-KEM-1024 also exists. You should prefer ML-KEM-768 where possible.
168 */
169
170/*
171 * MLKEM1024_public_key contains an ML-KEM-1024 public key. The contents of this
172 * object should never leave the address space since the format is unstable.
173 */
174struct MLKEM1024_public_key {
175 union {
176 uint8_t bytes[512 * (4 + 16) + 32 + 32];
177 uint16_t alignment;
178 } opaque;
179};
180
181/*
182 * MLKEM1024_private_key contains an ML-KEM-1024 private key. The contents of
183 * this object should never leave the address space since the format is
184 * unstable.
185 */
186struct MLKEM1024_private_key {
187 union {
188 uint8_t bytes[512 * (4 + 4 + 16) + 32 + 32 + 32];
189 uint16_t alignment;
190 } opaque;
191};
192
193/*
194 * MLKEM1024_PUBLIC_KEY_BYTES is the number of bytes in an encoded ML-KEM-1024
195 * public key.
196 */
197#define MLKEM1024_PUBLIC_KEY_BYTES 1568
198
199/*
200 * MLKEM1024_generate_key generates a random public/private key pair, writes the
201 * encoded public key to |out_encoded_public_key| and sets |out_private_key| to
202 * the private key. If |optional_out_seed| is not NULL then the seed used to
203 * generate the private key is written to it.
204 */
205void MLKEM1024_generate_key(
206 uint8_t out_encoded_public_key[MLKEM1024_PUBLIC_KEY_BYTES],
207 uint8_t optional_out_seed[MLKEM_SEED_BYTES],
208 struct MLKEM1024_private_key *out_private_key);
209
210/*
211 * MLKEM1024_private_key_from_seed derives a private key from a seed that was
212 * generated by |MLKEM1024_generate_key|. It fails and returns 0 if |seed_len|
213 * is incorrect, otherwise it writes |*out_private_key| and returns 1.
214 */
215int MLKEM1024_private_key_from_seed(
216 struct MLKEM1024_private_key *out_private_key, const uint8_t *seed,
217 size_t seed_len);
218
219/*
220 * MLKEM1024_public_from_private sets |*out_public_key| to the public key that
221 * corresponds to |private_key|. (This is faster than parsing the output of
222 * |MLKEM1024_generate_key| if, for some reason, you need to encapsulate to a
223 * key that was just generated.)
224 */
225void MLKEM1024_public_from_private(struct MLKEM1024_public_key *out_public_key,
226 const struct MLKEM1024_private_key *private_key);
227
228/* MLKEM1024_CIPHERTEXT_BYTES is the number of bytes in an ML-KEM-1024 ciphertext. */
229#define MLKEM1024_CIPHERTEXT_BYTES 1568
230
231/*
232 * MLKEM1024_encap encrypts a random shared secret for |public_key|, writes the
233 * ciphertext to |out_ciphertext|, and writes the random shared secret to
234 * |out_shared_secret|.
235 */
236void MLKEM1024_encap(uint8_t out_ciphertext[MLKEM1024_CIPHERTEXT_BYTES],
237 uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
238 const struct MLKEM1024_public_key *public_key);
239
240/*
241 * MLKEM1024_decap decrypts a shared secret from |ciphertext| using
242 * |private_key| and writes it to |out_shared_secret|. If |ciphertext_len| is
243 * incorrect it returns 0, otherwise it returns 1. If |ciphertext| is invalid
244 * (but of the correct length), |out_shared_secret| is filled with a key that
245 * will always be the same for the same |ciphertext| and |private_key|, but
246 * which appears to be random unless one has access to |private_key|. These
247 * alternatives occur in constant time. Any subsequent symmetric encryption
248 * using |out_shared_secret| must use an authenticated encryption scheme in
249 * order to discover the decapsulation failure.
250 */
251int MLKEM1024_decap(uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
252 const uint8_t *ciphertext, size_t ciphertext_len,
253 const struct MLKEM1024_private_key *private_key);
254
255/*
256 * Serialisation of ML-KEM-1024 keys.
257 * MLKEM1024_marshal_public_key serializes |public_key| to |out| in the standard
258 * format for ML-KEM-1024 public keys. It returns one on success or zero on
259 * allocation error.
260 */
261int MLKEM1024_marshal_public_key(struct cbb_st *out,
262 const struct MLKEM1024_public_key *public_key);
263
264/*
265 * MLKEM1024_parse_public_key parses a public key, in the format generated by
266 * |MLKEM1024_marshal_public_key|, from |in| and writes the result to
267 * |out_public_key|. It returns one on success or zero on parse error or if
268 * there are trailing bytes in |in|.
269 */
270int MLKEM1024_parse_public_key(struct MLKEM1024_public_key *out_public_key,
271 struct cbs_st *in);
272
273/*
274 * MLKEM1024_parse_private_key parses a private key, in NIST's format for
275 * private keys, from |in| and writes the result to |out_private_key|. It
276 * returns one on success or zero on parse error or if there are trailing bytes
277 * in |in|. This format is verbose and should be avoided. Private keys should be
278 * stored as seeds and parsed using |MLKEM1024_private_key_from_seed|.
279 */
280int MLKEM1024_parse_private_key(struct MLKEM1024_private_key *out_private_key,
281 struct cbs_st *in);
282
164#if defined(__cplusplus) 283#if defined(__cplusplus)
165} 284}
166#endif 285#endif
diff --git a/src/lib/libcrypto/mlkem/mlkem1024.c b/src/lib/libcrypto/mlkem/mlkem1024.c
new file mode 100644
index 0000000000..e0a71f335b
--- /dev/null
+++ b/src/lib/libcrypto/mlkem/mlkem1024.c
@@ -0,0 +1,1121 @@
1/* $OpenBSD: mlkem1024.c,v 1.1 2024/12/13 00:17:17 beck Exp $ */
2/*
3 * Copyright (c) 2024, Google Inc.
4 * Copyright (c) 2024, Bob Beck <beck@obtuse.com>
5 *
6 * Permission to use, copy, modify, and/or distribute this software for any
7 * purpose with or without fee is hereby granted, provided that the above
8 * copyright notice and this permission notice appear in all copies.
9 *
10 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
13 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
15 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
16 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17 */
18
19#include <openssl/mlkem.h>
20
21#include <assert.h>
22#include <stdlib.h>
23#include <string.h>
24
25#include "bytestring.h"
26
27#include "sha3_internal.h"
28#include "mlkem_internal.h"
29#include "constant_time.h"
30#include "crypto_internal.h"
31
32/* Remove later */
33#undef LCRYPTO_ALIAS
34#define LCRYPTO_ALIAS(A)
35
36/*
37 * See
38 * https://csrc.nist.gov/pubs/fips/203/final
39 */
40
41static void
42prf(uint8_t *out, size_t out_len, const uint8_t in[33])
43{
44 sha3_ctx ctx;
45 shake256_init(&ctx);
46 shake_update(&ctx, in, 33);
47 shake_xof(&ctx);
48 shake_out(&ctx, out, out_len);
49}
50
51/* Section 4.1 */
52static void
53hash_h(uint8_t out[32], const uint8_t *in, size_t len)
54{
55 sha3_ctx ctx;
56 sha3_init(&ctx, 32);
57 sha3_update(&ctx, in, len);
58 sha3_final(out, &ctx);
59}
60
61static void
62hash_g(uint8_t out[64], const uint8_t *in, size_t len)
63{
64 sha3_ctx ctx;
65 sha3_init(&ctx, 64);
66 sha3_update(&ctx, in, len);
67 sha3_final(out, &ctx);
68}
69
70/* this is called 'J' in the spec */
71static void
72kdf(uint8_t out[MLKEM_SHARED_SECRET_BYTES], const uint8_t failure_secret[32],
73 const uint8_t *in, size_t len)
74{
75 sha3_ctx ctx;
76 shake256_init(&ctx);
77 shake_update(&ctx, failure_secret, 32);
78 shake_update(&ctx, in, len);
79 shake_xof(&ctx);
80 shake_out(&ctx, out, MLKEM_SHARED_SECRET_BYTES);
81}
82
83#define DEGREE 256
84#define RANK1024 4
85
86static const size_t kBarrettMultiplier = 5039;
87static const unsigned kBarrettShift = 24;
88static const uint16_t kPrime = 3329;
89static const int kLog2Prime = 12;
90static const uint16_t kHalfPrime = (/*kPrime=*/3329 - 1) / 2;
91static const int kDU1024 = 11;
92static const int kDV1024 = 5;
93
94/*
95 * kInverseDegree is 128^-1 mod 3329; 128 because kPrime does not have a 512th
96 * root of unity.
97 */
98static const uint16_t kInverseDegree = 3303;
99static const size_t kEncodedVectorSize =
100 (/*kLog2Prime=*/12 * DEGREE / 8) * RANK1024;
101static const size_t kCompressedVectorSize = /*kDU1024=*/ 11 * RANK1024 * DEGREE /
102 8;
103
104typedef struct scalar {
105 /* On every function entry and exit, 0 <= c < kPrime. */
106 uint16_t c[DEGREE];
107} scalar;
108
109typedef struct vector {
110 scalar v[RANK1024];
111} vector;
112
113typedef struct matrix {
114 scalar v[RANK1024][RANK1024];
115} matrix;
116
117/*
118 * This bit of Python will be referenced in some of the following comments:
119 *
120 * p = 3329
121 *
122 * def bitreverse(i):
123 * ret = 0
124 * for n in range(7):
125 * bit = i & 1
126 * ret <<= 1
127 * ret |= bit
128 * i >>= 1
129 * return ret
130 */
131
132/* kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)] */
133static const uint16_t kNTTRoots[128] = {
134 1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797,
135 2786, 3260, 569, 1746, 296, 2447, 1339, 1476, 3046, 56, 2240, 1333,
136 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, 289, 331, 3253, 1756,
137 1197, 2304, 2277, 2055, 650, 1977, 2513, 632, 2865, 33, 1320, 1915,
138 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648,
139 2474, 3110, 1227, 910, 17, 2761, 583, 2649, 1637, 723, 2288, 1100,
140 1409, 2662, 3281, 233, 756, 2156, 3015, 3050, 1703, 1651, 2789, 1789,
141 1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641,
142 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, 1063, 319, 2773, 757,
143 2099, 561, 2466, 2594, 2804, 1092, 403, 1026, 1143, 2150, 2775, 886,
144 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154,
145};
146
147/* kInverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)] */
148static const uint16_t kInverseNTTRoots[128] = {
149 1, 1600, 40, 749, 2481, 1432, 2699, 687, 1583, 2760, 69, 543,
150 2532, 3136, 1410, 2267, 2508, 1355, 450, 936, 447, 2794, 1235, 1903,
151 1996, 1089, 3273, 283, 1853, 1990, 882, 3033, 2419, 2102, 219, 855,
152 2681, 1848, 712, 682, 927, 1795, 461, 1891, 2877, 2522, 1894, 1010,
153 1414, 2009, 3296, 464, 2697, 816, 1352, 2679, 1274, 1052, 1025, 2132,
154 1573, 76, 2998, 3040, 1175, 2444, 394, 1219, 2300, 1455, 2117, 1607,
155 2443, 554, 1179, 2186, 2303, 2926, 2237, 525, 735, 863, 2768, 1230,
156 2572, 556, 3010, 2266, 1684, 1239, 780, 2954, 109, 1292, 1031, 1745,
157 2688, 3061, 992, 2596, 941, 892, 1021, 2390, 642, 1868, 2377, 1482,
158 1540, 540, 1678, 1626, 279, 314, 1173, 2573, 3096, 48, 667, 1920,
159 2229, 1041, 2606, 1692, 680, 2746, 568, 3312,
160};
161
162/* kModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)] */
163static const uint16_t kModRoots[128] = {
164 17, 3312, 2761, 568, 583, 2746, 2649, 680, 1637, 1692, 723, 2606,
165 2288, 1041, 1100, 2229, 1409, 1920, 2662, 667, 3281, 48, 233, 3096,
166 756, 2573, 2156, 1173, 3015, 314, 3050, 279, 1703, 1626, 1651, 1678,
167 2789, 540, 1789, 1540, 1847, 1482, 952, 2377, 1461, 1868, 2687, 642,
168 939, 2390, 2308, 1021, 2437, 892, 2388, 941, 733, 2596, 2337, 992,
169 268, 3061, 641, 2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
170 375, 2954, 2549, 780, 2090, 1239, 1645, 1684, 1063, 2266, 319, 3010,
171 2773, 556, 757, 2572, 2099, 1230, 561, 2768, 2466, 863, 2594, 735,
172 2804, 525, 1092, 2237, 403, 2926, 1026, 2303, 1143, 2186, 2150, 1179,
173 2775, 554, 886, 2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
174 2110, 1219, 2935, 394, 885, 2444, 2154, 1175,
175};
176
177/* reduce_once reduces 0 <= x < 2*kPrime, mod kPrime. */
178static uint16_t
179reduce_once(uint16_t x)
180{
181 assert(x < 2 * kPrime);
182 const uint16_t subtracted = x - kPrime;
183 uint16_t mask = 0u - (subtracted >> 15);
184 /*
185 * On Aarch64, omitting a |value_barrier_u16| results in a 2x speedup of
186 * ML-KEM overall and Clang still produces constant-time code using
187 * `csel`. On other platforms & compilers on godbolt that we care about,
188 * this code also produces constant-time output.
189 */
190 return (mask & x) | (~mask & subtracted);
191}
192
193/*
194 * constant time reduce x mod kPrime using Barrett reduction. x must be less
195 * than kPrime + 2×kPrime².
196 */
197static uint16_t
198reduce(uint32_t x)
199{
200 uint64_t product = (uint64_t)x * kBarrettMultiplier;
201 uint32_t quotient = (uint32_t)(product >> kBarrettShift);
202 uint32_t remainder = x - quotient * kPrime;
203
204 assert(x < kPrime + 2u * kPrime * kPrime);
205 return reduce_once(remainder);
206}
207
208static void
209scalar_zero(scalar *out)
210{
211 memset(out, 0, sizeof(*out));
212}
213
214static void
215vector_zero(vector *out)
216{
217 memset(out, 0, sizeof(*out));
218}
219
220/*
221 * In place number theoretic transform of a given scalar.
222 * Note that MLKEM's kPrime 3329 does not have a 512th root of unity, so this
223 * transform leaves off the last iteration of the usual FFT code, with the 128
224 * relevant roots of unity being stored in |kNTTRoots|. This means the output
225 * should be seen as 128 elements in GF(3329^2), with the coefficients of the
226 * elements being consecutive entries in |s->c|.
227 */
228static void
229scalar_ntt(scalar *s)
230{
231 int offset = DEGREE;
232 int step;
233 /*
234 * `int` is used here because using `size_t` throughout caused a ~5% slowdown
235 * with Clang 14 on Aarch64.
236 */
237 for (step = 1; step < DEGREE / 2; step <<= 1) {
238 int i, j, k = 0;
239
240 offset >>= 1;
241 for (i = 0; i < step; i++) {
242 const uint32_t step_root = kNTTRoots[i + step];
243
244 for (j = k; j < k + offset; j++) {
245 uint16_t odd, even;
246
247 odd = reduce(step_root * s->c[j + offset]);
248 even = s->c[j];
249 s->c[j] = reduce_once(odd + even);
250 s->c[j + offset] = reduce_once(even - odd +
251 kPrime);
252 }
253 k += 2 * offset;
254 }
255 }
256}
257
258static void
259vector_ntt(vector *a)
260{
261 int i;
262
263 for (i = 0; i < RANK1024; i++) {
264 scalar_ntt(&a->v[i]);
265 }
266}
267
268/*
269 * In place inverse number theoretic transform of a given scalar, with pairs of
270 * entries of s->v being interpreted as elements of GF(3329^2). Just as with the
271 * number theoretic transform, this leaves off the first step of the normal iFFT
272 * to account for the fact that 3329 does not have a 512th root of unity, using
273 * the precomputed 128 roots of unity stored in |kInverseNTTRoots|.
274 */
275static void
276scalar_inverse_ntt(scalar *s)
277{
278 int i, j, k, offset, step = DEGREE / 2;
279
280 /*
281 * `int` is used here because using `size_t` throughout caused a ~5% slowdown
282 * with Clang 14 on Aarch64.
283 */
284 for (offset = 2; offset < DEGREE; offset <<= 1) {
285 step >>= 1;
286 k = 0;
287 for (i = 0; i < step; i++) {
288 uint32_t step_root = kInverseNTTRoots[i + step];
289 for (j = k; j < k + offset; j++) {
290 uint16_t odd, even;
291 odd = s->c[j + offset];
292 even = s->c[j];
293 s->c[j] = reduce_once(odd + even);
294 s->c[j + offset] = reduce(step_root *
295 (even - odd + kPrime));
296 }
297 k += 2 * offset;
298 }
299 }
300 for (i = 0; i < DEGREE; i++) {
301 s->c[i] = reduce(s->c[i] * kInverseDegree);
302 }
303}
304
305static void
306vector_inverse_ntt(vector *a)
307{
308 int i;
309
310 for (i = 0; i < RANK1024; i++) {
311 scalar_inverse_ntt(&a->v[i]);
312 }
313}
314
315static void
316scalar_add(scalar *lhs, const scalar *rhs)
317{
318 int i;
319
320 for (i = 0; i < DEGREE; i++) {
321 lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]);
322 }
323}
324
325static void
326scalar_sub(scalar *lhs, const scalar *rhs)
327{
328 int i;
329
330 for (i = 0; i < DEGREE; i++) {
331 lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime);
332 }
333}
334
335/*
336 * Multiplying two scalars in the number theoretically transformed state. Since
337 * 3329 does not have a 512th root of unity, this means we have to interpret
338 * the 2*ith and (2*i+1)th entries of the scalar as elements of GF(3329)[X]/(X^2
339 * - 17^(2*bitreverse(i)+1)) The value of 17^(2*bitreverse(i)+1) mod 3329 is
340 * stored in the precomputed |kModRoots| table. Note that our Barrett transform
341 * only allows us to multipy two reduced numbers together, so we need some
342 * intermediate reduction steps, even if an uint64_t could hold 3 multiplied
343 * numbers.
344 */
345static void
346scalar_mult(scalar *out, const scalar *lhs, const scalar *rhs)
347{
348 int i;
349
350 for (i = 0; i < DEGREE / 2; i++) {
351 uint32_t real_real = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i];
352 uint32_t img_img = (uint32_t)lhs->c[2 * i + 1] *
353 rhs->c[2 * i + 1];
354 uint32_t real_img = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i + 1];
355 uint32_t img_real = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i];
356
357 out->c[2 * i] =
358 reduce(real_real +
359 (uint32_t)reduce(img_img) * kModRoots[i]);
360 out->c[2 * i + 1] = reduce(img_real + real_img);
361 }
362}
363
364static void
365vector_add(vector *lhs, const vector *rhs)
366{
367 int i;
368
369 for (i = 0; i < RANK1024; i++) {
370 scalar_add(&lhs->v[i], &rhs->v[i]);
371 }
372}
373
374static void
375matrix_mult(vector *out, const matrix *m, const vector *a)
376{
377 int i, j;
378
379 vector_zero(out);
380 for (i = 0; i < RANK1024; i++) {
381 for (j = 0; j < RANK1024; j++) {
382 scalar product;
383
384 scalar_mult(&product, &m->v[i][j], &a->v[j]);
385 scalar_add(&out->v[i], &product);
386 }
387 }
388}
389
390static void
391matrix_mult_transpose(vector *out, const matrix *m,
392 const vector *a)
393{
394 int i, j;
395
396 vector_zero(out);
397 for (i = 0; i < RANK1024; i++) {
398 for (j = 0; j < RANK1024; j++) {
399 scalar product;
400
401 scalar_mult(&product, &m->v[j][i], &a->v[j]);
402 scalar_add(&out->v[i], &product);
403 }
404 }
405}
406
407static void
408scalar_inner_product(scalar *out, const vector *lhs,
409 const vector *rhs)
410{
411 int i;
412 scalar_zero(out);
413 for (i = 0; i < RANK1024; i++) {
414 scalar product;
415
416 scalar_mult(&product, &lhs->v[i], &rhs->v[i]);
417 scalar_add(out, &product);
418 }
419}
420
421/*
422 * Algorithm 6 of spec. Rejection samples a Keccak stream to get uniformly
423 * distributed elements. This is used for matrix expansion and only operates on
424 * public inputs.
425 */
426static void
427scalar_from_keccak_vartime(scalar *out, sha3_ctx *keccak_ctx)
428{
429 int i, done = 0;
430
431 while (done < DEGREE) {
432 uint8_t block[168];
433
434 shake_out(keccak_ctx, block, sizeof(block));
435 for (i = 0; i < sizeof(block) && done < DEGREE; i += 3) {
436 uint16_t d1 = block[i] + 256 * (block[i + 1] % 16);
437 uint16_t d2 = block[i + 1] / 16 + 16 * block[i + 2];
438
439 if (d1 < kPrime) {
440 out->c[done++] = d1;
441 }
442 if (d2 < kPrime && done < DEGREE) {
443 out->c[done++] = d2;
444 }
445 }
446 }
447}
448
449/*
450 * Algorithm 7 of the spec, with eta fixed to two and the PRF call
451 * included. Creates binominally distributed elements by sampling 2*|eta| bits,
452 * and setting the coefficient to the count of the first bits minus the count of
453 * the second bits, resulting in a centered binomial distribution. Since eta is
454 * two this gives -2/2 with a probability of 1/16, -1/1 with probability 1/4,
455 * and 0 with probability 3/8.
456 */
457static void
458scalar_centered_binomial_distribution_eta_2_with_prf(scalar *out,
459 const uint8_t input[33])
460{
461 uint8_t entropy[128];
462 int i;
463
464 CTASSERT(sizeof(entropy) == 2 * /*kEta=*/ 2 * DEGREE / 8);
465 prf(entropy, sizeof(entropy), input);
466
467 for (i = 0; i < DEGREE; i += 2) {
468 uint8_t byte = entropy[i / 2];
469 uint16_t value = kPrime;
470
471 value += (byte & 1) + ((byte >> 1) & 1);
472 value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
473 out->c[i] = reduce_once(value);
474
475 byte >>= 4;
476 value = kPrime;
477 value += (byte & 1) + ((byte >> 1) & 1);
478 value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
479 out->c[i + 1] = reduce_once(value);
480 }
481}
482
483/*
484 * Generates a secret vector by using
485 * |scalar_centered_binomial_distribution_eta_2_with_prf|, using the given seed
486 * appending and incrementing |counter| for entry of the vector.
487 */
488static void
489vector_generate_secret_eta_2(vector *out, uint8_t *counter,
490 const uint8_t seed[32])
491{
492 uint8_t input[33];
493 int i;
494
495 memcpy(input, seed, 32);
496 for (i = 0; i < RANK1024; i++) {
497 input[32] = (*counter)++;
498 scalar_centered_binomial_distribution_eta_2_with_prf(&out->v[i],
499 input);
500 }
501}
502
503/* Expands the matrix of a seed for key generation and for encaps-CPA. */
504static void
505matrix_expand(matrix *out, const uint8_t rho[32])
506{
507 uint8_t input[34];
508 int i, j;
509
510 memcpy(input, rho, 32);
511 for (i = 0; i < RANK1024; i++) {
512 for (j = 0; j < RANK1024; j++) {
513 sha3_ctx keccak_ctx;
514
515 input[32] = i;
516 input[33] = j;
517 shake128_init(&keccak_ctx);
518 shake_update(&keccak_ctx, input, sizeof(input));
519 shake_xof(&keccak_ctx);
520 scalar_from_keccak_vartime(&out->v[i][j], &keccak_ctx);
521 }
522 }
523}
524
525static const uint8_t kMasks[8] = {0x01, 0x03, 0x07, 0x0f,
526 0x1f, 0x3f, 0x7f, 0xff};
527
528static void
529scalar_encode(uint8_t *out, const scalar *s, int bits)
530{
531 uint8_t out_byte = 0;
532 int i, out_byte_bits = 0;
533
534 assert(bits <= (int)sizeof(*s->c) * 8 && bits != 1);
535 for (i = 0; i < DEGREE; i++) {
536 uint16_t element = s->c[i];
537 int element_bits_done = 0;
538
539 while (element_bits_done < bits) {
540 int chunk_bits = bits - element_bits_done;
541 int out_bits_remaining = 8 - out_byte_bits;
542
543 if (chunk_bits >= out_bits_remaining) {
544 chunk_bits = out_bits_remaining;
545 out_byte |= (element &
546 kMasks[chunk_bits - 1]) << out_byte_bits;
547 *out = out_byte;
548 out++;
549 out_byte_bits = 0;
550 out_byte = 0;
551 } else {
552 out_byte |= (element &
553 kMasks[chunk_bits - 1]) << out_byte_bits;
554 out_byte_bits += chunk_bits;
555 }
556
557 element_bits_done += chunk_bits;
558 element >>= chunk_bits;
559 }
560 }
561
562 if (out_byte_bits > 0) {
563 *out = out_byte;
564 }
565}
566
567/* scalar_encode_1 is |scalar_encode| specialised for |bits| == 1. */
568static void
569scalar_encode_1(uint8_t out[32], const scalar *s)
570{
571 int i, j;
572
573 for (i = 0; i < DEGREE; i += 8) {
574 uint8_t out_byte = 0;
575
576 for (j = 0; j < 8; j++) {
577 out_byte |= (s->c[i + j] & 1) << j;
578 }
579 *out = out_byte;
580 out++;
581 }
582}
583
584/*
585 * Encodes an entire vector into 32*|RANK1024|*|bits| bytes. Note that since 256
586 * (DEGREE) is divisible by 8, the individual vector entries will always fill a
587 * whole number of bytes, so we do not need to worry about bit packing here.
588 */
589static void
590vector_encode(uint8_t *out, const vector *a, int bits)
591{
592 int i;
593
594 for (i = 0; i < RANK1024; i++) {
595 scalar_encode(out + i * bits * DEGREE / 8, &a->v[i], bits);
596 }
597}
598
599/*
600 * scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
601 * |out|. It returns one on success and zero if any parsed value is >=
602 * |kPrime|.
603 */
604static int
605scalar_decode(scalar *out, const uint8_t *in, int bits)
606{
607 uint8_t in_byte = 0;
608 int i, in_byte_bits_left = 0;
609
610 assert(bits <= (int)sizeof(*out->c) * 8 && bits != 1);
611
612 for (i = 0; i < DEGREE; i++) {
613 uint16_t element = 0;
614 int element_bits_done = 0;
615
616 while (element_bits_done < bits) {
617 int chunk_bits = bits - element_bits_done;
618
619 if (in_byte_bits_left == 0) {
620 in_byte = *in;
621 in++;
622 in_byte_bits_left = 8;
623 }
624
625 if (chunk_bits > in_byte_bits_left) {
626 chunk_bits = in_byte_bits_left;
627 }
628
629 element |= (in_byte & kMasks[chunk_bits - 1]) <<
630 element_bits_done;
631 in_byte_bits_left -= chunk_bits;
632 in_byte >>= chunk_bits;
633
634 element_bits_done += chunk_bits;
635 }
636
637 if (element >= kPrime) {
638 return 0;
639 }
640 out->c[i] = element;
641 }
642
643 return 1;
644}
645
646/* scalar_decode_1 is |scalar_decode| specialised for |bits| == 1. */
647static void
648scalar_decode_1(scalar *out, const uint8_t in[32])
649{
650 int i, j;
651
652 for (i = 0; i < DEGREE; i += 8) {
653 uint8_t in_byte = *in;
654
655 in++;
656 for (j = 0; j < 8; j++) {
657 out->c[i + j] = in_byte & 1;
658 in_byte >>= 1;
659 }
660 }
661}
662
663/*
664 * Decodes 32*|RANK1024|*|bits| bytes from |in| into |out|. It returns one on
665 * success or zero if any parsed value is >= |kPrime|.
666 */
667static int
668vector_decode(vector *out, const uint8_t *in, int bits)
669{
670 int i;
671
672 for (i = 0; i < RANK1024; i++) {
673 if (!scalar_decode(&out->v[i], in + i * bits * DEGREE / 8,
674 bits)) {
675 return 0;
676 }
677 }
678 return 1;
679}
680
681/*
682 * Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
683 * numbers close to each other together. The formula used is
684 * round(2^|bits|/kPrime*x) mod 2^|bits|.
685 * Uses Barrett reduction to achieve constant time. Since we need both the
686 * remainder (for rounding) and the quotient (as the result), we cannot use
687 * |reduce| here, but need to do the Barrett reduction directly.
688 */
689static uint16_t
690compress(uint16_t x, int bits)
691{
692 uint32_t shifted = (uint32_t)x << bits;
693 uint64_t product = (uint64_t)shifted * kBarrettMultiplier;
694 uint32_t quotient = (uint32_t)(product >> kBarrettShift);
695 uint32_t remainder = shifted - quotient * kPrime;
696
697 /*
698 * Adjust the quotient to round correctly:
699 * 0 <= remainder <= kHalfPrime round to 0
700 * kHalfPrime < remainder <= kPrime + kHalfPrime round to 1
701 * kPrime + kHalfPrime < remainder < 2 * kPrime round to 2
702 */
703 assert(remainder < 2u * kPrime);
704 quotient += 1 & constant_time_lt(kHalfPrime, remainder);
705 quotient += 1 & constant_time_lt(kPrime + kHalfPrime, remainder);
706 return quotient & ((1 << bits) - 1);
707}
708
709/*
710 * Decompresses |x| by using an equi-distant representative. The formula is
711 * round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us to
712 * implement this logic using only bit operations.
713 */
714static uint16_t
715decompress(uint16_t x, int bits)
716{
717 uint32_t product = (uint32_t)x * kPrime;
718 uint32_t power = 1 << bits;
719 /* This is |product| % power, since |power| is a power of 2. */
720 uint32_t remainder = product & (power - 1);
721 /* This is |product| / power, since |power| is a power of 2. */
722 uint32_t lower = product >> bits;
723
724 /*
725 * The rounding logic works since the first half of numbers mod |power| have a
726 * 0 as first bit, and the second half has a 1 as first bit, since |power| is
727 * a power of 2. As a 12 bit number, |remainder| is always positive, so we
728 * will shift in 0s for a right shift.
729 */
730 return lower + (remainder >> (bits - 1));
731}
732
733static void
734scalar_compress(scalar *s, int bits)
735{
736 int i;
737
738 for (i = 0; i < DEGREE; i++) {
739 s->c[i] = compress(s->c[i], bits);
740 }
741}
742
743static void
744scalar_decompress(scalar *s, int bits)
745{
746 int i;
747
748 for (i = 0; i < DEGREE; i++) {
749 s->c[i] = decompress(s->c[i], bits);
750 }
751}
752
753static void
754vector_compress(vector *a, int bits)
755{
756 int i;
757
758 for (i = 0; i < RANK1024; i++) {
759 scalar_compress(&a->v[i], bits);
760 }
761}
762
763static void
764vector_decompress(vector *a, int bits)
765{
766 int i;
767
768 for (i = 0; i < RANK1024; i++) {
769 scalar_decompress(&a->v[i], bits);
770 }
771}
772
773struct public_key {
774 vector t;
775 uint8_t rho[32];
776 uint8_t public_key_hash[32];
777 matrix m;
778};
779
780static struct public_key *
781public_key_1024_from_external(const struct MLKEM1024_public_key *external)
782{
783 return (struct public_key *)external;
784}
785
786struct private_key {
787 struct public_key pub;
788 vector s;
789 uint8_t fo_failure_secret[32];
790};
791
792static struct private_key *
793private_key_1024_from_external(const struct MLKEM1024_private_key *external)
794{
795 return (struct private_key *)external;
796}
797
798/*
799 * Calls |MLKEM1024_generate_key_external_entropy| with random bytes from
800 * |RAND_bytes|.
801 */
802void
803MLKEM1024_generate_key(uint8_t out_encoded_public_key[MLKEM1024_PUBLIC_KEY_BYTES],
804 uint8_t optional_out_seed[MLKEM_SEED_BYTES],
805 struct MLKEM1024_private_key *out_private_key)
806{
807 uint8_t entropy_buf[MLKEM_SEED_BYTES];
808 uint8_t *entropy = optional_out_seed != NULL ? optional_out_seed :
809 entropy_buf;
810
811 arc4random_buf(entropy, MLKEM_SEED_BYTES);
812 MLKEM1024_generate_key_external_entropy(out_encoded_public_key,
813 out_private_key, entropy);
814}
815LCRYPTO_ALIAS(MLKEM1024_generate_key);
816
817int
818MLKEM1024_private_key_from_seed(struct MLKEM1024_private_key *out_private_key,
819 const uint8_t *seed, size_t seed_len)
820{
821 uint8_t public_key_bytes[MLKEM1024_PUBLIC_KEY_BYTES];
822
823 if (seed_len != MLKEM_SEED_BYTES) {
824 return 0;
825 }
826 MLKEM1024_generate_key_external_entropy(public_key_bytes,
827 out_private_key, seed);
828
829 return 1;
830}
831LCRYPTO_ALIAS(MLKEM1024_private_key_from_seed);
832
833static int
834mlkem_marshal_public_key(CBB *out, const struct public_key *pub)
835{
836 uint8_t *vector_output;
837
838 if (!CBB_add_space(out, &vector_output, kEncodedVectorSize)) {
839 return 0;
840 }
841 vector_encode(vector_output, &pub->t, kLog2Prime);
842 if (!CBB_add_bytes(out, pub->rho, sizeof(pub->rho))) {
843 return 0;
844 }
845 return 1;
846}
847
848void
849MLKEM1024_generate_key_external_entropy(
850 uint8_t out_encoded_public_key[MLKEM1024_PUBLIC_KEY_BYTES],
851 struct MLKEM1024_private_key *out_private_key,
852 const uint8_t entropy[MLKEM_SEED_BYTES])
853{
854 struct private_key *priv = private_key_1024_from_external(
855 out_private_key);
856 uint8_t augmented_seed[33];
857 uint8_t *rho, *sigma;
858 uint8_t counter = 0;
859 uint8_t hashed[64];
860 vector error;
861 CBB cbb;
862
863 memcpy(augmented_seed, entropy, 32);
864 augmented_seed[32] = RANK1024;
865 hash_g(hashed, augmented_seed, 33);
866 rho = hashed;
867 sigma = hashed + 32;
868 memcpy(priv->pub.rho, hashed, sizeof(priv->pub.rho));
869 matrix_expand(&priv->pub.m, rho);
870 vector_generate_secret_eta_2(&priv->s, &counter, sigma);
871 vector_ntt(&priv->s);
872 vector_generate_secret_eta_2(&error, &counter, sigma);
873 vector_ntt(&error);
874 matrix_mult_transpose(&priv->pub.t, &priv->pub.m, &priv->s);
875 vector_add(&priv->pub.t, &error);
876
877 CBB_init_fixed(&cbb, out_encoded_public_key, MLKEM1024_PUBLIC_KEY_BYTES);
878 if (!mlkem_marshal_public_key(&cbb, &priv->pub)) {
879 abort();
880 }
881
882 hash_h(priv->pub.public_key_hash, out_encoded_public_key,
883 MLKEM1024_PUBLIC_KEY_BYTES);
884 memcpy(priv->fo_failure_secret, entropy + 32, 32);
885}
886
887void
888MLKEM1024_public_from_private(struct MLKEM1024_public_key *out_public_key,
889 const struct MLKEM1024_private_key *private_key)
890{
891 struct public_key *const pub = public_key_1024_from_external(
892 out_public_key);
893 const struct private_key *const priv = private_key_1024_from_external(
894 private_key);
895
896 *pub = priv->pub;
897}
898LCRYPTO_ALIAS(MLKEM1024_public_from_private);
899
900/*
901 * Encrypts a message with given randomness to the ciphertext in |out|. Without
902 * applying the Fujisaki-Okamoto transform this would not result in a CCA secure
903 * scheme, since lattice schemes are vulnerable to decryption failure oracles.
904 */
905static void
906encrypt_cpa(uint8_t out[MLKEM1024_CIPHERTEXT_BYTES],
907 const struct public_key *pub, const uint8_t message[32],
908 const uint8_t randomness[32])
909{
910 scalar expanded_message, scalar_error;
911 vector secret, error, u;
912 uint8_t counter = 0;
913 uint8_t input[33];
914 scalar v;
915
916 vector_generate_secret_eta_2(&secret, &counter, randomness);
917 vector_ntt(&secret);
918 vector_generate_secret_eta_2(&error, &counter, randomness);
919 memcpy(input, randomness, 32);
920 input[32] = counter;
921 scalar_centered_binomial_distribution_eta_2_with_prf(&scalar_error,
922 input);
923 matrix_mult(&u, &pub->m, &secret);
924 vector_inverse_ntt(&u);
925 vector_add(&u, &error);
926 scalar_inner_product(&v, &pub->t, &secret);
927 scalar_inverse_ntt(&v);
928 scalar_add(&v, &scalar_error);
929 scalar_decode_1(&expanded_message, message);
930 scalar_decompress(&expanded_message, 1);
931 scalar_add(&v, &expanded_message);
932 vector_compress(&u, kDU1024);
933 vector_encode(out, &u, kDU1024);
934 scalar_compress(&v, kDV1024);
935 scalar_encode(out + kCompressedVectorSize, &v, kDV1024);
936}
937
938/* Calls MLKEM1024_encap_external_entropy| with random bytes */
939void
940MLKEM1024_encap(uint8_t out_ciphertext[MLKEM1024_CIPHERTEXT_BYTES],
941 uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
942 const struct MLKEM1024_public_key *public_key)
943{
944 uint8_t entropy[MLKEM_ENCAP_ENTROPY];
945
946 arc4random_buf(entropy, MLKEM_ENCAP_ENTROPY);
947 MLKEM1024_encap_external_entropy(out_ciphertext, out_shared_secret,
948 public_key, entropy);
949}
950LCRYPTO_ALIAS(MLKEM1024_encap);
951
952/* See section 6.2 of the spec. */
953void
954MLKEM1024_encap_external_entropy(
955 uint8_t out_ciphertext[MLKEM1024_CIPHERTEXT_BYTES],
956 uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
957 const struct MLKEM1024_public_key *public_key,
958 const uint8_t entropy[MLKEM_ENCAP_ENTROPY])
959{
960 const struct public_key *pub = public_key_1024_from_external(public_key);
961 uint8_t key_and_randomness[64];
962 uint8_t input[64];
963
964 memcpy(input, entropy, MLKEM_ENCAP_ENTROPY);
965 memcpy(input + MLKEM_ENCAP_ENTROPY, pub->public_key_hash,
966 sizeof(input) - MLKEM_ENCAP_ENTROPY);
967 hash_g(key_and_randomness, input, sizeof(input));
968 encrypt_cpa(out_ciphertext, pub, entropy, key_and_randomness + 32);
969 memcpy(out_shared_secret, key_and_randomness, 32);
970}
971
972static void
973decrypt_cpa(uint8_t out[32], const struct private_key *priv,
974 const uint8_t ciphertext[MLKEM1024_CIPHERTEXT_BYTES])
975{
976 scalar mask, v;
977 vector u;
978
979 vector_decode(&u, ciphertext, kDU1024);
980 vector_decompress(&u, kDU1024);
981 vector_ntt(&u);
982 scalar_decode(&v, ciphertext + kCompressedVectorSize, kDV1024);
983 scalar_decompress(&v, kDV1024);
984 scalar_inner_product(&mask, &priv->s, &u);
985 scalar_inverse_ntt(&mask);
986 scalar_sub(&v, &mask);
987 scalar_compress(&v, 1);
988 scalar_encode_1(out, &v);
989}
990
991/* See section 6.3 */
992int
993MLKEM1024_decap(uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
994 const uint8_t *ciphertext, size_t ciphertext_len,
995 const struct MLKEM1024_private_key *private_key)
996{
997 const struct private_key *priv = private_key_1024_from_external(
998 private_key);
999 uint8_t expected_ciphertext[MLKEM1024_CIPHERTEXT_BYTES];
1000 uint8_t key_and_randomness[64];
1001 uint8_t failure_key[32];
1002 uint8_t decrypted[64];
1003 uint8_t mask;
1004 int i;
1005
1006 if (ciphertext_len != MLKEM1024_CIPHERTEXT_BYTES) {
1007 arc4random_buf(out_shared_secret, MLKEM_SHARED_SECRET_BYTES);
1008 return 0;
1009 }
1010
1011 decrypt_cpa(decrypted, priv, ciphertext);
1012 memcpy(decrypted + 32, priv->pub.public_key_hash,
1013 sizeof(decrypted) - 32);
1014 hash_g(key_and_randomness, decrypted, sizeof(decrypted));
1015 encrypt_cpa(expected_ciphertext, &priv->pub, decrypted,
1016 key_and_randomness + 32);
1017 kdf(failure_key, priv->fo_failure_secret, ciphertext, ciphertext_len);
1018 mask = constant_time_eq_int_8(memcmp(ciphertext, expected_ciphertext,
1019 sizeof(expected_ciphertext)), 0);
1020 for (i = 0; i < MLKEM_SHARED_SECRET_BYTES; i++) {
1021 out_shared_secret[i] = constant_time_select_8(mask,
1022 key_and_randomness[i], failure_key[i]);
1023 }
1024
1025 return 1;
1026}
1027LCRYPTO_ALIAS(MLKEM1024_decap);
1028
1029int
1030MLKEM1024_marshal_public_key(CBB *out,
1031 const struct MLKEM1024_public_key *public_key)
1032{
1033 return mlkem_marshal_public_key(out,
1034 public_key_1024_from_external(public_key));
1035}
1036LCRYPTO_ALIAS(MLKEM1024_marshal_public_key);
1037
1038/*
1039 * mlkem_parse_public_key_no_hash parses |in| into |pub| but doesn't calculate
1040 * the value of |pub->public_key_hash|.
1041 */
1042static int
1043mlkem_parse_public_key_no_hash(struct public_key *pub, CBS *in)
1044{
1045 CBS t_bytes;
1046
1047 if (!CBS_get_bytes(in, &t_bytes, kEncodedVectorSize) ||
1048 !vector_decode(&pub->t, CBS_data(&t_bytes), kLog2Prime)) {
1049 return 0;
1050 }
1051 memcpy(pub->rho, CBS_data(in), sizeof(pub->rho));
1052 if (!CBS_skip(in, sizeof(pub->rho)))
1053 return 0;
1054 matrix_expand(&pub->m, pub->rho);
1055 return 1;
1056}
1057
1058int
1059MLKEM1024_parse_public_key(struct MLKEM1024_public_key *public_key, CBS *in)
1060{
1061 struct public_key *pub = public_key_1024_from_external(public_key);
1062 CBS orig_in = *in;
1063
1064 if (!mlkem_parse_public_key_no_hash(pub, in) ||
1065 CBS_len(in) != 0) {
1066 return 0;
1067 }
1068 hash_h(pub->public_key_hash, CBS_data(&orig_in), CBS_len(&orig_in));
1069 return 1;
1070}
1071LCRYPTO_ALIAS(MLKEM1024_parse_public_key);
1072
1073int
1074MLKEM1024_marshal_private_key(CBB *out,
1075 const struct MLKEM1024_private_key *private_key)
1076{
1077 const struct private_key *const priv = private_key_1024_from_external(
1078 private_key);
1079 uint8_t *s_output;
1080
1081 if (!CBB_add_space(out, &s_output, kEncodedVectorSize)) {
1082 return 0;
1083 }
1084 vector_encode(s_output, &priv->s, kLog2Prime);
1085 if (!mlkem_marshal_public_key(out, &priv->pub) ||
1086 !CBB_add_bytes(out, priv->pub.public_key_hash,
1087 sizeof(priv->pub.public_key_hash)) ||
1088 !CBB_add_bytes(out, priv->fo_failure_secret,
1089 sizeof(priv->fo_failure_secret))) {
1090 return 0;
1091 }
1092 return 1;
1093}
1094
1095int
1096MLKEM1024_parse_private_key(struct MLKEM1024_private_key *out_private_key,
1097 CBS *in)
1098{
1099 struct private_key *const priv = private_key_1024_from_external(
1100 out_private_key);
1101 CBS s_bytes;
1102
1103 if (!CBS_get_bytes(in, &s_bytes, kEncodedVectorSize) ||
1104 !vector_decode(&priv->s, CBS_data(&s_bytes), kLog2Prime) ||
1105 !mlkem_parse_public_key_no_hash(&priv->pub, in)) {
1106 return 0;
1107 }
1108 memcpy(priv->pub.public_key_hash, CBS_data(in),
1109 sizeof(priv->pub.public_key_hash));
1110 if (!CBS_skip(in, sizeof(priv->pub.public_key_hash)))
1111 return 0;
1112 memcpy(priv->fo_failure_secret, CBS_data(in),
1113 sizeof(priv->fo_failure_secret));
1114 if (!CBS_skip(in, sizeof(priv->fo_failure_secret)))
1115 return 0;
1116 if (CBS_len(in) != 0)
1117 return 0;
1118
1119 return 1;
1120}
1121LCRYPTO_ALIAS(MLKEM1024_parse_private_key);
diff --git a/src/lib/libcrypto/mlkem/mlkem_internal.h b/src/lib/libcrypto/mlkem/mlkem_internal.h
index 3ef877f6d1..3141160ac2 100644
--- a/src/lib/libcrypto/mlkem/mlkem_internal.h
+++ b/src/lib/libcrypto/mlkem/mlkem_internal.h
@@ -69,6 +69,45 @@ void MLKEM768_encap_external_entropy(
69 const struct MLKEM768_public_key *public_key, 69 const struct MLKEM768_public_key *public_key,
70 const uint8_t entropy[MLKEM_ENCAP_ENTROPY]); 70 const uint8_t entropy[MLKEM_ENCAP_ENTROPY]);
71 71
72/*
73 * MLKEM1024_generate_key_external_entropy is a deterministic function to create a
74 * pair of ML-KEM 1024 keys, using the supplied entropy. The entropy needs to be
75 * generated uniformly at random. This function should only be used for tests,
76 * regular callers should use the non-deterministic |MLKEM1024_generate_key|
77 * directly.
78 */
79void MLKEM1024_generate_key_external_entropy(
80 uint8_t out_encoded_public_key[MLKEM1024_PUBLIC_KEY_BYTES],
81 struct MLKEM1024_private_key *out_private_key,
82 const uint8_t entropy[MLKEM_SEED_BYTES]);
83
84/*
85 * MLKEM1024_PRIVATE_KEY_BYTES is the length of the data produced by
86 * |MLKEM1024_marshal_private_key|.
87 */
88#define MLKEM1024_PRIVATE_KEY_BYTES 3168
89
90/*
91 * MLKEM1024_marshal_private_key serializes |private_key| to |out| in the
92 * standard format for ML-KEM private keys. It returns one on success or zero on
93 * allocation error.
94 */
95int MLKEM1024_marshal_private_key(CBB *out,
96 const struct MLKEM1024_private_key *private_key);
97
98/*
99 * MLKEM1024_encap_external_entropy behaves like |MLKEM1024_encap|, but uses
100 * |MLKEM_ENCAP_ENTROPY| bytes of |entropy| for randomization. The decapsulating
101 * side will be able to recover |entropy| in full. This function should only be
102 * used for tests, regular callers should use the non-deterministic
103 * |MLKEM1024_encap| directly.
104 */
105void MLKEM1024_encap_external_entropy(
106 uint8_t out_ciphertext[MLKEM1024_CIPHERTEXT_BYTES],
107 uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
108 const struct MLKEM1024_public_key *public_key,
109 const uint8_t entropy[MLKEM_ENCAP_ENTROPY]);
110
72__END_HIDDEN_DECLS 111__END_HIDDEN_DECLS
73 112
74#if defined(__cplusplus) 113#if defined(__cplusplus)