Skip to content

Commit a707846

Browse files
committed
add explicit type hints to ml-kem
1 parent d0a8f82 commit a707846

15 files changed

+234
-192
lines changed

README.md

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
> [!CAUTION]
1010
> :warning: **Under no circumstances should this be used for cryptographic
1111
applications.** :warning:
12-
>
12+
>
1313
> This is an educational resource and has not been designed to be secure
1414
> against any form of side-channel attack. The intended use of this project
1515
> is for learning and experimenting with ML-KEM and Kyber
@@ -26,7 +26,7 @@ from the NIST post-quantum cryptography project.
2626
**Note**: This project accompanies
2727
[`dilithium-py`](https://github.com/GiacomoPope/dilithium-py) which is a
2828
pure-python implementation of ML-DSA and CRYSTALS-Dilithium and shares a lot of
29-
the lower-level code of this implementation.
29+
the lower-level code of this implementation.
3030

3131
## Disclaimer
3232

@@ -82,7 +82,7 @@ the
8282

8383
Originally this project was planned to have zero dependencies, however to make this work
8484
pass the KATs, we needed a deterministic CSRNG. The reference implementation uses
85-
AES256 CTR DRBG. I have implemented this in [`aes256_ctr_drbg.py`](src/kyber_py/drbg/aes256_ctr_drbg.py).
85+
AES256 CTR DRBG. I have implemented this in [`aes256_ctr_drbg.py`](src/kyber_py/drbg/aes256_ctr_drbg.py).
8686
However, I have not implemented AES itself, instead I import this from `pycryptodome`. If this dependency is too annoying, then please make an issue and we can have a pure-python AES included into the repo.
8787

8888
To install dependencies, run `pip -r install requirements`.
@@ -206,7 +206,7 @@ require each element in a ring to have a multiplicative inverse). The ring in qu
206206
To help with experimenting with these polynomial rings themselves, the file [`polynomials_generic.py`](src/kyber_py/polynomials/polynomials_generic.py) has an implementation of the univariate polynomial ring
207207

208208
$$
209-
R_q = \mathbb{F}_q[X] /(X^n + 1)
209+
R_q = \mathbb{F}_q[X] /(X^n + 1)
210210
$$
211211

212212
where the user can select any $q, n$. For example, you can create the
@@ -215,8 +215,8 @@ ring $R_{11} = \mathbb{F}_{11}[X] /(X^8 + 1)$ in the following way:
215215
#### Example
216216

217217
```python
218-
>>> from kyber_py.polynomials.polynomials_generic import PolynomialRing
219-
>>> R = PolynomialRing(11, 8)
218+
>>> from kyber_py.polynomials.polynomials_generic import GenericPolynomialRing
219+
>>> R = GenericPolynomialRing(11, 8)
220220
>>> x = R.gen()
221221
>>> f = 3*x**3 + 4*x**7
222222
>>> g = R.random_element(); g
@@ -233,23 +233,23 @@ We hope that this allows for some hands-on experience at working with these
233233
polynomials before starting to play with the whole of Kyber/ML-KEM.
234234

235235
For the "Kyber-specific" functions, needed to implement the protocol itself, we
236-
have made a child class `PolynomialRingKyber(PolynomialRing)` which has the
236+
have made a child class `PolynomialRing(GenericPolynomialRing)` which has the
237237
following additional methods:
238238

239-
- `PolynomialRingKyber`
239+
- `PolynomialRing`
240240
- `ntt_sample(bytes)` takes $3n$ bytes and produces a random polynomial in $R_q$
241241
- `decode(bytes, l)` takes $\ell n$ bits and produces a polynomial in $R_q$
242242
- `cbd(beta, eta)` takes $\eta \cdot n / 4$ bytes and produces a polynomial in
243243
$R_q$ with coefficents taken from a centered binomial distribution
244-
- `PolynomialKyber`
244+
- `Polynomial`
245245
- `encode(l)` takes the polynomial and returns a length $\ell n / 8$ bytearray
246246
- `to_ntt()` converts the polynomial into the NTT domain for efficient
247247
polynomial multiplication and returns an element of type
248-
`PolynomialKyberNTT`
249-
- `PolynomialKyberNTT`
248+
`PolynomialNTT`
249+
- `PolynomialNTT`
250250
- `from_ntt()` converts the polynomial back from the NTT domain and returns an
251-
element of type `PolynomialKyber`
252-
251+
element of type `Polynomial`
252+
253253
This class fixes $q = 3329$ and $n = 256$
254254

255255
Lastly, we define a `self.compress(d)` and `self.decompress(d)` method for
@@ -276,20 +276,20 @@ Building on `polynomials_generic.py` we also include a file
276276
[`modules_generic.py`](src/kyber_py/modules/modules_generic.py) which has all of
277277
the functions needed to perform linear algebra given a ring.
278278

279-
Note that `Matrix` allows elements of the module to be of size $m \times n$ but
279+
Note that `GenericMatrix` allows elements of the module to be of size $m \times n$ but
280280
for Kyber, we only need vectors of length $k$ and square matrices of size $k
281281
\times k$.
282282

283-
As an example of the operations we can perform with out `Module` lets revisit
283+
As an example of the operations we can perform with out `GenericModule` let's revisit
284284
the ring from the previous example:
285285

286286
#### Example
287287

288288
```python
289-
>>> R = PolynomialRing(11, 8)
289+
>>> R = GenericPolynomialRing(11, 8)
290290
>>> x = R.gen()
291291
>>>
292-
>>> M = Module(R)
292+
>>> M = GenericModule(R)
293293
>>> # We create a matrix by feeding the coefficients to M
294294
>>> A = M([[x + 3*x**2, 4 + 3*x**7], [3*x**3 + 9*x**7, x**4]])
295295
>>> A
@@ -325,8 +325,8 @@ the ring from the previous example:
325325
[ 2 + 6*x^4 + x^5]
326326
```
327327

328-
On top of this class, we have the classes `ModuleKyber(Module)` and
329-
`MatrixKyber(Matrix)` which have helper functions which (for example) encode
328+
On top of this class, we have the classes `Module(GenericModule)` and
329+
`Matrix(GenericMatrix)` which have helper functions which (for example) encode
330330
every element of a matrix, or convert every element to or from the NTT domain.
331-
These are simple functions which call the respective `PolynomialKyber` methods
331+
These are simple functions which call the respective `Polynomial` methods
332332
for every element.

benchmarks/benchmark_kyber.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,9 @@ def benchmark_kyber(Kyber, name, count):
5353
avg_dec = sum(dec_times) / count
5454
print(
5555
f" {name:11} |"
56-
f"{avg_keygen*1000:7.2f}ms | {1/avg_keygen:10.2f} |"
57-
f"{avg_enc*1000:6.2f}ms | {1/avg_enc:9.2f} |"
58-
f"{avg_dec*1000:6.2f}ms | {1/avg_dec:7.2f} |"
56+
f"{avg_keygen * 1000:7.2f}ms | {1 / avg_keygen:10.2f} |"
57+
f"{avg_enc * 1000:6.2f}ms | {1 / avg_enc:9.2f} |"
58+
f"{avg_dec * 1000:6.2f}ms | {1 / avg_dec:7.2f} |"
5959
)
6060

6161

benchmarks/benchmark_ml_kem.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,9 @@ def benchmark_ml_kem(ML_KEM, name, count):
5353
avg_dec = sum(dec_times) / count
5454
print(
5555
f" {name:11} |"
56-
f"{avg_keygen*1000:7.2f}ms | {1/avg_keygen:10.2f} |"
57-
f"{avg_enc*1000:6.2f}ms | {1/avg_enc:9.2f} |"
58-
f"{avg_dec*1000:6.2f}ms | {1/avg_dec:7.2f} |"
56+
f"{avg_keygen * 1000:7.2f}ms | {1 / avg_keygen:10.2f} |"
57+
f"{avg_enc * 1000:6.2f}ms | {1 / avg_enc:9.2f} |"
58+
f"{avg_dec * 1000:6.2f}ms | {1 / avg_dec:7.2f} |"
5959
)
6060

6161

src/kyber_py/kyber/kyber.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22
from hashlib import sha3_256, sha3_512, shake_128, shake_256
3-
from ..modules.modules import ModuleKyber
3+
from ..modules.modules import Module
44
from ..utilities.utils import select_bytes
55

66

@@ -17,7 +17,7 @@ def __init__(self, parameter_set):
1717
self.du = parameter_set["du"]
1818
self.dv = parameter_set["dv"]
1919

20-
self.M = ModuleKyber()
20+
self.M = Module()
2121
self.R = self.M.ring
2222

2323
# Use system randomness by default, for deterministic randomness
@@ -111,7 +111,7 @@ def _generate_error_vector(self, sigma, eta, N):
111111
Helper function which generates a element in the
112112
module from the Centered Binomial Distribution.
113113
"""
114-
elements = [0 for _ in range(self.k)]
114+
elements = [self.R.zero() for _ in range(self.k)]
115115
for i in range(self.k):
116116
input_bytes = self._prf(sigma, bytes([N]), 64 * eta)
117117
elements[i] = self.R.cbd(input_bytes, eta)
@@ -135,7 +135,9 @@ def _generate_matrix_from_seed(self, rho, transpose=False):
135135
136136
When `transpose` is set to True, the matrix A is built as the transpose.
137137
"""
138-
A_data = [[0 for _ in range(self.k)] for _ in range(self.k)]
138+
A_data = [
139+
[self.R.zero() for _ in range(self.k)] for _ in range(self.k)
140+
]
139141
for i in range(self.k):
140142
for j in range(self.k):
141143
input_bytes = self._xof(rho, bytes([j]), bytes([i]))

0 commit comments

Comments
 (0)