argmin_math/nalgebra_m/
eye.rs

1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use crate::{Allocator, ArgminEye};
9
10use num_traits::{One, Zero};
11
12use nalgebra::{base::dimension::Dim, DefaultAllocator, OMatrix, Scalar};
13
14impl<N, R, C> ArgminEye for OMatrix<N, R, C>
15where
16    N: Scalar + Zero + One,
17    R: Dim,
18    C: Dim,
19    DefaultAllocator: Allocator<N, R, C>,
20{
21    #[inline]
22    fn eye_like(&self) -> OMatrix<N, R, C> {
23        assert!(self.is_square());
24        Self::identity_generic(R::from_usize(self.nrows()), C::from_usize(self.ncols()))
25    }
26
27    #[inline]
28    fn eye(n: usize) -> OMatrix<N, R, C> {
29        Self::identity_generic(R::from_usize(n), C::from_usize(n))
30    }
31}
32
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use nalgebra::{Matrix2x3, Matrix3};
    use paste::item;

    // Generates, for one scalar type, tests that `eye` and `eye_like` yield a
    // 3x3 identity and that `eye_like` panics on a non-square (2x3) matrix.
    macro_rules! make_test {
        ($t:ty) => {
            item! {
                #[test]
                fn [<test_eye_ $t>]() {
                    let identity = Matrix3::new(
                        1 as $t, 0 as $t, 0 as $t,
                        0 as $t, 1 as $t, 0 as $t,
                        0 as $t, 0 as $t, 1 as $t
                    );
                    let produced: Matrix3<$t> = <Matrix3<$t> as ArgminEye>::eye(3);
                    // Visit every (row, col) pair of the 3x3 grid.
                    (0..3usize)
                        .flat_map(|row| (0..3usize).map(move |col| (row, col)))
                        .for_each(|idx| {
                            assert_relative_eq!(
                                identity[idx] as f64,
                                produced[idx] as f64,
                                epsilon = f64::EPSILON
                            );
                        });
                }
            }

            item! {
                #[test]
                fn [<test_eye_like_ $t>]() {
                    let identity = Matrix3::new(
                        1 as $t, 0 as $t, 0 as $t,
                        0 as $t, 1 as $t, 0 as $t,
                        0 as $t, 0 as $t, 1 as $t
                    );
                    let source = Matrix3::new(
                        0 as $t, 2 as $t, 6 as $t,
                        3 as $t, 2 as $t, 7 as $t,
                        9 as $t, 8 as $t, 1 as $t
                    );
                    let produced: Matrix3<$t> = source.eye_like();
                    // Visit every (row, col) pair of the 3x3 grid.
                    (0..3usize)
                        .flat_map(|row| (0..3usize).map(move |col| (row, col)))
                        .for_each(|idx| {
                            assert_relative_eq!(
                                identity[idx] as f64,
                                produced[idx] as f64,
                                epsilon = f64::EPSILON
                            );
                        });
                }
            }

            item! {
                #[test]
                #[should_panic]
                fn [<test_eye_like_panic_ $t>]() {
                    let non_square = Matrix2x3::new(
                        0 as $t, 2 as $t, 6 as $t,
                        3 as $t, 2 as $t, 7 as $t,
                    );
                    // The call itself must panic; the result is never inspected.
                    let _: Matrix2x3<$t> = non_square.eye_like();
                }
            }
        };
    }

    make_test!(i8);
    make_test!(u8);
    make_test!(i16);
    make_test!(u16);
    make_test!(i32);
    make_test!(u32);
    make_test!(i64);
    make_test!(u64);
    make_test!(f32);
    make_test!(f64);
}