Skip to content

Commit ad23ebc

Browse files
committed
feat: add simd uniform
1 parent 1996c09 commit ad23ebc

File tree

6 files changed

+74
-3
lines changed

6 files changed

+74
-3
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ flamegraph.svg
55
venv
66
CI.yml
77
/test
8-
.idea/
8+
.idea/
9+
.zed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ default = []
8181
jemalloc = ["dep:tikv-jemallocator"]
8282
malliavin = []
8383
mimalloc = ["dep:mimalloc"]
84+
simd = []
8485
yahoo = ["dep:time", "dep:yahoo_finance_api", "dep:polars"]
8586

8687
[lib]

src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
#![allow(clippy::type_complexity)]
44
#![allow(clippy::too_many_arguments)]
55
//#![warn(missing_docs)]
6-
6+
//
7+
#![feature(portable_simd)]
78
// TODO: this is just temporary
89
#![allow(dead_code)]
910

src/stats/distr.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ mod tests {
3535
beta::SimdBeta, binomial::SimdBinomial, cauchy::SimdCauchy, gamma::SimdGamma,
3636
geometric::SimdGeometric, hypergeometric::SimdHypergeometric, inverse_gauss::SimdInverseGauss,
3737
lognormal::SimdLogNormal, normal::SimdNormal, normal_inverse_gauss::SimdNormalInverseGauss,
38-
pareto::SimdPareto, poisson::SimdPoisson, studentt::SimdStudentT, weibull::SimdWeibull,
38+
pareto::SimdPareto, poisson::SimdPoisson, studentt::SimdStudentT, uniform::SimdUniform,
39+
weibull::SimdWeibull,
3940
};
4041

4142
use plotly::{
@@ -386,6 +387,22 @@ mod tests {
386387
plot.add_trace(trace);
387388
}
388389

390+
// 15) Uniform (0,1)
391+
{
392+
let (xa, ya) = subplot_axes(4, 3);
393+
let mut rng = thread_rng();
394+
let dist = SimdUniform::new(0.0, 1.0);
395+
let samples: Vec<f32> = (0..sample_size).map(|_| dist.sample(&mut rng)).collect();
396+
let (xs, bins) = make_histogram(&samples, 100, -4.0, 4.0);
397+
let trace = Scatter::new(xs, bins)
398+
.name("Uniform(0,1)")
399+
.mode(Mode::Lines)
400+
.line(Line::new().shape(LineShape::Linear))
401+
.x_axis(&xa)
402+
.y_axis(&ya);
403+
plot.add_trace(trace);
404+
}
405+
389406
plot.show();
390407
}
391408
}

src/stochastic.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,12 @@ pub trait SamplingExt<T: Clone + Send + Sync + Zero>: Send + Sync {
8282
/// Sample the process
8383
fn sample(&self) -> Array1<T>;
8484

85+
/// Sample the process with simd acceleration
86+
#[cfg(feature = "simd")]
87+
fn sample_simd(&self) -> Array1<T> {
88+
unimplemented!()
89+
}
90+
8591
/// Sample the process with CUDA support
8692
#[cfg(feature = "cuda")]
8793
fn sample_cuda(&self) -> Result<Either<Array1<T>, Array2<T>>> {

src/stochastic/diffusion/ou.rs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,25 @@ impl SamplingExt<f64> for OU {
3232
ou
3333
}
3434

35+
#[cfg(feature = "simd")]
36+
// TODO: experimental
37+
fn sample_simd(&self) -> Array1<f64> {
38+
use crate::stats::distr::normal::SimdNormal;
39+
40+
let dt = self.t.unwrap_or(1.0) as f32 / (self.n - 1) as f32;
41+
let gn = Array1::random(self.n, SimdNormal::new(0.0, dt.sqrt()));
42+
43+
let mut ou = Array1::<f64>::zeros(self.n);
44+
ou[0] = self.x0.unwrap_or(0.0);
45+
46+
for i in 1..self.n {
47+
ou[i] =
48+
ou[i - 1] + self.theta * (self.mu - ou[i - 1]) * dt as f64 + self.sigma * gn[i - 1] as f64
49+
}
50+
51+
ou
52+
}
53+
3554
/// Number of time steps
3655
fn n(&self) -> usize {
3756
self.n
@@ -73,6 +92,32 @@ mod tests {
7392
plot_1d!(ou.sample(), "Fractional Ornstein-Uhlenbeck (FOU) Process");
7493
}
7594

95+
#[cfg(feature = "simd")]
96+
#[test]
97+
fn sample_simd() {
98+
use std::time::Instant;
99+
100+
let start = Instant::now();
101+
let ou = OU::new(2.0, 1.0, 0.8, N, Some(X0), Some(1.0), None);
102+
103+
for _ in 0..100_000 {
104+
ou.sample_simd();
105+
}
106+
107+
let elapsed = start.elapsed();
108+
println!("Elapsed time for sample_simd: {:?}", elapsed);
109+
110+
let start = Instant::now();
111+
let ou = OU::new(2.0, 1.0, 0.8, N, Some(X0), Some(1.0), None);
112+
113+
for _ in 0..100_000 {
114+
ou.sample();
115+
}
116+
117+
let elapsed = start.elapsed();
118+
println!("Elapsed time for sample: {:?}", elapsed);
119+
}
120+
76121
#[test]
77122
#[ignore = "Not implemented"]
78123
#[cfg(feature = "malliavin")]

0 commit comments

Comments
 (0)