Skip to content

Commit d03920d

Browse files
committed
Impl VecMath for c32, c64
1 parent 76e7c79 commit d03920d

File tree

1 file changed

+156
-2
lines changed

1 file changed

+156
-2
lines changed

src/vecmath/ffi.rs

Lines changed: 156 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1+
use cauchy::Scalar;
12
use intel_mkl_sys::*;
3+
use num_complex::{Complex32 as c32, Complex64 as c64};
24

3-
trait VecMath: Sized {
5+
trait VecMath: Scalar {
46
/* Arthmetic */
57
fn add(a: &[Self], b: &[Self], out: &mut [Self]);
68
fn sub(a: &[Self], b: &[Self], out: &mut [Self]);
79
fn mul(a: &[Self], b: &[Self], out: &mut [Self]);
8-
fn abs(in_: &[Self], out: &mut [Self]);
10+
fn abs(in_: &[Self], out: &mut [Self::Real]);
911

1012
/* Power and Root */
1113
fn div(a: &[Self], b: &[Self], out: &mut [Self]);
@@ -195,3 +197,155 @@ impl VecMath for f64 {
195197
impl_unary!(f64, asinh, vdAsinh);
196198
impl_unary!(f64, atanh, vdAtanh);
197199
}
200+
201+
macro_rules! impl_unary_c {
202+
($scalar:ty, $mkl_complex:ty, $name:ident, $impl_name:ident) => {
203+
fn $name(in_: &[$scalar], out: &mut [$scalar]) {
204+
assert_eq!(in_.len(), out.len());
205+
let n = in_.len() as i32;
206+
unsafe {
207+
$impl_name(
208+
n,
209+
in_.as_ptr() as *const $mkl_complex,
210+
out.as_mut_ptr() as *mut $mkl_complex,
211+
)
212+
}
213+
}
214+
};
215+
}
216+
217+
macro_rules! impl_unary_real_c {
218+
($scalar:ty, $mkl_complex:ty, $name:ident, $impl_name:ident) => {
219+
fn $name(in_: &[$scalar], out: &mut [<$scalar as Scalar>::Real]) {
220+
assert_eq!(in_.len(), out.len());
221+
let n = in_.len() as i32;
222+
unsafe {
223+
$impl_name(
224+
n,
225+
in_.as_ptr() as *const $mkl_complex,
226+
out.as_mut_ptr() as *mut <$scalar as Scalar>::Real,
227+
)
228+
}
229+
}
230+
};
231+
}
232+
233+
macro_rules! impl_binary_c {
234+
($scalar:ty, $mkl_complex:ty, $name:ident, $impl_name:ident) => {
235+
fn $name(a: &[$scalar], b: &[$scalar], out: &mut [$scalar]) {
236+
assert_eq!(a.len(), out.len());
237+
assert_eq!(b.len(), out.len());
238+
let n = out.len() as i32;
239+
unsafe {
240+
$impl_name(
241+
n,
242+
a.as_ptr() as *const $mkl_complex,
243+
b.as_ptr() as *const $mkl_complex,
244+
out.as_mut_ptr() as *mut $mkl_complex,
245+
)
246+
}
247+
}
248+
};
249+
}
250+
251+
macro_rules! impl_binary_scalar_c {
252+
($scalar:ty, $mkl_complex:ty, $name:ident, $impl_name:ident) => {
253+
fn $name(a: &[$scalar], b: $scalar, out: &mut [$scalar]) {
254+
assert_eq!(a.len(), out.len());
255+
let n = out.len() as i32;
256+
unsafe {
257+
$impl_name(
258+
n,
259+
a.as_ptr() as *const $mkl_complex,
260+
b.into_mkl(),
261+
out.as_mut_ptr() as *mut $mkl_complex,
262+
)
263+
}
264+
}
265+
};
266+
}
267+
268+
trait IntoMKL {
269+
type Output;
270+
fn into_mkl(self) -> Self::Output;
271+
}
272+
273+
impl IntoMKL for c32 {
274+
type Output = MKL_Complex8;
275+
fn into_mkl(self) -> MKL_Complex8 {
276+
MKL_Complex8 {
277+
real: self.re,
278+
imag: self.im,
279+
}
280+
}
281+
}
282+
283+
impl IntoMKL for c64 {
284+
type Output = MKL_Complex16;
285+
fn into_mkl(self) -> MKL_Complex16 {
286+
MKL_Complex16 {
287+
real: self.re,
288+
imag: self.im,
289+
}
290+
}
291+
}
292+
293+
impl VecMath for c32 {
294+
impl_binary_c!(c32, MKL_Complex8, add, vcAdd);
295+
impl_binary_c!(c32, MKL_Complex8, sub, vcSub);
296+
impl_binary_c!(c32, MKL_Complex8, mul, vcMul);
297+
impl_unary_real_c!(c32, MKL_Complex8, abs, vcAbs);
298+
299+
impl_binary_c!(c32, MKL_Complex8, div, vcDiv);
300+
impl_unary_c!(c32, MKL_Complex8, sqrt, vcSqrt);
301+
impl_binary_c!(c32, MKL_Complex8, pow, vcPow);
302+
impl_binary_scalar_c!(c32, MKL_Complex8, powx, vcPowx);
303+
304+
impl_unary_c!(c32, MKL_Complex8, exp, vcExp);
305+
impl_unary_c!(c32, MKL_Complex8, ln, vcLn);
306+
impl_unary_c!(c32, MKL_Complex8, log10, vcLog10);
307+
308+
impl_unary_c!(c32, MKL_Complex8, cos, vcCos);
309+
impl_unary_c!(c32, MKL_Complex8, sin, vcSin);
310+
impl_unary_c!(c32, MKL_Complex8, tan, vcTan);
311+
impl_unary_c!(c32, MKL_Complex8, acos, vcAcos);
312+
impl_unary_c!(c32, MKL_Complex8, asin, vcAsin);
313+
impl_unary_c!(c32, MKL_Complex8, atan, vcAtan);
314+
315+
impl_unary_c!(c32, MKL_Complex8, cosh, vcCosh);
316+
impl_unary_c!(c32, MKL_Complex8, sinh, vcSinh);
317+
impl_unary_c!(c32, MKL_Complex8, tanh, vcTanh);
318+
impl_unary_c!(c32, MKL_Complex8, acosh, vcAcosh);
319+
impl_unary_c!(c32, MKL_Complex8, asinh, vcAsinh);
320+
impl_unary_c!(c32, MKL_Complex8, atanh, vcAtanh);
321+
}
322+
323+
impl VecMath for c64 {
324+
impl_binary_c!(c64, MKL_Complex16, add, vzAdd);
325+
impl_binary_c!(c64, MKL_Complex16, sub, vzSub);
326+
impl_binary_c!(c64, MKL_Complex16, mul, vzMul);
327+
impl_unary_real_c!(c64, MKL_Complex16, abs, vzAbs);
328+
329+
impl_binary_c!(c64, MKL_Complex16, div, vzDiv);
330+
impl_unary_c!(c64, MKL_Complex16, sqrt, vzSqrt);
331+
impl_binary_c!(c64, MKL_Complex16, pow, vzPow);
332+
impl_binary_scalar_c!(c64, MKL_Complex16, powx, vzPowx);
333+
334+
impl_unary_c!(c64, MKL_Complex16, exp, vzExp);
335+
impl_unary_c!(c64, MKL_Complex16, ln, vzLn);
336+
impl_unary_c!(c64, MKL_Complex16, log10, vzLog10);
337+
338+
impl_unary_c!(c64, MKL_Complex16, cos, vzCos);
339+
impl_unary_c!(c64, MKL_Complex16, sin, vzSin);
340+
impl_unary_c!(c64, MKL_Complex16, tan, vzTan);
341+
impl_unary_c!(c64, MKL_Complex16, acos, vzAcos);
342+
impl_unary_c!(c64, MKL_Complex16, asin, vzAsin);
343+
impl_unary_c!(c64, MKL_Complex16, atan, vzAtan);
344+
345+
impl_unary_c!(c64, MKL_Complex16, cosh, vzCosh);
346+
impl_unary_c!(c64, MKL_Complex16, sinh, vzSinh);
347+
impl_unary_c!(c64, MKL_Complex16, tanh, vzTanh);
348+
impl_unary_c!(c64, MKL_Complex16, acosh, vzAcosh);
349+
impl_unary_c!(c64, MKL_Complex16, asinh, vzAsinh);
350+
impl_unary_c!(c64, MKL_Complex16, atanh, vzAtanh);
351+
}

0 commit comments

Comments
 (0)