Skip to content

Commit 2be618f

Browse files
committed
format with black instead of ruff
1 parent 1fb9668 commit 2be618f

15 files changed

+138
-45
lines changed

src/kyber_py/drbg/aes256_ctr_drbg.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55

66

77
class AES256_CTR_DRBG:
8-
def __init__(self, seed: Optional[bytes] = None, personalization: bytes = b""):
8+
def __init__(
9+
self, seed: Optional[bytes] = None, personalization: bytes = b""
10+
):
911
"""
1012
DRBG implementation based on AES-256 CTR following the document NIST SP
1113
800-90A Section 10.2.1
@@ -57,7 +59,9 @@ def __instantiate(self, personalization: bytes = b"") -> bytes:
5759
f"{self.seed_length}. Input has length {len(personalization)}"
5860
)
5961
# Ensure personalization has exactly seed_length bytes
60-
personalization += bytes([0]) * (self.seed_length - len(personalization))
62+
personalization += bytes([0]) * (
63+
self.seed_length - len(personalization)
64+
)
6165
# debugging
6266
assert len(personalization) == self.seed_length
6367
return xor_bytes(self.entropy_input, personalization)
@@ -93,7 +97,9 @@ def __ctr_drbg_update(self, provided_data: bytes) -> None:
9397
self.key = tmp[:32]
9498
self.V = tmp[32:]
9599

96-
def random_bytes(self, num_bytes: int, additional: Optional[bytes] = None) -> bytes:
100+
def random_bytes(
101+
self, num_bytes: int, additional: Optional[bytes] = None
102+
) -> bytes:
97103
"""
98104
Generate pseudorandom bytes without a generating function
99105

src/kyber_py/kyber/kyber.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,9 @@ def _generate_matrix_from_seed(self, rho, transpose=False):
135135
136136
When `transpose` is set to True, the matrix A is built as the transpose.
137137
"""
138-
A_data = [[self.R.zero() for _ in range(self.k)] for _ in range(self.k)]
138+
A_data = [
139+
[self.R.zero() for _ in range(self.k)] for _ in range(self.k)
140+
]
139141
for i in range(self.k):
140142
for j in range(self.k):
141143
input_bytes = self._xof(rho, bytes([j]), bytes([i]))

src/kyber_py/ml_kem/ml_kem.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,15 +130,19 @@ def _G(s: bytes) -> tuple[bytes, bytes]:
130130
h = sha3_512(s).digest()
131131
return h[:32], h[32:]
132132

133-
def _generate_matrix_from_seed(self, rho: bytes, transpose: bool = False) -> Matrix:
133+
def _generate_matrix_from_seed(
134+
self, rho: bytes, transpose: bool = False
135+
) -> Matrix:
134136
"""
135137
Helper function which generates a element of size
136138
k x k from a seed `rho`.
137139
138140
When `transpose` is set to True, the matrix A is
139141
built as the transpose.
140142
"""
141-
A_data = [[self.R.zero() for _ in range(self.k)] for _ in range(self.k)]
143+
A_data = [
144+
[self.R.zero() for _ in range(self.k)] for _ in range(self.k)
145+
]
142146
for i in range(self.k):
143147
for j in range(self.k):
144148
xof_bytes = self._xof(rho, bytes([j]), bytes([i]))
@@ -236,7 +240,9 @@ def _k_pke_encrypt(self, ek_pke: bytes, m: bytes, r: bytes) -> bytes:
236240

237241
# Next check that t_hat has been canonically encoded
238242
if t_hat.encode(12) != t_hat_bytes:
239-
raise ValueError("Modulus check failed, t_hat does not encode correctly")
243+
raise ValueError(
244+
"Modulus check failed, t_hat does not encode correctly"
245+
)
240246

241247
# Generate A_hat^T from seed rho
242248
A_hat_T = self._generate_matrix_from_seed(rho, transpose=True)

src/kyber_py/modules/modules.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,9 @@ def decode_vector(self, input_bytes, k, d, is_ntt=False):
8989
# Ensure the input bytes are the correct length to create k elements with
9090
# d bits used for each coefficient
9191
if self.ring.n * d * k != len(input_bytes) * 8:
92-
raise ValueError("Byte length is the wrong length for given k, d values")
92+
raise ValueError(
93+
"Byte length is the wrong length for given k, d values"
94+
)
9395

9496
# Bytes needed to decode a polynomial
9597
n = 32 * d

src/kyber_py/modules/modules_generic.py

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,9 @@ def __getitem__(self, idx):
5454
"""
5555
matrix[i, j] returns the element on row i, column j
5656
"""
57-
assert isinstance(idx, tuple) and len(idx) == 2, "Can't access individual rows"
57+
assert (
58+
isinstance(idx, tuple) and len(idx) == 2
59+
), "Can't access individual rows"
5860
if not self._transpose:
5961
return self._data[idx[0]][idx[1]]
6062
else:
@@ -64,7 +66,9 @@ def __eq__(self, other):
6466
if self.dim() != other.dim():
6567
return False
6668
m, n = self.dim()
67-
return all([self[i, j] == other[i, j] for i in range(m) for j in range(n)])
69+
return all(
70+
[self[i, j] == other[i, j] for i in range(m) for j in range(n)]
71+
)
6872

6973
def __neg__(self):
7074
"""
@@ -128,7 +132,10 @@ def __matmul__(self, other):
128132

129133
return self.parent(
130134
[
131-
[sum(self[i, k] * other[k, j] for k in range(n)) for j in range(l)]
135+
[
136+
sum(self[i, k] * other[k, j] for k in range(n))
137+
for j in range(l)
138+
]
132139
for i in range(m)
133140
]
134141
)
@@ -149,10 +156,17 @@ def __repr__(self):
149156
if m == 1:
150157
return str(self._data[0])
151158

152-
max_col_width = [max(len(str(self[i, j])) for i in range(m)) for j in range(n)]
159+
max_col_width = [
160+
max(len(str(self[i, j])) for i in range(m)) for j in range(n)
161+
]
153162
info = "]\n[".join(
154163
[
155-
", ".join([f"{str(self[i, j]):>{max_col_width[j]}}" for j in range(n)])
164+
", ".join(
165+
[
166+
f"{str(self[i, j]):>{max_col_width[j]}}"
167+
for j in range(n)
168+
]
169+
)
156170
for i in range(m)
157171
]
158172
)
@@ -175,7 +189,9 @@ def random_element(self, m, n):
175189
:param int m: the number of columns in tge matrix
176190
:return: an element of the module with dimension `m times n`
177191
"""
178-
elements = [[self.ring.random_element() for _ in range(n)] for _ in range(m)]
192+
elements = [
193+
[self.ring.random_element() for _ in range(n)] for _ in range(m)
194+
]
179195
return self(elements)
180196

181197
def __repr__(self):
@@ -192,14 +208,18 @@ def __call__(self, matrix_elements, transpose=False):
192208

193209
if isinstance(matrix_elements[0], list):
194210
for element_list in matrix_elements:
195-
if not all(isinstance(aij, self.ring.element) for aij in element_list):
211+
if not all(
212+
isinstance(aij, self.ring.element) for aij in element_list
213+
):
196214
raise TypeError(
197215
f"All elements of the matrix must be elements of the ring: {self.ring}"
198216
)
199217
return self.matrix(self, matrix_elements, transpose=transpose)
200218

201219
elif isinstance(matrix_elements[0], self.ring.element):
202-
if not all(isinstance(aij, self.ring.element) for aij in matrix_elements):
220+
if not all(
221+
isinstance(aij, self.ring.element) for aij in matrix_elements
222+
):
203223
raise TypeError(
204224
f"All elements of the matrix must be elements of the ring: {self.ring}"
205225
)

src/kyber_py/polynomials/polynomials.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ def __init__(self):
1616
self.element_ntt = PolynomialNTT
1717

1818
root_of_unity = 17
19-
self.ntt_zetas = [pow(root_of_unity, self._br(i, 7), 3329) for i in range(128)]
19+
self.ntt_zetas = [
20+
pow(root_of_unity, self._br(i, 7), 3329) for i in range(128)
21+
]
2022
self.ntt_f = pow(128, -1, 3329)
2123

2224
@staticmethod
@@ -212,7 +214,9 @@ def to_ntt(self):
212214
"""
213215
Not supported, raises a ``TypeError``
214216
"""
215-
raise TypeError(f"Polynomial is already in the NTT domain: {type(self) = }")
217+
raise TypeError(
218+
f"Polynomial is already in the NTT domain: {type(self) = }"
219+
)
216220

217221
def from_ntt(self):
218222
"""
@@ -281,7 +285,9 @@ def _ntt_multiplication(self, other):
281285
"""
282286
Number Theoretic Transform multiplication.
283287
"""
284-
new_coeffs = self._ntt_coefficient_multiplication(self.coeffs, other.coeffs)
288+
new_coeffs = self._ntt_coefficient_multiplication(
289+
self.coeffs, other.coeffs
290+
)
285291
return new_coeffs
286292

287293
def __add__(self, other):

src/kyber_py/polynomials/polynomials_generic.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,13 +124,16 @@ def __neg__(self):
124124
def _add_(self, other):
125125
if isinstance(other, type(self)):
126126
new_coeffs = [
127-
self._add_mod_q(x, y) for x, y in zip(self.coeffs, other.coeffs)
127+
self._add_mod_q(x, y)
128+
for x, y in zip(self.coeffs, other.coeffs)
128129
]
129130
elif isinstance(other, int):
130131
new_coeffs = self.coeffs.copy()
131132
new_coeffs[0] = self._add_mod_q(new_coeffs[0], other)
132133
else:
133-
raise NotImplementedError("Polynomials can only be added to each other")
134+
raise NotImplementedError(
135+
"Polynomials can only be added to each other"
136+
)
134137
return new_coeffs
135138

136139
def __add__(self, other):
@@ -147,7 +150,8 @@ def __iadd__(self, other):
147150
def _sub_(self, other):
148151
if isinstance(other, type(self)):
149152
new_coeffs = [
150-
self._sub_mod_q(x, y) for x, y in zip(self.coeffs, other.coeffs)
153+
self._sub_mod_q(x, y)
154+
for x, y in zip(self.coeffs, other.coeffs)
151155
]
152156
elif isinstance(other, int):
153157
new_coeffs = self.coeffs.copy()
@@ -211,7 +215,10 @@ def __eq__(self, other):
211215
if isinstance(other, type(self)):
212216
return self.coeffs == other.coeffs
213217
elif isinstance(other, int):
214-
if self.is_constant() and (other % self.parent.q) == self.coeffs[0]:
218+
if (
219+
self.is_constant()
220+
and (other % self.parent.q) == self.coeffs[0]
221+
):
215222
return True
216223
return False
217224

tests/test_drbg.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ def test_bad_personalization(self):
3535
# if the personalization is longer than 48 bytes, fail
3636
seed = os.urandom(48)
3737
personalization = os.urandom(49)
38-
self.assertRaises(ValueError, lambda: AES256_CTR_DRBG(seed, personalization))
38+
self.assertRaises(
39+
ValueError, lambda: AES256_CTR_DRBG(seed, personalization)
40+
)
3941

4042
def test_additional(self):
4143
drbg = AES256_CTR_DRBG()
@@ -47,4 +49,6 @@ def test_bad_additional(self):
4749
# if the additional data is longer than 48 bytes, fail
4850
drbg = AES256_CTR_DRBG()
4951
additional = os.urandom(49)
50-
self.assertRaises(ValueError, lambda: drbg.random_bytes(32, additional))
52+
self.assertRaises(
53+
ValueError, lambda: drbg.random_bytes(32, additional)
54+
)

tests/test_kyber.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ def parse_kat_data(data):
1010
count_blocks = data.split("\n\n")
1111
for block in count_blocks[:-1]:
1212
block_data = block.split("\n")
13-
count, seed, pk, sk, ct, ss = [line.split(" = ")[-1] for line in block_data]
13+
count, seed, pk, sk, ct, ss = [
14+
line.split(" = ")[-1] for line in block_data
15+
]
1416
parsed_data[int(count)] = {
1517
"seed": bytes.fromhex(seed),
1618
"pk": bytes.fromhex(pk),

tests/test_ml_kem.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@ def test_encaps_modulus_check_failure(self):
5151
self.assertRaises(ValueError, lambda: ML_KEM_512.encaps(bad_ek))
5252

5353
def test_xof_failure(self):
54-
self.assertRaises(ValueError, lambda: ML_KEM_512._xof(b"1", b"2", b"3"))
54+
self.assertRaises(
55+
ValueError, lambda: ML_KEM_512._xof(b"1", b"2", b"3")
56+
)
5557

5658
def test_prf_failure(self):
5759
self.assertRaises(ValueError, lambda: ML_KEM_512._prf(2, b"1", b"2"))
@@ -256,7 +258,9 @@ def generic_keygen_kat(self, ML_KEM, index):
256258
self.assertEqual(dk, dk_kat)
257259

258260
def generic_encap_kat(self, ML_KEM, index):
259-
with open("assets/ML-KEM-encapDecap-FIPS203/internalProjection.json") as f:
261+
with open(
262+
"assets/ML-KEM-encapDecap-FIPS203/internalProjection.json"
263+
) as f:
260264
data = json.load(f)
261265
kat_data = data["testGroups"][index]["tests"]
262266

@@ -275,7 +279,9 @@ def generic_encap_kat(self, ML_KEM, index):
275279
self.assertEqual(K_prime, k_kat)
276280

277281
def generic_decap_kat(self, ML_KEM, index):
278-
with open("assets/ML-KEM-encapDecap-FIPS203/internalProjection.json") as f:
282+
with open(
283+
"assets/ML-KEM-encapDecap-FIPS203/internalProjection.json"
284+
) as f:
279285
data = json.load(f)
280286
kat_data = data["testGroups"][3 + index]["tests"]
281287

0 commit comments

Comments
 (0)