diff --git a/Cargo.toml b/Cargo.toml index 5050629c1..82aa89170 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -54,7 +54,6 @@ lazy_static = "1.0" half = "1.5.0" [dev-dependencies] -float-cmp = "0.6.0" half = "1.5.0" [build-dependencies] diff --git a/examples/pi.rs b/examples/pi.rs index 1d6e91bb4..f180478b7 100644 --- a/examples/pi.rs +++ b/examples/pi.rs @@ -23,7 +23,7 @@ fn main() { let root = &sqrt(xplusy); let cnst = &constant(1, dims); let (real, imag) = sum_all(&le(root, cnst, false)); - let pi_val = real * 4.0 / (samples as f64); + let pi_val = (real as f64) * 4.0 / (samples as f64); } println!("Estimated Pi Value in {:?}", start.elapsed()); diff --git a/src/algorithm/mod.rs b/src/algorithm/mod.rs index f8cf2ac19..35d286404 100644 --- a/src/algorithm/mod.rs +++ b/src/algorithm/mod.rs @@ -1,5 +1,5 @@ use super::core::{ - af_array, AfError, Array, BinaryOp, HasAfEnum, RealNumber, ReduceByKeyInput, Scanable, + af_array, AfError, Array, BinaryOp, Fromf64, HasAfEnum, RealNumber, ReduceByKeyInput, Scanable, HANDLE_ERROR, }; @@ -518,9 +518,13 @@ where } macro_rules! all_reduce_func_def { - ($doc_str: expr, $fn_name: ident, $ffi_name: ident) => { + ($doc_str: expr, $fn_name: ident, $ffi_name: ident, $out_type:ty) => { #[doc=$doc_str] - pub fn $fn_name(input: &Array) -> (f64, f64) { + pub fn $fn_name(input: &Array) -> ($out_type, $out_type) + where + T: HasAfEnum, + $out_type: HasAfEnum + Fromf64 + { let mut real: f64 = 0.0; let mut imag: f64 = 0.0; unsafe { @@ -529,7 +533,7 @@ macro_rules! all_reduce_func_def { ); HANDLE_ERROR(AfError::from(err_val)); } - (real, imag) + (<$out_type>::fromf64(real), <$out_type>::fromf64(imag)) } }; } @@ -559,7 +563,8 @@ all_reduce_func_def!( ``` ", sum_all, - af_sum_all + af_sum_all, + T::AggregateOutType ); all_reduce_func_def!( @@ -588,7 +593,8 @@ all_reduce_func_def!( ``` ", product_all, - af_product_all + af_product_all, + T::ProductOutType ); all_reduce_func_def!( @@ -616,7 +622,8 @@ all_reduce_func_def!( ``` ", min_all, - af_min_all + af_min_all, + T::InType ); all_reduce_func_def!( @@ -644,7 +651,8 @@ all_reduce_func_def!( ``` ", max_all, - af_max_all + af_max_all, + T::InType ); all_reduce_func_def!( @@ -670,7 +678,8 @@ all_reduce_func_def!( ``` ", all_true_all, - af_all_true_all + af_all_true_all, + bool ); all_reduce_func_def!( @@ -696,7 +705,8 @@ all_reduce_func_def!( ``` ", any_true_all, - af_any_true_all + af_any_true_all, + bool ); all_reduce_func_def!( @@ -722,7 +732,8 @@ all_reduce_func_def!( ``` ", count_all, - af_count_all + af_count_all, + u64 ); /// Sum all values using user provided value for `NAN` @@ -740,7 +751,11 @@ all_reduce_func_def!( /// A tuple of summation result. /// /// Note: For non-complex data type Arrays, second value of tuple is zero. -pub fn sum_nan_all(input: &Array, val: f64) -> (f64, f64) { +pub fn sum_nan_all(input: &Array, val: f64) -> (T::AggregateOutType, T::AggregateOutType) +where + T: HasAfEnum, + T::AggregateOutType: HasAfEnum + Fromf64, +{ let mut real: f64 = 0.0; let mut imag: f64 = 0.0; unsafe { @@ -752,7 +767,10 @@ pub fn sum_nan_all(input: &Array, val: f64) -> (f64, f64) { ); HANDLE_ERROR(AfError::from(err_val)); } - (real, imag) + ( + ::fromf64(real), + ::fromf64(imag), + ) } /// Product of all values using user provided value for `NAN` @@ -770,7 +788,11 @@ pub fn sum_nan_all(input: &Array, val: f64) -> (f64, f64) { /// A tuple of product result. /// /// Note: For non-complex data type Arrays, second value of tuple is zero. -pub fn product_nan_all(input: &Array, val: f64) -> (f64, f64) { +pub fn product_nan_all(input: &Array, val: f64) -> (T::ProductOutType, T::ProductOutType) +where + T: HasAfEnum, + T::ProductOutType: HasAfEnum + Fromf64, +{ let mut real: f64 = 0.0; let mut imag: f64 = 0.0; unsafe { @@ -782,7 +804,10 @@ pub fn product_nan_all(input: &Array, val: f64) -> (f64, f64) { ); HANDLE_ERROR(AfError::from(err_val)); } - (real, imag) + ( + ::fromf64(real), + ::fromf64(imag), + ) } macro_rules! dim_ireduce_func_def { @@ -833,9 +858,13 @@ dim_ireduce_func_def!(" ", imax, af_imax, InType); macro_rules! all_ireduce_func_def { - ($doc_str: expr, $fn_name: ident, $ffi_name: ident) => { + ($doc_str: expr, $fn_name: ident, $ffi_name: ident, $out_type:ty) => { #[doc=$doc_str] - pub fn $fn_name(input: &Array) -> (f64, f64, u32) { + pub fn $fn_name(input: &Array) -> ($out_type, $out_type, u32) + where + T: HasAfEnum, + $out_type: HasAfEnum + Fromf64 + { let mut real: f64 = 0.0; let mut imag: f64 = 0.0; let mut temp: u32 = 0; @@ -846,7 +875,7 @@ macro_rules! all_ireduce_func_def { ); HANDLE_ERROR(AfError::from(err_val)); } - (real, imag, temp) + (<$out_type>::fromf64(real), <$out_type>::fromf64(imag), temp) } }; } @@ -868,7 +897,8 @@ all_ireduce_func_def!( * index of minimum element in the third component. ", imin_all, - af_imin_all + af_imin_all, + T::InType ); all_ireduce_func_def!( " @@ -887,7 +917,8 @@ all_ireduce_func_def!( - index of maximum element in the third component. ", imax_all, - af_imax_all + af_imax_all, + T::InType ); /// Locate the indices of non-zero elements. diff --git a/src/core/util.rs b/src/core/util.rs index a5bbe7e93..2eff73f64 100644 --- a/src/core/util.rs +++ b/src/core/util.rs @@ -796,3 +796,34 @@ impl BitOr for MatProp { Self::from(self as u32 | rhs as u32) } } + +/// Trait to convert reduction's scalar output to appropriate output type +/// +/// This is an internal trait and ideally of no use to user usecases. +pub trait Fromf64 { + /// Convert to target type from a double precision value + fn fromf64(value: f64) -> Self; +} + +#[rustfmt::skip] +impl Fromf64 for usize{ fn fromf64(value: f64) -> Self { value as Self }} +#[rustfmt::skip] +impl Fromf64 for f64 { fn fromf64(value: f64) -> Self { value as Self }} +#[rustfmt::skip] +impl Fromf64 for u64 { fn fromf64(value: f64) -> Self { value as Self }} +#[rustfmt::skip] +impl Fromf64 for i64 { fn fromf64(value: f64) -> Self { value as Self }} +#[rustfmt::skip] +impl Fromf64 for f32 { fn fromf64(value: f64) -> Self { value as Self }} +#[rustfmt::skip] +impl Fromf64 for u32 { fn fromf64(value: f64) -> Self { value as Self }} +#[rustfmt::skip] +impl Fromf64 for i32 { fn fromf64(value: f64) -> Self { value as Self }} +#[rustfmt::skip] +impl Fromf64 for u16 { fn fromf64(value: f64) -> Self { value as Self }} +#[rustfmt::skip] +impl Fromf64 for i16 { fn fromf64(value: f64) -> Self { value as Self }} +#[rustfmt::skip] +impl Fromf64 for u8 { fn fromf64(value: f64) -> Self { value as Self }} +#[rustfmt::skip] +impl Fromf64 for bool { fn fromf64(value: f64) -> Self { value > 0.0 }} diff --git a/tests/scalar_arith.rs b/tests/scalar_arith.rs index 915538c54..a80fc2aec 100644 --- a/tests/scalar_arith.rs +++ b/tests/scalar_arith.rs @@ -1,5 +1,4 @@ use ::arrayfire::*; -use float_cmp::approx_eq; #[test] fn check_scalar_arith() { @@ -15,5 +14,5 @@ fn check_scalar_arith() { let scalar_res = all_true_all(&scalar_res_comp); let res = all_true_all(&res_comp); - assert!(approx_eq!(f64, scalar_res.0, res.0, ulps = 2)); + assert!(scalar_res.0 == res.0); }