argmin_testfunctions/
rastrigin.rs

1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! # Rastrigin test function
9//!
10//! Defined as
11//!
12//! $$
13//! f(x_1,\\,x_2,\\,\ldots,\\,x_d) = a\cdot d + \sum_{i=1}^{d} \left[ x_i^2 - a\cos(2\pi x_i) \right]
14//! $$
15//!
16//! where $x_i \in [-5.12,\\,5.12]$ and $a = 10$.
17//!
18//! The global minimum is at $f(x_1,\\,x_2,\\,\ldots,\\,x_d) = f(0,\\,0,\\,\ldots,\\,0) = 0$.
19
20use num::{Float, FromPrimitive};
21use std::f64::consts::PI;
22use std::iter::Sum;
23
24/// Rastrigin test function.
25///
26/// Defined as
27///
28/// $$
29/// f(x_1,\\,x_2,\\,\ldots,\\,x_d) = a\cdot d + \sum_{i=1}^{d} \left[ x_i^2 - a\cos(2\pi x_i) \right]
30/// $$
31///
32/// where $x_i \in [-5.12,\\,5.12]$ and $a = 10$.
33///
34/// The global minimum is at $f(x_1,\\,x_2,\\,\ldots,\\,x_d) = f(0,\\,0,\\,\ldots,\\,0) = 0$.
35///
36/// See [`rastrigin_a`] for a variant where the parameter `a` can be chosen freely.
37pub fn rastrigin<T>(param: &[T]) -> T
38where
39    T: Float + FromPrimitive + Sum,
40{
41    rastrigin_a(param, T::from_f64(10.0).unwrap())
42}
43
44/// Rastrigin test function.
45///
46/// The same as [`rastrigin`]; however, it allows to set the parameter a.
47pub fn rastrigin_a<T>(param: &[T], a: T) -> T
48where
49    T: Float + FromPrimitive + Sum,
50{
51    a * T::from_usize(param.len()).unwrap()
52        + param
53            .iter()
54            .map(|&x| x.powi(2) - a * (T::from_f64(2.0 * PI).unwrap() * x).cos())
55            .sum()
56}
57
58/// Derivative of Rastrigin test function where the parameter `a` can be chosen freely.
59pub fn rastrigin_a_derivative<T>(param: &[T], a: T) -> Vec<T>
60where
61    T: Float + FromPrimitive + Sum + Into<f64>,
62{
63    let npi2 = T::from_f64(2.0 * PI).unwrap();
64    let n2 = T::from_f64(2.0).unwrap();
65    param
66        .iter()
67        .map(|x| n2 * *x + npi2 * a * T::from_f64(f64::sin((npi2 * *x).into())).unwrap())
68        .collect()
69}
70
71/// Derivative of Rastrigin test function with `a = 10`.
72pub fn rastrigin_derivative<T>(param: &[T]) -> Vec<T>
73where
74    T: Float + FromPrimitive + Sum + Into<f64>,
75{
76    rastrigin_a_derivative(param, T::from_f64(10.0).unwrap())
77}
78
79/// Derivative of Rastrigin test function where the parameter `a` can be chosen freely.
80///
81/// This is the const generics version, which requires the number of parameters to be known
82/// at compile time.
83pub fn rastrigin_a_derivative_const<const N: usize, T>(param: &[T; N], a: T) -> [T; N]
84where
85    T: Float + FromPrimitive + Sum + Into<f64>,
86{
87    let npi2 = T::from_f64(2.0 * PI).unwrap();
88    let n2 = T::from_f64(2.0).unwrap();
89    let mut result = [T::from_f64(0.0).unwrap(); N];
90    for i in 0..N {
91        result[i] =
92            n2 * param[i] + npi2 * a * T::from_f64(f64::sin((npi2 * param[i]).into())).unwrap();
93    }
94    result
95}
96
97/// Derivative of Rastrigin test function with `a = 10`.
98///
99/// This is the const generics version, which requires the number of parameters to be known
100/// at compile time.
101pub fn rastrigin_derivative_const<const N: usize, T>(param: &[T; N]) -> [T; N]
102where
103    T: Float + FromPrimitive + Sum + Into<f64>,
104{
105    rastrigin_a_derivative_const(param, T::from_f64(10.0).unwrap())
106}
107
108/// Hessian of Rastrigin test function where the parameter `a` can be chosen freely.
109pub fn rastrigin_a_hessian<T>(param: &[T], a: T) -> Vec<Vec<T>>
110where
111    T: Float + FromPrimitive + Sum + Into<f64>,
112{
113    let npi2 = T::from_f64(2.0 * PI).unwrap();
114    let n4pisq = T::from_f64(4.0 * PI.powi(2)).unwrap();
115    let n2 = T::from_f64(2.0).unwrap();
116    let n0 = T::from_f64(0.0).unwrap();
117
118    let n = param.len();
119    let mut hessian = vec![vec![n0; n]; n];
120
121    for i in 0..n {
122        hessian[i][i] = n2 + n4pisq * a * T::from_f64(f64::cos((npi2 * param[i]).into())).unwrap();
123    }
124    hessian
125}
126
127/// Hessian of Rastrigin test function with `a = 10`.
128pub fn rastrigin_hessian<T>(param: &[T]) -> Vec<Vec<T>>
129where
130    T: Float + FromPrimitive + Sum + Into<f64>,
131{
132    rastrigin_a_hessian(param, T::from_f64(10.0).unwrap())
133}
134
135/// Hessian of Rastrigin test function where `a` can be chosen freely.
136///
137/// This is the const generics version, which requires the number of parameters to be known
138/// at compile time.
139pub fn rastrigin_a_hessian_const<const N: usize, T>(param: &[T], a: T) -> [[T; N]; N]
140where
141    T: Float + FromPrimitive + Sum + Into<f64>,
142{
143    let npi2 = T::from_f64(2.0 * PI).unwrap();
144    let n4pisq = T::from_f64(4.0 * PI.powi(2)).unwrap();
145    let n2 = T::from_f64(2.0).unwrap();
146    let n0 = T::from_f64(0.0).unwrap();
147
148    let mut hessian = [[n0; N]; N];
149
150    for i in 0..N {
151        hessian[i][i] = n2 + n4pisq * a * T::from_f64(f64::cos((npi2 * param[i]).into())).unwrap();
152    }
153    hessian
154}
155
156/// Hessian of Rastrigin test function with `a = 10`.
157///
158/// This is the const generics version, which requires the number of parameters to be known
159/// at compile time.
160pub fn rastrigin_hessian_const<const N: usize, T>(param: &[T; N]) -> [[T; N]; N]
161where
162    T: Float + FromPrimitive + Sum + Into<f64>,
163{
164    rastrigin_a_hessian_const(param, T::from_f64(10.0).unwrap())
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use finitediff::FiniteDiff;
    use proptest::prelude::*;

    // `f32::EPSILON` / `f64::EPSILON` below are the primitive types' associated constants; the
    // legacy `std::f32` / `std::f64` module constants are intentionally not imported.

    #[test]
    fn test_rastrigin_optimum() {
        assert_relative_eq!(rastrigin(&[0.0_f32, 0.0_f32]), 0.0, epsilon = f32::EPSILON);
        assert_relative_eq!(rastrigin(&[0.0_f64, 0.0_f64]), 0.0, epsilon = f64::EPSILON);
        // At the origin the `a * d` offset cancels the cosine terms for any choice of `a`.
        assert_relative_eq!(rastrigin_a(&[0.0_f64; 3], 3.5), 0.0, epsilon = f64::EPSILON);
    }

    #[test]
    fn test_parameter_a() {
        assert_relative_eq!(
            rastrigin(&[0.0_f32, 0.0_f32]),
            rastrigin_a(&[0.0_f32, 0.0_f32], 10.0),
            epsilon = f32::EPSILON
        );

        assert_relative_eq!(
            rastrigin(&[0.0_f64, 0.0_f64]),
            rastrigin_a(&[0.0_f64, 0.0_f64], 10.0),
            epsilon = f64::EPSILON
        );

        let derivative = rastrigin_derivative(&[1.0_f64, -1.0_f64]);
        let derivative_a = rastrigin_a_derivative(&[1.0_f64, -1.0_f64], 10.0);
        assert_eq!(derivative.len(), derivative_a.len());
        for (&d, &d_a) in derivative.iter().zip(&derivative_a) {
            assert_relative_eq!(d, d_a, epsilon = f64::EPSILON);
        }

        let derivative = rastrigin_derivative_const(&[1.0_f64, -1.0_f64]);
        let derivative_a = rastrigin_a_derivative_const(&[1.0_f64, -1.0_f64], 10.0);
        for (&d, &d_a) in derivative.iter().zip(&derivative_a) {
            assert_relative_eq!(d, d_a, epsilon = f64::EPSILON);
        }

        let hessian = rastrigin_hessian(&[1.0_f64, -1.0_f64]);
        let hessian_a = rastrigin_a_hessian(&[1.0_f64, -1.0_f64], 10.0);
        assert_eq!(hessian.len(), hessian_a.len());
        for (row, row_a) in hessian.iter().zip(&hessian_a) {
            assert_eq!(row.len(), row_a.len());
            for (&h, &h_a) in row.iter().zip(row_a) {
                assert_relative_eq!(h, h_a, epsilon = f64::EPSILON);
            }
        }

        let hessian = rastrigin_hessian_const(&[1.0_f64, -1.0_f64]);
        // `N` cannot be inferred from the slice argument, so it is fixed via the result type.
        let hessian_a: [[_; 2]; 2] = rastrigin_a_hessian_const(&[1.0_f64, -1.0_f64], 10.0);
        for (row, row_a) in hessian.iter().zip(&hessian_a) {
            for (&h, &h_a) in row.iter().zip(row_a) {
                assert_relative_eq!(h, h_a, epsilon = f64::EPSILON);
            }
        }
    }

    #[test]
    fn test_rastrigin_a_derivative_optimum() {
        let derivative = rastrigin_a_derivative(&[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 10.0);
        for elem in derivative {
            assert_relative_eq!(elem, 0.0, epsilon = f64::EPSILON);
        }
    }

    #[test]
    fn test_rastrigin_a_derivative_const_optimum() {
        let derivative =
            rastrigin_a_derivative_const(&[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 10.0);
        for elem in derivative {
            assert_relative_eq!(elem, 0.0, epsilon = f64::EPSILON);
        }
    }

    #[test]
    fn test_rastrigin_a_hessian_optimum() {
        let a = 10.0_f64;
        // At the origin cos(2πx_i) = 1, so every diagonal entry equals 2 + 4π²a.
        let diag = 2.0 + 4.0 * PI.powi(2) * a;

        let hessian = rastrigin_a_hessian(&[0.0_f64; 4], a);
        assert_eq!(hessian.len(), 4);
        for (i, row) in hessian.iter().enumerate() {
            assert_eq!(row.len(), 4);
            for (j, &h) in row.iter().enumerate() {
                let expected = if i == j { diag } else { 0.0 };
                assert_relative_eq!(h, expected, epsilon = 1e-10, max_relative = 1e-12);
            }
        }

        let hessian_const: [[f64; 4]; 4] = rastrigin_a_hessian_const(&[0.0_f64; 4], a);
        for (i, row) in hessian_const.iter().enumerate() {
            for (j, &h) in row.iter().enumerate() {
                let expected = if i == j { diag } else { 0.0 };
                assert_relative_eq!(h, expected, epsilon = 1e-10, max_relative = 1e-12);
            }
        }
    }

    proptest! {
        #[test]
        fn test_rastrigin_derivative_finitediff(a in -5.12..5.12,
                                                b in -5.12..5.12,
                                                c in -5.12..5.12,
                                                d in -5.12..5.12,
                                                e in -5.12..5.12,
                                                f in -5.12..5.12,
                                                g in -5.12..5.12,
                                                h in -5.12..5.12) {
            let param = [a, b, c, d, e, f, g, h];
            let derivative = rastrigin_derivative(&param);
            let derivative_fd = Vec::from(param).central_diff(&|x| rastrigin(&x));
            assert_eq!(derivative.len(), derivative_fd.len());
            for (&analytic, &numeric) in derivative.iter().zip(&derivative_fd) {
                assert_relative_eq!(analytic, numeric, epsilon = 1e-5, max_relative = 1e-2);
            }
        }
    }

    proptest! {
        #[test]
        fn test_rastrigin_derivative_const_finitediff(a in -5.12..5.12,
                                                      b in -5.12..5.12,
                                                      c in -5.12..5.12,
                                                      d in -5.12..5.12,
                                                      e in -5.12..5.12,
                                                      f in -5.12..5.12,
                                                      g in -5.12..5.12,
                                                      h in -5.12..5.12) {
            let param = [a, b, c, d, e, f, g, h];
            let derivative = rastrigin_derivative_const(&param);
            let derivative_fd = Vec::from(param).central_diff(&|x| rastrigin(&x));
            assert_eq!(derivative.len(), derivative_fd.len());
            for (&analytic, &numeric) in derivative.iter().zip(&derivative_fd) {
                assert_relative_eq!(analytic, numeric, epsilon = 1e-5, max_relative = 1e-2);
            }
        }
    }

    proptest! {
        #[test]
        fn test_rastrigin_hessian_finitediff(a in -5.12..5.12,
                                             b in -5.12..5.12,
                                             c in -5.12..5.12,
                                             d in -5.12..5.12,
                                             e in -5.12..5.12,
                                             f in -5.12..5.12,
                                             g in -5.12..5.12,
                                             h in -5.12..5.12) {
            let param = [a, b, c, d, e, f, g, h];
            let hessian = rastrigin_hessian(&param);
            let hessian_fd =
                Vec::from(param).forward_hessian(&|x| rastrigin_derivative(&x));
            let n = hessian.len();
            assert_eq!(hessian_fd.len(), n);
            for (row, row_fd) in hessian.iter().zip(&hessian_fd) {
                assert_eq!(row.len(), n);
                assert_eq!(row_fd.len(), n);
                for (&analytic, &numeric) in row.iter().zip(row_fd) {
                    assert_relative_eq!(analytic, numeric, epsilon = 1e-4, max_relative = 1e-2);
                }
            }
        }
    }

    proptest! {
        #[test]
        fn test_rastrigin_hessian_const_finitediff(a in -5.12..5.12,
                                                   b in -5.12..5.12,
                                                   c in -5.12..5.12,
                                                   d in -5.12..5.12,
                                                   e in -5.12..5.12,
                                                   f in -5.12..5.12,
                                                   g in -5.12..5.12,
                                                   h in -5.12..5.12) {
            let param = [a, b, c, d, e, f, g, h];
            let hessian = rastrigin_hessian_const(&param);
            let hessian_fd =
                Vec::from(param).forward_hessian(&|x| rastrigin_derivative(&x));
            let n = hessian.len();
            assert_eq!(hessian_fd.len(), n);
            for (row, row_fd) in hessian.iter().zip(&hessian_fd) {
                assert_eq!(row.len(), n);
                assert_eq!(row_fd.len(), n);
                for (&analytic, &numeric) in row.iter().zip(row_fd) {
                    assert_relative_eq!(analytic, numeric, epsilon = 1e-4, max_relative = 1e-2);
                }
            }
        }
    }
}