Skip to content

[RFC] ML-KEM: Add AArch64 arithmetic backend (from s2n-bignum) #2500

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions crypto/fipsmodule/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,20 @@ if((((ARCH STREQUAL "x86_64") AND NOT MY_ASSEMBLER_IS_TOO_OLD_FOR_512AVX) OR
${S2N_BIGNUM_DIR}/generic/bignum_copy_row_from_table_16.S
${S2N_BIGNUM_DIR}/generic/bignum_copy_row_from_table_32.S
)

# ML-KEM core arithmetic (assembly routines taken from s2n-bignum).
# NOTE(review): in the imported tree these sources sit under arm/mlkem;
# confirm S2N_BIGNUM_DIR points at the arm directory whenever this list is
# evaluated, i.e. that it cannot be reached on an x86_64 build.
list(APPEND BCM_ASM_SOURCES
${S2N_BIGNUM_DIR}/mlkem/mlkem_basemul_k2.S
${S2N_BIGNUM_DIR}/mlkem/mlkem_basemul_k3.S
${S2N_BIGNUM_DIR}/mlkem/mlkem_basemul_k4.S
${S2N_BIGNUM_DIR}/mlkem/mlkem_intt.S
${S2N_BIGNUM_DIR}/mlkem/mlkem_mulcache_compute.S
${S2N_BIGNUM_DIR}/mlkem/mlkem_ntt.S
${S2N_BIGNUM_DIR}/mlkem/mlkem_poly_reduce.S
${S2N_BIGNUM_DIR}/mlkem/mlkem_poly_tobytes.S
${S2N_BIGNUM_DIR}/mlkem/mlkem_poly_tomont.S
${S2N_BIGNUM_DIR}/mlkem/mlkem_rej_uniform_VARIABLE_TIME.S)

endif()

if(BORINGSSL_PREFIX)
Expand Down
1 change: 1 addition & 0 deletions crypto/fipsmodule/ml_kem/aarch64/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
This directory contains an AArch64 arithmetic backend for mlkem-native. The core assembly routines are imported from [s2n-bignum](https://github.com/awslabs/s2n-bignum/).
668 changes: 668 additions & 0 deletions crypto/fipsmodule/ml_kem/aarch64/constants.c

Large diffs are not rendered by default.

81 changes: 81 additions & 0 deletions crypto/fipsmodule/ml_kem/aarch64/meta.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 OR ISC

#ifndef ML_KEM_AARCH64_BACKEND_H
#define ML_KEM_AARCH64_BACKEND_H

#include "../mlkem/common.h"

#define MLK_USE_NATIVE_NTT
#define MLK_USE_NATIVE_INTT
#define MLK_USE_NATIVE_POLY_REDUCE
#define MLK_USE_NATIVE_POLY_TOMONT
#define MLK_USE_NATIVE_POLY_MULCACHE_COMPUTE
#define MLK_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED
#define MLK_USE_NATIVE_POLY_TOBYTES
#define MLK_USE_NATIVE_REJ_UNIFORM

extern const int16_t mlk_aarch64_ntt_zetas_layer12345[];
extern const int16_t mlk_aarch64_ntt_zetas_layer67[];
extern const int16_t mlk_aarch64_invntt_zetas_layer12345[];
extern const int16_t mlk_aarch64_invntt_zetas_layer67[];
extern const uint8_t mlk_rej_uniform_table[];
extern const int16_t mlk_aarch64_zetas_mulcache_native[];
extern const int16_t mlk_aarch64_zetas_mulcache_twisted_native[];

#include "../../../../third_party/s2n-bignum/s2n-bignum_aws-lc.h"

// Native NTT entry point (see MLK_USE_NATIVE_NTT above).
// Forwards `data` to mlkem_ntt together with the two tables
// mlk_aarch64_ntt_zetas_layer12345 and mlk_aarch64_ntt_zetas_layer67.
// NOTE(review): presumably transforms `data` in place, as there is no separate
// output argument -- confirm against the s2n-bignum prototype.
static MLK_INLINE void mlk_ntt_native(int16_t data[MLKEM_N]) {
mlkem_ntt(data, mlk_aarch64_ntt_zetas_layer12345, mlk_aarch64_ntt_zetas_layer67);
}

// Native inverse-NTT entry point (see MLK_USE_NATIVE_INTT above).
// Forwards `data` to mlkem_intt together with the two tables
// mlk_aarch64_invntt_zetas_layer12345 and mlk_aarch64_invntt_zetas_layer67.
// NOTE(review): presumably transforms `data` in place, as there is no separate
// output argument -- confirm against the s2n-bignum prototype.
static MLK_INLINE void mlk_intt_native(int16_t data[MLKEM_N]) {
mlkem_intt(data, mlk_aarch64_invntt_zetas_layer12345, mlk_aarch64_invntt_zetas_layer67);
}

// Native polynomial-reduction entry point (see MLK_USE_NATIVE_POLY_REDUCE
// above). Passes `data` straight through to mlkem_poly_reduce; no tables or
// extra arguments are involved.
// NOTE(review): the resulting coefficient range is defined by the callee and
// is not visible here -- confirm against the s2n-bignum documentation.
static MLK_INLINE void mlk_poly_reduce_native(int16_t data[MLKEM_N]) {
mlkem_poly_reduce(data);
}

// Native to-Montgomery entry point (see MLK_USE_NATIVE_POLY_TOMONT above).
// Passes `data` straight through to mlkem_poly_tomont; no tables or extra
// arguments are involved.
// NOTE(review): presumably converts `data` in place -- confirm against the
// s2n-bignum prototype.
static MLK_INLINE void mlk_poly_tomont_native(int16_t data[MLKEM_N]) {
mlkem_poly_tomont(data);
}

// Native mulcache entry point (see MLK_USE_NATIVE_POLY_MULCACHE_COMPUTE above).
// Forwards x (MLKEM_N / 2 entries, the only non-const argument) and y
// (MLKEM_N entries, const) to mlkem_mulcache_compute together with the tables
// mlk_aarch64_zetas_mulcache_native and
// mlk_aarch64_zetas_mulcache_twisted_native.
static MLK_INLINE void mlk_poly_mulcache_compute_native(int16_t x[MLKEM_N / 2], const int16_t y[MLKEM_N]) {
mlkem_mulcache_compute(x, y, mlk_aarch64_zetas_mulcache_native,
mlk_aarch64_zetas_mulcache_twisted_native);
}

// Native cached base-multiplication entry point for vectors of 2 polynomials
// (see MLK_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED above).
// Forwards all four arguments unchanged to mlkem_basemul_k2. Per the declared
// bounds, r has MLKEM_N entries, a and b have 2 * MLKEM_N entries each, and
// b_cache has 2 * (MLKEM_N / 2) entries; r is the only non-const argument.
static MLK_INLINE void mlk_polyvec_basemul_acc_montgomery_cached_k2_native(
int16_t r[MLKEM_N], const int16_t a[2 * MLKEM_N],
const int16_t b[2 * MLKEM_N], const int16_t b_cache[2 * (MLKEM_N / 2)]) {
mlkem_basemul_k2(r, a, b, b_cache);
}

// Native cached base-multiplication entry point for vectors of 3 polynomials
// (see MLK_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED above).
// Forwards all four arguments unchanged to mlkem_basemul_k3. Per the declared
// bounds, r has MLKEM_N entries, a and b have 3 * MLKEM_N entries each, and
// b_cache has 3 * (MLKEM_N / 2) entries; r is the only non-const argument.
static MLK_INLINE void mlk_polyvec_basemul_acc_montgomery_cached_k3_native(
int16_t r[MLKEM_N], const int16_t a[3 * MLKEM_N],
const int16_t b[3 * MLKEM_N], const int16_t b_cache[3 * (MLKEM_N / 2)]) {
mlkem_basemul_k3(r, a, b, b_cache);
}

// Native cached base-multiplication entry point for vectors of 4 polynomials
// (see MLK_USE_NATIVE_POLYVEC_BASEMUL_ACC_MONTGOMERY_CACHED above).
// Forwards all four arguments unchanged to mlkem_basemul_k4. Per the declared
// bounds, r has MLKEM_N entries, a and b have 4 * MLKEM_N entries each, and
// b_cache has 4 * (MLKEM_N / 2) entries; r is the only non-const argument.
static MLK_INLINE void mlk_polyvec_basemul_acc_montgomery_cached_k4_native(
int16_t r[MLKEM_N], const int16_t a[4 * MLKEM_N],
const int16_t b[4 * MLKEM_N], const int16_t b_cache[4 * (MLKEM_N / 2)]) {
mlkem_basemul_k4(r, a, b, b_cache);
}

// Native serialization entry point (see MLK_USE_NATIVE_POLY_TOBYTES above).
// Forwards r (MLKEM_POLYBYTES bytes, non-const) and a (MLKEM_N coefficients,
// const) unchanged to mlkem_poly_tobytes.
// NOTE(review): any precondition on the range of the coefficients in `a` is
// imposed by the callee and not visible here -- confirm before relying on it.
static MLK_INLINE void mlk_poly_tobytes_native(uint8_t r[MLKEM_POLYBYTES],
const int16_t a[MLKEM_N]) {
mlkem_poly_tobytes(r, a);
}

// Native rejection-sampling entry point (see MLK_USE_NATIVE_REJ_UNIFORM above).
//
// Returns -1 without touching r unless len equals MLKEM_N and buflen is a
// multiple of 24. Otherwise r, buf and buflen are handed to
// mlkem_rej_uniform_VARIABLE_TIME along with mlk_rej_uniform_table, and that
// routine's result is returned converted to int.
// NOTE(review): -1 presumably tells the caller to use its generic
// implementation instead -- confirm against the mlkem-native backend API.
static MLK_INLINE int mlk_rej_uniform_native(int16_t *r, unsigned len,
                                             const uint8_t *buf,
                                             unsigned buflen) {
  // Only a request for exactly MLKEM_N coefficients is accepted.
  if (len != MLKEM_N) {
    return -1;
  }

  // The input buffer must consist of whole 24-byte groups.
  if (buflen % 24 != 0) {
    return -1;
  }

  return (int) mlkem_rej_uniform_VARIABLE_TIME(r, buf, buflen,
                                               mlk_rej_uniform_table);
}

#endif /* ML_KEM_AARCH64_BACKEND_H */
35 changes: 20 additions & 15 deletions crypto/fipsmodule/ml_kem/ml_kem.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@

#include "./ml_kem.h"

// AArch64 backend constants (presumably the tables declared extern in
// aarch64/meta.h -- constants.c itself is not shown in this diff).
// The condition mirrors the one in mlkem_native_config.h that defines
// MLK_CONFIG_USE_NATIVE_BACKEND_ARITH, so the tables are only compiled into
// this translation unit on targets where the native backend is selected.
#if defined(OPENSSL_AARCH64) && !defined(OPENSSL_NO_ASM) && \
    (defined(OPENSSL_LINUX) || defined(OPENSSL_APPLE))
#include "aarch64/constants.c"
#endif

typedef struct {
uint8_t *buffer;
size_t *length;
Expand Down Expand Up @@ -92,7 +97,7 @@ int ml_kem_512_keypair_deterministic_no_self_test(uint8_t *public_key /* OUT */
if (!check_buffer(pkey) || !check_buffer(skey)) {
return 1;
}
const int res = mlkem512_keypair_derand(pkey.buffer, skey.buffer, seed);
const int res = mlkem_native512_keypair_derand(pkey.buffer, skey.buffer, seed);
#if defined(AWSLC_FIPS)
/* PCT failure is the only failure condition for key generation. */
if (res != 0) {
Expand All @@ -110,7 +115,7 @@ int ml_kem_512_keypair(uint8_t *public_key /* OUT */,
size_t *secret_len /* IN_OUT */) {
output_buffer pkey = {public_key, public_len, MLKEM512_PUBLIC_KEY_BYTES};
output_buffer skey = {secret_key, secret_len, MLKEM512_SECRET_KEY_BYTES};
return ml_kem_common_keypair(mlkem512_keypair, pkey, skey);
return ml_kem_common_keypair(mlkem_native512_keypair, pkey, skey);
}

int ml_kem_512_encapsulate_deterministic(uint8_t *ciphertext /* OUT */,
Expand All @@ -131,7 +136,7 @@ int ml_kem_512_encapsulate_deterministic_no_self_test(uint8_t *ciphertext
const uint8_t *seed /* IN */) {
output_buffer ctext = {ciphertext, ciphertext_len, MLKEM512_CIPHERTEXT_BYTES};
output_buffer ss = {shared_secret, shared_secret_len, MLKEM512_SHARED_SECRET_LEN};
return ml_kem_common_encapsulate_deterministic(mlkem512_enc_derand, ctext, ss, public_key, seed);
return ml_kem_common_encapsulate_deterministic(mlkem_native512_enc_derand, ctext, ss, public_key, seed);
}

int ml_kem_512_encapsulate(uint8_t *ciphertext /* OUT */,
Expand All @@ -141,7 +146,7 @@ int ml_kem_512_encapsulate(uint8_t *ciphertext /* OUT */,
const uint8_t *public_key /* IN */) {
output_buffer ctext = {ciphertext, ciphertext_len, MLKEM512_CIPHERTEXT_BYTES};
output_buffer ss = {shared_secret, shared_secret_len, MLKEM512_SHARED_SECRET_LEN};
return ml_kem_common_encapsulate(mlkem512_enc, ctext, ss, public_key);
return ml_kem_common_encapsulate(mlkem_native512_enc, ctext, ss, public_key);
}

int ml_kem_512_decapsulate(uint8_t *shared_secret /* OUT */,
Expand All @@ -157,7 +162,7 @@ int ml_kem_512_decapsulate_no_self_test(uint8_t *shared_secret /* OUT */,
const uint8_t *ciphertext /* IN */,
const uint8_t *secret_key /* IN */) {
output_buffer ss = {shared_secret, shared_secret_len, MLKEM512_SHARED_SECRET_LEN};
return ml_kem_common_decapsulate(mlkem512_dec, ss, ciphertext, secret_key);
return ml_kem_common_decapsulate(mlkem_native512_dec, ss, ciphertext, secret_key);
}


Expand All @@ -181,7 +186,7 @@ int ml_kem_768_keypair_deterministic_no_self_test(uint8_t *public_key /* OUT */,
if (!check_buffer(pkey) || !check_buffer(skey)) {
return 1;
}
const int res = mlkem768_keypair_derand(pkey.buffer, skey.buffer, seed);
const int res = mlkem_native768_keypair_derand(pkey.buffer, skey.buffer, seed);
#if defined(AWSLC_FIPS)
/* PCT failure is the only failure condition for key generation. */
if (res != 0) {
Expand All @@ -199,7 +204,7 @@ int ml_kem_768_keypair(uint8_t *public_key /* OUT */,
size_t *secret_len /* IN_OUT */) {
output_buffer pkey = {public_key, public_len, MLKEM768_PUBLIC_KEY_BYTES};
output_buffer skey = {secret_key, secret_len, MLKEM768_SECRET_KEY_BYTES};
return ml_kem_common_keypair(mlkem768_keypair, pkey, skey);
return ml_kem_common_keypair(mlkem_native768_keypair, pkey, skey);
}

int ml_kem_768_encapsulate_deterministic(uint8_t *ciphertext /* OUT */,
Expand All @@ -220,7 +225,7 @@ int ml_kem_768_encapsulate_deterministic_no_self_test(uint8_t *ciphertext
const uint8_t *seed /* IN */) {
output_buffer ctext = {ciphertext, ciphertext_len, MLKEM768_CIPHERTEXT_BYTES};
output_buffer ss = {shared_secret, shared_secret_len, MLKEM768_SHARED_SECRET_LEN};
return ml_kem_common_encapsulate_deterministic(mlkem768_enc_derand, ctext, ss, public_key, seed);
return ml_kem_common_encapsulate_deterministic(mlkem_native768_enc_derand, ctext, ss, public_key, seed);
}

int ml_kem_768_encapsulate(uint8_t *ciphertext /* OUT */,
Expand All @@ -230,7 +235,7 @@ int ml_kem_768_encapsulate(uint8_t *ciphertext /* OUT */,
const uint8_t *public_key /* IN */) {
output_buffer ctext = {ciphertext, ciphertext_len, MLKEM768_CIPHERTEXT_BYTES};
output_buffer ss = {shared_secret, shared_secret_len, MLKEM768_SHARED_SECRET_LEN};
return ml_kem_common_encapsulate(mlkem768_enc, ctext, ss, public_key);
return ml_kem_common_encapsulate(mlkem_native768_enc, ctext, ss, public_key);
}

int ml_kem_768_decapsulate(uint8_t *shared_secret /* OUT */,
Expand All @@ -246,7 +251,7 @@ int ml_kem_768_decapsulate_no_self_test(uint8_t *shared_secret /* OUT */,
const uint8_t *ciphertext /* IN */,
const uint8_t *secret_key /* IN */) {
output_buffer ss = {shared_secret, shared_secret_len, MLKEM768_SHARED_SECRET_LEN};
return ml_kem_common_decapsulate(mlkem768_dec, ss, ciphertext, secret_key);
return ml_kem_common_decapsulate(mlkem_native768_dec, ss, ciphertext, secret_key);
}

int ml_kem_1024_keypair_deterministic(uint8_t *public_key /* OUT */,
Expand All @@ -268,7 +273,7 @@ int ml_kem_1024_keypair_deterministic_no_self_test(uint8_t *public_key /* OUT */
if (!check_buffer(pkey) || !check_buffer(skey)) {
return 1;
}
const int res = mlkem1024_keypair_derand(pkey.buffer, skey.buffer, seed);
const int res = mlkem_native1024_keypair_derand(pkey.buffer, skey.buffer, seed);
#if defined(AWSLC_FIPS)
/* PCT failure is the only failure condition for key generation. */
if (res != 0) {
Expand All @@ -286,7 +291,7 @@ int ml_kem_1024_keypair(uint8_t *public_key /* OUT */,
size_t *secret_len /* IN_OUT */) {
output_buffer pkey = {public_key, public_len, MLKEM1024_PUBLIC_KEY_BYTES};
output_buffer skey = {secret_key, secret_len, MLKEM1024_SECRET_KEY_BYTES};
return ml_kem_common_keypair(mlkem1024_keypair, pkey, skey);
return ml_kem_common_keypair(mlkem_native1024_keypair, pkey, skey);
}

int ml_kem_1024_encapsulate_deterministic(uint8_t *ciphertext /* OUT */,
Expand All @@ -307,7 +312,7 @@ int ml_kem_1024_encapsulate_deterministic_no_self_test(uint8_t *ciphertext
const uint8_t *seed /* IN */) {
output_buffer ctext = {ciphertext, ciphertext_len, MLKEM1024_CIPHERTEXT_BYTES};
output_buffer ss = {shared_secret, shared_secret_len, MLKEM1024_SHARED_SECRET_LEN};
return ml_kem_common_encapsulate_deterministic(mlkem1024_enc_derand, ctext, ss, public_key, seed);
return ml_kem_common_encapsulate_deterministic(mlkem_native1024_enc_derand, ctext, ss, public_key, seed);
}

int ml_kem_1024_encapsulate(uint8_t *ciphertext /* OUT */,
Expand All @@ -317,7 +322,7 @@ int ml_kem_1024_encapsulate(uint8_t *ciphertext /* OUT */,
const uint8_t *public_key /* IN */) {
output_buffer ctext = {ciphertext, ciphertext_len, MLKEM1024_CIPHERTEXT_BYTES};
output_buffer ss = {shared_secret, shared_secret_len, MLKEM1024_SHARED_SECRET_LEN};
return ml_kem_common_encapsulate(mlkem1024_enc, ctext, ss, public_key);
return ml_kem_common_encapsulate(mlkem_native1024_enc, ctext, ss, public_key);
}

int ml_kem_1024_decapsulate(uint8_t *shared_secret /* OUT */,
Expand All @@ -333,7 +338,7 @@ int ml_kem_1024_decapsulate_no_self_test(uint8_t *shared_secret /* OUT */,
const uint8_t *ciphertext /* IN */,
const uint8_t *secret_key /* IN */) {
output_buffer ss = {shared_secret, shared_secret_len, MLKEM1024_SHARED_SECRET_LEN};
return ml_kem_common_decapsulate(mlkem1024_dec, ss, ciphertext, secret_key);
return ml_kem_common_decapsulate(mlkem_native1024_dec, ss, ciphertext, secret_key);
}

int ml_kem_common_keypair(int (*keypair)(uint8_t * public_key, uint8_t *secret_key),
Expand Down
8 changes: 7 additions & 1 deletion crypto/fipsmodule/ml_kem/mlkem_native_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
// Namespacing: All symbols are of the form mlkem_native*. Level-specific
// symbols are further prefixed with their security level, e.g.
// mlkem_native512*, mlkem_native768*, mlkem_native1024*.
#define MLK_CONFIG_NAMESPACE_PREFIX mlkem
#define MLK_CONFIG_NAMESPACE_PREFIX mlkem_native

// Replace mlkem-native's FIPS 202 headers with glue code to
// AWS-LC's own FIPS 202 implementation.
Expand Down Expand Up @@ -68,4 +68,10 @@ static MLK_INLINE void mlk_randombytes(void *ptr, size_t len) {
#define MLK_CONFIG_NO_ASM
#endif

#if defined(OPENSSL_AARCH64) && !defined(OPENSSL_NO_ASM) && \
(defined(OPENSSL_LINUX) || defined(OPENSSL_APPLE))
#define MLK_CONFIG_USE_NATIVE_BACKEND_ARITH
#define MLK_CONFIG_ARITH_BACKEND_FILE "../aarch64/meta.h"
#endif

#endif // MLkEM_NATIVE_CONFIG_H
8 changes: 4 additions & 4 deletions third_party/s2n-bignum/META.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
name: s2n-bignum-imported
source: awslabs/s2n-bignum.git
commit: 54e1fa5756d6b13961c2f61d90f75426aa25d373
target: main
imported-at: 2025-04-28T17:22:07+0000
source: jargh/s2n-bignum-dev.git
commit: ae84a59689cb50ad9b9c6e25cd34037d5b1fb2b4
target: mlkem
imported-at: 2025-06-23T13:38:02+0000
39 changes: 39 additions & 0 deletions third_party/s2n-bignum/s2n-bignum-imported/arm/mlkem/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#############################################################################
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0 OR ISC OR MIT-0
#############################################################################

# If actually on an ARM8 machine, just use the GNU assembler (as). Otherwise
# use a cross-assembling version so that the code can still be assembled
# and the proofs checked against the object files (though you won't be able
# to run code without additional emulation infrastructure). The aarch64
# cross-assembling version can be installed manually by something like:
#
# sudo apt-get install binutils-aarch64-linux-gnu

UNAME_RESULT=$(shell uname -p)

ifeq ($(UNAME_RESULT),aarch64)
GAS=as
else
GAS=aarch64-linux-gnu-as
endif

# List of object files

OBJ = mlkem_basemul_k2.o \
mlkem_basemul_k3.o \
mlkem_basemul_k4.o \
mlkem_intt.o \
mlkem_mulcache_compute.o \
mlkem_ntt.o \
mlkem_poly_reduce.o \
mlkem_poly_tobytes.o \
mlkem_poly_tomont.o \
mlkem_rej_uniform_VARIABLE_TIME.o

%.o : %.S ; $(CC) -E -I../../include $< | $(GAS) -o $@ -

default: $(OBJ);

clean:; rm -f *.o *.correct unopt/*.o
Loading
Loading