Skip to content

Commit aa1bf64

Browse files
authored
fix(extension): safeguard TTL parsing with NonZeroU64 and cover with tests (#160)
1 parent cd811f1 commit aa1bf64

File tree

1 file changed

+34
-43
lines changed

1 file changed

+34
-43
lines changed

src/extension.rs

Lines changed: 34 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
use std::time::{SystemTime, UNIX_EPOCH};
1+
use std::{
2+
num::NonZeroU64,
3+
time::{SystemTime, UNIX_EPOCH},
4+
};
25

36
/// Enum representing different types of extensions.
47
#[allow(clippy::upper_case_acronyms)]
@@ -79,18 +82,6 @@ fn parser(prefix: String, full: String) -> Extension {
7982
///
8083
/// If the string `s` does not start with the prefix, the function returns
8184
/// `None`.
82-
///
83-
/// # Arguments
84-
///
85-
/// * `trim` - Whether to trim the string before checking the prefix.
86-
/// * `s` - The string to handle.
87-
/// * `prefix` - The prefix to check and remove from the string.
88-
/// * `handler` - The function to apply to the string after removing the prefix.
89-
///
90-
/// # Returns
91-
///
92-
/// This function returns an `Option<Extensions>`. If the string starts with the
93-
/// prefix, it returns `Some(Extensions)`. Otherwise, it returns `None`.
9485
#[tracing::instrument(level = "trace", skip(handler))]
9586
#[inline]
9687
fn parse_extension(
@@ -117,12 +108,6 @@ fn parse_extension(
117108
/// extension. The function uses the `murmurhash3_x64_128` function to generate
118109
/// a 128-bit hash from the string. The hash is then returned as a tuple `(a, b)`
119110
/// wrapped in the `Extensions::Range` variant.
120-
/// # Arguments
121-
/// * `s` - The string to parse.
122-
/// # Returns
123-
/// This function returns an `Extensions` enum.
124-
/// If the string is empty, it returns `Extensions::None`.
125-
/// If the string is not empty, it returns `Extensions::Range(a, b)`.
126111
#[inline(always)]
127112
fn parse_range_extension(s: &str) -> Extension {
128113
let hash = fxhash::hash64(s.as_bytes());
@@ -139,16 +124,6 @@ fn parse_range_extension(s: &str) -> Extension {
139124
/// wrapped in the `Extensions::Session` variant.
140125
///
141126
/// If the string is empty, the function returns `Extensions::None`.
142-
///
143-
/// # Arguments
144-
///
145-
/// * `s` - The string to parse.
146-
///
147-
/// # Returns
148-
///
149-
/// This function returns an `Extensions` enum. If the string is not empty, it
150-
/// will return a `Extensions::Session` variant containing a tuple `(a, b)`.
151-
/// Otherwise, it will return `Extensions::None`.
152127
#[inline(always)]
153128
fn parse_session_extension(s: &str) -> Extension {
154129
let hash = fxhash::hash64(s.as_bytes());
@@ -161,29 +136,45 @@ fn parse_session_extension(s: &str) -> Extension {
161136
/// the TTL value. If successful, it returns an `Extensions::Session` variant
162137
/// with the parsed TTL value and a fixed value of `1`. If the string cannot be
163138
/// parsed into a `u64`, it returns `Extensions::None`.
164-
///
165-
/// # Arguments
166-
///
167-
/// * `s` - The string to parse as a TTL value.
168-
///
169-
/// # Returns
170-
///
171-
/// Returns an `Extensions` enum variant. If parsing is successful, returns
172-
/// `Extensions::Session` with the TTL value and `1`. Otherwise, returns
173-
/// `Extensions::None`.
174-
#[inline(always)]
139+
#[inline]
175140
fn parse_ttl_extension(s: &str) -> Extension {
176-
if let Ok(ttl) = s.parse::<u64>() {
141+
if let Ok(Some(ttl)) = s.parse::<u64>().map(NonZeroU64::new) {
177142
let start = SystemTime::now();
178143
let timestamp = start
179144
.duration_since(UNIX_EPOCH)
180145
.map(|d| d.as_secs())
181146
.unwrap_or(rand::random());
182147

183-
let time = timestamp - (timestamp % ttl);
184-
148+
let time = timestamp - (timestamp % ttl.get());
185149
let hash = fxhash::hash64(&time.to_be_bytes());
186150
return Extension::TTL(hash);
187151
}
188152
Extension::None
189153
}
154+
155+
#[cfg(test)]
156+
mod tests {
157+
use super::*;
158+
159+
#[test]
160+
fn test_parse_ttl_extension_zero() {
161+
// Should return Extension::None for zero input
162+
assert!(matches!(parse_ttl_extension("0"), Extension::None));
163+
}
164+
165+
#[test]
166+
fn test_parse_ttl_extension_nonzero() {
167+
// Should return Extension::TTL for non-zero input
168+
let ext = parse_ttl_extension("60");
169+
match ext {
170+
Extension::TTL(_) => {}
171+
_ => panic!("Expected Extension::TTL"),
172+
}
173+
}
174+
175+
#[test]
176+
fn test_parse_ttl_extension_invalid() {
177+
// Should return Extension::None for invalid input
178+
assert!(matches!(parse_ttl_extension("abc"), Extension::None));
179+
}
180+
}

0 commit comments

Comments
 (0)