// argmin_math/nalgebra_m/dot.rs

// Copyright 2018-2024 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.
7
8use crate::{Allocator, ArgminDot};
9
10use num_traits::{One, Zero};
11
12use crate::{ClosedAdd, ClosedMul};
13use nalgebra::{
14    base::{
15        constraint::{AreMultipliable, DimEq, ShapeConstraint},
16        dimension::Dim,
17        storage::Storage,
18        Scalar,
19    },
20    DefaultAllocator, Matrix, OMatrix,
21};
22
23impl<N, R1, R2, C1, C2, SA, SB> ArgminDot<Matrix<N, R2, C2, SB>, N> for Matrix<N, R1, C1, SA>
24where
25    N: Scalar + Zero + ClosedAdd + ClosedMul,
26    R1: Dim,
27    R2: Dim,
28    C1: Dim,
29    C2: Dim,
30    SA: Storage<N, R1, C1>,
31    SB: Storage<N, R2, C2>,
32    ShapeConstraint: DimEq<R1, R2> + DimEq<C1, C2>,
33{
34    #[inline]
35    #[allow(clippy::only_used_in_recursion)]
36    fn dot(&self, other: &Matrix<N, R2, C2, SB>) -> N {
37        self.dot(other)
38    }
39}
40
41impl<N, R, C, S> ArgminDot<N, OMatrix<N, R, C>> for Matrix<N, R, C, S>
42where
43    N: Scalar + Copy + ClosedMul,
44    R: Dim,
45    C: Dim,
46    S: Storage<N, R, C>,
47    DefaultAllocator: Allocator<N, R, C>,
48{
49    #[inline]
50    fn dot(&self, other: &N) -> OMatrix<N, R, C> {
51        self * *other
52    }
53}
54
55impl<N, R, C, S> ArgminDot<Matrix<N, R, C, S>, OMatrix<N, R, C>> for N
56where
57    N: Scalar + Copy + ClosedMul,
58    R: Dim,
59    C: Dim,
60    S: Storage<N, R, C>,
61    DefaultAllocator: Allocator<N, R, C>,
62{
63    #[inline]
64    fn dot(&self, other: &Matrix<N, R, C, S>) -> OMatrix<N, R, C> {
65        other * *self
66    }
67}
68
69impl<N, R1, R2, C1, C2, SA, SB> ArgminDot<Matrix<N, R2, C2, SB>, OMatrix<N, R1, C2>>
70    for Matrix<N, R1, C1, SA>
71where
72    N: Scalar + Zero + One + ClosedAdd + ClosedMul,
73    R1: Dim,
74    R2: Dim,
75    C1: Dim,
76    C2: Dim,
77    SA: Storage<N, R1, C1>,
78    SB: Storage<N, R2, C2>,
79    DefaultAllocator: Allocator<N, R1, C2>,
80    ShapeConstraint: AreMultipliable<R1, C1, R2, C2>,
81{
82    #[inline]
83    fn dot(&self, other: &Matrix<N, R2, C2, SB>) -> OMatrix<N, R1, C2> {
84        self * other
85    }
86}
87
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use nalgebra::{Matrix3, RowVector3, Vector3};
    use paste::item;

    // Asserts that two nalgebra matrices/vectors hold the same elements, compared pairwise.
    // Both sides are converted to `f64`, so the check works for integer and float element
    // types alike.
    macro_rules! assert_same_elements {
        ($expected:expr, $actual:expr) => {
            for (want, got) in $expected.iter().zip($actual.iter()) {
                assert_relative_eq!(*want as f64, *got as f64, epsilon = f64::EPSILON);
            }
        };
    }

    // Instantiates the full `ArgminDot` test suite for the primitive element type `$t`.
    macro_rules! make_test {
        ($t:ty) => {
            item! {
                #[test]
                fn [<test_vec_vec_ $t>]() {
                    let a = Vector3::new(1 as $t, 2 as $t, 3 as $t);
                    let b = Vector3::new(4 as $t, 5 as $t, 6 as $t);
                    // Inner product: 1*4 + 2*5 + 3*6 = 32.
                    let res: $t = <Vector3<$t> as ArgminDot<Vector3<$t>, $t>>::dot(&a, &b);
                    assert_relative_eq!(res as f64, 32 as f64, epsilon = f64::EPSILON);
                }
            }

            item! {
                #[test]
                fn [<test_vec_scalar_ $t>]() {
                    let a = Vector3::new(1 as $t, 2 as $t, 3 as $t);
                    let b = 2 as $t;
                    let expected = Vector3::new(2 as $t, 4 as $t, 6 as $t);
                    let product: Vector3<$t> =
                        <Vector3<$t> as ArgminDot<$t, Vector3<$t>>>::dot(&a, &b);
                    assert_same_elements!(expected, product);
                }
            }

            item! {
                #[test]
                fn [<test_scalar_vec_ $t>]() {
                    let a = Vector3::new(1 as $t, 2 as $t, 3 as $t);
                    let b = 2 as $t;
                    let expected = Vector3::new(2 as $t, 4 as $t, 6 as $t);
                    let product: Vector3<$t> =
                        <$t as ArgminDot<Vector3<$t>, Vector3<$t>>>::dot(&b, &a);
                    assert_same_elements!(expected, product);
                }
            }

            item! {
                #[test]
                fn [<test_mat_vec_ $t>]() {
                    // Column vector times row vector: the 3x3 outer product.
                    let a = Vector3::new(1 as $t, 2 as $t, 3 as $t);
                    let b = RowVector3::new(4 as $t, 5 as $t, 6 as $t);
                    let expected = Matrix3::new(
                        4 as $t, 5 as $t, 6 as $t,
                        8 as $t, 10 as $t, 12 as $t,
                        12 as $t, 15 as $t, 18 as $t
                    );
                    let product: Matrix3<$t> =
                        <Vector3<$t> as ArgminDot<RowVector3<$t>, Matrix3<$t>>>::dot(&a, &b);
                    assert_same_elements!(expected, product);
                }
            }

            item! {
                #[test]
                fn [<test_mat_vec_2_ $t>]() {
                    let a = Matrix3::new(
                        1 as $t, 2 as $t, 3 as $t,
                        4 as $t, 5 as $t, 6 as $t,
                        7 as $t, 8 as $t, 9 as $t
                    );
                    let b = Vector3::new(1 as $t, 2 as $t, 3 as $t);
                    // Row i of `a` dotted with `b`: 1+4+9, 4+10+18, 7+16+27.
                    let expected = Vector3::new(14 as $t, 32 as $t, 50 as $t);
                    let product: Vector3<$t> =
                        <Matrix3<$t> as ArgminDot<Vector3<$t>, Vector3<$t>>>::dot(&a, &b);
                    assert_same_elements!(expected, product);
                }
            }

            item! {
                #[test]
                fn [<test_mat_mat_ $t>]() {
                    let a = Matrix3::new(
                        1 as $t, 2 as $t, 3 as $t,
                        4 as $t, 5 as $t, 6 as $t,
                        3 as $t, 2 as $t, 1 as $t
                    );
                    let b = Matrix3::new(
                        3 as $t, 2 as $t, 1 as $t,
                        6 as $t, 5 as $t, 4 as $t,
                        2 as $t, 4 as $t, 3 as $t
                    );
                    // Largest entry is 57, so every value fits even in `i8`.
                    let expected = Matrix3::new(
                        21 as $t, 24 as $t, 18 as $t,
                        54 as $t, 57 as $t, 42 as $t,
                        23 as $t, 20 as $t, 14 as $t
                    );
                    let product: Matrix3<$t> =
                        <Matrix3<$t> as ArgminDot<Matrix3<$t>, Matrix3<$t>>>::dot(&a, &b);
                    assert_same_elements!(expected, product);
                }
            }

            item! {
                #[test]
                fn [<test_mat_primitive_ $t>]() {
                    let a = Matrix3::new(
                        1 as $t, 2 as $t, 3 as $t,
                        4 as $t, 5 as $t, 6 as $t,
                        3 as $t, 2 as $t, 1 as $t
                    );
                    let expected = Matrix3::new(
                        2 as $t, 4 as $t, 6 as $t,
                        8 as $t, 10 as $t, 12 as $t,
                        6 as $t, 4 as $t, 2 as $t
                    );
                    let product: Matrix3<$t> =
                        <Matrix3<$t> as ArgminDot<$t, Matrix3<$t>>>::dot(&a, &(2 as $t));
                    assert_same_elements!(expected, product);
                }
            }

            item! {
                #[test]
                fn [<test_primitive_mat_ $t>]() {
                    let a = Matrix3::new(
                        1 as $t, 2 as $t, 3 as $t,
                        4 as $t, 5 as $t, 6 as $t,
                        3 as $t, 2 as $t, 1 as $t
                    );
                    let expected = Matrix3::new(
                        2 as $t, 4 as $t, 6 as $t,
                        8 as $t, 10 as $t, 12 as $t,
                        6 as $t, 4 as $t, 2 as $t
                    );
                    let product: Matrix3<$t> =
                        <$t as ArgminDot<Matrix3<$t>, Matrix3<$t>>>::dot(&(2 as $t), &a);
                    assert_same_elements!(expected, product);
                }
            }
        };
    }

    make_test!(i8);
    make_test!(u8);
    make_test!(i16);
    make_test!(u16);
    make_test!(i32);
    make_test!(u32);
    make_test!(i64);
    make_test!(u64);
    make_test!(f32);
    make_test!(f64);
}