argmin_math/ndarray_m/
random.rs

1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use rand::Rng;
9
10use crate::ArgminRandom;
11
/// Implements [`ArgminRandom`] for `ndarray::Array1<$t>` and `ndarray::Array2<$t>`.
///
/// Every output element is drawn independently from the pair of elements at the same position
/// in `min` and `max` (see the internal `@between` rule for how a single pair is sampled).
macro_rules! make_random {
    // Internal rule: sample one value from the interval spanned by `$a` and `$b`.
    //
    // The bounds may be given in either order. If they compare exactly equal, that value is
    // returned without drawing from `$rng`. Otherwise the value is drawn from the half-open
    // range `smaller..larger`, so the larger bound itself is never produced.
    //
    // This rule must stay first: `@` cannot start a type, so `make_random!(i8)` falls through
    // to the `$t:ty` rule below.
    (@between $rng:expr, $a:expr, $b:expr) => {{
        let (a, b) = ($a, $b);
        // We do want to know if a and b are *exactly* the same.
        #[allow(clippy::float_cmp)]
        if a == b {
            a
        } else if a < b {
            $rng.gen_range(a..b)
        } else {
            $rng.gen_range(b..a)
        }
    }};
    ($t:ty) => {
        impl ArgminRandom for ndarray::Array1<$t> {
            /// Returns a vector whose `i`-th entry is sampled between `min[i]` and `max[i]`.
            ///
            /// `min[i] <= max[i]` is not required; see `make_random!(@between ..)`.
            ///
            /// # Panics
            ///
            /// Panics if `min` is empty or if `min` and `max` differ in length.
            fn rand_from_range<R: Rng>(min: &Self, max: &Self, rng: &mut R) -> ndarray::Array1<$t> {
                assert!(!min.is_empty());
                assert_eq!(min.len(), max.len());

                // Elements are `Copy`, so destructure the references instead of cloning.
                min.iter()
                    .zip(max.iter())
                    .map(|(&a, &b)| make_random!(@between rng, a, b))
                    .collect()
            }
        }

        impl ArgminRandom for ndarray::Array2<$t> {
            /// Returns a matrix whose `(i, j)`-th entry is sampled between `min[(i, j)]` and
            /// `max[(i, j)]`.
            ///
            /// `min[(i, j)] <= max[(i, j)]` is not required; see `make_random!(@between ..)`.
            ///
            /// # Panics
            ///
            /// Panics if `min` is empty or if `min` and `max` have different shapes.
            fn rand_from_range<R: Rng>(min: &Self, max: &Self, rng: &mut R) -> ndarray::Array2<$t> {
                assert!(!min.is_empty());
                assert_eq!(min.raw_dim(), max.raw_dim());

                // The shapes were asserted equal above, so every index produced for `min`'s
                // shape is also in bounds for `max`.
                ndarray::Array2::from_shape_fn(min.raw_dim(), |ix| {
                    make_random!(@between rng, min[ix], max[ix])
                })
            }
        }
    };
}
58
59make_random!(i8);
60make_random!(u8);
61make_random!(i16);
62make_random!(u16);
63make_random!(i32);
64make_random!(u32);
65make_random!(i64);
66make_random!(u64);
67make_random!(f32);
68make_random!(f64);
69
// All code that does not depend on a linked ndarray-linalg backend can still be tested as normal.
// To avoid duplicating tests, and to allow convenient testing of functionality that does not need
// ndarray-linalg, the shared tests are included here.
// The included tests refer to the crate containing the tested functions as `argmin_math`, so in
// test builds the current crate is brought into scope under that name.
#[cfg(test)]
use crate as argmin_math;
// Splices the shared test source in at this point, located relative to the package directory
// (`CARGO_MANIFEST_DIR`). Note that `#[cfg(test)]` above gates only the `use`, not this
// `include!`. NOTE(review): presumably the included file gates its own contents behind
// `#[cfg(test)]` — confirm.
include!(concat!(
    env!("CARGO_MANIFEST_DIR"),
    "/ndarray-tests-src/random.rs"
));