1use crate::{Allocator, ArgminDot};
9
10use num_traits::{One, Zero};
11
12use crate::{ClosedAdd, ClosedMul};
13use nalgebra::{
14 base::{
15 constraint::{AreMultipliable, DimEq, ShapeConstraint},
16 dimension::Dim,
17 storage::Storage,
18 Scalar,
19 },
20 DefaultAllocator, Matrix, OMatrix,
21};
22
23impl<N, R1, R2, C1, C2, SA, SB> ArgminDot<Matrix<N, R2, C2, SB>, N> for Matrix<N, R1, C1, SA>
24where
25 N: Scalar + Zero + ClosedAdd + ClosedMul,
26 R1: Dim,
27 R2: Dim,
28 C1: Dim,
29 C2: Dim,
30 SA: Storage<N, R1, C1>,
31 SB: Storage<N, R2, C2>,
32 ShapeConstraint: DimEq<R1, R2> + DimEq<C1, C2>,
33{
34 #[inline]
35 #[allow(clippy::only_used_in_recursion)]
36 fn dot(&self, other: &Matrix<N, R2, C2, SB>) -> N {
37 self.dot(other)
38 }
39}
40
41impl<N, R, C, S> ArgminDot<N, OMatrix<N, R, C>> for Matrix<N, R, C, S>
42where
43 N: Scalar + Copy + ClosedMul,
44 R: Dim,
45 C: Dim,
46 S: Storage<N, R, C>,
47 DefaultAllocator: Allocator<N, R, C>,
48{
49 #[inline]
50 fn dot(&self, other: &N) -> OMatrix<N, R, C> {
51 self * *other
52 }
53}
54
55impl<N, R, C, S> ArgminDot<Matrix<N, R, C, S>, OMatrix<N, R, C>> for N
56where
57 N: Scalar + Copy + ClosedMul,
58 R: Dim,
59 C: Dim,
60 S: Storage<N, R, C>,
61 DefaultAllocator: Allocator<N, R, C>,
62{
63 #[inline]
64 fn dot(&self, other: &Matrix<N, R, C, S>) -> OMatrix<N, R, C> {
65 other * *self
66 }
67}
68
69impl<N, R1, R2, C1, C2, SA, SB> ArgminDot<Matrix<N, R2, C2, SB>, OMatrix<N, R1, C2>>
70 for Matrix<N, R1, C1, SA>
71where
72 N: Scalar + Zero + One + ClosedAdd + ClosedMul,
73 R1: Dim,
74 R2: Dim,
75 C1: Dim,
76 C2: Dim,
77 SA: Storage<N, R1, C1>,
78 SB: Storage<N, R2, C2>,
79 DefaultAllocator: Allocator<N, R1, C2>,
80 ShapeConstraint: AreMultipliable<R1, C1, R2, C2>,
81{
82 #[inline]
83 fn dot(&self, other: &Matrix<N, R2, C2, SB>) -> OMatrix<N, R1, C2> {
84 self * other
85 }
86}
87
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use nalgebra::{Matrix3, RowVector3, Vector3};
    use paste::item;

    // Compares two equally-shaped nalgebra matrices entry by entry after
    // casting every entry to `f64`. The expected matrix is passed first.
    macro_rules! assert_entries_eq {
        ($expected:expr, $actual:expr) => {
            for (want, got) in $expected.iter().zip($actual.iter()) {
                assert_relative_eq!(*want as f64, *got as f64, epsilon = f64::EPSILON);
            }
        };
    }

    // Generates the full set of `ArgminDot` tests for one primitive type.
    macro_rules! make_test {
        ($t:ty) => {
            item! {
                #[test]
                fn [<test_vec_vec_ $t>]() {
                    let lhs = Vector3::new(1 as $t, 2 as $t, 3 as $t);
                    let rhs = Vector3::new(4 as $t, 5 as $t, 6 as $t);
                    // 1*4 + 2*5 + 3*6 = 32
                    let res: $t = <Vector3<$t> as ArgminDot<Vector3<$t>, $t>>::dot(&lhs, &rhs);
                    assert_relative_eq!(res as f64, 32 as f64, epsilon = f64::EPSILON);
                }
            }

            item! {
                #[test]
                fn [<test_vec_scalar_ $t>]() {
                    let vector = Vector3::new(1 as $t, 2 as $t, 3 as $t);
                    let scalar = 2 as $t;
                    let expected = Vector3::new(2 as $t, 4 as $t, 6 as $t);
                    let product: Vector3<$t> =
                        <Vector3<$t> as ArgminDot<$t, Vector3<$t>>>::dot(&vector, &scalar);
                    assert_entries_eq!(expected, product);
                }
            }

            item! {
                #[test]
                fn [<test_scalar_vec_ $t>]() {
                    let vector = Vector3::new(1 as $t, 2 as $t, 3 as $t);
                    let scalar = 2 as $t;
                    let expected = Vector3::new(2 as $t, 4 as $t, 6 as $t);
                    let product: Vector3<$t> =
                        <$t as ArgminDot<Vector3<$t>, Vector3<$t>>>::dot(&scalar, &vector);
                    assert_entries_eq!(expected, product);
                }
            }

            item! {
                #[test]
                fn [<test_mat_vec_ $t>]() {
                    // Column vector times row vector: outer product.
                    let column = Vector3::new(1 as $t, 2 as $t, 3 as $t);
                    let row = RowVector3::new(4 as $t, 5 as $t, 6 as $t);
                    let expected = Matrix3::new(
                        4 as $t, 5 as $t, 6 as $t,
                        8 as $t, 10 as $t, 12 as $t,
                        12 as $t, 15 as $t, 18 as $t
                    );
                    let product: Matrix3<$t> =
                        <Vector3<$t> as ArgminDot<RowVector3<$t>, Matrix3<$t>>>::dot(&column, &row);
                    assert_entries_eq!(expected, product);
                }
            }

            item! {
                #[test]
                fn [<test_mat_vec_2_ $t>]() {
                    let matrix = Matrix3::new(
                        1 as $t, 2 as $t, 3 as $t,
                        4 as $t, 5 as $t, 6 as $t,
                        7 as $t, 8 as $t, 9 as $t
                    );
                    let vector = Vector3::new(1 as $t, 2 as $t, 3 as $t);
                    let expected = Vector3::new(14 as $t, 32 as $t, 50 as $t);
                    let product: Vector3<$t> =
                        <Matrix3<$t> as ArgminDot<Vector3<$t>, Vector3<$t>>>::dot(&matrix, &vector);
                    assert_entries_eq!(expected, product);
                }
            }

            item! {
                #[test]
                fn [<test_mat_mat_ $t>]() {
                    let lhs = Matrix3::new(
                        1 as $t, 2 as $t, 3 as $t,
                        4 as $t, 5 as $t, 6 as $t,
                        3 as $t, 2 as $t, 1 as $t
                    );
                    let rhs = Matrix3::new(
                        3 as $t, 2 as $t, 1 as $t,
                        6 as $t, 5 as $t, 4 as $t,
                        2 as $t, 4 as $t, 3 as $t
                    );
                    let expected = Matrix3::new(
                        21 as $t, 24 as $t, 18 as $t,
                        54 as $t, 57 as $t, 42 as $t,
                        23 as $t, 20 as $t, 14 as $t
                    );
                    let product: Matrix3<$t> =
                        <Matrix3<$t> as ArgminDot<Matrix3<$t>, Matrix3<$t>>>::dot(&lhs, &rhs);
                    assert_entries_eq!(expected, product);
                }
            }

            item! {
                #[test]
                fn [<test_mat_primitive_ $t>]() {
                    let matrix = Matrix3::new(
                        1 as $t, 2 as $t, 3 as $t,
                        4 as $t, 5 as $t, 6 as $t,
                        3 as $t, 2 as $t, 1 as $t
                    );
                    let expected = Matrix3::new(
                        2 as $t, 4 as $t, 6 as $t,
                        8 as $t, 10 as $t, 12 as $t,
                        6 as $t, 4 as $t, 2 as $t
                    );
                    let product: Matrix3<$t> =
                        <Matrix3<$t> as ArgminDot<$t, Matrix3<$t>>>::dot(&matrix, &(2 as $t));
                    assert_entries_eq!(expected, product);
                }
            }

            item! {
                #[test]
                fn [<test_primitive_mat_ $t>]() {
                    let matrix = Matrix3::new(
                        1 as $t, 2 as $t, 3 as $t,
                        4 as $t, 5 as $t, 6 as $t,
                        3 as $t, 2 as $t, 1 as $t
                    );
                    let expected = Matrix3::new(
                        2 as $t, 4 as $t, 6 as $t,
                        8 as $t, 10 as $t, 12 as $t,
                        6 as $t, 4 as $t, 2 as $t
                    );
                    let product: Matrix3<$t> =
                        <$t as ArgminDot<Matrix3<$t>, Matrix3<$t>>>::dot(&(2 as $t), &matrix);
                    assert_entries_eq!(expected, product);
                }
            }
        };
    }

    make_test!(i8);
    make_test!(u8);
    make_test!(i16);
    make_test!(u16);
    make_test!(i32);
    make_test!(u32);
    make_test!(i64);
    make_test!(u64);
    make_test!(f32);
    make_test!(f64);
}