Skip to content

Commit b3811b4

Browse files
committed
add explicit type hints to ml-kem
1 parent d0a8f82 commit b3811b4

18 files changed

+238
-289
lines changed

benchmarks/benchmark_kyber.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,9 @@ def benchmark_kyber(Kyber, name, count):
5353
avg_dec = sum(dec_times) / count
5454
print(
5555
f" {name:11} |"
56-
f"{avg_keygen*1000:7.2f}ms | {1/avg_keygen:10.2f} |"
57-
f"{avg_enc*1000:6.2f}ms | {1/avg_enc:9.2f} |"
58-
f"{avg_dec*1000:6.2f}ms | {1/avg_dec:7.2f} |"
56+
f"{avg_keygen * 1000:7.2f}ms | {1 / avg_keygen:10.2f} |"
57+
f"{avg_enc * 1000:6.2f}ms | {1 / avg_enc:9.2f} |"
58+
f"{avg_dec * 1000:6.2f}ms | {1 / avg_dec:7.2f} |"
5959
)
6060

6161

benchmarks/benchmark_ml_kem.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,9 @@ def benchmark_ml_kem(ML_KEM, name, count):
5353
avg_dec = sum(dec_times) / count
5454
print(
5555
f" {name:11} |"
56-
f"{avg_keygen*1000:7.2f}ms | {1/avg_keygen:10.2f} |"
57-
f"{avg_enc*1000:6.2f}ms | {1/avg_enc:9.2f} |"
58-
f"{avg_dec*1000:6.2f}ms | {1/avg_dec:7.2f} |"
56+
f"{avg_keygen * 1000:7.2f}ms | {1 / avg_keygen:10.2f} |"
57+
f"{avg_enc * 1000:6.2f}ms | {1 / avg_enc:9.2f} |"
58+
f"{avg_dec * 1000:6.2f}ms | {1 / avg_dec:7.2f} |"
5959
)
6060

6161

src/kyber_py/drbg/aes256_ctr_drbg.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@
55

66

77
class AES256_CTR_DRBG:
8-
def __init__(
9-
self, seed: Optional[bytes] = None, personalization: bytes = b""
10-
):
8+
def __init__(self, seed: Optional[bytes] = None, personalization: bytes = b""):
119
"""
1210
DRBG implementation based on AES-256 CTR following the document NIST SP
1311
800-90A Section 10.2.1
@@ -59,9 +57,7 @@ def __instantiate(self, personalization: bytes = b"") -> bytes:
5957
f"{self.seed_length}. Input has length {len(personalization)}"
6058
)
6159
# Ensure personalization has exactly seed_length bytes
62-
personalization += bytes([0]) * (
63-
self.seed_length - len(personalization)
64-
)
60+
personalization += bytes([0]) * (self.seed_length - len(personalization))
6561
# debugging
6662
assert len(personalization) == self.seed_length
6763
return xor_bytes(self.entropy_input, personalization)
@@ -97,9 +93,7 @@ def __ctr_drbg_update(self, provided_data: bytes) -> None:
9793
self.key = tmp[:32]
9894
self.V = tmp[32:]
9995

100-
def random_bytes(
101-
self, num_bytes: int, additional: Optional[bytes] = None
102-
) -> bytes:
96+
def random_bytes(self, num_bytes: int, additional: Optional[bytes] = None) -> bytes:
10397
"""
10498
Generate pseudorandom bytes without a generating function
10599

src/kyber_py/kyber/kyber.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22
from hashlib import sha3_256, sha3_512, shake_128, shake_256
3-
from ..modules.modules import ModuleKyber
3+
from ..modules.modules import Module
44
from ..utilities.utils import select_bytes
55

66

@@ -17,7 +17,7 @@ def __init__(self, parameter_set):
1717
self.du = parameter_set["du"]
1818
self.dv = parameter_set["dv"]
1919

20-
self.M = ModuleKyber()
20+
self.M = Module()
2121
self.R = self.M.ring
2222

2323
# Use system randomness by default, for deterministic randomness
@@ -111,7 +111,7 @@ def _generate_error_vector(self, sigma, eta, N):
111111
Helper function which generates a element in the
112112
module from the Centered Binomial Distribution.
113113
"""
114-
elements = [0 for _ in range(self.k)]
114+
elements = [self.R.zero() for _ in range(self.k)]
115115
for i in range(self.k):
116116
input_bytes = self._prf(sigma, bytes([N]), 64 * eta)
117117
elements[i] = self.R.cbd(input_bytes, eta)
@@ -135,7 +135,7 @@ def _generate_matrix_from_seed(self, rho, transpose=False):
135135
136136
When `transpose` is set to True, the matrix A is built as the transpose.
137137
"""
138-
A_data = [[0 for _ in range(self.k)] for _ in range(self.k)]
138+
A_data = [[self.R.zero() for _ in range(self.k)] for _ in range(self.k)]
139139
for i in range(self.k):
140140
for j in range(self.k):
141141
input_bytes = self._xof(rho, bytes([j]), bytes([i]))

src/kyber_py/ml_kem/ml_kem.py

Lines changed: 33 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55

66
import os
77
from hashlib import sha3_256, sha3_512, shake_128, shake_256
8-
from ..modules.modules import ModuleKyber
8+
from ..modules.modules import Module, Matrix, Vector
9+
from ..polynomials.polynomials import Polynomial
910
from ..utilities.utils import select_bytes
1011

1112

1213
class ML_KEM:
13-
def __init__(self, params):
14+
def __init__(self, params: dict):
1415
"""
1516
Initialise the ML-KEM with specified lattice parameters.
1617
@@ -23,31 +24,31 @@ def __init__(self, params):
2324
self.du = params["du"]
2425
self.dv = params["dv"]
2526

26-
self.M = ModuleKyber()
27+
self.M = Module()
2728
self.R = self.M.ring
2829
self.oid = params["oid"] if "oid" in params else None
2930

3031
# Use system randomness by default, for deterministic randomness
3132
# use the method `set_drbg_seed()`
3233
self.random_bytes = os.urandom
3334

34-
def _ek_size(self):
35+
def _ek_size(self) -> int:
3536
"""
3637
Return the size of the encapsulation key for the selected paramters.
3738
3839
:rtype: int
3940
"""
4041
return 384 * self.k + 32
4142

42-
def _dk_size(self):
43+
def _dk_size(self) -> int:
4344
"""
4445
Return the size of the decapsulation key for the selected parameters.
4546
4647
:rtype: int
4748
"""
4849
return 768 * self.k + 96
4950

50-
def set_drbg_seed(self, seed):
51+
def set_drbg_seed(self, seed: bytes):
5152
"""
5253
Change entropy source to a DRBG and seed it with provided value.
5354
@@ -74,7 +75,7 @@ def set_drbg_seed(self, seed):
7475
)
7576

7677
@staticmethod
77-
def _xof(bytes32, i, j):
78+
def _xof(b: bytes, i: bytes, j: bytes) -> bytes:
7879
"""
7980
eXtendable-Output Function (XOF) described in 4.9 of FIPS 203 (page 19)
8081
@@ -88,15 +89,15 @@ def _xof(bytes32, i, j):
8889
Casa de Chá da Boa Nova
8990
https://cryptojedi.org/papers/terminate-20230516.pdf
9091
"""
91-
input_bytes = bytes32 + i + j
92+
input_bytes = b + i + j
9293
if len(input_bytes) != 34:
9394
raise ValueError(
9495
"Input bytes should be one 32 byte array and 2 single bytes."
9596
)
9697
return shake_128(input_bytes).digest(840)
9798

9899
@staticmethod
99-
def _prf(eta, s, b):
100+
def _prf(eta: int, s: bytes, b: bytes) -> bytes:
100101
"""
101102
Pseudorandom function described in 4.3 of FIPS 203 (page 18)
102103
"""
@@ -108,57 +109,61 @@ def _prf(eta, s, b):
108109
return shake_256(input_bytes).digest(eta * 64)
109110

110111
@staticmethod
111-
def _H(s):
112+
def _H(s: bytes) -> bytes:
112113
"""
113114
Hash function described in 4.4 of FIPS 203 (page 18)
114115
"""
115116
return sha3_256(s).digest()
116117

117118
@staticmethod
118-
def _J(s):
119+
def _J(s: bytes) -> bytes:
119120
"""
120121
Hash function described in 4.4 of FIPS 203 (page 18)
121122
"""
122123
return shake_256(s).digest(32)
123124

124125
@staticmethod
125-
def _G(s):
126+
def _G(s: bytes) -> tuple[bytes, bytes]:
126127
"""
127128
Hash function described in 4.5 of FIPS 203 (page 18)
128129
"""
129130
h = sha3_512(s).digest()
130131
return h[:32], h[32:]
131132

132-
def _generate_matrix_from_seed(self, rho, transpose=False):
133+
def _generate_matrix_from_seed(self, rho: bytes, transpose: bool = False) -> Matrix:
133134
"""
134135
Helper function which generates a element of size
135136
k x k from a seed `rho`.
136137
137138
When `transpose` is set to True, the matrix A is
138139
built as the transpose.
139140
"""
140-
A_data = [[0 for _ in range(self.k)] for _ in range(self.k)]
141+
A_data = [[self.R.zero() for _ in range(self.k)] for _ in range(self.k)]
141142
for i in range(self.k):
142143
for j in range(self.k):
143144
xof_bytes = self._xof(rho, bytes([j]), bytes([i]))
144145
A_data[i][j] = self.R.ntt_sample(xof_bytes)
145146
A_hat = self.M(A_data, transpose=transpose)
146147
return A_hat
147148

148-
def _generate_error_vector(self, sigma, eta, N):
149+
def _generate_error_vector(
150+
self, sigma: bytes, eta: int, N: int
151+
) -> tuple[Vector, int]:
149152
"""
150153
Helper function which generates a element in the
151154
module from the Centered Binomial Distribution.
152155
"""
153-
elements = [0 for _ in range(self.k)]
156+
elements = [self.R.zero() for _ in range(self.k)]
154157
for i in range(self.k):
155158
prf_output = self._prf(eta, sigma, bytes([N]))
156159
elements[i] = self.R.cbd(prf_output, eta)
157160
N += 1
158161
v = self.M.vector(elements)
159162
return v, N
160163

161-
def _generate_polynomial(self, sigma, eta, N):
164+
def _generate_polynomial(
165+
self, sigma: bytes, eta: int, N: int
166+
) -> tuple[Polynomial, int]:
162167
"""
163168
Helper function which generates a element in the
164169
polynomial ring from the Centered Binomial Distribution.
@@ -167,7 +172,7 @@ def _generate_polynomial(self, sigma, eta, N):
167172
p = self.R.cbd(prf_output, eta)
168173
return p, N + 1
169174

170-
def _k_pke_keygen(self, d):
175+
def _k_pke_keygen(self, d: bytes) -> tuple[bytes, bytes]:
171176
"""
172177
Use randomness to generate an encryption key and a corresponding
173178
decryption key following Algorithm 13 (FIPS 203)
@@ -203,7 +208,7 @@ def _k_pke_keygen(self, d):
203208

204209
return (ek_pke, dk_pke)
205210

206-
def _k_pke_encrypt(self, ek_pke, m, r):
211+
def _k_pke_encrypt(self, ek_pke: bytes, m: bytes, r: bytes) -> bytes:
207212
"""
208213
Uses the encryption key to encrypt a plaintext message using the
209214
randomness r following Algorithm 14 (FIPS 203)
@@ -231,9 +236,7 @@ def _k_pke_encrypt(self, ek_pke, m, r):
231236

232237
# Next check that t_hat has been canonically encoded
233238
if t_hat.encode(12) != t_hat_bytes:
234-
raise ValueError(
235-
"Modulus check failed, t_hat does not encode correctly"
236-
)
239+
raise ValueError("Modulus check failed, t_hat does not encode correctly")
237240

238241
# Generate A_hat^T from seed rho
239242
A_hat_T = self._generate_matrix_from_seed(rho, transpose=True)
@@ -255,7 +258,7 @@ def _k_pke_encrypt(self, ek_pke, m, r):
255258

256259
return c1 + c2
257260

258-
def _k_pke_decrypt(self, dk_pke, c):
261+
def _k_pke_decrypt(self, dk_pke: bytes, c: bytes) -> bytes:
259262
"""
260263
Uses the decryption key to decrypt a ciphertext following
261264
Algorithm 15 (FIPS 203)
@@ -273,7 +276,7 @@ def _k_pke_decrypt(self, dk_pke, c):
273276

274277
return m
275278

276-
def _keygen_internal(self, d, z):
279+
def _keygen_internal(self, d: bytes, z: bytes) -> tuple[bytes, bytes]:
277280
"""
278281
Use randomness to generate an encapsulation key and a corresponding
279282
decapsulation key following Algorithm 16 (FIPS 203)
@@ -288,7 +291,7 @@ def _keygen_internal(self, d, z):
288291

289292
return (ek, dk)
290293

291-
def keygen(self):
294+
def keygen(self) -> tuple[bytes, bytes]:
292295
"""
293296
Generate an encapsulation key and corresponding decapsulation key
294297
following Algorithm 19 (FIPS 203)
@@ -309,7 +312,7 @@ def keygen(self):
309312
) = self._keygen_internal(d, z)
310313
return (ek, dk)
311314

312-
def key_derive(self, seed):
315+
def key_derive(self, seed: bytes) -> tuple[bytes, bytes]:
313316
"""
314317
Derive an encapsulation key and corresponding decapsulation key
315318
following the approach from Section 7.1 (FIPS 203)
@@ -329,7 +332,7 @@ def key_derive(self, seed):
329332
ek, dk = self._keygen_internal(d, z)
330333
return (ek, dk)
331334

332-
def _encaps_internal(self, ek, m):
335+
def _encaps_internal(self, ek: bytes, m: bytes) -> tuple[bytes, bytes]:
333336
"""
334337
Uses the encapsulation key and randomness to generate a key and an
335338
associated ciphertext following Algorithm 17 (FIPS 203)
@@ -355,7 +358,7 @@ def _encaps_internal(self, ek, m):
355358

356359
return K, c
357360

358-
def encaps(self, ek):
361+
def encaps(self, ek: bytes) -> tuple[bytes, bytes]:
359362
"""
360363
Uses the encapsulation key to generate a shared secret key and an
361364
associated ciphertext following Algorithm 20 (FIPS 203)
@@ -374,7 +377,7 @@ def encaps(self, ek):
374377
K, c = self._encaps_internal(ek, m)
375378
return K, c
376379

377-
def _decaps_internal(self, dk, c):
380+
def _decaps_internal(self, dk: bytes, c: bytes) -> bytes:
378381
"""
379382
Uses the decapsulation key to produce a shared secret key from a
380383
ciphertext following Algorithm 18 (FIPS 203)
@@ -429,7 +432,7 @@ def _decaps_internal(self, dk, c):
429432
# performed in constant time
430433
return select_bytes(K_bar, K_prime, c == c_prime)
431434

432-
def decaps(self, dk, c):
435+
def decaps(self, dk: bytes, c: bytes) -> bytes:
433436
"""
434437
Uses the decapsulation key to produce a shared secret key from a
435438
ciphertext following Algorithm 21 (FIPS 203).

src/kyber_py/ml_kem/pkcs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def ek_to_der(kem, ek):
3434
raise ValueError("Only KEMs with specified OIDs can be encoded")
3535

3636
if len(ek) != kem._ek_size():
37-
raise ValueError(f"Provided key size doesn't match the provided kem")
37+
raise ValueError("Provided key size doesn't match the provided kem")
3838

3939
enc = der.encode_sequence(
4040
der.encode_sequence(
@@ -252,7 +252,7 @@ def dk_from_der(enc_key):
252252
else:
253253
tag, seed, empty = der.remove_implicit(priv_key)
254254
if tag != 0:
255-
raise der.UnexpectedDER(f"Unexpected tag in private key encoding")
255+
raise der.UnexpectedDER("Unexpected tag in private key encoding")
256256
if empty:
257257
raise der.UnexpectedDER("Junk after seed encoding")
258258

0 commit comments

Comments
 (0)