Skip to content

Commit 28a8113

Browse files
authored
Improve EMA indicator parity with Cython (#2642)
1 parent 6620b2f commit 28a8113

File tree

1 file changed

+96
-1
lines changed
  • crates/indicators/src/average

1 file changed

+96
-1
lines changed

crates/indicators/src/average/ema.rs

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ pub struct ExponentialMovingAverage {
4040

4141
impl Display for ExponentialMovingAverage {
4242
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43-
write!(f, "{}({})", self.name(), self.period,)
43+
write!(f, "{}({})", self.name(), self.period)
4444
}
4545
}
4646

@@ -79,8 +79,15 @@ impl Indicator for ExponentialMovingAverage {
7979

8080
impl ExponentialMovingAverage {
8181
/// Creates a new [`ExponentialMovingAverage`] instance.
82+
///
83+
/// # Panics
84+
/// * If `period` is not positive.
8285
#[must_use]
8386
pub fn new(period: usize, price_type: Option<PriceType>) -> Self {
87+
assert!(
88+
period > 0,
89+
"ExponentialMovingAverage::new → `period` must be positive (> 0); got {period}"
90+
);
8491
Self {
8592
period,
8693
price_type: price_type.unwrap_or(PriceType::Last),
@@ -101,10 +108,17 @@ impl MovingAverage for ExponentialMovingAverage {
101108
fn count(&self) -> usize {
102109
self.count
103110
}
111+
104112
fn update_raw(&mut self, value: f64) {
105113
if !self.has_inputs {
106114
self.has_inputs = true;
107115
self.value = value;
116+
self.count = 1;
117+
118+
if self.period == 1 {
119+
self.initialized = true;
120+
}
121+
return;
108122
}
109123

110124
self.value = self.alpha.mul_add(value, (1.0 - self.alpha) * self.value);
@@ -224,4 +238,85 @@ mod tests {
224238
assert!(!indicator_ema_10.initialized);
225239
assert_eq!(indicator_ema_10.value, 1522.0);
226240
}
241+
242+
#[rstest]
243+
fn test_period_one_behaviour() {
244+
let mut ema = ExponentialMovingAverage::new(1, None);
245+
assert_eq!(ema.alpha, 1.0, "α must be 1 when period = 1");
246+
247+
ema.update_raw(10.0);
248+
assert!(ema.initialized());
249+
assert_eq!(ema.value(), 10.0);
250+
251+
ema.update_raw(42.0);
252+
assert_eq!(
253+
ema.value(),
254+
42.0,
255+
"With α = 1, the EMA must track the latest sample exactly"
256+
);
257+
}
258+
259+
#[rstest]
260+
fn test_default_price_type_is_last() {
261+
let ema = ExponentialMovingAverage::new(3, None);
262+
assert_eq!(
263+
ema.price_type,
264+
PriceType::Last,
265+
"`price_type` default mismatch"
266+
);
267+
}
268+
269+
#[rstest]
270+
fn test_nan_poisoning_and_reset_recovery() {
271+
let mut ema = ExponentialMovingAverage::new(4, None);
272+
for x in 0..3 {
273+
ema.update_raw(x as f64);
274+
assert!(ema.value().is_finite());
275+
}
276+
277+
ema.update_raw(f64::NAN);
278+
assert!(ema.value().is_nan());
279+
280+
ema.update_raw(123.456);
281+
assert!(ema.value().is_nan());
282+
283+
ema.reset();
284+
assert!(!ema.has_inputs());
285+
ema.update_raw(7.0);
286+
assert_eq!(ema.value(), 7.0);
287+
assert!(ema.value().is_finite());
288+
}
289+
290+
#[rstest]
291+
fn test_reset_without_inputs_is_safe() {
292+
let mut ema = ExponentialMovingAverage::new(8, None);
293+
ema.reset();
294+
assert!(!ema.has_inputs());
295+
assert_eq!(ema.count(), 0);
296+
assert!(!ema.initialized());
297+
}
298+
299+
#[rstest]
300+
fn test_has_inputs_lifecycle() {
301+
let mut ema = ExponentialMovingAverage::new(5, None);
302+
assert!(!ema.has_inputs());
303+
304+
ema.update_raw(1.23);
305+
assert!(ema.has_inputs());
306+
307+
ema.reset();
308+
assert!(!ema.has_inputs());
309+
}
310+
311+
#[rstest]
312+
fn test_subnormal_inputs_do_not_underflow() {
313+
let mut ema = ExponentialMovingAverage::new(2, None);
314+
let tiny = f64::MIN_POSITIVE / 2.0;
315+
ema.update_raw(tiny);
316+
ema.update_raw(tiny);
317+
assert!(
318+
ema.value() > 0.0,
319+
"Underflow: EMA value collapsed to zero for sub-normal inputs"
320+
);
321+
}
227322
}

0 commit comments

Comments
 (0)