Skip to content

Commit 8853527

Browse files
committed
Improve timing security and encoding
Bumps version to v0.8.1b1 and refines security features. Enhances constant-time comparison by adding input validation, improved integer handling with DOS protections, and a redundant masking step. Updates digest conversion and encoding parameters for consistency. Improves modular arithmetic with explicit gmpy2 operations. Signed-off-by: DavidOsipov <[email protected]>
1 parent fdfbd41 commit 8853527

File tree

1 file changed

+133
-62
lines changed

1 file changed

+133
-62
lines changed

feldman_vss.py

Lines changed: 133 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""
22
Post-Quantum Secure Feldman's Verifiable Secret Sharing (VSS) Implementation
33
4-
Version 0.8.1b0
4+
Version 0.8.1b1
55
Developed in 2025 by David Osipov
66
Licensed under the MIT License
77
@@ -204,7 +204,7 @@
204204
logger = logging.getLogger("feldman_vss")
205205

206206
# Version of the Library
207-
__version__ = "0.8.1b0"
207+
__version__ = "0.8.1b1"
208208

209209

210210
# The above code is defining the `__all__` list in a Python module. This list specifies the names of
@@ -756,47 +756,107 @@ def constant_time_compare(a: Union[int, str, bytes, "gmpy2.mpz"], b: Union[int,
756756
757757
Outputs:
758758
bool: True if values are equal, False otherwise.
759+
760+
Security Notes:
761+
- While this function aims to perform comparisons in constant time, Python's
762+
inherent behavior means true constant-time operations cannot be guaranteed
763+
at the CPU level.
764+
- For critical security applications, consider using specialized cryptographic
765+
libraries implemented in lower-level languages.
766+
- This implementation provides reasonable protection against basic timing attacks
767+
but should not be relied upon for defending against sophisticated side-channel
768+
attacks where the attacker has access to precise timing measurements.
769+
770+
Examples:
771+
>>> constant_time_compare(1234, 1234)
772+
True
773+
>>> constant_time_compare(b"secret", b"secrat")
774+
False
775+
>>> constant_time_compare(gmpy2.mpz(101), 101)
776+
True
759777
"""
760778
# Input validation
779+
if a is None or b is None:
780+
return False
781+
782+
# Optimize for identical objects (safe shortcut that doesn't affect timing security)
783+
if a is b:
784+
return True
785+
786+
# Normalize mpz types to int
761787
if isinstance(a, gmpy2.mpz):
762788
a = int(a)
763789
if isinstance(b, gmpy2.mpz):
764790
b = int(b)
765791

766792
# Convert to bytes for consistent handling
767-
if isinstance(a, int) and isinstance(b, int):
768-
# For integers, ensure same bit length with padding
769-
# Handle the case where a or b might be 0 (which doesn't have bit_length directly applicable)
770-
a_bits = a.bit_length() if a != 0 and hasattr(a, 'bit_length') else 0
771-
b_bits = b.bit_length() if b != 0 and hasattr(b, 'bit_length') else 0
772-
bit_length: int = max(a_bits, b_bits, 8) # Minimum 8 bits
773-
byte_length: int = (bit_length + 7) // 8
774-
a_bytes: bytes = a.to_bytes(byte_length, byteorder="big")
775-
b_bytes: bytes = b.to_bytes(byte_length, byteorder="big")
776-
elif isinstance(a, str) and isinstance(b, str):
777-
a_bytes = a.encode("utf-8")
778-
b_bytes = b.encode("utf-8")
779-
elif isinstance(a, bytes) and isinstance(b, bytes):
780-
a_bytes = a
781-
b_bytes = b
782-
else:
783-
# For mixed types, use a consistent conversion approach
784-
a_bytes = str(a).encode("utf-8")
785-
b_bytes = str(b).encode("utf-8")
786-
787-
# Handle different lengths with a padded comparison
788-
# to maintain constant time behavior
789-
max_len: int = max(len(a_bytes), len(b_bytes))
790-
a_bytes = a_bytes.ljust(max_len, b"\0")
791-
b_bytes = b_bytes.ljust(max_len, b"\0")
792-
793-
# Constant-time comparison with the full length
794-
result: int = 0
795-
for x, y in zip(a_bytes, b_bytes):
796-
result |= x ^ y
797-
798-
# Final result is 0 only if all bytes matched
799-
return result == 0
793+
try:
794+
if isinstance(a, int) and isinstance(b, int):
795+
# For integers, handle negative values securely
796+
if (a < 0) != (b < 0): # Different signs
797+
return False
798+
799+
# For integers, ensure same bit length with padding
800+
a_bits: int = a.bit_length() if hasattr(a, "bit_length") and a != 0 else 0
801+
b_bits: int = b.bit_length() if hasattr(b, "bit_length") and b != 0 else 0
802+
803+
# Protect against DOS with excessive memory allocation
804+
if max(a_bits, b_bits) > 1_000_000: # Reasonable upper limit
805+
raise ValueError("Integer values too large for secure comparison")
806+
807+
bit_length: int = max(a_bits, b_bits, 8) # Minimum 8 bits
808+
byte_length: int = (bit_length + 7) // 8
809+
810+
# Convert to bytes with same length
811+
a_bytes: bytes = abs(a).to_bytes(byte_length, byteorder="big")
812+
b_bytes: bytes = abs(b).to_bytes(byte_length, byteorder="big")
813+
814+
elif isinstance(a, str) and isinstance(b, str):
815+
a_bytes = a.encode(encoding="utf-8")
816+
b_bytes = b.encode(encoding="utf-8")
817+
818+
elif isinstance(a, bytes) and isinstance(b, bytes):
819+
a_bytes = a
820+
b_bytes = b
821+
822+
else:
823+
# For mixed types, use a consistent conversion approach
824+
# Note: This branch is less secure for timing, but necessary for flexibility
825+
a_bytes = str(a).encode(encoding="utf-8")
826+
b_bytes = str(b).encode(encoding="utf-8")
827+
828+
# Protect against DOS with excessive memory allocation
829+
if max(len(a_bytes), len(b_bytes)) > 10_000_000: # 10MB limit
830+
raise ValueError("Input values too large for secure comparison")
831+
832+
# Handle different lengths with a padded comparison
833+
# to maintain constant time behavior
834+
max_len: int = max(len(a_bytes), len(b_bytes))
835+
a_bytes = a_bytes.ljust(max_len, b"\0")
836+
b_bytes = b_bytes.ljust(max_len, b"\0")
837+
838+
# Constant-time comparison with the full length
839+
result: int = 0
840+
841+
# Perform two passes to further mask timing differences
842+
# First pass - standard XOR comparison
843+
for x, y in zip(a_bytes, b_bytes):
844+
result |= x ^ y
845+
846+
# Redundant second pass to mask CPU-level optimizations and cache effects
847+
# The 'dummy' variable is intentionally unused - its calculation serves
848+
# to normalize execution time against sophisticated timing attacks
849+
dummy: int = 0
850+
for x, y in zip(a_bytes, b_bytes):
851+
dummy |= x & y
852+
853+
# Final result is 0 only if all bytes matched
854+
return result == 0
855+
856+
except (ValueError, TypeError, OverflowError) as e:
857+
# Log error for debugging but maintain security by returning False
858+
logger.debug(f"Error in constant_time_compare: {e}")
859+
return False
800860

801861
def validate_timestamp(timestamp: Optional[int], max_future_drift: int = MAX_TIME_DRIFT,
802862
min_past_drift: int = 86400, allow_none: bool = True) -> int:
@@ -1190,15 +1250,27 @@ def compute_checksum(data: bytes) -> int:
11901250
11911251
Outputs:
11921252
int: The computed checksum.
1253+
1254+
Raises:
1255+
TypeError: If data is not bytes.
11931256
"""
11941257
# Input validation
11951258
if not isinstance(data, bytes):
11961259
raise TypeError("data must be bytes")
11971260

1261+
# Explicitly annotate digest variables
1262+
digest: bytes
1263+
11981264
if has_blake3 and blake3 is not None:
1199-
# trunk-ignore(pyright/reportPossiblyUnboundVariable)
1200-
return int.from_bytes(blake3.blake3(data).digest()[:16], "big")
1201-
return int.from_bytes(hashlib.sha3_256(data).digest()[:16], "big")
1265+
# Use blake3 if available (faster and more secure)
1266+
digest = blake3.blake3(data).digest()[:16]
1267+
else:
1268+
# Fall back to SHA3-256 if blake3 is not available
1269+
digest = hashlib.sha3_256(data).digest()[:16]
1270+
1271+
# Convert digest to integer with explicit annotation
1272+
checksum: int = int.from_bytes(digest, byteorder="big")
1273+
return checksum
12021274

12031275
def create_secure_deterministic_rng(seed: bytes) -> Callable[[Union[int, "gmpy2.mpz"]], int]:
12041276
"""
@@ -2064,29 +2136,28 @@ def _exp_with_precomputation(self, exponent: Union[int, "gmpy2.mpz"]) -> "gmpy2.
20642136
result: "gmpy2.mpz" = gmpy2.mpz(1)
20652137
remaining: "gmpy2.mpz" = exponent_mpz
20662138

2067-
# Process large steps first
2068-
large_count: int
2069-
max_step: int
2139+
# Process large steps first
20702140
while remaining >= large_step:
20712141
# Extract how many large steps to take
2072-
large_count = remaining // large_step
2073-
if large_count in large_window:
2142+
large_count_mpz = remaining // large_step
2143+
large_count_int = int(large_count_mpz) # Convert to int for dict key
2144+
2145+
if large_count_int in large_window:
20742146
# Use precomputed large step
2075-
result = (result * large_window[large_count]) % self.prime
2076-
remaining -= large_count * large_step
2147+
result = gmpy2.mul(result, large_window[large_count_int]) % self.prime
2148+
remaining = remaining - gmpy2.mul(large_count_mpz, large_step)
20772149
else:
20782150
# Take the largest available step
2079-
max_step = max(
2080-
(k for k in large_window.keys() if k <= large_count), default=0
2151+
max_step: int = max(
2152+
(k for k in large_window.keys() if k <= large_count_int), default=0
20812153
)
20822154
if max_step > 0:
2083-
result = (result * large_window[max_step]) % self.prime
2084-
remaining -= max_step * large_step
2155+
result = gmpy2.mul(result, large_window[max_step]) % self.prime
2156+
remaining = remaining - gmpy2.mul(gmpy2.mpz(max_step), large_step)
20852157
else:
20862158
# Fall back to small steps
20872159
break
2088-
2089-
# Process remaining small steps
2160+
# Process remaining small steps
20902161
small_val: int
20912162
while remaining > 0:
20922163
# Extract small window bits
@@ -2131,8 +2202,7 @@ def mul(self, a: Union[int, "gmpy2.mpz"], b: Union[int, "gmpy2.mpz"]) -> "gmpy2.
21312202
"Multiplication operation would exceed memory limits. "
21322203
"The operands are too large for available system memory."
21332204
)
2134-
2135-
return (a_mpz * b_mpz) % self.prime
2205+
return gmpy2.mod(a_mpz * b_mpz, self.prime)
21362206

21372207
def secure_random_element(self) -> "gmpy2.mpz":
21382208
"""
@@ -2148,7 +2218,8 @@ def secure_random_element(self) -> "gmpy2.mpz":
21482218
Outputs:
21492219
int: A random element in the range [1, prime-1].
21502220
"""
2151-
return gmpy2.mpz(secrets.randbelow(int(self.prime - 1)) + 1)
2221+
random_element: "gmpy2.mpz" = gmpy2.mpz(secrets.randbelow(exclusive_upper_bound=int(self.prime - 1)) + 1)
2222+
return random_element
21522223

21532224
def clear_cache(self) -> None:
21542225
"""
@@ -2267,12 +2338,12 @@ def _enhanced_encode_for_hash(self, *args: Any, context: str = "FeldmanVSS") ->
22672338
encoded: bytes = b""
22682339

22692340
# Add protocol version identifier
2270-
encoded += VSS_VERSION.encode("utf-8")
2341+
encoded += VSS_VERSION.encode(encoding="utf-8")
22712342

22722343
# Add context string with type tag and length prefixing for domain separation
2273-
context_bytes: bytes = context.encode("utf-8")
2344+
context_bytes: bytes = context.encode(encoding="utf-8")
22742345
encoded += b"\x01" # Type tag for context string
2275-
encoded += len(context_bytes).to_bytes(4, "big")
2346+
encoded += len(context_bytes).to_bytes(length=4, byteorder="big")
22762347
encoded += context_bytes
22772348

22782349
# Calculate byte length for integer serialization once
@@ -2289,16 +2360,16 @@ def _enhanced_encode_for_hash(self, *args: Any, context: str = "FeldmanVSS") ->
22892360
arg_bytes = arg
22902361
elif isinstance(arg, str):
22912362
encoded += b"\x01" # Tag for string
2292-
arg_bytes = arg.encode("utf-8")
2363+
arg_bytes = arg.encode(encoding="utf-8")
22932364
elif isinstance(arg, int) or isinstance(arg, gmpy2.mpz):
22942365
encoded += b"\x02" # Tag for int/mpz
22952366
arg_bytes = int(arg).to_bytes(byte_length, "big")
22962367
else:
22972368
encoded += b"\x03" # Tag for other types
2298-
arg_bytes = str(arg).encode("utf-8")
2369+
arg_bytes = str(arg).encode(encoding="utf-8")
22992370

23002371
# Add 4-byte length followed by the data itself
2301-
encoded += len(arg_bytes).to_bytes(4, "big")
2372+
encoded += len(arg_bytes).to_bytes(4, byteorder="big")
23022373
encoded += arg_bytes
23032374

23042375
return encoded
@@ -2370,7 +2441,7 @@ def efficient_multi_exp(self, bases: List[Union[int, "gmpy2.mpz"]], exponents: L
23702441
j: int
23712442
for j in range(n):
23722443
if (i >> j) & 1:
2373-
product = (product * bases_mpz[j]) % prime
2444+
product = gmpy2.mod(product * bases_mpz[j], prime)
23742445
precomp[i] = product
23752446
else:
23762447
# For larger n, use selective precomputation
@@ -2676,7 +2747,7 @@ def _compute_hash_commitment_single(
26762747
if isinstance(extra_entropy, bytes):
26772748
elements.append(extra_entropy)
26782749
else:
2679-
elements.append(str(extra_entropy).encode("utf-8"))
2750+
elements.append(str(extra_entropy).encode(encoding="utf-8"))
26802751

26812752
# Use the consistent encoding method from the group class
26822753
encoded: bytes = self.group._enhanced_encode_for_hash(*elements)

0 commit comments

Comments
 (0)