argmin_math/ndarray_m/
dot.rs

1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use crate::ArgminDot;
9use ndarray::{Array1, Array2};
10use num_complex::Complex;
11
// Implements every `ArgminDot` combination between one element type `$e`, `Array1<$e>` and
// `Array2<$e>`. `make_dot_ndarray!` invokes it for a primitive type and again for the matching
// `Complex` type, so the real and complex impl sets are generated from the same tokens.
//
// Generated impls (`self` . `other` -> result):
//   Array1 . Array1 -> elem    inner product, sum of `self[i] * other[i]`
//   Array1 . elem   -> Array1  every element multiplied by the scalar
//   elem   . Array1 -> Array1  every element multiplied by the scalar
//   Array1 . Array1 -> Array2  outer product, `out[(i, j)] = self[i] * other[j]`
//   Array2 . Array1 -> Array1  matrix-vector product
//   Array2 . Array2 -> Array2  matrix-matrix product
//   Array2 . elem   -> Array2  every element multiplied by the scalar
//   elem   . Array2 -> Array2  every element multiplied by the scalar
//
// No complex conjugation is applied by any of these products. Integer element types follow
// Rust's normal integer-overflow rules for the element-wise `*`.
macro_rules! make_dot_ndarray_elem {
    ($e:ty) => {
        // Inner product via ndarray's `dot`; panics if the two lengths differ.
        impl ArgminDot<Array1<$e>, $e> for Array1<$e> {
            #[inline]
            fn dot(&self, other: &Array1<$e>) -> $e {
                ndarray::Array1::dot(self, other)
            }
        }

        // Vector times scalar. `mapv` walks the elements directly, with no index lookups.
        impl ArgminDot<$e, Array1<$e>> for Array1<$e> {
            #[inline]
            fn dot(&self, other: &$e) -> Array1<$e> {
                self.mapv(|s| s * *other)
            }
        }

        // Scalar times vector.
        impl ArgminDot<Array1<$e>, Array1<$e>> for $e {
            #[inline]
            fn dot(&self, other: &Array1<$e>) -> Array1<$e> {
                other.mapv(|o| o * *self)
            }
        }

        // Outer product: a `self.len() x other.len()` matrix with `out[(i, j)] = self[i] * other[j]`.
        impl ArgminDot<Array1<$e>, Array2<$e>> for Array1<$e> {
            #[inline]
            fn dot(&self, other: &Array1<$e>) -> Array2<$e> {
                Array2::from_shape_fn((self.len(), other.len()), |(i, j)| self[i] * other[j])
            }
        }

        // Matrix-vector product via ndarray's `dot`; panics on incompatible shapes.
        impl ArgminDot<Array1<$e>, Array1<$e>> for Array2<$e> {
            #[inline]
            fn dot(&self, other: &Array1<$e>) -> Array1<$e> {
                ndarray::Array2::dot(self, other)
            }
        }

        // Matrix-matrix product via ndarray's `dot`; panics on incompatible shapes.
        impl ArgminDot<Array2<$e>, Array2<$e>> for Array2<$e> {
            #[inline]
            fn dot(&self, other: &Array2<$e>) -> Array2<$e> {
                ndarray::Array2::dot(self, other)
            }
        }

        // Matrix times scalar; `mapv` replaces per-element bounds-checked 2-D indexing.
        impl ArgminDot<$e, Array2<$e>> for Array2<$e> {
            #[inline]
            fn dot(&self, other: &$e) -> Array2<$e> {
                self.mapv(|x| *other * x)
            }
        }

        // Scalar times matrix.
        impl ArgminDot<Array2<$e>, Array2<$e>> for $e {
            #[inline]
            fn dot(&self, other: &Array2<$e>) -> Array2<$e> {
                other.mapv(|x| *self * x)
            }
        }
    };
}

// Implements all `ArgminDot` combinations for both the primitive element type `$t` and its
// complex counterpart `Complex<$t>` (see `make_dot_ndarray_elem!` for the generated impls).
macro_rules! make_dot_ndarray {
    ($t:ty) => {
        make_dot_ndarray_elem!($t);
        make_dot_ndarray_elem!(Complex<$t>);
    };
}
131
132make_dot_ndarray!(i8);
133make_dot_ndarray!(i16);
134make_dot_ndarray!(i32);
135make_dot_ndarray!(i64);
136make_dot_ndarray!(u8);
137make_dot_ndarray!(u16);
138make_dot_ndarray!(u32);
139make_dot_ndarray!(u64);
140make_dot_ndarray!(f32);
141make_dot_ndarray!(f64);
142
143// All code that does not depend on a linked ndarray-linalg backend can still be tested as normal.
144// To avoid dublicating tests and to allow convenient testing of functionality that does not need ndarray-linalg the tests are still included here.
145// The tests expect the name for the crate containing the tested functions to be argmin_math
146#[cfg(test)]
147use crate as argmin_math;
148include!(concat!(
149    env!("CARGO_MANIFEST_DIR"),
150    "/ndarray-tests-src/dot.rs"
151));