argmin_math/nalgebra_m/mul.rs

1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use crate::{Allocator, ArgminMul, SameShapeAllocator};
9
10use crate::ClosedMul;
11use nalgebra::{
12    base::{
13        constraint::{SameNumberOfColumns, SameNumberOfRows, ShapeConstraint},
14        dimension::Dim,
15        storage::Storage,
16        MatrixSum, Scalar,
17    },
18    DefaultAllocator, Matrix, OMatrix,
19};
20
21impl<N, R, C, S> ArgminMul<N, OMatrix<N, R, C>> for Matrix<N, R, C, S>
22where
23    N: Scalar + Copy + ClosedMul,
24    R: Dim,
25    C: Dim,
26    S: Storage<N, R, C>,
27    DefaultAllocator: Allocator<N, R, C>,
28{
29    #[inline]
30    fn mul(&self, other: &N) -> OMatrix<N, R, C> {
31        self * *other
32    }
33}
34
35impl<N, R, C, S> ArgminMul<Matrix<N, R, C, S>, OMatrix<N, R, C>> for N
36where
37    N: Scalar + Copy + ClosedMul,
38    R: Dim,
39    C: Dim,
40    S: Storage<N, R, C>,
41    DefaultAllocator: Allocator<N, R, C>,
42{
43    #[inline]
44    fn mul(&self, other: &Matrix<N, R, C, S>) -> OMatrix<N, R, C> {
45        other * *self
46    }
47}
48
49impl<N, R1, R2, C1, C2, SA, SB> ArgminMul<Matrix<N, R2, C2, SB>, MatrixSum<N, R1, C1, R2, C2>>
50    for Matrix<N, R1, C1, SA>
51where
52    N: Scalar + ClosedMul,
53    R1: Dim,
54    R2: Dim,
55    C1: Dim,
56    C2: Dim,
57    SA: Storage<N, R1, C1>,
58    SB: Storage<N, R2, C2>,
59    DefaultAllocator: SameShapeAllocator<N, R1, C1, R2, C2>,
60    ShapeConstraint: SameNumberOfRows<R1, R2> + SameNumberOfColumns<C1, C2>,
61{
62    #[inline]
63    fn mul(&self, other: &Matrix<N, R2, C2, SB>) -> MatrixSum<N, R1, C1, R2, C2> {
64        self.component_mul(other)
65    }
66}
67
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use nalgebra::{DVector, Matrix2x3, Vector3};
    use paste::item;

    /// Asserts that `target` and `res` have the same shape and that every pair
    /// of corresponding entries agrees once converted to `f64`.
    ///
    /// Iterating over both containers (instead of hard-coded index ranges)
    /// guarantees every entry is compared, whatever the operands' shape.
    macro_rules! assert_elementwise_eq {
        ($target:expr, $res:expr) => {{
            let (target, res) = (&$target, &$res);
            assert_eq!(target.shape(), res.shape());
            for (t, r) in target.iter().zip(res.iter()) {
                assert_relative_eq!(*t as f64, *r as f64, epsilon = f64::EPSILON);
            }
        }};
    }

    /// Generates the full set of `ArgminMul` tests for one scalar type `$t`.
    macro_rules! make_test {
        ($t:ty) => {
            item! {
                #[test]
                fn [<test_mul_vec_scalar_ $t>]() {
                    let a = Vector3::new(1 as $t, 4 as $t, 8 as $t);
                    let b = 2 as $t;
                    let target = Vector3::new(2 as $t, 8 as $t, 16 as $t);
                    let res = <Vector3<$t> as ArgminMul<$t, Vector3<$t>>>::mul(&a, &b);
                    assert_elementwise_eq!(target, res);
                }
            }

            item! {
                #[test]
                fn [<test_mul_scalar_vec_ $t>]() {
                    let a = Vector3::new(1 as $t, 4 as $t, 8 as $t);
                    let b = 2 as $t;
                    let target = Vector3::new(2 as $t, 8 as $t, 16 as $t);
                    let res = <$t as ArgminMul<Vector3<$t>, Vector3<$t>>>::mul(&b, &a);
                    assert_elementwise_eq!(target, res);
                }
            }

            item! {
                #[test]
                fn [<test_mul_vec_vec_ $t>]() {
                    let a = Vector3::new(1 as $t, 4 as $t, 8 as $t);
                    let b = Vector3::new(2 as $t, 3 as $t, 4 as $t);
                    let target = Vector3::new(2 as $t, 12 as $t, 32 as $t);
                    let res = <Vector3<$t> as ArgminMul<Vector3<$t>, Vector3<$t>>>::mul(&a, &b);
                    assert_elementwise_eq!(target, res);
                }
            }

            item! {
                #[test]
                fn [<test_mul_mat_mat_ $t>]() {
                    let a = Matrix2x3::new(
                        1 as $t, 4 as $t, 8 as $t,
                        2 as $t, 5 as $t, 9 as $t
                    );
                    let b = Matrix2x3::new(
                        2 as $t, 3 as $t, 4 as $t,
                        3 as $t, 4 as $t, 5 as $t
                    );
                    let target = Matrix2x3::new(
                        2 as $t, 12 as $t, 32 as $t,
                        6 as $t, 20 as $t, 45 as $t
                    );
                    let res = <Matrix2x3<$t> as ArgminMul<Matrix2x3<$t>, Matrix2x3<$t>>>::mul(&a, &b);
                    assert_elementwise_eq!(target, res);
                }
            }

            item! {
                #[test]
                fn [<test_mul_scalar_mat_1_ $t>]() {
                    let a = Matrix2x3::new(
                        1 as $t, 4 as $t, 8 as $t,
                        2 as $t, 5 as $t, 9 as $t
                    );
                    let b = 2 as $t;
                    let target = Matrix2x3::new(
                        2 as $t, 8 as $t, 16 as $t,
                        4 as $t, 10 as $t, 18 as $t
                    );
                    let res = <Matrix2x3<$t> as ArgminMul<$t, Matrix2x3<$t>>>::mul(&a, &b);
                    assert_elementwise_eq!(target, res);
                }
            }

            item! {
                #[test]
                fn [<test_mul_scalar_mat_2_ $t>]() {
                    let b = Matrix2x3::new(
                        1 as $t, 4 as $t, 8 as $t,
                        2 as $t, 5 as $t, 9 as $t
                    );
                    let a = 2 as $t;
                    let target = Matrix2x3::new(
                        2 as $t, 8 as $t, 16 as $t,
                        4 as $t, 10 as $t, 18 as $t
                    );
                    let res = <$t as ArgminMul<Matrix2x3<$t>, Matrix2x3<$t>>>::mul(&a, &b);
                    assert_elementwise_eq!(target, res);
                }
            }

            // Dynamically sized operands: shapes are not fixed by the type, so
            // these exercise the runtime-dimension path of the impls above.
            item! {
                #[test]
                fn [<test_mul_dvec_scalar_ $t>]() {
                    let a = DVector::from_vec(vec![1 as $t, 4 as $t, 8 as $t]);
                    let b = 2 as $t;
                    let target = DVector::from_vec(vec![2 as $t, 8 as $t, 16 as $t]);
                    let res = <DVector<$t> as ArgminMul<$t, DVector<$t>>>::mul(&a, &b);
                    assert_elementwise_eq!(target, res);
                }
            }

            item! {
                #[test]
                fn [<test_mul_dvec_dvec_ $t>]() {
                    let a = DVector::from_vec(vec![1 as $t, 4 as $t, 8 as $t]);
                    let b = DVector::from_vec(vec![2 as $t, 3 as $t, 4 as $t]);
                    let target = DVector::from_vec(vec![2 as $t, 12 as $t, 32 as $t]);
                    let res = <DVector<$t> as ArgminMul<DVector<$t>, DVector<$t>>>::mul(&a, &b);
                    assert_elementwise_eq!(target, res);
                }
            }

            // Mismatched dynamic lengths must not be silently truncated; nalgebra's
            // component-wise product is expected to panic instead.
            item! {
                #[test]
                #[should_panic]
                fn [<test_mul_dvec_dvec_mismatched_dims_ $t>]() {
                    let a = DVector::from_vec(vec![1 as $t, 4 as $t, 8 as $t]);
                    let b = DVector::from_vec(vec![2 as $t, 3 as $t]);
                    let _ = <DVector<$t> as ArgminMul<DVector<$t>, DVector<$t>>>::mul(&a, &b);
                }
            }
        };
    }

    make_test!(i8);
    make_test!(u8);
    make_test!(i16);
    make_test!(u16);
    make_test!(i32);
    make_test!(u32);
    make_test!(i64);
    make_test!(u64);
    make_test!(f32);
    make_test!(f64);
}