summaryrefslogtreecommitdiff
path: root/src/lib/libcrypto/mlkem
diff options
context:
space:
mode:
authorcvs2svn <admin@example.com>2025-04-14 17:32:06 +0000
committercvs2svn <admin@example.com>2025-04-14 17:32:06 +0000
commiteb8dd9dca1228af0cd132f515509051ecfabf6f6 (patch)
treeedb6da6af7e865d488dc1a29309f1e1ec226e603 /src/lib/libcrypto/mlkem
parent247f0352e0ed72a4f476db9dc91f4d982bc83eb2 (diff)
downloadopenbsd-tb_20250414.tar.gz
openbsd-tb_20250414.tar.bz2
openbsd-tb_20250414.zip
This commit was manufactured by cvs2git to create tag 'tb_20250414'.tb_20250414
Diffstat (limited to 'src/lib/libcrypto/mlkem')
-rw-r--r--src/lib/libcrypto/mlkem/mlkem.h285
-rw-r--r--src/lib/libcrypto/mlkem/mlkem1024.c1139
-rw-r--r--src/lib/libcrypto/mlkem/mlkem768.c1138
-rw-r--r--src/lib/libcrypto/mlkem/mlkem_internal.h121
4 files changed, 0 insertions, 2683 deletions
diff --git a/src/lib/libcrypto/mlkem/mlkem.h b/src/lib/libcrypto/mlkem/mlkem.h
deleted file mode 100644
index 055d92290e..0000000000
--- a/src/lib/libcrypto/mlkem/mlkem.h
+++ /dev/null
@@ -1,285 +0,0 @@
1/* $OpenBSD: mlkem.h,v 1.5 2025/03/28 12:17:16 tb Exp $ */
2/*
3 * Copyright (c) 2024, Google Inc.
4 *
5 * Permission to use, copy, modify, and/or distribute this software for any
6 * purpose with or without fee is hereby granted, provided that the above
7 * copyright notice and this permission notice appear in all copies.
8 *
9 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
12 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
14 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
15 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16 */
17
18#ifndef OPENSSL_HEADER_MLKEM_H
19#define OPENSSL_HEADER_MLKEM_H
20
21#include <sys/types.h>
22#include <stdint.h>
23
24#if defined(__cplusplus)
25extern "C" {
26#endif
27
28/* Hack for now */
29struct cbs_st;
30struct cbb_st;
31
32/*
33 * ML-KEM-768
34 *
35 * This implements the Module-Lattice-Based Key-Encapsulation Mechanism from
36 * https://csrc.nist.gov/pubs/fips/204/final
37 */
38
39/*
40 * MLKEM768_public_key contains a ML-KEM-768 public key. The contents of this
41 * object should never leave the address space since the format is unstable.
42 */
43struct MLKEM768_public_key {
44 union {
45 uint8_t bytes[512 * (3 + 9) + 32 + 32];
46 uint16_t alignment;
47 } opaque;
48};
49
50/*
51 * MLKEM768_private_key contains a ML-KEM-768 private key. The contents of this
52 * object should never leave the address space since the format is unstable.
53 */
54struct MLKEM768_private_key {
55 union {
56 uint8_t bytes[512 * (3 + 3 + 9) + 32 + 32 + 32];
57 uint16_t alignment;
58 } opaque;
59};
60
61/*
62 * MLKEM768_PUBLIC_KEY_BYTES is the number of bytes in an encoded ML-KEM768 public
63 * key.
64 */
65#define MLKEM768_PUBLIC_KEY_BYTES 1184
66
67/* MLKEM_SEED_BYTES is the number of bytes in an ML-KEM seed. */
68#define MLKEM_SEED_BYTES 64
69
70/*
71 * MLKEM_SHARED_SECRET_BYTES is the number of bytes in the ML-KEM768 shared
72 * secret. Although the round-3 specification has a variable-length output, the
73 * final ML-KEM construction is expected to use a fixed 32-byte output. To
74 * simplify the future transition, we apply the same restriction.
75 */
76#define MLKEM_SHARED_SECRET_BYTES 32
77
78/*
79 * MLKEM_generate_key generates a random public/private key pair, writes the
80 * encoded public key to |out_encoded_public_key| and sets |out_private_key| to
81 * the private key. If |optional_out_seed| is not NULL then the seed used to
82 * generate the private key is written to it.
83 */
84void MLKEM768_generate_key(
85 uint8_t out_encoded_public_key[MLKEM768_PUBLIC_KEY_BYTES],
86 uint8_t optional_out_seed[MLKEM_SEED_BYTES],
87 struct MLKEM768_private_key *out_private_key);
88
89/*
90 * MLKEM768_private_key_from_seed derives a private key from a seed that was
91 * generated by |MLKEM768_generate_key|. It fails and returns 0 if |seed_len| is
92 * incorrect, otherwise it writes |*out_private_key| and returns 1.
93 */
94int MLKEM768_private_key_from_seed(struct MLKEM768_private_key *out_private_key,
95 const uint8_t *seed, size_t seed_len);
96
97/*
98 * MLKEM_public_from_private sets |*out_public_key| to the public key that
99 * corresponds to |private_key|. (This is faster than parsing the output of
100 * |MLKEM_generate_key| if, for some reason, you need to encapsulate to a key
101 * that was just generated.)
102 */
103void MLKEM768_public_from_private(struct MLKEM768_public_key *out_public_key,
104 const struct MLKEM768_private_key *private_key);
105
106/* MLKEM768_CIPHERTEXT_BYTES is number of bytes in the ML-KEM768 ciphertext. */
107#define MLKEM768_CIPHERTEXT_BYTES 1088
108
109/*
110 * MLKEM768_encap encrypts a random shared secret for |public_key|, writes the
111 * ciphertext to |out_ciphertext|, and writes the random shared secret to
112 * |out_shared_secret|.
113 */
114void MLKEM768_encap(uint8_t out_ciphertext[MLKEM768_CIPHERTEXT_BYTES],
115 uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
116 const struct MLKEM768_public_key *public_key);
117
118/*
119 * MLKEM768_decap decrypts a shared secret from |ciphertext| using |private_key|
120 * and writes it to |out_shared_secret|. If |ciphertext_len| is incorrect it
121 * returns 0, otherwise it rreturns 1. If |ciphertext| is invalid,
122 * |out_shared_secret| is filled with a key that will always be the same for the
123 * same |ciphertext| and |private_key|, but which appears to be random unless
124 * one has access to |private_key|. These alternatives occur in constant time.
125 * Any subsequent symmetric encryption using |out_shared_secret| must use an
126 * authenticated encryption scheme in order to discover the decapsulation
127 * failure.
128 */
129int MLKEM768_decap(uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
130 const uint8_t *ciphertext, size_t ciphertext_len,
131 const struct MLKEM768_private_key *private_key);
132
133/* Serialisation of keys. */
134
135/*
136 * MLKEM768_marshal_public_key serializes |public_key| to |out| in the standard
137 * format for ML-KEM public keys. It returns one on success or zero on allocation
138 * error.
139 */
140int MLKEM768_marshal_public_key(struct cbb_st *out,
141 const struct MLKEM768_public_key *public_key);
142
143/*
144 * MLKEM768_parse_public_key parses a public key, in the format generated by
145 * |MLKEM_marshal_public_key|, from |in| and writes the result to
146 * |out_public_key|. It returns one on success or zero on parse error or if
147 * there are trailing bytes in |in|.
148 */
149int MLKEM768_parse_public_key(struct MLKEM768_public_key *out_public_key,
150 struct cbs_st *in);
151
152/*
153 * MLKEM_parse_private_key parses a private key, in the format generated by
154 * |MLKEM_marshal_private_key|, from |in| and writes the result to
155 * |out_private_key|. It returns one on success or zero on parse error or if
156 * there are trailing bytes in |in|. This formate is verbose and should be avoided.
157 * Private keys should be stored as seeds and parsed using |MLKEM768_private_key_from_seed|.
158 */
159int MLKEM768_parse_private_key(struct MLKEM768_private_key *out_private_key,
160 struct cbs_st *in);
161
162/*
163 * ML-KEM-1024
164 *
165 * ML-KEM-1024 also exists. You should prefer ML-KEM-768 where possible.
166 */
167
168/*
169 * MLKEM1024_public_key contains an ML-KEM-1024 public key. The contents of this
170 * object should never leave the address space since the format is unstable.
171 */
172struct MLKEM1024_public_key {
173 union {
174 uint8_t bytes[512 * (4 + 16) + 32 + 32];
175 uint16_t alignment;
176 } opaque;
177};
178
179/*
180 * MLKEM1024_private_key contains a ML-KEM-1024 private key. The contents of
181 * this object should never leave the address space since the format is
182 * unstable.
183 */
184struct MLKEM1024_private_key {
185 union {
186 uint8_t bytes[512 * (4 + 4 + 16) + 32 + 32 + 32];
187 uint16_t alignment;
188 } opaque;
189};
190
191/*
192 * MLKEM1024_PUBLIC_KEY_BYTES is the number of bytes in an encoded ML-KEM-1024
193 * public key.
194 */
195#define MLKEM1024_PUBLIC_KEY_BYTES 1568
196
197/*
198 * MLKEM1024_generate_key generates a random public/private key pair, writes the
199 * encoded public key to |out_encoded_public_key| and sets |out_private_key| to
200 * the private key. If |optional_out_seed| is not NULL then the seed used to
201 * generate the private key is written to it.
202 */
203void MLKEM1024_generate_key(
204 uint8_t out_encoded_public_key[MLKEM1024_PUBLIC_KEY_BYTES],
205 uint8_t optional_out_seed[MLKEM_SEED_BYTES],
206 struct MLKEM1024_private_key *out_private_key);
207
208/*
209 * MLKEM1024_private_key_from_seed derives a private key from a seed that was
210 * generated by |MLKEM1024_generate_key|. It fails and returns 0 if |seed_len|
211 * is incorrect, otherwise it writes |*out_private_key| and returns 1.
212 */
213int MLKEM1024_private_key_from_seed(
214 struct MLKEM1024_private_key *out_private_key, const uint8_t *seed,
215 size_t seed_len);
216
217/*
218 * MLKEM1024_public_from_private sets |*out_public_key| to the public key that
219 * corresponds to |private_key|. (This is faster than parsing the output of
220 * |MLKEM1024_generate_key| if, for some reason, you need to encapsulate to a
221 * key that was just generated.)
222 */
223void MLKEM1024_public_from_private(struct MLKEM1024_public_key *out_public_key,
224 const struct MLKEM1024_private_key *private_key);
225
226/* MLKEM1024_CIPHERTEXT_BYTES is number of bytes in the ML-KEM-1024 ciphertext. */
227#define MLKEM1024_CIPHERTEXT_BYTES 1568
228
229/*
230 * MLKEM1024_encap encrypts a random shared secret for |public_key|, writes the
231 * ciphertext to |out_ciphertext|, and writes the random shared secret to
232 * |out_shared_secret|.
233 */
234void MLKEM1024_encap(uint8_t out_ciphertext[MLKEM1024_CIPHERTEXT_BYTES],
235 uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
236 const struct MLKEM1024_public_key *public_key);
237
238/*
239 * MLKEM1024_decap decrypts a shared secret from |ciphertext| using
240 * |private_key| and writes it to |out_shared_secret|. If |ciphertext_len| is
241 * incorrect it returns 0, otherwise it returns 1. If |ciphertext| is invalid
242 * (but of the correct length), |out_shared_secret| is filled with a key that
243 * will always be the same for the same |ciphertext| and |private_key|, but
244 * which appears to be random unless one has access to |private_key|. These
245 * alternatives occur in constant time. Any subsequent symmetric encryption
246 * using |out_shared_secret| must use an authenticated encryption scheme in
247 * order to discover the decapsulation failure.
248 */
249int MLKEM1024_decap(uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
250 const uint8_t *ciphertext, size_t ciphertext_len,
251 const struct MLKEM1024_private_key *private_key);
252
253/*
254 * Serialisation of ML-KEM-1024 keys.
255 * MLKEM1024_marshal_public_key serializes |public_key| to |out| in the standard
256 * format for ML-KEM-1024 public keys. It returns one on success or zero on
257 * allocation error.
258 */
259int MLKEM1024_marshal_public_key(struct cbb_st *out,
260 const struct MLKEM1024_public_key *public_key);
261
262/*
263 * MLKEM1024_parse_public_key parses a public key, in the format generated by
264 * |MLKEM1024_marshal_public_key|, from |in| and writes the result to
265 * |out_public_key|. It returns one on success or zero on parse error or if
266 * there are trailing bytes in |in|.
267 */
268int MLKEM1024_parse_public_key(struct MLKEM1024_public_key *out_public_key,
269 struct cbs_st *in);
270
271/*
272 * MLKEM1024_parse_private_key parses a private key, in NIST's format for
273 * private keys, from |in| and writes the result to |out_private_key|. It
274 * returns one on success or zero on parse error or if there are trailing bytes
275 * in |in|. This format is verbose and should be avoided. Private keys should be
276 * stored as seeds and parsed using |MLKEM1024_private_key_from_seed|.
277 */
278int MLKEM1024_parse_private_key(struct MLKEM1024_private_key *out_private_key,
279 struct cbs_st *in);
280
281#if defined(__cplusplus)
282}
283#endif
284
285#endif /* OPENSSL_HEADER_MLKEM_H */
diff --git a/src/lib/libcrypto/mlkem/mlkem1024.c b/src/lib/libcrypto/mlkem/mlkem1024.c
deleted file mode 100644
index f6fccdf6a8..0000000000
--- a/src/lib/libcrypto/mlkem/mlkem1024.c
+++ /dev/null
@@ -1,1139 +0,0 @@
1/* $OpenBSD: mlkem1024.c,v 1.6 2025/01/03 08:19:24 tb Exp $ */
2/*
3 * Copyright (c) 2024, Google Inc.
4 * Copyright (c) 2024, Bob Beck <beck@obtuse.com>
5 *
6 * Permission to use, copy, modify, and/or distribute this software for any
7 * purpose with or without fee is hereby granted, provided that the above
8 * copyright notice and this permission notice appear in all copies.
9 *
10 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
13 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
15 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
16 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17 */
18
19#include <assert.h>
20#include <stdlib.h>
21#include <string.h>
22
23#include "bytestring.h"
24#include "mlkem.h"
25
26#include "sha3_internal.h"
27#include "mlkem_internal.h"
28#include "constant_time.h"
29#include "crypto_internal.h"
30
31/* Remove later */
32#undef LCRYPTO_ALIAS
33#define LCRYPTO_ALIAS(A)
34
35/*
36 * See
37 * https://csrc.nist.gov/pubs/fips/203/final
38 */
39
40static void
41prf(uint8_t *out, size_t out_len, const uint8_t in[33])
42{
43 sha3_ctx ctx;
44 shake256_init(&ctx);
45 shake_update(&ctx, in, 33);
46 shake_xof(&ctx);
47 shake_out(&ctx, out, out_len);
48}
49
50/* Section 4.1 */
51static void
52hash_h(uint8_t out[32], const uint8_t *in, size_t len)
53{
54 sha3_ctx ctx;
55 sha3_init(&ctx, 32);
56 sha3_update(&ctx, in, len);
57 sha3_final(out, &ctx);
58}
59
60static void
61hash_g(uint8_t out[64], const uint8_t *in, size_t len)
62{
63 sha3_ctx ctx;
64 sha3_init(&ctx, 64);
65 sha3_update(&ctx, in, len);
66 sha3_final(out, &ctx);
67}
68
69/* this is called 'J' in the spec */
70static void
71kdf(uint8_t out[MLKEM_SHARED_SECRET_BYTES], const uint8_t failure_secret[32],
72 const uint8_t *in, size_t len)
73{
74 sha3_ctx ctx;
75 shake256_init(&ctx);
76 shake_update(&ctx, failure_secret, 32);
77 shake_update(&ctx, in, len);
78 shake_xof(&ctx);
79 shake_out(&ctx, out, MLKEM_SHARED_SECRET_BYTES);
80}
81
82#define DEGREE 256
83#define RANK1024 4
84
85static const size_t kBarrettMultiplier = 5039;
86static const unsigned kBarrettShift = 24;
87static const uint16_t kPrime = 3329;
88static const int kLog2Prime = 12;
89static const uint16_t kHalfPrime = (/*kPrime=*/3329 - 1) / 2;
90static const int kDU1024 = 11;
91static const int kDV1024 = 5;
92
93/*
94 * kInverseDegree is 128^-1 mod 3329; 128 because kPrime does not have a 512th
95 * root of unity.
96 */
97static const uint16_t kInverseDegree = 3303;
98static const size_t kEncodedVectorSize =
99 (/*kLog2Prime=*/12 * DEGREE / 8) * RANK1024;
100static const size_t kCompressedVectorSize = /*kDU1024=*/ 11 * RANK1024 * DEGREE /
101 8;
102
103typedef struct scalar {
104 /* On every function entry and exit, 0 <= c < kPrime. */
105 uint16_t c[DEGREE];
106} scalar;
107
108typedef struct vector {
109 scalar v[RANK1024];
110} vector;
111
112typedef struct matrix {
113 scalar v[RANK1024][RANK1024];
114} matrix;
115
116/*
117 * This bit of Python will be referenced in some of the following comments:
118 *
119 * p = 3329
120 *
121 * def bitreverse(i):
122 * ret = 0
123 * for n in range(7):
124 * bit = i & 1
125 * ret <<= 1
126 * ret |= bit
127 * i >>= 1
128 * return ret
129 */
130
131/* kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)] */
132static const uint16_t kNTTRoots[128] = {
133 1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797,
134 2786, 3260, 569, 1746, 296, 2447, 1339, 1476, 3046, 56, 2240, 1333,
135 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, 289, 331, 3253, 1756,
136 1197, 2304, 2277, 2055, 650, 1977, 2513, 632, 2865, 33, 1320, 1915,
137 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648,
138 2474, 3110, 1227, 910, 17, 2761, 583, 2649, 1637, 723, 2288, 1100,
139 1409, 2662, 3281, 233, 756, 2156, 3015, 3050, 1703, 1651, 2789, 1789,
140 1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641,
141 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, 1063, 319, 2773, 757,
142 2099, 561, 2466, 2594, 2804, 1092, 403, 1026, 1143, 2150, 2775, 886,
143 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154,
144};
145
146/* kInverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)] */
147static const uint16_t kInverseNTTRoots[128] = {
148 1, 1600, 40, 749, 2481, 1432, 2699, 687, 1583, 2760, 69, 543,
149 2532, 3136, 1410, 2267, 2508, 1355, 450, 936, 447, 2794, 1235, 1903,
150 1996, 1089, 3273, 283, 1853, 1990, 882, 3033, 2419, 2102, 219, 855,
151 2681, 1848, 712, 682, 927, 1795, 461, 1891, 2877, 2522, 1894, 1010,
152 1414, 2009, 3296, 464, 2697, 816, 1352, 2679, 1274, 1052, 1025, 2132,
153 1573, 76, 2998, 3040, 1175, 2444, 394, 1219, 2300, 1455, 2117, 1607,
154 2443, 554, 1179, 2186, 2303, 2926, 2237, 525, 735, 863, 2768, 1230,
155 2572, 556, 3010, 2266, 1684, 1239, 780, 2954, 109, 1292, 1031, 1745,
156 2688, 3061, 992, 2596, 941, 892, 1021, 2390, 642, 1868, 2377, 1482,
157 1540, 540, 1678, 1626, 279, 314, 1173, 2573, 3096, 48, 667, 1920,
158 2229, 1041, 2606, 1692, 680, 2746, 568, 3312,
159};
160
161/* kModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)] */
162static const uint16_t kModRoots[128] = {
163 17, 3312, 2761, 568, 583, 2746, 2649, 680, 1637, 1692, 723, 2606,
164 2288, 1041, 1100, 2229, 1409, 1920, 2662, 667, 3281, 48, 233, 3096,
165 756, 2573, 2156, 1173, 3015, 314, 3050, 279, 1703, 1626, 1651, 1678,
166 2789, 540, 1789, 1540, 1847, 1482, 952, 2377, 1461, 1868, 2687, 642,
167 939, 2390, 2308, 1021, 2437, 892, 2388, 941, 733, 2596, 2337, 992,
168 268, 3061, 641, 2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
169 375, 2954, 2549, 780, 2090, 1239, 1645, 1684, 1063, 2266, 319, 3010,
170 2773, 556, 757, 2572, 2099, 1230, 561, 2768, 2466, 863, 2594, 735,
171 2804, 525, 1092, 2237, 403, 2926, 1026, 2303, 1143, 2186, 2150, 1179,
172 2775, 554, 886, 2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
173 2110, 1219, 2935, 394, 885, 2444, 2154, 1175,
174};
175
176/* reduce_once reduces 0 <= x < 2*kPrime, mod kPrime. */
177static uint16_t
178reduce_once(uint16_t x)
179{
180 assert(x < 2 * kPrime);
181 const uint16_t subtracted = x - kPrime;
182 uint16_t mask = 0u - (subtracted >> 15);
183
184 /*
185 * Although this is a constant-time select, we omit a value barrier here.
186 * Value barriers impede auto-vectorization (likely because it forces the
187 * value to transit through a general-purpose register). On AArch64, this
188 * is a difference of 2x.
189 *
190 * We usually add value barriers to selects because Clang turns
191 * consecutive selects with the same condition into a branch instead of
192 * CMOV/CSEL. This condition does not occur in ML-KEM, so omitting it
193 * seems to be safe so far but see
194 * |scalar_centered_binomial_distribution_eta_2_with_prf|.
195 */
196 return (mask & x) | (~mask & subtracted);
197}
198
199/*
200 * constant time reduce x mod kPrime using Barrett reduction. x must be less
201 * than kPrime + 2×kPrime².
202 */
203static uint16_t
204reduce(uint32_t x)
205{
206 uint64_t product = (uint64_t)x * kBarrettMultiplier;
207 uint32_t quotient = (uint32_t)(product >> kBarrettShift);
208 uint32_t remainder = x - quotient * kPrime;
209
210 assert(x < kPrime + 2u * kPrime * kPrime);
211 return reduce_once(remainder);
212}
213
214static void
215scalar_zero(scalar *out)
216{
217 memset(out, 0, sizeof(*out));
218}
219
220static void
221vector_zero(vector *out)
222{
223 memset(out, 0, sizeof(*out));
224}
225
226/*
227 * In place number theoretic transform of a given scalar.
228 * Note that MLKEM's kPrime 3329 does not have a 512th root of unity, so this
229 * transform leaves off the last iteration of the usual FFT code, with the 128
230 * relevant roots of unity being stored in |kNTTRoots|. This means the output
231 * should be seen as 128 elements in GF(3329^2), with the coefficients of the
232 * elements being consecutive entries in |s->c|.
233 */
234static void
235scalar_ntt(scalar *s)
236{
237 int offset = DEGREE;
238 int step;
239 /*
240 * `int` is used here because using `size_t` throughout caused a ~5% slowdown
241 * with Clang 14 on Aarch64.
242 */
243 for (step = 1; step < DEGREE / 2; step <<= 1) {
244 int i, j, k = 0;
245
246 offset >>= 1;
247 for (i = 0; i < step; i++) {
248 const uint32_t step_root = kNTTRoots[i + step];
249
250 for (j = k; j < k + offset; j++) {
251 uint16_t odd, even;
252
253 odd = reduce(step_root * s->c[j + offset]);
254 even = s->c[j];
255 s->c[j] = reduce_once(odd + even);
256 s->c[j + offset] = reduce_once(even - odd +
257 kPrime);
258 }
259 k += 2 * offset;
260 }
261 }
262}
263
264static void
265vector_ntt(vector *a)
266{
267 int i;
268
269 for (i = 0; i < RANK1024; i++) {
270 scalar_ntt(&a->v[i]);
271 }
272}
273
274/*
275 * In place inverse number theoretic transform of a given scalar, with pairs of
276 * entries of s->v being interpreted as elements of GF(3329^2). Just as with the
277 * number theoretic transform, this leaves off the first step of the normal iFFT
278 * to account for the fact that 3329 does not have a 512th root of unity, using
279 * the precomputed 128 roots of unity stored in |kInverseNTTRoots|.
280 */
281static void
282scalar_inverse_ntt(scalar *s)
283{
284 int i, j, k, offset, step = DEGREE / 2;
285
286 /*
287 * `int` is used here because using `size_t` throughout caused a ~5% slowdown
288 * with Clang 14 on Aarch64.
289 */
290 for (offset = 2; offset < DEGREE; offset <<= 1) {
291 step >>= 1;
292 k = 0;
293 for (i = 0; i < step; i++) {
294 uint32_t step_root = kInverseNTTRoots[i + step];
295 for (j = k; j < k + offset; j++) {
296 uint16_t odd, even;
297 odd = s->c[j + offset];
298 even = s->c[j];
299 s->c[j] = reduce_once(odd + even);
300 s->c[j + offset] = reduce(step_root *
301 (even - odd + kPrime));
302 }
303 k += 2 * offset;
304 }
305 }
306 for (i = 0; i < DEGREE; i++) {
307 s->c[i] = reduce(s->c[i] * kInverseDegree);
308 }
309}
310
311static void
312vector_inverse_ntt(vector *a)
313{
314 int i;
315
316 for (i = 0; i < RANK1024; i++) {
317 scalar_inverse_ntt(&a->v[i]);
318 }
319}
320
321static void
322scalar_add(scalar *lhs, const scalar *rhs)
323{
324 int i;
325
326 for (i = 0; i < DEGREE; i++) {
327 lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]);
328 }
329}
330
331static void
332scalar_sub(scalar *lhs, const scalar *rhs)
333{
334 int i;
335
336 for (i = 0; i < DEGREE; i++) {
337 lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime);
338 }
339}
340
341/*
342 * Multiplying two scalars in the number theoretically transformed state.
343 * Since 3329 does not have a 512th root of unity, this means we have to
344 * interpret the 2*ith and (2*i+1)th entries of the scalar as elements of
345 * GF(3329)[X]/(X^2 - 17^(2*bitreverse(i)+1)).
346 * The value of 17^(2*bitreverse(i)+1) mod 3329 is stored in the precomputed
347 * |kModRoots| table. Our Barrett transform only allows us to multiply two
348 * reduced numbers together, so we need some intermediate reduction steps,
349 * even if an uint64_t could hold 3 multiplied numbers.
350 */
351static void
352scalar_mult(scalar *out, const scalar *lhs, const scalar *rhs)
353{
354 int i;
355
356 for (i = 0; i < DEGREE / 2; i++) {
357 uint32_t real_real = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i];
358 uint32_t img_img = (uint32_t)lhs->c[2 * i + 1] *
359 rhs->c[2 * i + 1];
360 uint32_t real_img = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i + 1];
361 uint32_t img_real = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i];
362
363 out->c[2 * i] =
364 reduce(real_real +
365 (uint32_t)reduce(img_img) * kModRoots[i]);
366 out->c[2 * i + 1] = reduce(img_real + real_img);
367 }
368}
369
370static void
371vector_add(vector *lhs, const vector *rhs)
372{
373 int i;
374
375 for (i = 0; i < RANK1024; i++) {
376 scalar_add(&lhs->v[i], &rhs->v[i]);
377 }
378}
379
380static void
381matrix_mult(vector *out, const matrix *m, const vector *a)
382{
383 int i, j;
384
385 vector_zero(out);
386 for (i = 0; i < RANK1024; i++) {
387 for (j = 0; j < RANK1024; j++) {
388 scalar product;
389
390 scalar_mult(&product, &m->v[i][j], &a->v[j]);
391 scalar_add(&out->v[i], &product);
392 }
393 }
394}
395
396static void
397matrix_mult_transpose(vector *out, const matrix *m,
398 const vector *a)
399{
400 int i, j;
401
402 vector_zero(out);
403 for (i = 0; i < RANK1024; i++) {
404 for (j = 0; j < RANK1024; j++) {
405 scalar product;
406
407 scalar_mult(&product, &m->v[j][i], &a->v[j]);
408 scalar_add(&out->v[i], &product);
409 }
410 }
411}
412
413static void
414scalar_inner_product(scalar *out, const vector *lhs,
415 const vector *rhs)
416{
417 int i;
418 scalar_zero(out);
419 for (i = 0; i < RANK1024; i++) {
420 scalar product;
421
422 scalar_mult(&product, &lhs->v[i], &rhs->v[i]);
423 scalar_add(out, &product);
424 }
425}
426
427/*
428 * Algorithm 6 of spec. Rejection samples a Keccak stream to get uniformly
429 * distributed elements. This is used for matrix expansion and only operates on
430 * public inputs.
431 */
432static void
433scalar_from_keccak_vartime(scalar *out, sha3_ctx *keccak_ctx)
434{
435 int i, done = 0;
436
437 while (done < DEGREE) {
438 uint8_t block[168];
439
440 shake_out(keccak_ctx, block, sizeof(block));
441 for (i = 0; i < sizeof(block) && done < DEGREE; i += 3) {
442 uint16_t d1 = block[i] + 256 * (block[i + 1] % 16);
443 uint16_t d2 = block[i + 1] / 16 + 16 * block[i + 2];
444
445 if (d1 < kPrime) {
446 out->c[done++] = d1;
447 }
448 if (d2 < kPrime && done < DEGREE) {
449 out->c[done++] = d2;
450 }
451 }
452 }
453}
454
455/*
456 * Algorithm 7 of the spec, with eta fixed to two and the PRF call
457 * included. Creates binominally distributed elements by sampling 2*|eta| bits,
458 * and setting the coefficient to the count of the first bits minus the count of
459 * the second bits, resulting in a centered binomial distribution. Since eta is
460 * two this gives -2/2 with a probability of 1/16, -1/1 with probability 1/4,
461 * and 0 with probability 3/8.
462 */
463static void
464scalar_centered_binomial_distribution_eta_2_with_prf(scalar *out,
465 const uint8_t input[33])
466{
467 uint8_t entropy[128];
468 int i;
469
470 CTASSERT(sizeof(entropy) == 2 * /*kEta=*/ 2 * DEGREE / 8);
471 prf(entropy, sizeof(entropy), input);
472
473 for (i = 0; i < DEGREE; i += 2) {
474 uint8_t byte = entropy[i / 2];
475 uint16_t mask;
476 uint16_t value = (byte & 1) + ((byte >> 1) & 1);
477
478 value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
479
480 /*
481 * Add |kPrime| if |value| underflowed. See |reduce_once| for a
482 * discussion on why the value barrier is omitted. While this
483 * could have been written reduce_once(value + kPrime), this is
484 * one extra addition and small range of |value| tempts some
485 * versions of Clang to emit a branch.
486 */
487 mask = 0u - (value >> 15);
488 out->c[i] = ((value + kPrime) & mask) | (value & ~mask);
489
490 byte >>= 4;
491 value = (byte & 1) + ((byte >> 1) & 1);
492 value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
493 /* See above. */
494 mask = 0u - (value >> 15);
495 out->c[i + 1] = ((value + kPrime) & mask) | (value & ~mask);
496 }
497}
498
499/*
500 * Generates a secret vector by using
501 * |scalar_centered_binomial_distribution_eta_2_with_prf|, using the given seed
502 * appending and incrementing |counter| for entry of the vector.
503 */
504static void
505vector_generate_secret_eta_2(vector *out, uint8_t *counter,
506 const uint8_t seed[32])
507{
508 uint8_t input[33];
509 int i;
510
511 memcpy(input, seed, 32);
512 for (i = 0; i < RANK1024; i++) {
513 input[32] = (*counter)++;
514 scalar_centered_binomial_distribution_eta_2_with_prf(&out->v[i],
515 input);
516 }
517}
518
519/* Expands the matrix of a seed for key generation and for encaps-CPA. */
520static void
521matrix_expand(matrix *out, const uint8_t rho[32])
522{
523 uint8_t input[34];
524 int i, j;
525
526 memcpy(input, rho, 32);
527 for (i = 0; i < RANK1024; i++) {
528 for (j = 0; j < RANK1024; j++) {
529 sha3_ctx keccak_ctx;
530
531 input[32] = i;
532 input[33] = j;
533 shake128_init(&keccak_ctx);
534 shake_update(&keccak_ctx, input, sizeof(input));
535 shake_xof(&keccak_ctx);
536 scalar_from_keccak_vartime(&out->v[i][j], &keccak_ctx);
537 }
538 }
539}
540
541static const uint8_t kMasks[8] = {0x01, 0x03, 0x07, 0x0f,
542 0x1f, 0x3f, 0x7f, 0xff};
543
544static void
545scalar_encode(uint8_t *out, const scalar *s, int bits)
546{
547 uint8_t out_byte = 0;
548 int i, out_byte_bits = 0;
549
550 assert(bits <= (int)sizeof(*s->c) * 8 && bits != 1);
551 for (i = 0; i < DEGREE; i++) {
552 uint16_t element = s->c[i];
553 int element_bits_done = 0;
554
555 while (element_bits_done < bits) {
556 int chunk_bits = bits - element_bits_done;
557 int out_bits_remaining = 8 - out_byte_bits;
558
559 if (chunk_bits >= out_bits_remaining) {
560 chunk_bits = out_bits_remaining;
561 out_byte |= (element &
562 kMasks[chunk_bits - 1]) << out_byte_bits;
563 *out = out_byte;
564 out++;
565 out_byte_bits = 0;
566 out_byte = 0;
567 } else {
568 out_byte |= (element &
569 kMasks[chunk_bits - 1]) << out_byte_bits;
570 out_byte_bits += chunk_bits;
571 }
572
573 element_bits_done += chunk_bits;
574 element >>= chunk_bits;
575 }
576 }
577
578 if (out_byte_bits > 0) {
579 *out = out_byte;
580 }
581}
582
583/* scalar_encode_1 is |scalar_encode| specialised for |bits| == 1. */
584static void
585scalar_encode_1(uint8_t out[32], const scalar *s)
586{
587 int i, j;
588
589 for (i = 0; i < DEGREE; i += 8) {
590 uint8_t out_byte = 0;
591
592 for (j = 0; j < 8; j++) {
593 out_byte |= (s->c[i + j] & 1) << j;
594 }
595 *out = out_byte;
596 out++;
597 }
598}
599
600/*
601 * Encodes an entire vector into 32*|RANK1024|*|bits| bytes. Note that since 256
602 * (DEGREE) is divisible by 8, the individual vector entries will always fill a
603 * whole number of bytes, so we do not need to worry about bit packing here.
604 */
605static void
606vector_encode(uint8_t *out, const vector *a, int bits)
607{
608 int i;
609
610 for (i = 0; i < RANK1024; i++) {
611 scalar_encode(out + i * bits * DEGREE / 8, &a->v[i], bits);
612 }
613}
614
615/*
616 * scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
617 * |out|. It returns one on success and zero if any parsed value is >=
618 * |kPrime|.
619 */
620static int
621scalar_decode(scalar *out, const uint8_t *in, int bits)
622{
623 uint8_t in_byte = 0;
624 int i, in_byte_bits_left = 0;
625
626 assert(bits <= (int)sizeof(*out->c) * 8 && bits != 1);
627
628 for (i = 0; i < DEGREE; i++) {
629 uint16_t element = 0;
630 int element_bits_done = 0;
631
632 while (element_bits_done < bits) {
633 int chunk_bits = bits - element_bits_done;
634
635 if (in_byte_bits_left == 0) {
636 in_byte = *in;
637 in++;
638 in_byte_bits_left = 8;
639 }
640
641 if (chunk_bits > in_byte_bits_left) {
642 chunk_bits = in_byte_bits_left;
643 }
644
645 element |= (in_byte & kMasks[chunk_bits - 1]) <<
646 element_bits_done;
647 in_byte_bits_left -= chunk_bits;
648 in_byte >>= chunk_bits;
649
650 element_bits_done += chunk_bits;
651 }
652
653 if (element >= kPrime) {
654 return 0;
655 }
656 out->c[i] = element;
657 }
658
659 return 1;
660}
661
662/* scalar_decode_1 is |scalar_decode| specialised for |bits| == 1. */
663static void
664scalar_decode_1(scalar *out, const uint8_t in[32])
665{
666 int i, j;
667
668 for (i = 0; i < DEGREE; i += 8) {
669 uint8_t in_byte = *in;
670
671 in++;
672 for (j = 0; j < 8; j++) {
673 out->c[i + j] = in_byte & 1;
674 in_byte >>= 1;
675 }
676 }
677}
678
679/*
680 * Decodes 32*|RANK1024|*|bits| bytes from |in| into |out|. It returns one on
681 * success or zero if any parsed value is >= |kPrime|.
682 */
683static int
684vector_decode(vector *out, const uint8_t *in, int bits)
685{
686 int i;
687
688 for (i = 0; i < RANK1024; i++) {
689 if (!scalar_decode(&out->v[i], in + i * bits * DEGREE / 8,
690 bits)) {
691 return 0;
692 }
693 }
694 return 1;
695}
696
697/*
698 * Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
699 * numbers close to each other together. The formula used is
700 * round(2^|bits|/kPrime*x) mod 2^|bits|.
701 * Uses Barrett reduction to achieve constant time. Since we need both the
702 * remainder (for rounding) and the quotient (as the result), we cannot use
703 * |reduce| here, but need to do the Barrett reduction directly.
704 */
705static uint16_t
706compress(uint16_t x, int bits)
707{
708 uint32_t shifted = (uint32_t)x << bits;
709 uint64_t product = (uint64_t)shifted * kBarrettMultiplier;
710 uint32_t quotient = (uint32_t)(product >> kBarrettShift);
711 uint32_t remainder = shifted - quotient * kPrime;
712
713 /*
714 * Adjust the quotient to round correctly:
715 * 0 <= remainder <= kHalfPrime round to 0
716 * kHalfPrime < remainder <= kPrime + kHalfPrime round to 1
717 * kPrime + kHalfPrime < remainder < 2 * kPrime round to 2
718 */
719 assert(remainder < 2u * kPrime);
720 quotient += 1 & constant_time_lt(kHalfPrime, remainder);
721 quotient += 1 & constant_time_lt(kPrime + kHalfPrime, remainder);
722 return quotient & ((1 << bits) - 1);
723}
724
725/*
726 * Decompresses |x| by using an equi-distant representative. The formula is
727 * round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us to
728 * implement this logic using only bit operations.
729 */
730static uint16_t
731decompress(uint16_t x, int bits)
732{
733 uint32_t product = (uint32_t)x * kPrime;
734 uint32_t power = 1 << bits;
735 /* This is |product| % power, since |power| is a power of 2. */
736 uint32_t remainder = product & (power - 1);
737 /* This is |product| / power, since |power| is a power of 2. */
738 uint32_t lower = product >> bits;
739
740 /*
741 * The rounding logic works since the first half of numbers mod |power| have a
742 * 0 as first bit, and the second half has a 1 as first bit, since |power| is
743 * a power of 2. As a 12 bit number, |remainder| is always positive, so we
744 * will shift in 0s for a right shift.
745 */
746 return lower + (remainder >> (bits - 1));
747}
748
749static void
750scalar_compress(scalar *s, int bits)
751{
752 int i;
753
754 for (i = 0; i < DEGREE; i++) {
755 s->c[i] = compress(s->c[i], bits);
756 }
757}
758
759static void
760scalar_decompress(scalar *s, int bits)
761{
762 int i;
763
764 for (i = 0; i < DEGREE; i++) {
765 s->c[i] = decompress(s->c[i], bits);
766 }
767}
768
769static void
770vector_compress(vector *a, int bits)
771{
772 int i;
773
774 for (i = 0; i < RANK1024; i++) {
775 scalar_compress(&a->v[i], bits);
776 }
777}
778
779static void
780vector_decompress(vector *a, int bits)
781{
782 int i;
783
784 for (i = 0; i < RANK1024; i++) {
785 scalar_decompress(&a->v[i], bits);
786 }
787}
788
789struct public_key {
790 vector t;
791 uint8_t rho[32];
792 uint8_t public_key_hash[32];
793 matrix m;
794};
795
796static struct public_key *
797public_key_1024_from_external(const struct MLKEM1024_public_key *external)
798{
799 return (struct public_key *)external;
800}
801
802struct private_key {
803 struct public_key pub;
804 vector s;
805 uint8_t fo_failure_secret[32];
806};
807
808static struct private_key *
809private_key_1024_from_external(const struct MLKEM1024_private_key *external)
810{
811 return (struct private_key *)external;
812}
813
814/*
815 * Calls |MLKEM1024_generate_key_external_entropy| with random bytes from
816 * |RAND_bytes|.
817 */
818void
819MLKEM1024_generate_key(uint8_t out_encoded_public_key[MLKEM1024_PUBLIC_KEY_BYTES],
820 uint8_t optional_out_seed[MLKEM_SEED_BYTES],
821 struct MLKEM1024_private_key *out_private_key)
822{
823 uint8_t entropy_buf[MLKEM_SEED_BYTES];
824 uint8_t *entropy = optional_out_seed != NULL ? optional_out_seed :
825 entropy_buf;
826
827 arc4random_buf(entropy, MLKEM_SEED_BYTES);
828 MLKEM1024_generate_key_external_entropy(out_encoded_public_key,
829 out_private_key, entropy);
830}
831LCRYPTO_ALIAS(MLKEM1024_generate_key);
832
833int
834MLKEM1024_private_key_from_seed(struct MLKEM1024_private_key *out_private_key,
835 const uint8_t *seed, size_t seed_len)
836{
837 uint8_t public_key_bytes[MLKEM1024_PUBLIC_KEY_BYTES];
838
839 if (seed_len != MLKEM_SEED_BYTES) {
840 return 0;
841 }
842 MLKEM1024_generate_key_external_entropy(public_key_bytes,
843 out_private_key, seed);
844
845 return 1;
846}
847LCRYPTO_ALIAS(MLKEM1024_private_key_from_seed);
848
849static int
850mlkem_marshal_public_key(CBB *out, const struct public_key *pub)
851{
852 uint8_t *vector_output;
853
854 if (!CBB_add_space(out, &vector_output, kEncodedVectorSize)) {
855 return 0;
856 }
857 vector_encode(vector_output, &pub->t, kLog2Prime);
858 if (!CBB_add_bytes(out, pub->rho, sizeof(pub->rho))) {
859 return 0;
860 }
861 return 1;
862}
863
864void
865MLKEM1024_generate_key_external_entropy(
866 uint8_t out_encoded_public_key[MLKEM1024_PUBLIC_KEY_BYTES],
867 struct MLKEM1024_private_key *out_private_key,
868 const uint8_t entropy[MLKEM_SEED_BYTES])
869{
870 struct private_key *priv = private_key_1024_from_external(
871 out_private_key);
872 uint8_t augmented_seed[33];
873 uint8_t *rho, *sigma;
874 uint8_t counter = 0;
875 uint8_t hashed[64];
876 vector error;
877 CBB cbb;
878
879 memcpy(augmented_seed, entropy, 32);
880 augmented_seed[32] = RANK1024;
881 hash_g(hashed, augmented_seed, 33);
882 rho = hashed;
883 sigma = hashed + 32;
884 memcpy(priv->pub.rho, hashed, sizeof(priv->pub.rho));
885 matrix_expand(&priv->pub.m, rho);
886 vector_generate_secret_eta_2(&priv->s, &counter, sigma);
887 vector_ntt(&priv->s);
888 vector_generate_secret_eta_2(&error, &counter, sigma);
889 vector_ntt(&error);
890 matrix_mult_transpose(&priv->pub.t, &priv->pub.m, &priv->s);
891 vector_add(&priv->pub.t, &error);
892
893 /* XXX - error checking. */
894 CBB_init_fixed(&cbb, out_encoded_public_key, MLKEM1024_PUBLIC_KEY_BYTES);
895 if (!mlkem_marshal_public_key(&cbb, &priv->pub)) {
896 abort();
897 }
898 CBB_cleanup(&cbb);
899
900 hash_h(priv->pub.public_key_hash, out_encoded_public_key,
901 MLKEM1024_PUBLIC_KEY_BYTES);
902 memcpy(priv->fo_failure_secret, entropy + 32, 32);
903}
904
905void
906MLKEM1024_public_from_private(struct MLKEM1024_public_key *out_public_key,
907 const struct MLKEM1024_private_key *private_key)
908{
909 struct public_key *const pub = public_key_1024_from_external(
910 out_public_key);
911 const struct private_key *const priv = private_key_1024_from_external(
912 private_key);
913
914 *pub = priv->pub;
915}
916LCRYPTO_ALIAS(MLKEM1024_public_from_private);
917
918/*
919 * Encrypts a message with given randomness to the ciphertext in |out|. Without
920 * applying the Fujisaki-Okamoto transform this would not result in a CCA secure
921 * scheme, since lattice schemes are vulnerable to decryption failure oracles.
922 */
923static void
924encrypt_cpa(uint8_t out[MLKEM1024_CIPHERTEXT_BYTES],
925 const struct public_key *pub, const uint8_t message[32],
926 const uint8_t randomness[32])
927{
928 scalar expanded_message, scalar_error;
929 vector secret, error, u;
930 uint8_t counter = 0;
931 uint8_t input[33];
932 scalar v;
933
934 vector_generate_secret_eta_2(&secret, &counter, randomness);
935 vector_ntt(&secret);
936 vector_generate_secret_eta_2(&error, &counter, randomness);
937 memcpy(input, randomness, 32);
938 input[32] = counter;
939 scalar_centered_binomial_distribution_eta_2_with_prf(&scalar_error,
940 input);
941 matrix_mult(&u, &pub->m, &secret);
942 vector_inverse_ntt(&u);
943 vector_add(&u, &error);
944 scalar_inner_product(&v, &pub->t, &secret);
945 scalar_inverse_ntt(&v);
946 scalar_add(&v, &scalar_error);
947 scalar_decode_1(&expanded_message, message);
948 scalar_decompress(&expanded_message, 1);
949 scalar_add(&v, &expanded_message);
950 vector_compress(&u, kDU1024);
951 vector_encode(out, &u, kDU1024);
952 scalar_compress(&v, kDV1024);
953 scalar_encode(out + kCompressedVectorSize, &v, kDV1024);
954}
955
956/* Calls MLKEM1024_encap_external_entropy| with random bytes */
957void
958MLKEM1024_encap(uint8_t out_ciphertext[MLKEM1024_CIPHERTEXT_BYTES],
959 uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
960 const struct MLKEM1024_public_key *public_key)
961{
962 uint8_t entropy[MLKEM_ENCAP_ENTROPY];
963
964 arc4random_buf(entropy, MLKEM_ENCAP_ENTROPY);
965 MLKEM1024_encap_external_entropy(out_ciphertext, out_shared_secret,
966 public_key, entropy);
967}
968LCRYPTO_ALIAS(MLKEM1024_encap);
969
970/* See section 6.2 of the spec. */
971void
972MLKEM1024_encap_external_entropy(
973 uint8_t out_ciphertext[MLKEM1024_CIPHERTEXT_BYTES],
974 uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
975 const struct MLKEM1024_public_key *public_key,
976 const uint8_t entropy[MLKEM_ENCAP_ENTROPY])
977{
978 const struct public_key *pub = public_key_1024_from_external(public_key);
979 uint8_t key_and_randomness[64];
980 uint8_t input[64];
981
982 memcpy(input, entropy, MLKEM_ENCAP_ENTROPY);
983 memcpy(input + MLKEM_ENCAP_ENTROPY, pub->public_key_hash,
984 sizeof(input) - MLKEM_ENCAP_ENTROPY);
985 hash_g(key_and_randomness, input, sizeof(input));
986 encrypt_cpa(out_ciphertext, pub, entropy, key_and_randomness + 32);
987 memcpy(out_shared_secret, key_and_randomness, 32);
988}
989
990static void
991decrypt_cpa(uint8_t out[32], const struct private_key *priv,
992 const uint8_t ciphertext[MLKEM1024_CIPHERTEXT_BYTES])
993{
994 scalar mask, v;
995 vector u;
996
997 vector_decode(&u, ciphertext, kDU1024);
998 vector_decompress(&u, kDU1024);
999 vector_ntt(&u);
1000 scalar_decode(&v, ciphertext + kCompressedVectorSize, kDV1024);
1001 scalar_decompress(&v, kDV1024);
1002 scalar_inner_product(&mask, &priv->s, &u);
1003 scalar_inverse_ntt(&mask);
1004 scalar_sub(&v, &mask);
1005 scalar_compress(&v, 1);
1006 scalar_encode_1(out, &v);
1007}
1008
1009/* See section 6.3 */
1010int
1011MLKEM1024_decap(uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
1012 const uint8_t *ciphertext, size_t ciphertext_len,
1013 const struct MLKEM1024_private_key *private_key)
1014{
1015 const struct private_key *priv = private_key_1024_from_external(
1016 private_key);
1017 uint8_t expected_ciphertext[MLKEM1024_CIPHERTEXT_BYTES];
1018 uint8_t key_and_randomness[64];
1019 uint8_t failure_key[32];
1020 uint8_t decrypted[64];
1021 uint8_t mask;
1022 int i;
1023
1024 if (ciphertext_len != MLKEM1024_CIPHERTEXT_BYTES) {
1025 arc4random_buf(out_shared_secret, MLKEM_SHARED_SECRET_BYTES);
1026 return 0;
1027 }
1028
1029 decrypt_cpa(decrypted, priv, ciphertext);
1030 memcpy(decrypted + 32, priv->pub.public_key_hash,
1031 sizeof(decrypted) - 32);
1032 hash_g(key_and_randomness, decrypted, sizeof(decrypted));
1033 encrypt_cpa(expected_ciphertext, &priv->pub, decrypted,
1034 key_and_randomness + 32);
1035 kdf(failure_key, priv->fo_failure_secret, ciphertext, ciphertext_len);
1036 mask = constant_time_eq_int_8(memcmp(ciphertext, expected_ciphertext,
1037 sizeof(expected_ciphertext)), 0);
1038 for (i = 0; i < MLKEM_SHARED_SECRET_BYTES; i++) {
1039 out_shared_secret[i] = constant_time_select_8(mask,
1040 key_and_randomness[i], failure_key[i]);
1041 }
1042
1043 return 1;
1044}
1045LCRYPTO_ALIAS(MLKEM1024_decap);
1046
1047int
1048MLKEM1024_marshal_public_key(CBB *out,
1049 const struct MLKEM1024_public_key *public_key)
1050{
1051 return mlkem_marshal_public_key(out,
1052 public_key_1024_from_external(public_key));
1053}
1054LCRYPTO_ALIAS(MLKEM1024_marshal_public_key);
1055
1056/*
1057 * mlkem_parse_public_key_no_hash parses |in| into |pub| but doesn't calculate
1058 * the value of |pub->public_key_hash|.
1059 */
1060static int
1061mlkem_parse_public_key_no_hash(struct public_key *pub, CBS *in)
1062{
1063 CBS t_bytes;
1064
1065 if (!CBS_get_bytes(in, &t_bytes, kEncodedVectorSize) ||
1066 !vector_decode(&pub->t, CBS_data(&t_bytes), kLog2Prime)) {
1067 return 0;
1068 }
1069 memcpy(pub->rho, CBS_data(in), sizeof(pub->rho));
1070 if (!CBS_skip(in, sizeof(pub->rho)))
1071 return 0;
1072 matrix_expand(&pub->m, pub->rho);
1073 return 1;
1074}
1075
1076int
1077MLKEM1024_parse_public_key(struct MLKEM1024_public_key *public_key, CBS *in)
1078{
1079 struct public_key *pub = public_key_1024_from_external(public_key);
1080 CBS orig_in = *in;
1081
1082 if (!mlkem_parse_public_key_no_hash(pub, in) ||
1083 CBS_len(in) != 0) {
1084 return 0;
1085 }
1086 hash_h(pub->public_key_hash, CBS_data(&orig_in), CBS_len(&orig_in));
1087 return 1;
1088}
1089LCRYPTO_ALIAS(MLKEM1024_parse_public_key);
1090
1091int
1092MLKEM1024_marshal_private_key(CBB *out,
1093 const struct MLKEM1024_private_key *private_key)
1094{
1095 const struct private_key *const priv = private_key_1024_from_external(
1096 private_key);
1097 uint8_t *s_output;
1098
1099 if (!CBB_add_space(out, &s_output, kEncodedVectorSize)) {
1100 return 0;
1101 }
1102 vector_encode(s_output, &priv->s, kLog2Prime);
1103 if (!mlkem_marshal_public_key(out, &priv->pub) ||
1104 !CBB_add_bytes(out, priv->pub.public_key_hash,
1105 sizeof(priv->pub.public_key_hash)) ||
1106 !CBB_add_bytes(out, priv->fo_failure_secret,
1107 sizeof(priv->fo_failure_secret))) {
1108 return 0;
1109 }
1110 return 1;
1111}
1112
1113int
1114MLKEM1024_parse_private_key(struct MLKEM1024_private_key *out_private_key,
1115 CBS *in)
1116{
1117 struct private_key *const priv = private_key_1024_from_external(
1118 out_private_key);
1119 CBS s_bytes;
1120
1121 if (!CBS_get_bytes(in, &s_bytes, kEncodedVectorSize) ||
1122 !vector_decode(&priv->s, CBS_data(&s_bytes), kLog2Prime) ||
1123 !mlkem_parse_public_key_no_hash(&priv->pub, in)) {
1124 return 0;
1125 }
1126 memcpy(priv->pub.public_key_hash, CBS_data(in),
1127 sizeof(priv->pub.public_key_hash));
1128 if (!CBS_skip(in, sizeof(priv->pub.public_key_hash)))
1129 return 0;
1130 memcpy(priv->fo_failure_secret, CBS_data(in),
1131 sizeof(priv->fo_failure_secret));
1132 if (!CBS_skip(in, sizeof(priv->fo_failure_secret)))
1133 return 0;
1134 if (CBS_len(in) != 0)
1135 return 0;
1136
1137 return 1;
1138}
1139LCRYPTO_ALIAS(MLKEM1024_parse_private_key);
diff --git a/src/lib/libcrypto/mlkem/mlkem768.c b/src/lib/libcrypto/mlkem/mlkem768.c
deleted file mode 100644
index bacde0c0b7..0000000000
--- a/src/lib/libcrypto/mlkem/mlkem768.c
+++ /dev/null
@@ -1,1138 +0,0 @@
1/* $OpenBSD: mlkem768.c,v 1.7 2025/01/03 08:19:24 tb Exp $ */
2/*
3 * Copyright (c) 2024, Google Inc.
4 * Copyright (c) 2024, Bob Beck <beck@obtuse.com>
5 *
6 * Permission to use, copy, modify, and/or distribute this software for any
7 * purpose with or without fee is hereby granted, provided that the above
8 * copyright notice and this permission notice appear in all copies.
9 *
10 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
13 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
15 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
16 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17 */
18
19#include <assert.h>
20#include <stdlib.h>
21#include <string.h>
22
23#include "bytestring.h"
24#include "mlkem.h"
25
26#include "sha3_internal.h"
27#include "mlkem_internal.h"
28#include "constant_time.h"
29#include "crypto_internal.h"
30
31/* Remove later */
32#undef LCRYPTO_ALIAS
33#define LCRYPTO_ALIAS(A)
34
35/*
36 * See
37 * https://csrc.nist.gov/pubs/fips/203/final
38 */
39
40static void
41prf(uint8_t *out, size_t out_len, const uint8_t in[33])
42{
43 sha3_ctx ctx;
44 shake256_init(&ctx);
45 shake_update(&ctx, in, 33);
46 shake_xof(&ctx);
47 shake_out(&ctx, out, out_len);
48}
49
50/* Section 4.1 */
51static void
52hash_h(uint8_t out[32], const uint8_t *in, size_t len)
53{
54 sha3_ctx ctx;
55 sha3_init(&ctx, 32);
56 sha3_update(&ctx, in, len);
57 sha3_final(out, &ctx);
58}
59
60static void
61hash_g(uint8_t out[64], const uint8_t *in, size_t len)
62{
63 sha3_ctx ctx;
64 sha3_init(&ctx, 64);
65 sha3_update(&ctx, in, len);
66 sha3_final(out, &ctx);
67}
68
69/* this is called 'J' in the spec */
70static void
71kdf(uint8_t out[MLKEM_SHARED_SECRET_BYTES], const uint8_t failure_secret[32],
72 const uint8_t *in, size_t len)
73{
74 sha3_ctx ctx;
75 shake256_init(&ctx);
76 shake_update(&ctx, failure_secret, 32);
77 shake_update(&ctx, in, len);
78 shake_xof(&ctx);
79 shake_out(&ctx, out, MLKEM_SHARED_SECRET_BYTES);
80}
81
82#define DEGREE 256
83#define RANK768 3
84
85static const size_t kBarrettMultiplier = 5039;
86static const unsigned kBarrettShift = 24;
87static const uint16_t kPrime = 3329;
88static const int kLog2Prime = 12;
89static const uint16_t kHalfPrime = (/*kPrime=*/3329 - 1) / 2;
90static const int kDU768 = 10;
91static const int kDV768 = 4;
92/*
93 * kInverseDegree is 128^-1 mod 3329; 128 because kPrime does not have a 512th
94 * root of unity.
95 */
96static const uint16_t kInverseDegree = 3303;
97static const size_t kEncodedVectorSize =
98 (/*kLog2Prime=*/12 * DEGREE / 8) * RANK768;
99static const size_t kCompressedVectorSize = /*kDU768=*/ 10 * RANK768 * DEGREE /
100 8;
101
102typedef struct scalar {
103 /* On every function entry and exit, 0 <= c < kPrime. */
104 uint16_t c[DEGREE];
105} scalar;
106
107typedef struct vector {
108 scalar v[RANK768];
109} vector;
110
111typedef struct matrix {
112 scalar v[RANK768][RANK768];
113} matrix;
114
115/*
116 * This bit of Python will be referenced in some of the following comments:
117 *
118 * p = 3329
119 *
120 * def bitreverse(i):
121 * ret = 0
122 * for n in range(7):
123 * bit = i & 1
124 * ret <<= 1
125 * ret |= bit
126 * i >>= 1
127 * return ret
128 */
129
130/* kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)] */
131static const uint16_t kNTTRoots[128] = {
132 1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797,
133 2786, 3260, 569, 1746, 296, 2447, 1339, 1476, 3046, 56, 2240, 1333,
134 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, 289, 331, 3253, 1756,
135 1197, 2304, 2277, 2055, 650, 1977, 2513, 632, 2865, 33, 1320, 1915,
136 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648,
137 2474, 3110, 1227, 910, 17, 2761, 583, 2649, 1637, 723, 2288, 1100,
138 1409, 2662, 3281, 233, 756, 2156, 3015, 3050, 1703, 1651, 2789, 1789,
139 1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641,
140 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, 1063, 319, 2773, 757,
141 2099, 561, 2466, 2594, 2804, 1092, 403, 1026, 1143, 2150, 2775, 886,
142 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154,
143};
144
145/* kInverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)] */
146static const uint16_t kInverseNTTRoots[128] = {
147 1, 1600, 40, 749, 2481, 1432, 2699, 687, 1583, 2760, 69, 543,
148 2532, 3136, 1410, 2267, 2508, 1355, 450, 936, 447, 2794, 1235, 1903,
149 1996, 1089, 3273, 283, 1853, 1990, 882, 3033, 2419, 2102, 219, 855,
150 2681, 1848, 712, 682, 927, 1795, 461, 1891, 2877, 2522, 1894, 1010,
151 1414, 2009, 3296, 464, 2697, 816, 1352, 2679, 1274, 1052, 1025, 2132,
152 1573, 76, 2998, 3040, 1175, 2444, 394, 1219, 2300, 1455, 2117, 1607,
153 2443, 554, 1179, 2186, 2303, 2926, 2237, 525, 735, 863, 2768, 1230,
154 2572, 556, 3010, 2266, 1684, 1239, 780, 2954, 109, 1292, 1031, 1745,
155 2688, 3061, 992, 2596, 941, 892, 1021, 2390, 642, 1868, 2377, 1482,
156 1540, 540, 1678, 1626, 279, 314, 1173, 2573, 3096, 48, 667, 1920,
157 2229, 1041, 2606, 1692, 680, 2746, 568, 3312,
158};
159
160/* kModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)] */
161static const uint16_t kModRoots[128] = {
162 17, 3312, 2761, 568, 583, 2746, 2649, 680, 1637, 1692, 723, 2606,
163 2288, 1041, 1100, 2229, 1409, 1920, 2662, 667, 3281, 48, 233, 3096,
164 756, 2573, 2156, 1173, 3015, 314, 3050, 279, 1703, 1626, 1651, 1678,
165 2789, 540, 1789, 1540, 1847, 1482, 952, 2377, 1461, 1868, 2687, 642,
166 939, 2390, 2308, 1021, 2437, 892, 2388, 941, 733, 2596, 2337, 992,
167 268, 3061, 641, 2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
168 375, 2954, 2549, 780, 2090, 1239, 1645, 1684, 1063, 2266, 319, 3010,
169 2773, 556, 757, 2572, 2099, 1230, 561, 2768, 2466, 863, 2594, 735,
170 2804, 525, 1092, 2237, 403, 2926, 1026, 2303, 1143, 2186, 2150, 1179,
171 2775, 554, 886, 2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
172 2110, 1219, 2935, 394, 885, 2444, 2154, 1175,
173};
174
175/* reduce_once reduces 0 <= x < 2*kPrime, mod kPrime. */
176static uint16_t
177reduce_once(uint16_t x)
178{
179 assert(x < 2 * kPrime);
180 const uint16_t subtracted = x - kPrime;
181 uint16_t mask = 0u - (subtracted >> 15);
182
183 /*
184 * Although this is a constant-time select, we omit a value barrier here.
185 * Value barriers impede auto-vectorization (likely because it forces the
186 * value to transit through a general-purpose register). On AArch64, this
187 * is a difference of 2x.
188 *
189 * We usually add value barriers to selects because Clang turns
190 * consecutive selects with the same condition into a branch instead of
191 * CMOV/CSEL. This condition does not occur in ML-KEM, so omitting it
192 * seems to be safe so far but see
193 * |scalar_centered_binomial_distribution_eta_2_with_prf|.
194 */
195 return (mask & x) | (~mask & subtracted);
196}
197
198/*
199 * constant time reduce x mod kPrime using Barrett reduction. x must be less
200 * than kPrime + 2×kPrime².
201 */
202static uint16_t
203reduce(uint32_t x)
204{
205 uint64_t product = (uint64_t)x * kBarrettMultiplier;
206 uint32_t quotient = (uint32_t)(product >> kBarrettShift);
207 uint32_t remainder = x - quotient * kPrime;
208
209 assert(x < kPrime + 2u * kPrime * kPrime);
210 return reduce_once(remainder);
211}
212
213static void
214scalar_zero(scalar *out)
215{
216 memset(out, 0, sizeof(*out));
217}
218
219static void
220vector_zero(vector *out)
221{
222 memset(out, 0, sizeof(*out));
223}
224
225/*
226 * In place number theoretic transform of a given scalar.
227 * Note that MLKEM's kPrime 3329 does not have a 512th root of unity, so this
228 * transform leaves off the last iteration of the usual FFT code, with the 128
229 * relevant roots of unity being stored in |kNTTRoots|. This means the output
230 * should be seen as 128 elements in GF(3329^2), with the coefficients of the
231 * elements being consecutive entries in |s->c|.
232 */
233static void
234scalar_ntt(scalar *s)
235{
236 int offset = DEGREE;
237 int step;
238 /*
239 * `int` is used here because using `size_t` throughout caused a ~5% slowdown
240 * with Clang 14 on Aarch64.
241 */
242 for (step = 1; step < DEGREE / 2; step <<= 1) {
243 int i, j, k = 0;
244
245 offset >>= 1;
246 for (i = 0; i < step; i++) {
247 const uint32_t step_root = kNTTRoots[i + step];
248
249 for (j = k; j < k + offset; j++) {
250 uint16_t odd, even;
251
252 odd = reduce(step_root * s->c[j + offset]);
253 even = s->c[j];
254 s->c[j] = reduce_once(odd + even);
255 s->c[j + offset] = reduce_once(even - odd +
256 kPrime);
257 }
258 k += 2 * offset;
259 }
260 }
261}
262
263static void
264vector_ntt(vector *a)
265{
266 int i;
267
268 for (i = 0; i < RANK768; i++) {
269 scalar_ntt(&a->v[i]);
270 }
271}
272
273/*
274 * In place inverse number theoretic transform of a given scalar, with pairs of
275 * entries of s->v being interpreted as elements of GF(3329^2). Just as with the
276 * number theoretic transform, this leaves off the first step of the normal iFFT
277 * to account for the fact that 3329 does not have a 512th root of unity, using
278 * the precomputed 128 roots of unity stored in |kInverseNTTRoots|.
279 */
280static void
281scalar_inverse_ntt(scalar *s)
282{
283 int i, j, k, offset, step = DEGREE / 2;
284
285 /*
286 * `int` is used here because using `size_t` throughout caused a ~5% slowdown
287 * with Clang 14 on Aarch64.
288 */
289 for (offset = 2; offset < DEGREE; offset <<= 1) {
290 step >>= 1;
291 k = 0;
292 for (i = 0; i < step; i++) {
293 uint32_t step_root = kInverseNTTRoots[i + step];
294 for (j = k; j < k + offset; j++) {
295 uint16_t odd, even;
296 odd = s->c[j + offset];
297 even = s->c[j];
298 s->c[j] = reduce_once(odd + even);
299 s->c[j + offset] = reduce(step_root *
300 (even - odd + kPrime));
301 }
302 k += 2 * offset;
303 }
304 }
305 for (i = 0; i < DEGREE; i++) {
306 s->c[i] = reduce(s->c[i] * kInverseDegree);
307 }
308}
309
310static void
311vector_inverse_ntt(vector *a)
312{
313 int i;
314
315 for (i = 0; i < RANK768; i++) {
316 scalar_inverse_ntt(&a->v[i]);
317 }
318}
319
320static void
321scalar_add(scalar *lhs, const scalar *rhs)
322{
323 int i;
324
325 for (i = 0; i < DEGREE; i++) {
326 lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]);
327 }
328}
329
330static void
331scalar_sub(scalar *lhs, const scalar *rhs)
332{
333 int i;
334
335 for (i = 0; i < DEGREE; i++) {
336 lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime);
337 }
338}
339
340/*
341 * Multiplying two scalars in the number theoretically transformed state.
342 * Since 3329 does not have a 512th root of unity, this means we have to
343 * interpret the 2*ith and (2*i+1)th entries of the scalar as elements of
344 * GF(3329)[X]/(X^2 - 17^(2*bitreverse(i)+1)).
345 * The value of 17^(2*bitreverse(i)+1) mod 3329 is stored in the precomputed
346 * |kModRoots| table. Our Barrett transform only allows us to multiply two
347 * reduced numbers together, so we need some intermediate reduction steps,
348 * even if an uint64_t could hold 3 multiplied numbers.
349 */
350static void
351scalar_mult(scalar *out, const scalar *lhs, const scalar *rhs)
352{
353 int i;
354
355 for (i = 0; i < DEGREE / 2; i++) {
356 uint32_t real_real = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i];
357 uint32_t img_img = (uint32_t)lhs->c[2 * i + 1] *
358 rhs->c[2 * i + 1];
359 uint32_t real_img = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i + 1];
360 uint32_t img_real = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i];
361
362 out->c[2 * i] =
363 reduce(real_real +
364 (uint32_t)reduce(img_img) * kModRoots[i]);
365 out->c[2 * i + 1] = reduce(img_real + real_img);
366 }
367}
368
369static void
370vector_add(vector *lhs, const vector *rhs)
371{
372 int i;
373
374 for (i = 0; i < RANK768; i++) {
375 scalar_add(&lhs->v[i], &rhs->v[i]);
376 }
377}
378
379static void
380matrix_mult(vector *out, const matrix *m, const vector *a)
381{
382 int i, j;
383
384 vector_zero(out);
385 for (i = 0; i < RANK768; i++) {
386 for (j = 0; j < RANK768; j++) {
387 scalar product;
388
389 scalar_mult(&product, &m->v[i][j], &a->v[j]);
390 scalar_add(&out->v[i], &product);
391 }
392 }
393}
394
395static void
396matrix_mult_transpose(vector *out, const matrix *m,
397 const vector *a)
398{
399 int i, j;
400
401 vector_zero(out);
402 for (i = 0; i < RANK768; i++) {
403 for (j = 0; j < RANK768; j++) {
404 scalar product;
405
406 scalar_mult(&product, &m->v[j][i], &a->v[j]);
407 scalar_add(&out->v[i], &product);
408 }
409 }
410}
411
412static void
413scalar_inner_product(scalar *out, const vector *lhs,
414 const vector *rhs)
415{
416 int i;
417 scalar_zero(out);
418 for (i = 0; i < RANK768; i++) {
419 scalar product;
420
421 scalar_mult(&product, &lhs->v[i], &rhs->v[i]);
422 scalar_add(out, &product);
423 }
424}
425
426/*
427 * Algorithm 6 of spec. Rejection samples a Keccak stream to get uniformly
428 * distributed elements. This is used for matrix expansion and only operates on
429 * public inputs.
430 */
431static void
432scalar_from_keccak_vartime(scalar *out, sha3_ctx *keccak_ctx)
433{
434 int i, done = 0;
435
436 while (done < DEGREE) {
437 uint8_t block[168];
438
439 shake_out(keccak_ctx, block, sizeof(block));
440 for (i = 0; i < sizeof(block) && done < DEGREE; i += 3) {
441 uint16_t d1 = block[i] + 256 * (block[i + 1] % 16);
442 uint16_t d2 = block[i + 1] / 16 + 16 * block[i + 2];
443
444 if (d1 < kPrime) {
445 out->c[done++] = d1;
446 }
447 if (d2 < kPrime && done < DEGREE) {
448 out->c[done++] = d2;
449 }
450 }
451 }
452}
453
454/*
455 * Algorithm 7 of the spec, with eta fixed to two and the PRF call
456 * included. Creates binominally distributed elements by sampling 2*|eta| bits,
457 * and setting the coefficient to the count of the first bits minus the count of
458 * the second bits, resulting in a centered binomial distribution. Since eta is
459 * two this gives -2/2 with a probability of 1/16, -1/1 with probability 1/4,
460 * and 0 with probability 3/8.
461 */
462static void
463scalar_centered_binomial_distribution_eta_2_with_prf(scalar *out,
464 const uint8_t input[33])
465{
466 uint8_t entropy[128];
467 int i;
468
469 CTASSERT(sizeof(entropy) == 2 * /*kEta=*/ 2 * DEGREE / 8);
470 prf(entropy, sizeof(entropy), input);
471
472 for (i = 0; i < DEGREE; i += 2) {
473 uint8_t byte = entropy[i / 2];
474 uint16_t mask;
475 uint16_t value = (byte & 1) + ((byte >> 1) & 1);
476
477 value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
478
479 /*
480 * Add |kPrime| if |value| underflowed. See |reduce_once| for a
481 * discussion on why the value barrier is omitted. While this
482 * could have been written reduce_once(value + kPrime), this is
483 * one extra addition and small range of |value| tempts some
484 * versions of Clang to emit a branch.
485 */
486 mask = 0u - (value >> 15);
487 out->c[i] = ((value + kPrime) & mask) | (value & ~mask);
488
489 byte >>= 4;
490 value = (byte & 1) + ((byte >> 1) & 1);
491 value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
492 /* See above. */
493 mask = 0u - (value >> 15);
494 out->c[i + 1] = ((value + kPrime) & mask) | (value & ~mask);
495 }
496}
497
498/*
499 * Generates a secret vector by using
500 * |scalar_centered_binomial_distribution_eta_2_with_prf|, using the given seed
501 * appending and incrementing |counter| for entry of the vector.
502 */
503static void
504vector_generate_secret_eta_2(vector *out, uint8_t *counter,
505 const uint8_t seed[32])
506{
507 uint8_t input[33];
508 int i;
509
510 memcpy(input, seed, 32);
511 for (i = 0; i < RANK768; i++) {
512 input[32] = (*counter)++;
513 scalar_centered_binomial_distribution_eta_2_with_prf(&out->v[i],
514 input);
515 }
516}
517
518/* Expands the matrix of a seed for key generation and for encaps-CPA. */
519static void
520matrix_expand(matrix *out, const uint8_t rho[32])
521{
522 uint8_t input[34];
523 int i, j;
524
525 memcpy(input, rho, 32);
526 for (i = 0; i < RANK768; i++) {
527 for (j = 0; j < RANK768; j++) {
528 sha3_ctx keccak_ctx;
529
530 input[32] = i;
531 input[33] = j;
532 shake128_init(&keccak_ctx);
533 shake_update(&keccak_ctx, input, sizeof(input));
534 shake_xof(&keccak_ctx);
535 scalar_from_keccak_vartime(&out->v[i][j], &keccak_ctx);
536 }
537 }
538}
539
540static const uint8_t kMasks[8] = {0x01, 0x03, 0x07, 0x0f,
541 0x1f, 0x3f, 0x7f, 0xff};
542
543static void
544scalar_encode(uint8_t *out, const scalar *s, int bits)
545{
546 uint8_t out_byte = 0;
547 int i, out_byte_bits = 0;
548
549 assert(bits <= (int)sizeof(*s->c) * 8 && bits != 1);
550 for (i = 0; i < DEGREE; i++) {
551 uint16_t element = s->c[i];
552 int element_bits_done = 0;
553
554 while (element_bits_done < bits) {
555 int chunk_bits = bits - element_bits_done;
556 int out_bits_remaining = 8 - out_byte_bits;
557
558 if (chunk_bits >= out_bits_remaining) {
559 chunk_bits = out_bits_remaining;
560 out_byte |= (element &
561 kMasks[chunk_bits - 1]) << out_byte_bits;
562 *out = out_byte;
563 out++;
564 out_byte_bits = 0;
565 out_byte = 0;
566 } else {
567 out_byte |= (element &
568 kMasks[chunk_bits - 1]) << out_byte_bits;
569 out_byte_bits += chunk_bits;
570 }
571
572 element_bits_done += chunk_bits;
573 element >>= chunk_bits;
574 }
575 }
576
577 if (out_byte_bits > 0) {
578 *out = out_byte;
579 }
580}
581
582/* scalar_encode_1 is |scalar_encode| specialised for |bits| == 1. */
583static void
584scalar_encode_1(uint8_t out[32], const scalar *s)
585{
586 int i, j;
587
588 for (i = 0; i < DEGREE; i += 8) {
589 uint8_t out_byte = 0;
590
591 for (j = 0; j < 8; j++) {
592 out_byte |= (s->c[i + j] & 1) << j;
593 }
594 *out = out_byte;
595 out++;
596 }
597}
598
599/*
600 * Encodes an entire vector into 32*|RANK768|*|bits| bytes. Note that since 256
601 * (DEGREE) is divisible by 8, the individual vector entries will always fill a
602 * whole number of bytes, so we do not need to worry about bit packing here.
603 */
604static void
605vector_encode(uint8_t *out, const vector *a, int bits)
606{
607 int i;
608
609 for (i = 0; i < RANK768; i++) {
610 scalar_encode(out + i * bits * DEGREE / 8, &a->v[i], bits);
611 }
612}
613
614/*
615 * scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
616 * |out|. It returns one on success and zero if any parsed value is >=
617 * |kPrime|.
618 */
619static int
620scalar_decode(scalar *out, const uint8_t *in, int bits)
621{
622 uint8_t in_byte = 0;
623 int i, in_byte_bits_left = 0;
624
625 assert(bits <= (int)sizeof(*out->c) * 8 && bits != 1);
626
627 for (i = 0; i < DEGREE; i++) {
628 uint16_t element = 0;
629 int element_bits_done = 0;
630
631 while (element_bits_done < bits) {
632 int chunk_bits = bits - element_bits_done;
633
634 if (in_byte_bits_left == 0) {
635 in_byte = *in;
636 in++;
637 in_byte_bits_left = 8;
638 }
639
640 if (chunk_bits > in_byte_bits_left) {
641 chunk_bits = in_byte_bits_left;
642 }
643
644 element |= (in_byte & kMasks[chunk_bits - 1]) <<
645 element_bits_done;
646 in_byte_bits_left -= chunk_bits;
647 in_byte >>= chunk_bits;
648
649 element_bits_done += chunk_bits;
650 }
651
652 if (element >= kPrime) {
653 return 0;
654 }
655 out->c[i] = element;
656 }
657
658 return 1;
659}
660
661/* scalar_decode_1 is |scalar_decode| specialised for |bits| == 1. */
662static void
663scalar_decode_1(scalar *out, const uint8_t in[32])
664{
665 int i, j;
666
667 for (i = 0; i < DEGREE; i += 8) {
668 uint8_t in_byte = *in;
669
670 in++;
671 for (j = 0; j < 8; j++) {
672 out->c[i + j] = in_byte & 1;
673 in_byte >>= 1;
674 }
675 }
676}
677
678/*
679 * Decodes 32*|RANK768|*|bits| bytes from |in| into |out|. It returns one on
680 * success or zero if any parsed value is >= |kPrime|.
681 */
682static int
683vector_decode(vector *out, const uint8_t *in, int bits)
684{
685 int i;
686
687 for (i = 0; i < RANK768; i++) {
688 if (!scalar_decode(&out->v[i], in + i * bits * DEGREE / 8,
689 bits)) {
690 return 0;
691 }
692 }
693 return 1;
694}
695
696/*
697 * Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
698 * numbers close to each other together. The formula used is
699 * round(2^|bits|/kPrime*x) mod 2^|bits|.
700 * Uses Barrett reduction to achieve constant time. Since we need both the
701 * remainder (for rounding) and the quotient (as the result), we cannot use
702 * |reduce| here, but need to do the Barrett reduction directly.
703 */
704static uint16_t
705compress(uint16_t x, int bits)
706{
707 uint32_t shifted = (uint32_t)x << bits;
708 uint64_t product = (uint64_t)shifted * kBarrettMultiplier;
709 uint32_t quotient = (uint32_t)(product >> kBarrettShift);
710 uint32_t remainder = shifted - quotient * kPrime;
711
712 /*
713 * Adjust the quotient to round correctly:
714 * 0 <= remainder <= kHalfPrime round to 0
715 * kHalfPrime < remainder <= kPrime + kHalfPrime round to 1
716 * kPrime + kHalfPrime < remainder < 2 * kPrime round to 2
717 */
718 assert(remainder < 2u * kPrime);
719 quotient += 1 & constant_time_lt(kHalfPrime, remainder);
720 quotient += 1 & constant_time_lt(kPrime + kHalfPrime, remainder);
721 return quotient & ((1 << bits) - 1);
722}
723
724/*
725 * Decompresses |x| by using an equi-distant representative. The formula is
726 * round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us to
727 * implement this logic using only bit operations.
728 */
729static uint16_t
730decompress(uint16_t x, int bits)
731{
732 uint32_t product = (uint32_t)x * kPrime;
733 uint32_t power = 1 << bits;
734 /* This is |product| % power, since |power| is a power of 2. */
735 uint32_t remainder = product & (power - 1);
736 /* This is |product| / power, since |power| is a power of 2. */
737 uint32_t lower = product >> bits;
738
739 /*
740 * The rounding logic works since the first half of numbers mod |power| have a
741 * 0 as first bit, and the second half has a 1 as first bit, since |power| is
742 * a power of 2. As a 12 bit number, |remainder| is always positive, so we
743 * will shift in 0s for a right shift.
744 */
745 return lower + (remainder >> (bits - 1));
746}
747
748static void
749scalar_compress(scalar *s, int bits)
750{
751 int i;
752
753 for (i = 0; i < DEGREE; i++) {
754 s->c[i] = compress(s->c[i], bits);
755 }
756}
757
758static void
759scalar_decompress(scalar *s, int bits)
760{
761 int i;
762
763 for (i = 0; i < DEGREE; i++) {
764 s->c[i] = decompress(s->c[i], bits);
765 }
766}
767
768static void
769vector_compress(vector *a, int bits)
770{
771 int i;
772
773 for (i = 0; i < RANK768; i++) {
774 scalar_compress(&a->v[i], bits);
775 }
776}
777
778static void
779vector_decompress(vector *a, int bits)
780{
781 int i;
782
783 for (i = 0; i < RANK768; i++) {
784 scalar_decompress(&a->v[i], bits);
785 }
786}
787
788struct public_key {
789 vector t;
790 uint8_t rho[32];
791 uint8_t public_key_hash[32];
792 matrix m;
793};
794
795static struct public_key *
796public_key_768_from_external(const struct MLKEM768_public_key *external)
797{
798 return (struct public_key *)external;
799}
800
801struct private_key {
802 struct public_key pub;
803 vector s;
804 uint8_t fo_failure_secret[32];
805};
806
807static struct private_key *
808private_key_768_from_external(const struct MLKEM768_private_key *external)
809{
810 return (struct private_key *)external;
811}
812
813/*
814 * Calls |MLKEM768_generate_key_external_entropy| with random bytes from
815 * |RAND_bytes|.
816 */
817void
818MLKEM768_generate_key(uint8_t out_encoded_public_key[MLKEM768_PUBLIC_KEY_BYTES],
819 uint8_t optional_out_seed[MLKEM_SEED_BYTES],
820 struct MLKEM768_private_key *out_private_key)
821{
822 uint8_t entropy_buf[MLKEM_SEED_BYTES];
823 uint8_t *entropy = optional_out_seed != NULL ? optional_out_seed :
824 entropy_buf;
825
826 arc4random_buf(entropy, MLKEM_SEED_BYTES);
827 MLKEM768_generate_key_external_entropy(out_encoded_public_key,
828 out_private_key, entropy);
829}
830LCRYPTO_ALIAS(MLKEM768_generate_key);
831
832int
833MLKEM768_private_key_from_seed(struct MLKEM768_private_key *out_private_key,
834 const uint8_t *seed, size_t seed_len)
835{
836 uint8_t public_key_bytes[MLKEM768_PUBLIC_KEY_BYTES];
837
838 if (seed_len != MLKEM_SEED_BYTES) {
839 return 0;
840 }
841 MLKEM768_generate_key_external_entropy(public_key_bytes,
842 out_private_key, seed);
843
844 return 1;
845}
846LCRYPTO_ALIAS(MLKEM768_private_key_from_seed);
847
848static int
849mlkem_marshal_public_key(CBB *out, const struct public_key *pub)
850{
851 uint8_t *vector_output;
852
853 if (!CBB_add_space(out, &vector_output, kEncodedVectorSize)) {
854 return 0;
855 }
856 vector_encode(vector_output, &pub->t, kLog2Prime);
857 if (!CBB_add_bytes(out, pub->rho, sizeof(pub->rho))) {
858 return 0;
859 }
860 return 1;
861}
862
863void
864MLKEM768_generate_key_external_entropy(
865 uint8_t out_encoded_public_key[MLKEM768_PUBLIC_KEY_BYTES],
866 struct MLKEM768_private_key *out_private_key,
867 const uint8_t entropy[MLKEM_SEED_BYTES])
868{
869 struct private_key *priv = private_key_768_from_external(
870 out_private_key);
871 uint8_t augmented_seed[33];
872 uint8_t *rho, *sigma;
873 uint8_t counter = 0;
874 uint8_t hashed[64];
875 vector error;
876 CBB cbb;
877
878 memcpy(augmented_seed, entropy, 32);
879 augmented_seed[32] = RANK768;
880 hash_g(hashed, augmented_seed, 33);
881 rho = hashed;
882 sigma = hashed + 32;
883 memcpy(priv->pub.rho, hashed, sizeof(priv->pub.rho));
884 matrix_expand(&priv->pub.m, rho);
885 vector_generate_secret_eta_2(&priv->s, &counter, sigma);
886 vector_ntt(&priv->s);
887 vector_generate_secret_eta_2(&error, &counter, sigma);
888 vector_ntt(&error);
889 matrix_mult_transpose(&priv->pub.t, &priv->pub.m, &priv->s);
890 vector_add(&priv->pub.t, &error);
891
892 /* XXX - error checking */
893 CBB_init_fixed(&cbb, out_encoded_public_key, MLKEM768_PUBLIC_KEY_BYTES);
894 if (!mlkem_marshal_public_key(&cbb, &priv->pub)) {
895 abort();
896 }
897 CBB_cleanup(&cbb);
898
899 hash_h(priv->pub.public_key_hash, out_encoded_public_key,
900 MLKEM768_PUBLIC_KEY_BYTES);
901 memcpy(priv->fo_failure_secret, entropy + 32, 32);
902}
903
904void
905MLKEM768_public_from_private(struct MLKEM768_public_key *out_public_key,
906 const struct MLKEM768_private_key *private_key)
907{
908 struct public_key *const pub = public_key_768_from_external(
909 out_public_key);
910 const struct private_key *const priv = private_key_768_from_external(
911 private_key);
912
913 *pub = priv->pub;
914}
915LCRYPTO_ALIAS(MLKEM768_public_from_private);
916
917/*
918 * Encrypts a message with given randomness to the ciphertext in |out|. Without
919 * applying the Fujisaki-Okamoto transform this would not result in a CCA secure
920 * scheme, since lattice schemes are vulnerable to decryption failure oracles.
921 */
922static void
923encrypt_cpa(uint8_t out[MLKEM768_CIPHERTEXT_BYTES],
924 const struct public_key *pub, const uint8_t message[32],
925 const uint8_t randomness[32])
926{
927 scalar expanded_message, scalar_error;
928 vector secret, error, u;
929 uint8_t counter = 0;
930 uint8_t input[33];
931 scalar v;
932
933 vector_generate_secret_eta_2(&secret, &counter, randomness);
934 vector_ntt(&secret);
935 vector_generate_secret_eta_2(&error, &counter, randomness);
936 memcpy(input, randomness, 32);
937 input[32] = counter;
938 scalar_centered_binomial_distribution_eta_2_with_prf(&scalar_error,
939 input);
940 matrix_mult(&u, &pub->m, &secret);
941 vector_inverse_ntt(&u);
942 vector_add(&u, &error);
943 scalar_inner_product(&v, &pub->t, &secret);
944 scalar_inverse_ntt(&v);
945 scalar_add(&v, &scalar_error);
946 scalar_decode_1(&expanded_message, message);
947 scalar_decompress(&expanded_message, 1);
948 scalar_add(&v, &expanded_message);
949 vector_compress(&u, kDU768);
950 vector_encode(out, &u, kDU768);
951 scalar_compress(&v, kDV768);
952 scalar_encode(out + kCompressedVectorSize, &v, kDV768);
953}
954
955/* Calls MLKEM768_encap_external_entropy| with random bytes */
956void
957MLKEM768_encap(uint8_t out_ciphertext[MLKEM768_CIPHERTEXT_BYTES],
958 uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
959 const struct MLKEM768_public_key *public_key)
960{
961 uint8_t entropy[MLKEM_ENCAP_ENTROPY];
962
963 arc4random_buf(entropy, MLKEM_ENCAP_ENTROPY);
964 MLKEM768_encap_external_entropy(out_ciphertext, out_shared_secret,
965 public_key, entropy);
966}
967LCRYPTO_ALIAS(MLKEM768_encap);
968
969/* See section 6.2 of the spec. */
970void
971MLKEM768_encap_external_entropy(
972 uint8_t out_ciphertext[MLKEM768_CIPHERTEXT_BYTES],
973 uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
974 const struct MLKEM768_public_key *public_key,
975 const uint8_t entropy[MLKEM_ENCAP_ENTROPY])
976{
977 const struct public_key *pub = public_key_768_from_external(public_key);
978 uint8_t key_and_randomness[64];
979 uint8_t input[64];
980
981 memcpy(input, entropy, MLKEM_ENCAP_ENTROPY);
982 memcpy(input + MLKEM_ENCAP_ENTROPY, pub->public_key_hash,
983 sizeof(input) - MLKEM_ENCAP_ENTROPY);
984 hash_g(key_and_randomness, input, sizeof(input));
985 encrypt_cpa(out_ciphertext, pub, entropy, key_and_randomness + 32);
986 memcpy(out_shared_secret, key_and_randomness, 32);
987}
988
989static void
990decrypt_cpa(uint8_t out[32], const struct private_key *priv,
991 const uint8_t ciphertext[MLKEM768_CIPHERTEXT_BYTES])
992{
993 scalar mask, v;
994 vector u;
995
996 vector_decode(&u, ciphertext, kDU768);
997 vector_decompress(&u, kDU768);
998 vector_ntt(&u);
999 scalar_decode(&v, ciphertext + kCompressedVectorSize, kDV768);
1000 scalar_decompress(&v, kDV768);
1001 scalar_inner_product(&mask, &priv->s, &u);
1002 scalar_inverse_ntt(&mask);
1003 scalar_sub(&v, &mask);
1004 scalar_compress(&v, 1);
1005 scalar_encode_1(out, &v);
1006}
1007
1008/* See section 6.3 */
1009int
1010MLKEM768_decap(uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
1011 const uint8_t *ciphertext, size_t ciphertext_len,
1012 const struct MLKEM768_private_key *private_key)
1013{
1014 const struct private_key *priv = private_key_768_from_external(
1015 private_key);
1016 uint8_t expected_ciphertext[MLKEM768_CIPHERTEXT_BYTES];
1017 uint8_t key_and_randomness[64];
1018 uint8_t failure_key[32];
1019 uint8_t decrypted[64];
1020 uint8_t mask;
1021 int i;
1022
1023 if (ciphertext_len != MLKEM768_CIPHERTEXT_BYTES) {
1024 arc4random_buf(out_shared_secret, MLKEM_SHARED_SECRET_BYTES);
1025 return 0;
1026 }
1027
1028 decrypt_cpa(decrypted, priv, ciphertext);
1029 memcpy(decrypted + 32, priv->pub.public_key_hash,
1030 sizeof(decrypted) - 32);
1031 hash_g(key_and_randomness, decrypted, sizeof(decrypted));
1032 encrypt_cpa(expected_ciphertext, &priv->pub, decrypted,
1033 key_and_randomness + 32);
1034 kdf(failure_key, priv->fo_failure_secret, ciphertext, ciphertext_len);
1035 mask = constant_time_eq_int_8(memcmp(ciphertext, expected_ciphertext,
1036 sizeof(expected_ciphertext)), 0);
1037 for (i = 0; i < MLKEM_SHARED_SECRET_BYTES; i++) {
1038 out_shared_secret[i] = constant_time_select_8(mask,
1039 key_and_randomness[i], failure_key[i]);
1040 }
1041
1042 return 1;
1043}
1044LCRYPTO_ALIAS(MLKEM768_decap);
1045
1046int
1047MLKEM768_marshal_public_key(CBB *out,
1048 const struct MLKEM768_public_key *public_key)
1049{
1050 return mlkem_marshal_public_key(out,
1051 public_key_768_from_external(public_key));
1052}
1053LCRYPTO_ALIAS(MLKEM768_marshal_public_key);
1054
1055/*
1056 * mlkem_parse_public_key_no_hash parses |in| into |pub| but doesn't calculate
1057 * the value of |pub->public_key_hash|.
1058 */
1059static int
1060mlkem_parse_public_key_no_hash(struct public_key *pub, CBS *in)
1061{
1062 CBS t_bytes;
1063
1064 if (!CBS_get_bytes(in, &t_bytes, kEncodedVectorSize) ||
1065 !vector_decode(&pub->t, CBS_data(&t_bytes), kLog2Prime)) {
1066 return 0;
1067 }
1068 memcpy(pub->rho, CBS_data(in), sizeof(pub->rho));
1069 if (!CBS_skip(in, sizeof(pub->rho)))
1070 return 0;
1071 matrix_expand(&pub->m, pub->rho);
1072 return 1;
1073}
1074
1075int
1076MLKEM768_parse_public_key(struct MLKEM768_public_key *public_key, CBS *in)
1077{
1078 struct public_key *pub = public_key_768_from_external(public_key);
1079 CBS orig_in = *in;
1080
1081 if (!mlkem_parse_public_key_no_hash(pub, in) ||
1082 CBS_len(in) != 0) {
1083 return 0;
1084 }
1085 hash_h(pub->public_key_hash, CBS_data(&orig_in), CBS_len(&orig_in));
1086 return 1;
1087}
1088LCRYPTO_ALIAS(MLKEM768_parse_public_key);
1089
1090int
1091MLKEM768_marshal_private_key(CBB *out,
1092 const struct MLKEM768_private_key *private_key)
1093{
1094 const struct private_key *const priv = private_key_768_from_external(
1095 private_key);
1096 uint8_t *s_output;
1097
1098 if (!CBB_add_space(out, &s_output, kEncodedVectorSize)) {
1099 return 0;
1100 }
1101 vector_encode(s_output, &priv->s, kLog2Prime);
1102 if (!mlkem_marshal_public_key(out, &priv->pub) ||
1103 !CBB_add_bytes(out, priv->pub.public_key_hash,
1104 sizeof(priv->pub.public_key_hash)) ||
1105 !CBB_add_bytes(out, priv->fo_failure_secret,
1106 sizeof(priv->fo_failure_secret))) {
1107 return 0;
1108 }
1109 return 1;
1110}
1111
1112int
1113MLKEM768_parse_private_key(struct MLKEM768_private_key *out_private_key,
1114 CBS *in)
1115{
1116 struct private_key *const priv = private_key_768_from_external(
1117 out_private_key);
1118 CBS s_bytes;
1119
1120 if (!CBS_get_bytes(in, &s_bytes, kEncodedVectorSize) ||
1121 !vector_decode(&priv->s, CBS_data(&s_bytes), kLog2Prime) ||
1122 !mlkem_parse_public_key_no_hash(&priv->pub, in)) {
1123 return 0;
1124 }
1125 memcpy(priv->pub.public_key_hash, CBS_data(in),
1126 sizeof(priv->pub.public_key_hash));
1127 if (!CBS_skip(in, sizeof(priv->pub.public_key_hash)))
1128 return 0;
1129 memcpy(priv->fo_failure_secret, CBS_data(in),
1130 sizeof(priv->fo_failure_secret));
1131 if (!CBS_skip(in, sizeof(priv->fo_failure_secret)))
1132 return 0;
1133 if (CBS_len(in) != 0)
1134 return 0;
1135
1136 return 1;
1137}
1138LCRYPTO_ALIAS(MLKEM768_parse_private_key);
diff --git a/src/lib/libcrypto/mlkem/mlkem_internal.h b/src/lib/libcrypto/mlkem/mlkem_internal.h
deleted file mode 100644
index d3f325932f..0000000000
--- a/src/lib/libcrypto/mlkem/mlkem_internal.h
+++ /dev/null
@@ -1,121 +0,0 @@
1/* $OpenBSD: mlkem_internal.h,v 1.4 2024/12/19 23:52:26 tb Exp $ */
2/*
3 * Copyright (c) 2023, Google Inc.
4 *
5 * Permission to use, copy, modify, and/or distribute this software for any
6 * purpose with or without fee is hereby granted, provided that the above
7 * copyright notice and this permission notice appear in all copies.
8 *
9 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
12 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
14 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
15 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16 */
17
18#ifndef OPENSSL_HEADER_CRYPTO_MLKEM_INTERNAL_H
19#define OPENSSL_HEADER_CRYPTO_MLKEM_INTERNAL_H
20
21#include "bytestring.h"
22#include "mlkem.h"
23
24#if defined(__cplusplus)
25extern "C" {
26#endif
27
28__BEGIN_HIDDEN_DECLS
29
30/*
31 * MLKEM_ENCAP_ENTROPY is the number of bytes of uniformly random entropy
32 * necessary to encapsulate a secret. The entropy will be leaked to the
33 * decapsulating party.
34 */
35#define MLKEM_ENCAP_ENTROPY 32
36
37/*
38 * MLKEM768_generate_key_external_entropy is a deterministic function to create a
39 * pair of ML-KEM 768 keys, using the supplied entropy. The entropy needs to be
40 * uniformly random generated. This function is should only be used for tests,
41 * regular callers should use the non-deterministic |MLKEM_generate_key|
42 * directly.
43 */
44void MLKEM768_generate_key_external_entropy(
45 uint8_t out_encoded_public_key[MLKEM768_PUBLIC_KEY_BYTES],
46 struct MLKEM768_private_key *out_private_key,
47 const uint8_t entropy[MLKEM_SEED_BYTES]);
48
49/*
50 * MLKEM768_PRIVATE_KEY_BYTES is the length of the data produced by
51 * |MLKEM768_marshal_private_key|.
52 */
53#define MLKEM768_PRIVATE_KEY_BYTES 2400
54
55/*
56 * MLKEM768_marshal_private_key serializes |private_key| to |out| in the standard
57 * format for ML-KEM private keys. It returns one on success or zero on
58 * allocation error.
59 */
60int MLKEM768_marshal_private_key(CBB *out,
61 const struct MLKEM768_private_key *private_key);
62
63/*
64 * MLKEM_encap_external_entropy behaves like |MLKEM_encap|, but uses
65 * |MLKEM_ENCAP_ENTROPY| bytes of |entropy| for randomization. The decapsulating
66 * side will be able to recover |entropy| in full. This function should only be
67 * used for tests, regular callers should use the non-deterministic
68 * |MLKEM_encap| directly.
69 */
70void MLKEM768_encap_external_entropy(
71 uint8_t out_ciphertext[MLKEM768_CIPHERTEXT_BYTES],
72 uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
73 const struct MLKEM768_public_key *public_key,
74 const uint8_t entropy[MLKEM_ENCAP_ENTROPY]);
75
76/*
77 * MLKEM1024_generate_key_external_entropy is a deterministic function to create a
78 * pair of ML-KEM 1024 keys, using the supplied entropy. The entropy needs to be
79 * uniformly random generated. This function is should only be used for tests,
80 * regular callers should use the non-deterministic |MLKEM_generate_key|
81 * directly.
82 */
83void MLKEM1024_generate_key_external_entropy(
84 uint8_t out_encoded_public_key[MLKEM1024_PUBLIC_KEY_BYTES],
85 struct MLKEM1024_private_key *out_private_key,
86 const uint8_t entropy[MLKEM_SEED_BYTES]);
87
88/*
89 * MLKEM1024_PRIVATE_KEY_BYTES is the length of the data produced by
90 * |MLKEM1024_marshal_private_key|.
91 */
92#define MLKEM1024_PRIVATE_KEY_BYTES 3168
93
94/*
95 * MLKEM1024_marshal_private_key serializes |private_key| to |out| in the
96 * standard format for ML-KEM private keys. It returns one on success or zero on
97 * allocation error.
98 */
99int MLKEM1024_marshal_private_key(CBB *out,
100 const struct MLKEM1024_private_key *private_key);
101
102/*
103 * MLKEM_encap_external_entropy behaves like |MLKEM_encap|, but uses
104 * |MLKEM_ENCAP_ENTROPY| bytes of |entropy| for randomization. The decapsulating
105 * side will be able to recover |entropy| in full. This function should only be
106 * used for tests, regular callers should use the non-deterministic
107 * |MLKEM_encap| directly.
108 */
109void MLKEM1024_encap_external_entropy(
110 uint8_t out_ciphertext[MLKEM1024_CIPHERTEXT_BYTES],
111 uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
112 const struct MLKEM1024_public_key *public_key,
113 const uint8_t entropy[MLKEM_ENCAP_ENTROPY]);
114
115__END_HIDDEN_DECLS
116
117#if defined(__cplusplus)
118}
119#endif
120
121#endif /* OPENSSL_HEADER_CRYPTO_MLKEM_INTERNAL_H */