argmin_math/nalgebra_m/
random.rs1use rand::{distributions::uniform::SampleUniform, Rng};
9
10use crate::{Allocator, ArgminRandom};
11
12use nalgebra::{
13 base::{dimension::Dim, Scalar},
14 DefaultAllocator, OMatrix,
15};
16
17impl<N, R, C> ArgminRandom for OMatrix<N, R, C>
18where
19 N: Scalar + PartialOrd + SampleUniform,
20 R: Dim,
21 C: Dim,
22 DefaultAllocator: Allocator<N, R, C>,
23{
24 #[inline]
25 fn rand_from_range<T: Rng>(min: &Self, max: &Self, rng: &mut T) -> OMatrix<N, R, C> {
26 assert!(!min.is_empty());
27 assert_eq!(min.shape(), max.shape());
28
29 Self::from_iterator_generic(
30 R::from_usize(min.nrows()),
31 C::from_usize(min.ncols()),
32 min.iter().zip(max.iter()).map(|(a, b)| {
33 #[allow(clippy::float_cmp)]
37 if a == b {
38 a.clone()
39 } else if a < b {
40 rng.gen_range(a.clone()..b.clone())
41 } else {
42 rng.gen_range(b.clone()..a.clone())
43 }
44 }),
45 )
46 }
47}
48
#[cfg(test)]
mod tests {
    use super::*;
    use nalgebra::{DMatrix, Matrix2x3, Vector3};
    use paste::item;
    use rand::SeedableRng;

    // Generates the `rand_from_range` test suite for a single scalar type `$t`.
    // Every generated test is named `test_<case>_<type>` (e.g. `test_random_vec_equal_f64`).
    // Each one seeds its RNG with a fixed value so that failures are reproducible.
    macro_rules! make_test {
        ($t:ty) => {
            item! {
                // Ordered bounds (`min < max` element-wise): samples stay within the bounds.
                #[test]
                fn [<test_random_vec_ $t>]() {
                    let a = Vector3::new(1 as $t, 2 as $t, 3 as $t);
                    let b = Vector3::new(2 as $t, 3 as $t, 4 as $t);
                    let mut rng = rand::rngs::StdRng::seed_from_u64(42);
                    let random = Vector3::<$t>::rand_from_range(&a, &b, &mut rng);
                    for i in 0..3 {
                        assert!(random[i] >= a[i]);
                        assert!(random[i] <= b[i]);
                    }
                }
            }

            item! {
                // Identical bounds: every entry equals the shared bound.
                #[test]
                fn [<test_random_vec_equal_ $t>]() {
                    let a = Vector3::new(1 as $t, 2 as $t, 3 as $t);
                    let b = Vector3::new(1 as $t, 2 as $t, 3 as $t);
                    let mut rng = rand::rngs::StdRng::seed_from_u64(42);
                    let random = Vector3::<$t>::rand_from_range(&a, &b, &mut rng);
                    for i in 0..3 {
                        assert!((random[i] as f64 - a[i] as f64).abs() < f64::EPSILON);
                        assert!((random[i] as f64 - b[i] as f64).abs() < f64::EPSILON);
                    }
                }
            }

            item! {
                // Reversed bounds (`min > max`): still accepted, samples lie between them.
                #[test]
                fn [<test_random_vec_reverse_ $t>]() {
                    let b = Vector3::new(1 as $t, 2 as $t, 3 as $t);
                    let a = Vector3::new(2 as $t, 3 as $t, 4 as $t);
                    let mut rng = rand::rngs::StdRng::seed_from_u64(42);
                    let random = Vector3::<$t>::rand_from_range(&a, &b, &mut rng);
                    for i in 0..3 {
                        assert!(random[i] >= b[i]);
                        assert!(random[i] <= a[i]);
                    }
                }
            }

            item! {
                // Non-vector shape: every (row, column) entry stays within its bounds.
                #[test]
                fn [<test_random_mat_ $t>]() {
                    let a = Matrix2x3::new(
                        1 as $t, 3 as $t, 5 as $t,
                        2 as $t, 4 as $t, 6 as $t
                    );
                    let b = Matrix2x3::new(
                        2 as $t, 4 as $t, 6 as $t,
                        3 as $t, 5 as $t, 7 as $t
                    );
                    let mut rng = rand::rngs::StdRng::seed_from_u64(42);
                    let random = Matrix2x3::<$t>::rand_from_range(&a, &b, &mut rng);
                    for i in 0..3 {
                        for j in 0..2 {
                            assert!(random[(j, i)] >= a[(j, i)]);
                            assert!(random[(j, i)] <= b[(j, i)]);
                        }
                    }
                }
            }

            item! {
                // Bounds with different shapes must be rejected by the shape assertion.
                #[test]
                #[should_panic]
                fn [<test_random_mat_shape_mismatch_ $t>]() {
                    let a = DMatrix::<$t>::from_element(2, 3, 1 as $t);
                    let b = DMatrix::<$t>::from_element(3, 2, 2 as $t);
                    let mut rng = rand::rngs::StdRng::seed_from_u64(42);
                    let _ = DMatrix::<$t>::rand_from_range(&a, &b, &mut rng);
                }
            }

            item! {
                // Empty bounds must be rejected by the non-empty assertion.
                #[test]
                #[should_panic]
                fn [<test_random_mat_empty_ $t>]() {
                    let a = DMatrix::<$t>::from_element(0, 0, 0 as $t);
                    let b = DMatrix::<$t>::from_element(0, 0, 0 as $t);
                    let mut rng = rand::rngs::StdRng::seed_from_u64(42);
                    let _ = DMatrix::<$t>::rand_from_range(&a, &b, &mut rng);
                }
            }
        };
    }

    make_test!(i8);
    make_test!(u8);
    make_test!(i16);
    make_test!(u16);
    make_test!(i32);
    make_test!(u32);
    make_test!(i64);
    make_test!(u64);
    make_test!(f32);
    make_test!(f64);
}