Skip to content

Commit 752841e

Browse files
committed
ML-KEM: Add AArch64 arithmetic backend
Context: The ML-KEM implementation in AWS-LC is imported from mlkem-native. mlkem-native comes in a "C-only" version, but also offers AArch64 and x86_64 backends for (a) arithmetic, and (b) FIPS-202. Currently, only the "C-only" version is imported into AWS-LC. This commit adds a custom AArch64 backend to AWS-LC. The backend is essentially the same as in mlkem-native, but its assembly sources are taken from s2n-bignum and its headers are written from scratch. The constant tables used in the backend are copied from mlkem-native. Compared to extending the mlkem-native->AWS-LC importer to include mlkem-native's AArch64 backend, this approach sticks to s2n-bignum as the sole source of verified assembly. It also provides greater flexibility in maintaining and adjusting the backend, both the assembly and the headers. For example, the assembly may be optimized for Graviton cores in the future, or the dispatch in the metadata files adjusted; the latter will mostly be relevant as we integrate x86_64 assembly, for which we aim to use the same methodology. To avoid a symbol clash with s2n-bignum, the mlkem-native namespace is changed from `mlkem` to `mlkem_native`. s2n-bignum is partially re-imported from the development branch https://github.com/jargh/s2n-bignum-dev/tree/mlkem/, restricting to the ML-KEM related files. Signed-off-by: Hanno Becker <[email protected]>
1 parent 587cf97 commit 752841e

19 files changed

+3019
-20
lines changed

crypto/fipsmodule/CMakeLists.txt

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,20 @@ if((((ARCH STREQUAL "x86_64") AND NOT MY_ASSEMBLER_IS_TOO_OLD_FOR_512AVX) OR
291291
${S2N_BIGNUM_DIR}/generic/bignum_copy_row_from_table_16.S
292292
${S2N_BIGNUM_DIR}/generic/bignum_copy_row_from_table_32.S
293293
)
294+
295+
# ML-KEM core arithmetic
296+
list(APPEND BCM_ASM_SOURCES
297+
${S2N_BIGNUM_DIR}/mlkem/mlkem_basemul_k2.S
298+
${S2N_BIGNUM_DIR}/mlkem/mlkem_basemul_k3.S
299+
${S2N_BIGNUM_DIR}/mlkem/mlkem_basemul_k4.S
300+
${S2N_BIGNUM_DIR}/mlkem/mlkem_intt.S
301+
${S2N_BIGNUM_DIR}/mlkem/mlkem_mulcache_compute.S
302+
${S2N_BIGNUM_DIR}/mlkem/mlkem_ntt.S
303+
${S2N_BIGNUM_DIR}/mlkem/mlkem_poly_reduce.S
304+
${S2N_BIGNUM_DIR}/mlkem/mlkem_poly_tobytes.S
305+
${S2N_BIGNUM_DIR}/mlkem/mlkem_poly_tomont.S
306+
${S2N_BIGNUM_DIR}/mlkem/mlkem_rej_uniform_VARIABLE_TIME.S)
307+
294308
endif()
295309

296310
if(BORINGSSL_PREFIX)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
This directory contains an AArch64 arithmetic backend for mlkem-native. The core assembly routines are imported from [s2n-bignum](https://github.com/awslabs/s2n-bignum/).

crypto/fipsmodule/ml_kem/aarch64/constants.c

Lines changed: 668 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0 OR ISC
3+
4+
#ifndef ML_KEM_AARCH64_BACKEND_H
5+
#define ML_KEM_AARCH64_BACKEND_H
6+
7+
#include "../mlkem/common.h"
8+
9+
#define MLK_USE_NATIVE_NTT
10+
#define MLK_USE_NATIVE_INTT
11+
#define MLK_USE_NATIVE_POLY_REDUCE
12+
#define MLK_USE_NATIVE_POLY_TOMONT
13+
#define MLK_USE_NATIVE_POLY_MULCACHE_COMPUTE
14+
#define MLK_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED
15+
#define MLK_USE_NATIVE_POLY_TOBYTES
16+
#define MLK_USE_NATIVE_REJ_UNIFORM
17+
18+
extern const int16_t mlk_aarch64_ntt_zetas_layer12345[];
19+
extern const int16_t mlk_aarch64_ntt_zetas_layer67[];
20+
extern const int16_t mlk_aarch64_invntt_zetas_layer12345[];
21+
extern const int16_t mlk_aarch64_invntt_zetas_layer67[];
22+
extern const uint8_t mlk_rej_uniform_table[];
23+
extern const int16_t mlk_aarch64_zetas_mulcache_native[];
24+
extern const int16_t mlk_aarch64_zetas_mulcache_twisted_native[];
25+
26+
#include "s2n-bignum.h"
27+
28+
// Thin wrapper around the s2n-bignum routine mlkem_ntt: passes `data`
// together with the two AArch64 zeta tables declared above
// (mlk_aarch64_ntt_zetas_layer12345 and mlk_aarch64_ntt_zetas_layer67).
// NOTE(review): presumably a forward NTT that transforms `data` in place
// (the argument is non-const) -- confirm against s2n-bignum.h.
static MLK_INLINE void mlk_ntt_native(int16_t data[MLKEM_N]) {
29+
mlkem_ntt(data, mlk_aarch64_ntt_zetas_layer12345, mlk_aarch64_ntt_zetas_layer67);
30+
}
31+
32+
// Thin wrapper around the s2n-bignum routine mlkem_intt: passes `data`
// together with the two inverse-NTT zeta tables declared above
// (mlk_aarch64_invntt_zetas_layer12345 and mlk_aarch64_invntt_zetas_layer67).
// NOTE(review): presumably an in-place inverse NTT -- confirm against
// s2n-bignum.h.
static MLK_INLINE void mlk_intt_native(int16_t data[MLKEM_N]) {
33+
mlkem_intt(data, mlk_aarch64_invntt_zetas_layer12345, mlk_aarch64_invntt_zetas_layer67);
34+
}
35+
36+
// Thin wrapper: forwards `data` unchanged to the s2n-bignum routine
// mlkem_poly_reduce. No tables or extra arguments are involved.
// NOTE(review): presumably reduces the MLKEM_N coefficients in place --
// confirm the exact output range against s2n-bignum.h.
static MLK_INLINE void mlk_poly_reduce_native(int16_t data[MLKEM_N]) {
37+
mlkem_poly_reduce(data);
38+
}
39+
40+
// Thin wrapper: forwards `data` unchanged to the s2n-bignum routine
// mlkem_poly_tomont. No tables or extra arguments are involved.
// NOTE(review): presumably converts the coefficients to Montgomery form in
// place -- confirm against s2n-bignum.h.
static MLK_INLINE void mlk_poly_tomont_native(int16_t data[MLKEM_N]) {
41+
mlkem_poly_tomont(data);
42+
}
43+
44+
// Thin wrapper around the s2n-bignum routine mlkem_mulcache_compute.
//   x: non-const buffer of MLKEM_N / 2 entries, passed as the first argument
//      (presumably the output cache -- confirm against s2n-bignum.h).
//   y: const input of MLKEM_N entries.
// The two mulcache zeta tables declared above are supplied as the remaining
// arguments.
static MLK_INLINE void mlk_poly_mulcache_compute_native(int16_t x[MLKEM_N / 2], const int16_t y[MLKEM_N]) {
45+
mlkem_mulcache_compute(x, y, mlk_aarch64_zetas_mulcache_native,
46+
mlk_aarch64_zetas_mulcache_twisted_native);
47+
}
48+
49+
// Thin wrapper: forwards all four arguments unchanged to the s2n-bignum
// routine mlkem_basemul_k2.
//   r:       non-const buffer of MLKEM_N entries (presumably the result --
//            confirm against s2n-bignum.h).
//   a, b:    const inputs of 2 * MLKEM_N entries each.
//   b_cache: const input of 2 * (MLKEM_N / 2) entries.
static MLK_INLINE void mlk_polyvec_basemul_acc_montgomery_cached_k2_native(
50+
int16_t r[MLKEM_N], const int16_t a[2 * MLKEM_N],
51+
const int16_t b[2 * MLKEM_N], const int16_t b_cache[2 * (MLKEM_N / 2)]) {
52+
mlkem_basemul_k2(r, a, b, b_cache);
53+
}
54+
55+
// Thin wrapper: forwards all four arguments unchanged to the s2n-bignum
// routine mlkem_basemul_k3.
//   r:       non-const buffer of MLKEM_N entries (presumably the result --
//            confirm against s2n-bignum.h).
//   a, b:    const inputs of 3 * MLKEM_N entries each.
//   b_cache: const input of 3 * (MLKEM_N / 2) entries.
static MLK_INLINE void mlk_polyvec_basemul_acc_montgomery_cached_k3_native(
56+
int16_t r[MLKEM_N], const int16_t a[3 * MLKEM_N],
57+
const int16_t b[3 * MLKEM_N], const int16_t b_cache[3 * (MLKEM_N / 2)]) {
58+
mlkem_basemul_k3(r, a, b, b_cache);
59+
}
60+
61+
// Thin wrapper: forwards all four arguments unchanged to the s2n-bignum
// routine mlkem_basemul_k4.
//   r:       non-const buffer of MLKEM_N entries (presumably the result --
//            confirm against s2n-bignum.h).
//   a, b:    const inputs of 4 * MLKEM_N entries each.
//   b_cache: const input of 4 * (MLKEM_N / 2) entries.
static MLK_INLINE void mlk_polyvec_basemul_acc_montgomery_cached_k4_native(
62+
int16_t r[MLKEM_N], const int16_t a[4 * MLKEM_N],
63+
const int16_t b[4 * MLKEM_N], const int16_t b_cache[4 * (MLKEM_N / 2)]) {
64+
mlkem_basemul_k4(r, a, b, b_cache);
65+
}
66+
67+
// Thin wrapper: forwards both arguments unchanged to the s2n-bignum routine
// mlkem_poly_tobytes.
//   r: non-const byte buffer of MLKEM_POLYBYTES entries (presumably the
//      serialized output -- confirm against s2n-bignum.h).
//   a: const input of MLKEM_N int16_t entries.
static MLK_INLINE void mlk_poly_tobytes_native(uint8_t r[MLKEM_POLYBYTES],
68+
const int16_t a[MLKEM_N]) {
69+
mlkem_poly_tobytes(r, a);
70+
}
71+
72+
// Wrapper around the s2n-bignum routine mlkem_rej_uniform_VARIABLE_TIME.
//   r:      non-const int16_t buffer, passed through as the first argument.
//   len:    must equal MLKEM_N, otherwise the call is rejected.
//   buf:    const input bytes.
//   buflen: length of buf; must be a multiple of 24, otherwise the call is
//           rejected.
// Returns -1 without calling the assembly when the preconditions on len or
// buflen are not met; otherwise returns the assembly routine's result cast
// to int.
// NOTE(review): presumably -1 tells the mlkem-native caller to use its C
// fallback, and 24 is the input block size the assembly expects -- confirm
// against the mlkem-native native-backend API and s2n-bignum.h.
static MLK_INLINE int mlk_rej_uniform_native(int16_t *r, unsigned len,
73+
const uint8_t *buf,
74+
unsigned buflen) {
75+
if (len != MLKEM_N || buflen % 24 != 0) {
76+
return -1;
77+
}
78+
// mlk_rej_uniform_table is the lookup table declared extern above.
return (int) mlkem_rej_uniform_VARIABLE_TIME(r, buf, buflen, mlk_rej_uniform_table);
79+
}
80+
81+
#endif /* ML_KEM_AARCH64_BACKEND_H */

crypto/fipsmodule/ml_kem/ml_kem.c

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@
2626

2727
#include "./ml_kem.h"
2828

29+
// AArch64 backend
30+
#if defined(OPENSSL_AARCH64) && !defined(OPENSSL_NO_ASM)
31+
#include "aarch64/constants.c"
32+
#endif
33+
2934
typedef struct {
3035
uint8_t *buffer;
3136
size_t *length;
@@ -92,7 +97,7 @@ int ml_kem_512_keypair_deterministic_no_self_test(uint8_t *public_key /* OUT */
9297
if (!check_buffer(pkey) || !check_buffer(skey)) {
9398
return 1;
9499
}
95-
const int res = mlkem512_keypair_derand(pkey.buffer, skey.buffer, seed);
100+
const int res = mlkem_native512_keypair_derand(pkey.buffer, skey.buffer, seed);
96101
#if defined(AWSLC_FIPS)
97102
/* PCT failure is the only failure condition for key generation. */
98103
if (res != 0) {
@@ -110,7 +115,7 @@ int ml_kem_512_keypair(uint8_t *public_key /* OUT */,
110115
size_t *secret_len /* IN_OUT */) {
111116
output_buffer pkey = {public_key, public_len, MLKEM512_PUBLIC_KEY_BYTES};
112117
output_buffer skey = {secret_key, secret_len, MLKEM512_SECRET_KEY_BYTES};
113-
return ml_kem_common_keypair(mlkem512_keypair, pkey, skey);
118+
return ml_kem_common_keypair(mlkem_native512_keypair, pkey, skey);
114119
}
115120

116121
int ml_kem_512_encapsulate_deterministic(uint8_t *ciphertext /* OUT */,
@@ -131,7 +136,7 @@ int ml_kem_512_encapsulate_deterministic_no_self_test(uint8_t *ciphertext
131136
const uint8_t *seed /* IN */) {
132137
output_buffer ctext = {ciphertext, ciphertext_len, MLKEM512_CIPHERTEXT_BYTES};
133138
output_buffer ss = {shared_secret, shared_secret_len, MLKEM512_SHARED_SECRET_LEN};
134-
return ml_kem_common_encapsulate_deterministic(mlkem512_enc_derand, ctext, ss, public_key, seed);
139+
return ml_kem_common_encapsulate_deterministic(mlkem_native512_enc_derand, ctext, ss, public_key, seed);
135140
}
136141

137142
int ml_kem_512_encapsulate(uint8_t *ciphertext /* OUT */,
@@ -141,7 +146,7 @@ int ml_kem_512_encapsulate(uint8_t *ciphertext /* OUT */,
141146
const uint8_t *public_key /* IN */) {
142147
output_buffer ctext = {ciphertext, ciphertext_len, MLKEM512_CIPHERTEXT_BYTES};
143148
output_buffer ss = {shared_secret, shared_secret_len, MLKEM512_SHARED_SECRET_LEN};
144-
return ml_kem_common_encapsulate(mlkem512_enc, ctext, ss, public_key);
149+
return ml_kem_common_encapsulate(mlkem_native512_enc, ctext, ss, public_key);
145150
}
146151

147152
int ml_kem_512_decapsulate(uint8_t *shared_secret /* OUT */,
@@ -157,7 +162,7 @@ int ml_kem_512_decapsulate_no_self_test(uint8_t *shared_secret /* OUT */,
157162
const uint8_t *ciphertext /* IN */,
158163
const uint8_t *secret_key /* IN */) {
159164
output_buffer ss = {shared_secret, shared_secret_len, MLKEM512_SHARED_SECRET_LEN};
160-
return ml_kem_common_decapsulate(mlkem512_dec, ss, ciphertext, secret_key);
165+
return ml_kem_common_decapsulate(mlkem_native512_dec, ss, ciphertext, secret_key);
161166
}
162167

163168

@@ -181,7 +186,7 @@ int ml_kem_768_keypair_deterministic_no_self_test(uint8_t *public_key /* OUT */,
181186
if (!check_buffer(pkey) || !check_buffer(skey)) {
182187
return 1;
183188
}
184-
const int res = mlkem768_keypair_derand(pkey.buffer, skey.buffer, seed);
189+
const int res = mlkem_native768_keypair_derand(pkey.buffer, skey.buffer, seed);
185190
#if defined(AWSLC_FIPS)
186191
/* PCT failure is the only failure condition for key generation. */
187192
if (res != 0) {
@@ -199,7 +204,7 @@ int ml_kem_768_keypair(uint8_t *public_key /* OUT */,
199204
size_t *secret_len /* IN_OUT */) {
200205
output_buffer pkey = {public_key, public_len, MLKEM768_PUBLIC_KEY_BYTES};
201206
output_buffer skey = {secret_key, secret_len, MLKEM768_SECRET_KEY_BYTES};
202-
return ml_kem_common_keypair(mlkem768_keypair, pkey, skey);
207+
return ml_kem_common_keypair(mlkem_native768_keypair, pkey, skey);
203208
}
204209

205210
int ml_kem_768_encapsulate_deterministic(uint8_t *ciphertext /* OUT */,
@@ -220,7 +225,7 @@ int ml_kem_768_encapsulate_deterministic_no_self_test(uint8_t *ciphertext
220225
const uint8_t *seed /* IN */) {
221226
output_buffer ctext = {ciphertext, ciphertext_len, MLKEM768_CIPHERTEXT_BYTES};
222227
output_buffer ss = {shared_secret, shared_secret_len, MLKEM768_SHARED_SECRET_LEN};
223-
return ml_kem_common_encapsulate_deterministic(mlkem768_enc_derand, ctext, ss, public_key, seed);
228+
return ml_kem_common_encapsulate_deterministic(mlkem_native768_enc_derand, ctext, ss, public_key, seed);
224229
}
225230

226231
int ml_kem_768_encapsulate(uint8_t *ciphertext /* OUT */,
@@ -230,7 +235,7 @@ int ml_kem_768_encapsulate(uint8_t *ciphertext /* OUT */,
230235
const uint8_t *public_key /* IN */) {
231236
output_buffer ctext = {ciphertext, ciphertext_len, MLKEM768_CIPHERTEXT_BYTES};
232237
output_buffer ss = {shared_secret, shared_secret_len, MLKEM768_SHARED_SECRET_LEN};
233-
return ml_kem_common_encapsulate(mlkem768_enc, ctext, ss, public_key);
238+
return ml_kem_common_encapsulate(mlkem_native768_enc, ctext, ss, public_key);
234239
}
235240

236241
int ml_kem_768_decapsulate(uint8_t *shared_secret /* OUT */,
@@ -246,7 +251,7 @@ int ml_kem_768_decapsulate_no_self_test(uint8_t *shared_secret /* OUT */,
246251
const uint8_t *ciphertext /* IN */,
247252
const uint8_t *secret_key /* IN */) {
248253
output_buffer ss = {shared_secret, shared_secret_len, MLKEM768_SHARED_SECRET_LEN};
249-
return ml_kem_common_decapsulate(mlkem768_dec, ss, ciphertext, secret_key);
254+
return ml_kem_common_decapsulate(mlkem_native768_dec, ss, ciphertext, secret_key);
250255
}
251256

252257
int ml_kem_1024_keypair_deterministic(uint8_t *public_key /* OUT */,
@@ -268,7 +273,7 @@ int ml_kem_1024_keypair_deterministic_no_self_test(uint8_t *public_key /* OUT */
268273
if (!check_buffer(pkey) || !check_buffer(skey)) {
269274
return 1;
270275
}
271-
const int res = mlkem1024_keypair_derand(pkey.buffer, skey.buffer, seed);
276+
const int res = mlkem_native1024_keypair_derand(pkey.buffer, skey.buffer, seed);
272277
#if defined(AWSLC_FIPS)
273278
/* PCT failure is the only failure condition for key generation. */
274279
if (res != 0) {
@@ -286,7 +291,7 @@ int ml_kem_1024_keypair(uint8_t *public_key /* OUT */,
286291
size_t *secret_len /* IN_OUT */) {
287292
output_buffer pkey = {public_key, public_len, MLKEM1024_PUBLIC_KEY_BYTES};
288293
output_buffer skey = {secret_key, secret_len, MLKEM1024_SECRET_KEY_BYTES};
289-
return ml_kem_common_keypair(mlkem1024_keypair, pkey, skey);
294+
return ml_kem_common_keypair(mlkem_native1024_keypair, pkey, skey);
290295
}
291296

292297
int ml_kem_1024_encapsulate_deterministic(uint8_t *ciphertext /* OUT */,
@@ -307,7 +312,7 @@ int ml_kem_1024_encapsulate_deterministic_no_self_test(uint8_t *ciphertext
307312
const uint8_t *seed /* IN */) {
308313
output_buffer ctext = {ciphertext, ciphertext_len, MLKEM1024_CIPHERTEXT_BYTES};
309314
output_buffer ss = {shared_secret, shared_secret_len, MLKEM1024_SHARED_SECRET_LEN};
310-
return ml_kem_common_encapsulate_deterministic(mlkem1024_enc_derand, ctext, ss, public_key, seed);
315+
return ml_kem_common_encapsulate_deterministic(mlkem_native1024_enc_derand, ctext, ss, public_key, seed);
311316
}
312317

313318
int ml_kem_1024_encapsulate(uint8_t *ciphertext /* OUT */,
@@ -317,7 +322,7 @@ int ml_kem_1024_encapsulate(uint8_t *ciphertext /* OUT */,
317322
const uint8_t *public_key /* IN */) {
318323
output_buffer ctext = {ciphertext, ciphertext_len, MLKEM1024_CIPHERTEXT_BYTES};
319324
output_buffer ss = {shared_secret, shared_secret_len, MLKEM1024_SHARED_SECRET_LEN};
320-
return ml_kem_common_encapsulate(mlkem1024_enc, ctext, ss, public_key);
325+
return ml_kem_common_encapsulate(mlkem_native1024_enc, ctext, ss, public_key);
321326
}
322327

323328
int ml_kem_1024_decapsulate(uint8_t *shared_secret /* OUT */,
@@ -333,7 +338,7 @@ int ml_kem_1024_decapsulate_no_self_test(uint8_t *shared_secret /* OUT */,
333338
const uint8_t *ciphertext /* IN */,
334339
const uint8_t *secret_key /* IN */) {
335340
output_buffer ss = {shared_secret, shared_secret_len, MLKEM1024_SHARED_SECRET_LEN};
336-
return ml_kem_common_decapsulate(mlkem1024_dec, ss, ciphertext, secret_key);
341+
return ml_kem_common_decapsulate(mlkem_native1024_dec, ss, ciphertext, secret_key);
337342
}
338343

339344
int ml_kem_common_keypair(int (*keypair)(uint8_t * public_key, uint8_t *secret_key),

crypto/fipsmodule/ml_kem/mlkem_native_config.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
// Namespacing: All symbols are of the form mlkem_native*. Level-specific
1010
// symbols are further prefixed with their security level, e.g.
1111
// mlkem_native512*, mlkem_native768*, mlkem_native1024*.
12-
#define MLK_CONFIG_NAMESPACE_PREFIX mlkem
12+
#define MLK_CONFIG_NAMESPACE_PREFIX mlkem_native
1313

1414
// Replace mlkem-native's FIPS 202 headers with glue code to
1515
// AWS-LC's own FIPS 202 implementation.
@@ -68,4 +68,9 @@ static MLK_INLINE void mlk_randombytes(void *ptr, size_t len) {
6868
#define MLK_CONFIG_NO_ASM
6969
#endif
7070

71+
#if defined(OPENSSL_AARCH64) && !defined(OPENSSL_NO_ASM)
72+
#define MLK_CONFIG_USE_NATIVE_BACKEND_ARITH
73+
#define MLK_CONFIG_ARITH_BACKEND_FILE "../aarch64/meta.h"
74+
#endif
75+
7176
#endif // MLkEM_NATIVE_CONFIG_H

third_party/s2n-bignum/META.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
name: s2n-bignum-imported
2-
source: awslabs/s2n-bignum.git
3-
commit: 54e1fa5756d6b13961c2f61d90f75426aa25d373
4-
target: main
5-
imported-at: 2025-04-28T17:22:07+0000
2+
source: jargh/s2n-bignum-dev.git
3+
commit: ae84a59689cb50ad9b9c6e25cd34037d5b1fb2b4
4+
target: mlkem
5+
imported-at: 2025-06-23T13:38:02+0000
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#############################################################################
2+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
# SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT-0
4+
#############################################################################
5+
6+
# If actually on an ARM8 machine, just use the GNU assembler (as). Otherwise
7+
# use a cross-assembling version so that the code can still be assembled
8+
# and the proofs checked against the object files (though you won't be able
9+
# to run code without additional emulation infrastructure). The aarch64
10+
# cross-assembling version can be installed manually by something like:
11+
#
12+
# sudo apt-get install binutils-aarch64-linux-gnu
13+
14+
UNAME_RESULT=$(shell uname -p)
15+
16+
ifeq ($(UNAME_RESULT),aarch64)
17+
GAS=as
18+
else
19+
GAS=aarch64-linux-gnu-as
20+
endif
21+
22+
# List of object files
23+
24+
OBJ = mlkem_basemul_k2.o \
25+
mlkem_basemul_k3.o \
26+
mlkem_basemul_k4.o \
27+
mlkem_intt.o \
28+
mlkem_mulcache_compute.o \
29+
mlkem_ntt.o \
30+
mlkem_poly_reduce.o \
31+
mlkem_poly_tobytes.o \
32+
mlkem_poly_tomont.o \
33+
mlkem_rej_uniform_VARIABLE_TIME.o
34+
35+
%.o : %.S ; $(CC) -E -I../../include $< | $(GAS) -o $@ -
36+
37+
default: $(OBJ);
38+
39+
clean:; rm -f *.o *.correct unopt/*.o

0 commit comments

Comments
 (0)