diff --git a/src/arith/mod.rs b/src/arith/mod.rs index 32e291dd9..a3f194ce8 100644 --- a/src/arith/mod.rs +++ b/src/arith/mod.rs @@ -180,12 +180,12 @@ macro_rules! binary_func { /// /// This is an element wise binary operation. #[allow(unused_mut)] - pub fn $fn_name(lhs: &Array, rhs: &Array) -> Array { + pub fn $fn_name(lhs: &Array, rhs: &Array, batch: bool) -> Array { unsafe { let mut temp: i64 = 0; let err_val = $ffi_fn(&mut temp as MutAfArray, lhs.get() as AfArray, rhs.get() as AfArray, - 0); + batch as c_int); HANDLE_ERROR(AfError::from(err_val)); Array::from(temp) } @@ -217,6 +217,8 @@ macro_rules! convertable_type_def { ) } +convertable_type_def!(Complex); +convertable_type_def!(Complex); convertable_type_def!(u64); convertable_type_def!(i64); convertable_type_def!(f64); @@ -350,19 +352,13 @@ pub fn clamp (input: &Array, arg1: &T, arg2: &U, batch: bool) -> Array } macro_rules! arith_scalar_func { - ($rust_type: ty, $op_name:ident, $fn_name: ident, $ffi_fn: ident) => ( + ($rust_type: ty, $op_name:ident, $fn_name: ident) => ( impl<'f> $op_name<$rust_type> for &'f Array { type Output = Array; fn $fn_name(self, rhs: $rust_type) -> Array { - let cnst_arr = constant(rhs, self.dims()); - unsafe { - let mut temp: i64 = 0; - let err_val = $ffi_fn(&mut temp as MutAfArray, self.get() as AfArray, - cnst_arr.get() as AfArray, 0); - HANDLE_ERROR(AfError::from(err_val)); - Array::from(temp) - } + let temp = rhs.clone(); + $fn_name(self, &temp, false) } } @@ -370,14 +366,8 @@ macro_rules! arith_scalar_func { type Output = Array; fn $fn_name(self, rhs: $rust_type) -> Array { - let cnst_arr = constant(rhs, self.dims()); - unsafe { - let mut temp: i64 = 0; - let err_val = $ffi_fn(&mut temp as MutAfArray, self.get() as AfArray, - cnst_arr.get() as AfArray, 0); - HANDLE_ERROR(AfError::from(err_val)); - Array::from(temp) - } + let temp = rhs.clone(); + $fn_name(&self, &temp, false) } } ) @@ -385,10 +375,10 @@ macro_rules! arith_scalar_func { macro_rules! arith_scalar_spec { ($ty_name:ty) => ( - arith_scalar_func!($ty_name, Add, add, af_add); - arith_scalar_func!($ty_name, Sub, sub, af_sub); - arith_scalar_func!($ty_name, Mul, mul, af_mul); - arith_scalar_func!($ty_name, Div, div, af_div); + arith_scalar_func!($ty_name, Add, add); + arith_scalar_func!($ty_name, Sub, sub); + arith_scalar_func!($ty_name, Mul, mul); + arith_scalar_func!($ty_name, Div, div); ) } @@ -403,33 +393,51 @@ arith_scalar_spec!(i32); arith_scalar_spec!(u8); macro_rules! arith_func { - ($op_name:ident, $fn_name:ident, $ffi_fn: ident) => ( + ($op_name:ident, $fn_name:ident, $delegate:ident) => ( impl $op_name for Array { type Output = Array; fn $fn_name(self, rhs: Array) -> Array { - unsafe { - let mut temp: i64 = 0; - let err_val = $ffi_fn(&mut temp as MutAfArray, - self.get() as AfArray, rhs.get() as AfArray, 0); - HANDLE_ERROR(AfError::from(err_val)); - Array::from(temp) - } + $delegate(&self, &rhs, false) + } + } + + impl<'a> $op_name<&'a Array> for Array { + type Output = Array; + + fn $fn_name(self, rhs: &'a Array) -> Array { + $delegate(&self, rhs, false) + } + } + + impl<'a> $op_name for &'a Array { + type Output = Array; + + fn $fn_name(self, rhs: Array) -> Array { + $delegate(self, &rhs, false) + } + } + + impl<'a, 'b> $op_name<&'a Array> for &'b Array { + type Output = Array; + + fn $fn_name(self, rhs: &'a Array) -> Array { + $delegate(self, rhs, false) } } ) } -arith_func!(Add, add, af_add); -arith_func!(Sub, sub, af_sub); -arith_func!(Mul, mul, af_mul); -arith_func!(Div, div, af_div); -arith_func!(Rem, rem, af_rem); -arith_func!(BitAnd, bitand, af_bitand); -arith_func!(BitOr, bitor, af_bitor); -arith_func!(BitXor, bitxor, af_bitxor); -arith_func!(Shl, shl, af_bitshiftl); -arith_func!(Shr, shr, af_bitshiftr); +arith_func!(Add , add , add ); +arith_func!(Sub , sub , sub ); +arith_func!(Mul , mul , mul ); +arith_func!(Div , div , div ); +arith_func!(Rem , rem , rem ); +arith_func!(Shl , shl , shiftl); +arith_func!(Shr , shr , shiftr); +arith_func!(BitAnd, bitand, bitand); +arith_func!(BitOr , bitor , bitor ); +arith_func!(BitXor, bitxor, bitxor); #[cfg(op_assign)] mod op_assign { @@ -477,7 +485,7 @@ macro_rules! bit_assign_func { let mut idxrs = Indexer::new(); idxrs.set_index(&Seq::::default(), 0, Some(false)); idxrs.set_index(&Seq::::default(), 1, Some(false)); - let tmp = assign_gen(self as &Array, &idxrs, & $func(self as &Array, &rhs)); + let tmp = assign_gen(self as &Array, &idxrs, & $func(self as &Array, &rhs, false)); mem::replace(self, tmp); } }