argmin_testfunctions/
himmelblau.rs

1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! # Himmelblau's test function.
9//!
10//! Defined as
11//!
12//! $$
13//! f(x_1,\\,x_2) = (x_1^2 + x_2 - 11)^2 + (x_1 + x_2^2 - 7)^2
14//! $$
15//!
16//! where $x_i \in [-5,\\,5]$.
17//!
18//! The global minima are at
19//!  * $f(x_1,\\,x_2) = f(3,\\,2) = 0$.
20//!  * $f(x_1,\\,x_2) = f(-2.805118,\\,3.131312) = 0$.
21//!  * $f(x_1,\\,x_2) = f(-3.779310,\\,-3.283186) = 0$.
22//!  * $f(x_1,\\,x_2) = f(3.584428,\\,-1.848126) = 0$.
23
24use num::{Float, FromPrimitive};
25
26/// Himmelblau's test function.
27///
28/// Defined as
29///
30/// $$
31/// f(x_1,\\,x_2) = (x_1^2 + x_2 - 11)^2 + (x_1 + x_2^2 - 7)^2
32/// $$
33///
34/// where $x_i \in [-5,\\,5]$.
35///
36/// The global minima are at
37///  * $f(x_1,\\,x_2) = f(3,\\,2) = 0$.
38///  * $f(x_1,\\,x_2) = f(-2.805118,\\,3.131312) = 0$.
39///  * $f(x_1,\\,x_2) = f(-3.779310,\\,-3.283186) = 0$.
40///  * $f(x_1,\\,x_2) = f(3.584428,\\,-1.848126) = 0$.
41pub fn himmelblau<T>(param: &[T; 2]) -> T
42where
43    T: Float + FromPrimitive,
44{
45    let [x1, x2] = *param;
46    let n7 = T::from_f64(7.0).unwrap();
47    let n11 = T::from_f64(11.0).unwrap();
48    (x1.powi(2) + x2 - n11).powi(2) + (x1 + x2.powi(2) - n7).powi(2)
49}
50
51/// Derivative of Himmelblau's test function.
52pub fn himmelblau_derivative<T>(param: &[T; 2]) -> [T; 2]
53where
54    T: Float + FromPrimitive,
55{
56    let [x1, x2] = *param;
57
58    let n2 = T::from_f64(2.0).unwrap();
59    let n4 = T::from_f64(4.0).unwrap();
60    let n7 = T::from_f64(7.0).unwrap();
61    let n11 = T::from_f64(11.0).unwrap();
62
63    [
64        n4 * x1 * (x1.powi(2) + x2 - n11) + n2 * (x1 + x2.powi(2) - n7),
65        n4 * x2 * (x2.powi(2) + x1 - n7) + n2 * (x2 + x1.powi(2) - n11),
66    ]
67}
68
69/// Hessian of Himmelblau's test function.
70pub fn himmelblau_hessian<T>(param: &[T; 2]) -> [[T; 2]; 2]
71where
72    T: Float + FromPrimitive,
73{
74    let [x1, x2] = *param;
75
76    let n2 = T::from_f64(2.0).unwrap();
77    let n4 = T::from_f64(4.0).unwrap();
78    let n7 = T::from_f64(7.0).unwrap();
79    let n8 = T::from_f64(8.0).unwrap();
80    let n11 = T::from_f64(11.0).unwrap();
81
82    let offdiag = n4 * (x1 + x2);
83
84    [
85        [n4 * (x1.powi(2) + x2 - n11) + n8 * x1.powi(2) + n2, offdiag],
86        [offdiag, n4 * (x2.powi(2) + x1 - n7) + n8 * x2.powi(2) + n2],
87    ]
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use finitediff::FiniteDiff;
    use proptest::prelude::*;

    /// Checks that the function value and the gradient are (close to) zero at the four
    /// documented minima.
    #[test]
    fn test_himmelblau_optimum() {
        assert_relative_eq!(himmelblau(&[3.0_f32, 2.0_f32]), 0.0, epsilon = f32::EPSILON);
        assert_relative_eq!(
            himmelblau(&[-2.805118_f32, 3.131312_f32]),
            0.0,
            epsilon = f32::EPSILON
        );
        assert_relative_eq!(
            himmelblau(&[-3.779310_f32, -3.283186_f32]),
            0.0,
            epsilon = f32::EPSILON
        );
        assert_relative_eq!(
            himmelblau(&[3.584428_f32, -1.848126_f32]),
            0.0,
            epsilon = f32::EPSILON
        );

        // The locations of the three non-integer minima are only known to f32 precision, so
        // the f64 version cannot be reliably tested at these points without allowing an error
        // larger than f64::EPSILON.
        assert_relative_eq!(himmelblau(&[3.0_f64, 2.0_f64]), 0.0, epsilon = f64::EPSILON);
        assert_relative_eq!(
            himmelblau(&[-2.805118_f64, 3.131312_f64]),
            0.0,
            epsilon = f32::EPSILON.into()
        );
        assert_relative_eq!(
            himmelblau(&[-3.779310_f64, -3.283186_f64]),
            0.0,
            epsilon = f32::EPSILON.into()
        );
        assert_relative_eq!(
            himmelblau(&[3.584428_f64, -1.848126_f64]),
            0.0,
            epsilon = f32::EPSILON.into()
        );

        // The minimum at (3, 2) is exact, so the gradient must vanish up to rounding.
        let deriv = himmelblau_derivative(&[3.0_f32, 2.0_f32]);
        for &d in &deriv {
            assert_relative_eq!(d, 0.0, epsilon = f32::EPSILON);
        }

        // The remaining minima are rounded locations, hence the looser tolerance.
        let deriv = himmelblau_derivative(&[-2.805118_f32, 3.131312_f32]);
        for &d in &deriv {
            assert_relative_eq!(d, 0.0, epsilon = 1e-4);
        }

        let deriv = himmelblau_derivative(&[-3.779310_f64, -3.283186_f64]);
        for &d in &deriv {
            assert_relative_eq!(d, 0.0, epsilon = 1e-4);
        }

        let deriv = himmelblau_derivative(&[3.584428_f64, -1.848126_f64]);
        for &d in &deriv {
            assert_relative_eq!(d, 0.0, epsilon = 1e-4);
        }
    }

    proptest! {
        /// The analytic gradient must agree with a central finite-difference approximation.
        #[test]
        fn test_himmelblau_derivative_finitediff(a in -5.0..5.0, b in -5.0..5.0) {
            let param = [a, b];
            let derivative = himmelblau_derivative(&param);
            let derivative_fd = Vec::from(param).central_diff(&|x| himmelblau(&[x[0], x[1]]));
            // `zip` would silently truncate, so make a length mismatch an explicit failure.
            assert_eq!(derivative_fd.len(), derivative.len());
            for (&analytic, &numeric) in derivative.iter().zip(derivative_fd.iter()) {
                assert_relative_eq!(
                    analytic,
                    numeric,
                    epsilon = 1e-4,
                    max_relative = 1e-2
                );
            }
        }
    }

    proptest! {
        /// The analytic Hessian must agree with a central finite-difference approximation of
        /// the analytic gradient.
        #[test]
        fn test_himmelblau_hessian_finitediff(a in -5.0..5.0, b in -5.0..5.0) {
            let param = [a, b];
            let hessian = himmelblau_hessian(&param);
            let hessian_fd =
                Vec::from(param).central_hessian(&|x| himmelblau_derivative(&[x[0], x[1]]).to_vec());
            let n = hessian.len();
            // `zip` would silently truncate, so make a shape mismatch an explicit failure.
            assert_eq!(hessian_fd.len(), n);
            for (row, row_fd) in hessian.iter().zip(hessian_fd.iter()) {
                assert_eq!(row.len(), n);
                assert_eq!(row_fd.len(), n);
                for (&analytic, &numeric) in row.iter().zip(row_fd.iter()) {
                    // Non-finite finite-difference entries carry no information; skip them.
                    if numeric.is_finite() {
                        assert_relative_eq!(
                            analytic,
                            numeric,
                            epsilon = 1e-5,
                            max_relative = 1e-2
                        );
                    }
                }
            }
        }
    }
}