5
5
6
6
import os
7
7
from hashlib import sha3_256 , sha3_512 , shake_128 , shake_256
8
- from ..modules .modules import ModuleKyber
8
+ from ..modules .modules import Module , Matrix , Vector
9
+ from ..polynomials .polynomials import Polynomial
9
10
from ..utilities .utils import select_bytes
10
11
11
12
12
13
class ML_KEM :
13
- def __init__ (self , params ):
14
+ def __init__ (self , params : dict ):
14
15
"""
15
16
Initialise the ML-KEM with specified lattice parameters.
16
17
@@ -23,31 +24,31 @@ def __init__(self, params):
23
24
self .du = params ["du" ]
24
25
self .dv = params ["dv" ]
25
26
26
- self .M = ModuleKyber ()
27
+ self .M = Module ()
27
28
self .R = self .M .ring
28
29
self .oid = params ["oid" ] if "oid" in params else None
29
30
30
31
# Use system randomness by default, for deterministic randomness
31
32
# use the method `set_drbg_seed()`
32
33
self .random_bytes = os .urandom
33
34
34
- def _ek_size (self ):
35
+ def _ek_size (self ) -> int :
35
36
"""
36
37
Return the size of the encapsulation key for the selected paramters.
37
38
38
39
:rtype: int
39
40
"""
40
41
return 384 * self .k + 32
41
42
42
- def _dk_size (self ):
43
+ def _dk_size (self ) -> int :
43
44
"""
44
45
Return the size of the decapsulation key for the selected parameters.
45
46
46
47
:rtype: int
47
48
"""
48
49
return 768 * self .k + 96
49
50
50
- def set_drbg_seed (self , seed ):
51
+ def set_drbg_seed (self , seed : bytes ):
51
52
"""
52
53
Change entropy source to a DRBG and seed it with provided value.
53
54
@@ -74,7 +75,7 @@ def set_drbg_seed(self, seed):
74
75
)
75
76
76
77
@staticmethod
77
- def _xof (bytes32 , i , j ) :
78
+ def _xof (b : bytes , i : bytes , j : bytes ) -> bytes :
78
79
"""
79
80
eXtendable-Output Function (XOF) described in 4.9 of FIPS 203 (page 19)
80
81
@@ -88,15 +89,15 @@ def _xof(bytes32, i, j):
88
89
Casa de Chá da Boa Nova
89
90
https://cryptojedi.org/papers/terminate-20230516.pdf
90
91
"""
91
- input_bytes = bytes32 + i + j
92
+ input_bytes = b + i + j
92
93
if len (input_bytes ) != 34 :
93
94
raise ValueError (
94
95
"Input bytes should be one 32 byte array and 2 single bytes."
95
96
)
96
97
return shake_128 (input_bytes ).digest (840 )
97
98
98
99
@staticmethod
99
- def _prf (eta , s , b ) :
100
+ def _prf (eta : int , s : bytes , b : bytes ) -> bytes :
100
101
"""
101
102
Pseudorandom function described in 4.3 of FIPS 203 (page 18)
102
103
"""
@@ -108,57 +109,61 @@ def _prf(eta, s, b):
108
109
return shake_256 (input_bytes ).digest (eta * 64 )
109
110
110
111
@staticmethod
111
- def _H (s ) :
112
+ def _H (s : bytes ) -> bytes :
112
113
"""
113
114
Hash function described in 4.4 of FIPS 203 (page 18)
114
115
"""
115
116
return sha3_256 (s ).digest ()
116
117
117
118
@staticmethod
118
- def _J (s ) :
119
+ def _J (s : bytes ) -> bytes :
119
120
"""
120
121
Hash function described in 4.4 of FIPS 203 (page 18)
121
122
"""
122
123
return shake_256 (s ).digest (32 )
123
124
124
125
@staticmethod
125
- def _G (s ) :
126
+ def _G (s : bytes ) -> tuple [ bytes , bytes ] :
126
127
"""
127
128
Hash function described in 4.5 of FIPS 203 (page 18)
128
129
"""
129
130
h = sha3_512 (s ).digest ()
130
131
return h [:32 ], h [32 :]
131
132
132
- def _generate_matrix_from_seed (self , rho , transpose = False ):
133
+ def _generate_matrix_from_seed (self , rho : bytes , transpose : bool = False ) -> Matrix :
133
134
"""
134
135
Helper function which generates a element of size
135
136
k x k from a seed `rho`.
136
137
137
138
When `transpose` is set to True, the matrix A is
138
139
built as the transpose.
139
140
"""
140
- A_data = [[0 for _ in range (self .k )] for _ in range (self .k )]
141
+ A_data = [[self . R . zero () for _ in range (self .k )] for _ in range (self .k )]
141
142
for i in range (self .k ):
142
143
for j in range (self .k ):
143
144
xof_bytes = self ._xof (rho , bytes ([j ]), bytes ([i ]))
144
145
A_data [i ][j ] = self .R .ntt_sample (xof_bytes )
145
146
A_hat = self .M (A_data , transpose = transpose )
146
147
return A_hat
147
148
148
- def _generate_error_vector (self , sigma , eta , N ):
149
+ def _generate_error_vector (
150
+ self , sigma : bytes , eta : int , N : int
151
+ ) -> tuple [Vector , int ]:
149
152
"""
150
153
Helper function which generates a element in the
151
154
module from the Centered Binomial Distribution.
152
155
"""
153
- elements = [0 for _ in range (self .k )]
156
+ elements = [self . R . zero () for _ in range (self .k )]
154
157
for i in range (self .k ):
155
158
prf_output = self ._prf (eta , sigma , bytes ([N ]))
156
159
elements [i ] = self .R .cbd (prf_output , eta )
157
160
N += 1
158
161
v = self .M .vector (elements )
159
162
return v , N
160
163
161
- def _generate_polynomial (self , sigma , eta , N ):
164
+ def _generate_polynomial (
165
+ self , sigma : bytes , eta : int , N : int
166
+ ) -> tuple [Polynomial , int ]:
162
167
"""
163
168
Helper function which generates a element in the
164
169
polynomial ring from the Centered Binomial Distribution.
@@ -167,7 +172,7 @@ def _generate_polynomial(self, sigma, eta, N):
167
172
p = self .R .cbd (prf_output , eta )
168
173
return p , N + 1
169
174
170
- def _k_pke_keygen (self , d ) :
175
+ def _k_pke_keygen (self , d : bytes ) -> tuple [ bytes , bytes ] :
171
176
"""
172
177
Use randomness to generate an encryption key and a corresponding
173
178
decryption key following Algorithm 13 (FIPS 203)
@@ -203,7 +208,7 @@ def _k_pke_keygen(self, d):
203
208
204
209
return (ek_pke , dk_pke )
205
210
206
- def _k_pke_encrypt (self , ek_pke , m , r ) :
211
+ def _k_pke_encrypt (self , ek_pke : bytes , m : bytes , r : bytes ) -> bytes :
207
212
"""
208
213
Uses the encryption key to encrypt a plaintext message using the
209
214
randomness r following Algorithm 14 (FIPS 203)
@@ -231,9 +236,7 @@ def _k_pke_encrypt(self, ek_pke, m, r):
231
236
232
237
# Next check that t_hat has been canonically encoded
233
238
if t_hat .encode (12 ) != t_hat_bytes :
234
- raise ValueError (
235
- "Modulus check failed, t_hat does not encode correctly"
236
- )
239
+ raise ValueError ("Modulus check failed, t_hat does not encode correctly" )
237
240
238
241
# Generate A_hat^T from seed rho
239
242
A_hat_T = self ._generate_matrix_from_seed (rho , transpose = True )
@@ -255,7 +258,7 @@ def _k_pke_encrypt(self, ek_pke, m, r):
255
258
256
259
return c1 + c2
257
260
258
- def _k_pke_decrypt (self , dk_pke , c ) :
261
+ def _k_pke_decrypt (self , dk_pke : bytes , c : bytes ) -> bytes :
259
262
"""
260
263
Uses the decryption key to decrypt a ciphertext following
261
264
Algorithm 15 (FIPS 203)
@@ -273,7 +276,7 @@ def _k_pke_decrypt(self, dk_pke, c):
273
276
274
277
return m
275
278
276
- def _keygen_internal (self , d , z ) :
279
+ def _keygen_internal (self , d : bytes , z : bytes ) -> tuple [ bytes , bytes ] :
277
280
"""
278
281
Use randomness to generate an encapsulation key and a corresponding
279
282
decapsulation key following Algorithm 16 (FIPS 203)
@@ -288,7 +291,7 @@ def _keygen_internal(self, d, z):
288
291
289
292
return (ek , dk )
290
293
291
- def keygen (self ):
294
+ def keygen (self ) -> tuple [ bytes , bytes ] :
292
295
"""
293
296
Generate an encapsulation key and corresponding decapsulation key
294
297
following Algorithm 19 (FIPS 203)
@@ -309,7 +312,7 @@ def keygen(self):
309
312
) = self ._keygen_internal (d , z )
310
313
return (ek , dk )
311
314
312
- def key_derive (self , seed ) :
315
+ def key_derive (self , seed : bytes ) -> tuple [ bytes , bytes ] :
313
316
"""
314
317
Derive an encapsulation key and corresponding decapsulation key
315
318
following the approach from Section 7.1 (FIPS 203)
@@ -329,7 +332,7 @@ def key_derive(self, seed):
329
332
ek , dk = self ._keygen_internal (d , z )
330
333
return (ek , dk )
331
334
332
- def _encaps_internal (self , ek , m ) :
335
+ def _encaps_internal (self , ek : bytes , m : bytes ) -> tuple [ bytes , bytes ] :
333
336
"""
334
337
Uses the encapsulation key and randomness to generate a key and an
335
338
associated ciphertext following Algorithm 17 (FIPS 203)
@@ -355,7 +358,7 @@ def _encaps_internal(self, ek, m):
355
358
356
359
return K , c
357
360
358
- def encaps (self , ek ) :
361
+ def encaps (self , ek : bytes ) -> tuple [ bytes , bytes ] :
359
362
"""
360
363
Uses the encapsulation key to generate a shared secret key and an
361
364
associated ciphertext following Algorithm 20 (FIPS 203)
@@ -374,7 +377,7 @@ def encaps(self, ek):
374
377
K , c = self ._encaps_internal (ek , m )
375
378
return K , c
376
379
377
- def _decaps_internal (self , dk , c ) :
380
+ def _decaps_internal (self , dk : bytes , c : bytes ) -> bytes :
378
381
"""
379
382
Uses the decapsulation key to produce a shared secret key from a
380
383
ciphertext following Algorithm 18 (FIPS 203)
@@ -429,7 +432,7 @@ def _decaps_internal(self, dk, c):
429
432
# performed in constant time
430
433
return select_bytes (K_bar , K_prime , c == c_prime )
431
434
432
- def decaps (self , dk , c ) :
435
+ def decaps (self , dk : bytes , c : bytes ) -> bytes :
433
436
"""
434
437
Uses the decapsulation key to produce a shared secret key from a
435
438
ciphertext following Algorithm 21 (FIPS 203).
0 commit comments