Skip to content

Commit 26f5b4a

Browse files
committed
Improves bit length and error sanitization
Handles edge cases for zero values in bit length calculations Ensures error sanitization function is validated before use Signed-off-by: DavidOsipov <[email protected]>
1 parent 36d6fb7 commit 26f5b4a

File tree

1 file changed

+15
-11
lines changed

1 file changed

+15
-11
lines changed

feldman_vss.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -546,7 +546,10 @@ def constant_time_compare(a: Union[int, str, bytes], b: Union[int, str, bytes])
546546
# Convert to bytes for consistent handling
547547
if isinstance(a, int) and isinstance(b, int):
548548
# For integers, ensure same bit length with padding
549-
bit_length: int = max(a.bit_length(), b.bit_length(), 8) # Minimum 8 bits
549+
# Handle the case where a or b might be 0 (which doesn't have bit_length directly applicable)
550+
a_bits = a.bit_length() if a != 0 and hasattr(a, 'bit_length') else 0
551+
b_bits = b.bit_length() if b != 0 and hasattr(b, 'bit_length') else 0
552+
bit_length: int = max(a_bits, b_bits, 8) # Minimum 8 bits
550553
byte_length: int = (bit_length + 7) // 8
551554
a_bytes: bytes = a.to_bytes(byte_length, byteorder="big")
552555
b_bytes: bytes = b.to_bytes(byte_length, byteorder="big")
@@ -588,7 +591,8 @@ def estimate_mpz_size(n: Union[int, "gmpy2.mpz"]) -> int:
588591
"""
589592
if isinstance(n, (int, gmpy2.mpz)):
590593
bit_length: int = (
591-
n.bit_length() if hasattr(n, "bit_length") else gmpy2.mpz(n).bit_length()
594+
n.bit_length() if hasattr(n, "bit_length") and n != 0 else
595+
gmpy2.mpz(n).bit_length() if n != 0 else 0
592596
)
593597
else:
594598
bit_length = n # Assume n is already a bit length
@@ -775,14 +779,14 @@ def check_memory_safety(operation: str, *args: Any, max_size_mb: int = 1024, rej
775779
b: Any
776780
a, b = args
777781
a_bits: int = (
778-
a.bit_length()
779-
if hasattr(a, "bit_length")
780-
else gmpy2.mpz(a).bit_length()
782+
a.bit_length() if hasattr(a, "bit_length") and a != 0
783+
else gmpy2.mpz(a).bit_length() if a != 0
784+
else 0
781785
)
782786
b_bits: int = (
783-
b.bit_length()
784-
if hasattr(b, "bit_length")
785-
else gmpy2.mpz(b).bit_length()
787+
b.bit_length() if hasattr(b, "bit_length") and b != 0
788+
else gmpy2.mpz(b).bit_length() if b != 0
789+
else 0
786790
)
787791
result_bits = a_bits + b_bits # Multiplication roughly adds bit lengths
788792
estimated_bytes = estimate_mpz_size(result_bits)
@@ -1016,7 +1020,7 @@ def secure_redundant_execution(
10161020
logger.error(detailed_message)
10171021

10181022
# Use sanitization function if provided
1019-
if callable(sanitize_error_func):
1023+
if sanitize_error_func is not None and callable(sanitize_error_func):
10201024
sanitized_message: str = sanitize_error_func(message, detailed_message)
10211025
raise SecurityError(sanitized_message)
10221026
else:
@@ -1083,7 +1087,7 @@ def secure_redundant_execution(
10831087
logger.error(detailed_message)
10841088

10851089
# Use sanitization function if provided, otherwise use the generic message
1086-
if callable(sanitize_error_func):
1090+
if sanitize_error_func is not None and callable(sanitize_error_func):
10871091
sanitized_message = sanitize_error_func(message, detailed_message)
10881092
raise SecurityError(sanitized_message)
10891093
else:
@@ -1103,7 +1107,7 @@ def secure_redundant_execution(
11031107
message: str = "Security validation process failed"
11041108
logger.error(detailed_message)
11051109

1106-
if callable(sanitize_error_func):
1110+
if sanitize_error_func is not None and callable(sanitize_error_func):
11071111
sanitized_message: str = sanitize_error_func(message, detailed_message)
11081112
raise SecurityError(sanitized_message) from e
11091113
else:

0 commit comments

Comments
 (0)