argmin_math/nalgebra_m/random.rs

1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use rand::{distributions::uniform::SampleUniform, Rng};
9
10use crate::{Allocator, ArgminRandom};
11
12use nalgebra::{
13    base::{dimension::Dim, Scalar},
14    DefaultAllocator, OMatrix,
15};
16
17impl<N, R, C> ArgminRandom for OMatrix<N, R, C>
18where
19    N: Scalar + PartialOrd + SampleUniform,
20    R: Dim,
21    C: Dim,
22    DefaultAllocator: Allocator<N, R, C>,
23{
24    #[inline]
25    fn rand_from_range<T: Rng>(min: &Self, max: &Self, rng: &mut T) -> OMatrix<N, R, C> {
26        assert!(!min.is_empty());
27        assert_eq!(min.shape(), max.shape());
28
29        Self::from_iterator_generic(
30            R::from_usize(min.nrows()),
31            C::from_usize(min.ncols()),
32            min.iter().zip(max.iter()).map(|(a, b)| {
33                // Do not require a < b:
34
35                // We do want to know if a and b are *exactly* the same.
36                #[allow(clippy::float_cmp)]
37                if a == b {
38                    a.clone()
39                } else if a < b {
40                    rng.gen_range(a.clone()..b.clone())
41                } else {
42                    rng.gen_range(b.clone()..a.clone())
43                }
44            }),
45        )
46    }
47}
48
#[cfg(test)]
mod tests {
    use super::*;
    use nalgebra::{DMatrix, DVector, Matrix2x3, Vector3};
    use paste::item;
    use rand::SeedableRng;

    /// Generates the per-scalar-type test suite for `rand_from_range`.
    macro_rules! make_test {
        ($t:ty) => {
            item! {
                #[test]
                fn [<test_random_vec_ $t>]() {
                    let a = Vector3::new(1 as $t, 2 as $t, 3 as $t);
                    let b = Vector3::new(2 as $t, 3 as $t, 4 as $t);
                    let mut rng = rand::rngs::StdRng::seed_from_u64(42);
                    let random = Vector3::<$t>::rand_from_range(&a, &b, &mut rng);
                    // `gen_range(lo..hi)` is half-open: every entry lies in [a, b).
                    for ((r, lo), hi) in random.iter().zip(a.iter()).zip(b.iter()) {
                        assert!(*r >= *lo);
                        assert!(*r < *hi);
                    }
                }
            }

            item! {
                #[test]
                // Equal bounds must be copied verbatim, so exact equality is the
                // contract being checked (no tolerance).
                #[allow(clippy::float_cmp)]
                fn [<test_random_vec_equal_ $t>]() {
                    let a = Vector3::new(1 as $t, 2 as $t, 3 as $t);
                    let b = Vector3::new(1 as $t, 2 as $t, 3 as $t);
                    let mut rng = rand::rngs::StdRng::seed_from_u64(42);
                    let random = Vector3::<$t>::rand_from_range(&a, &b, &mut rng);
                    assert_eq!(random, a);
                    assert_eq!(random, b);
                }
            }

            item! {
                #[test]
                fn [<test_random_vec_reverse_ $t>]() {
                    let b = Vector3::new(1 as $t, 2 as $t, 3 as $t);
                    let a = Vector3::new(2 as $t, 3 as $t, 4 as $t);
                    let mut rng = rand::rngs::StdRng::seed_from_u64(42);
                    let random = Vector3::<$t>::rand_from_range(&a, &b, &mut rng);
                    // Swapped bounds: every entry lies in [b, a).
                    for ((r, lo), hi) in random.iter().zip(b.iter()).zip(a.iter()) {
                        assert!(*r >= *lo);
                        assert!(*r < *hi);
                    }
                }
            }

            item! {
                #[test]
                fn [<test_random_mat_ $t>]() {
                    let a = Matrix2x3::new(
                        1 as $t, 3 as $t, 5 as $t,
                        2 as $t, 4 as $t, 6 as $t
                    );
                    let b = Matrix2x3::new(
                        2 as $t, 4 as $t, 6 as $t,
                        3 as $t, 5 as $t, 7 as $t
                    );
                    let mut rng = rand::rngs::StdRng::seed_from_u64(42);
                    let random = Matrix2x3::<$t>::rand_from_range(&a, &b, &mut rng);
                    // Element-wise check over all entries, independent of storage order.
                    for ((r, lo), hi) in random.iter().zip(a.iter()).zip(b.iter()) {
                        assert!(*r >= *lo);
                        assert!(*r < *hi);
                    }
                }
            }
        };
    }

    make_test!(i8);
    make_test!(u8);
    make_test!(i16);
    make_test!(u16);
    make_test!(i32);
    make_test!(u32);
    make_test!(i64);
    make_test!(u64);
    make_test!(f32);
    make_test!(f64);

    /// An empty `min` violates the non-empty precondition and must panic.
    #[test]
    #[should_panic]
    fn test_random_empty() {
        let a = DVector::<f64>::zeros(0);
        let b = DVector::<f64>::zeros(0);
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let _ = DVector::<f64>::rand_from_range(&a, &b, &mut rng);
    }

    /// Bounds with the same element count but different shapes must panic
    /// rather than silently zipping mismatched entries.
    #[test]
    #[should_panic]
    fn test_random_shape_mismatch() {
        let a = DMatrix::<f64>::zeros(2, 3);
        let b = DMatrix::<f64>::zeros(3, 2);
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let _ = DMatrix::<f64>::rand_from_range(&a, &b, &mut rng);
    }
}