argmin_testfunctions/
goldsteinprice.rs

1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! # Goldstein-Price test function.
9//!
10//! Defined as
11//!
12//! $$
13//! \begin{aligned}
14//! f(x_1,\\,x_2) &= [1 + (x_1 + x_2 + 1)^2(19 - 14 x_1 + 3 x_1^2 - 14 x_2 + 6 x_1 x_2 + 3 x_2^2)] \\\\
15//!                &\times [30 + (2 x_1 - 3 x_2)^2(18 - 32 x_1 + 12 x_1^2 + 48 x_2 -
16//!                   36x_1 x_2 + 27 x_2^2) ]
17//! \end{aligned}
18//! $$
19//!
20//! where $x_i \in [-2,\\,2]$.
21//!
22//! The global minimum is at $f(x_1,\\,x_2) = f(0,\\,-1) = 3$.
23
24use num::{Float, FromPrimitive};
25
26/// Goldstein-Price test function.
27///
28/// Defined as
29///
30/// $$
31/// \begin{aligned}
32/// f(x_1,\\,x_2) &= [1 + (x_1 + x_2 + 1)^2(19 - 14 x_1 + 3 x_1^2 - 14 x_2 + 6 x_1 x_2 + 3 x_2^2)] \\\\
33///                &\times [30 + (2 x_1 - 3 x_2)^2(18 - 32 x_1 + 12 x_1^2 + 48 x_2 -
34///                   36x_1 x_2 + 27 x_2^2) ]
35/// \end{aligned}
36/// $$
37///
38/// where $x_i \in [-2,\\,2]$.
39///
40/// The global minimum is at $f(x_1,\\,x_2) = f(0,\\,-1) = 3$.
41pub fn goldsteinprice<T>(param: &[T; 2]) -> T
42where
43    T: Float + FromPrimitive,
44{
45    let [x1, x2] = *param;
46    let n1 = T::from_f64(1.0).unwrap();
47    let n2 = T::from_f64(2.0).unwrap();
48    let n3 = T::from_f64(3.0).unwrap();
49    let n6 = T::from_f64(6.0).unwrap();
50    let n12 = T::from_f64(12.0).unwrap();
51    let n14 = T::from_f64(14.0).unwrap();
52    let n18 = T::from_f64(18.0).unwrap();
53    let n19 = T::from_f64(19.0).unwrap();
54    let n27 = T::from_f64(27.0).unwrap();
55    let n30 = T::from_f64(30.0).unwrap();
56    let n32 = T::from_f64(32.0).unwrap();
57    let n36 = T::from_f64(36.0).unwrap();
58    let n48 = T::from_f64(48.0).unwrap();
59    (n1 + (x1 + x2 + n1).powi(2)
60        * (n19 - n14 * (x1 + x2) + n3 * (x1.powi(2) + x2.powi(2)) + n6 * x1 * x2))
61        * (n30
62            + (n2 * x1 - n3 * x2).powi(2)
63                * (n18 - n32 * x1 + n12 * x1.powi(2) + n48 * x2 - n36 * x1 * x2 + n27 * x2.powi(2)))
64}
65
66/// Derivative of Goldstein-Price test function.
67pub fn goldsteinprice_derivative<T>(param: &[T; 2]) -> [T; 2]
68where
69    T: Float + FromPrimitive,
70{
71    let [x1, x2] = *param;
72
73    let n1 = T::from_f64(1.0).unwrap();
74    let n2 = T::from_f64(2.0).unwrap();
75    let n3 = T::from_f64(3.0).unwrap();
76    let n4 = T::from_f64(4.0).unwrap();
77    let n6 = T::from_f64(6.0).unwrap();
78    let n12 = T::from_f64(12.0).unwrap();
79    let n14 = T::from_f64(14.0).unwrap();
80    let n18 = T::from_f64(18.0).unwrap();
81    let n19 = T::from_f64(19.0).unwrap();
82    let n24 = T::from_f64(24.0).unwrap();
83    let n27 = T::from_f64(27.0).unwrap();
84    let n30 = T::from_f64(30.0).unwrap();
85    let n32 = T::from_f64(32.0).unwrap();
86    let n36 = T::from_f64(36.0).unwrap();
87    let n48 = T::from_f64(48.0).unwrap();
88    let n54 = T::from_f64(54.0).unwrap();
89
90    let x1s = x1.powi(2);
91    let x2s = x2.powi(2);
92
93    [
94        (n2 * (x1 + x2 + n1) * (n3 * x1s + n6 * x2 * x1 - n14 * x1 + n3 * x2s - n14 * x2 + n19)
95            + (x1 + x2 + n1).powi(2) * (n6 * x1 + n6 * x2 - n14))
96            * ((n2 * x1 - n3 * x2).powi(2)
97                * (n12 * x1s - n36 * x2 * x1 - n32 * x1 + n27 * x2s + n48 * x2 + n18)
98                + n30)
99            + ((x1 + x2 + n1).powi(2)
100                * (n3 * x1s + n6 * x2 * x1 - n14 * x1 + n3 * x2s - n14 * x2 + n19)
101                + n1)
102                * (n4
103                    * (n2 * x1 - n3 * x2)
104                    * (n12 * x1s - n36 * x2 * x1 - n32 * x1 + n27 * x2s + n48 * x2 + n18)
105                    + (n2 * x1 - n3 * x2).powi(2) * (n24 * x1 - n36 * x2 - n32)),
106        ((x2 + x1 + n1).powi(2) * (n3 * x2s + n6 * x1 * x2 - n14 * x2 + n3 * x1s - n14 * x1 + n19)
107            + n1)
108            * ((n2 * x1 - n3 * x2).powi(2) * (n54 * x2 - n36 * x1 + n48)
109                - n6 * (n2 * x1 - n3 * x2)
110                    * (n27 * x2s - n36 * x1 * x2 + n48 * x2 + n12 * x1s - n32 * x1 + n18))
111            + (n2
112                * (x2 + x1 + n1)
113                * (n3 * x2s + n6 * x1 * x2 - n14 * x2 + n3 * x1s - n14 * x1 + n19)
114                + (x2 + x1 + n1).powi(2) * (n6 * x2 + n6 * x1 - n14))
115                * ((n2 * x1 - n3 * x2).powi(2)
116                    * (n27 * x2s - n36 * x1 * x2 + n48 * x2 + n12 * x1s - n32 * x1 + n18)
117                    + n30),
118    ]
119}
120
121/// Hessian of Goldstein-Price test function.
122pub fn goldsteinprice_hessian<T>(param: &[T; 2]) -> [[T; 2]; 2]
123where
124    T: Float + FromPrimitive,
125{
126    let [x1, x2] = *param;
127
128    let n840 = T::from_f64(840.0).unwrap();
129    let n1296 = T::from_f64(1296.0).unwrap();
130    let n2016 = T::from_f64(2016.0).unwrap();
131    let n2520 = T::from_f64(2520.0).unwrap();
132    let n2916 = T::from_f64(2916.0).unwrap();
133    let n3360 = T::from_f64(3360.0).unwrap();
134    let n4680 = T::from_f64(4680.0).unwrap();
135    let n5184 = T::from_f64(5184.0).unwrap();
136    let n5940 = T::from_f64(5940.0).unwrap();
137    let n6120 = T::from_f64(6120.0).unwrap();
138    let n6432 = T::from_f64(6432.0).unwrap();
139    let n6804 = T::from_f64(6804.0).unwrap();
140    let n7344 = T::from_f64(7344.0).unwrap();
141    let n7440 = T::from_f64(7440.0).unwrap();
142    let n7776 = T::from_f64(7776.0).unwrap();
143    let n8064 = T::from_f64(8064.0).unwrap();
144    let n10080 = T::from_f64(10080.0).unwrap();
145    let n10740 = T::from_f64(10740.0).unwrap();
146    let n11016 = T::from_f64(11016.0).unwrap();
147    let n11160 = T::from_f64(11160.0).unwrap();
148    let n11664 = T::from_f64(11664.0).unwrap();
149    let n12096 = T::from_f64(12096.0).unwrap();
150    let n14688 = T::from_f64(14688.0).unwrap();
151    let n15552 = T::from_f64(15552.0).unwrap();
152    let n15660 = T::from_f64(15660.0).unwrap();
153    let n17352 = T::from_f64(17352.0).unwrap();
154    let n17460 = T::from_f64(17460.0).unwrap();
155    let n17496 = T::from_f64(17496.0).unwrap();
156    let n18360 = T::from_f64(18360.0).unwrap();
157    let n19440 = T::from_f64(19440.0).unwrap();
158    let n19680 = T::from_f64(19680.0).unwrap();
159    let n20880 = T::from_f64(20880.0).unwrap();
160    let n23760 = T::from_f64(23760.0).unwrap();
161    let n24480 = T::from_f64(24480.0).unwrap();
162    let n25920 = T::from_f64(25920.0).unwrap();
163    let n26880 = T::from_f64(26880.0).unwrap();
164    let n27216 = T::from_f64(27216.0).unwrap();
165    let n27540 = T::from_f64(27540.0).unwrap();
166    let n28560 = T::from_f64(28560.0).unwrap();
167    let n29448 = T::from_f64(29448.0).unwrap();
168    let n30240 = T::from_f64(30240.0).unwrap();
169    let n30720 = T::from_f64(30720.0).unwrap();
170    let n31104 = T::from_f64(31104.0).unwrap();
171    let n32256 = T::from_f64(32256.0).unwrap();
172    let n34704 = T::from_f64(34704.0).unwrap();
173    let n36720 = T::from_f64(36720.0).unwrap();
174    let n38592 = T::from_f64(38592.0).unwrap();
175    let n38880 = T::from_f64(38880.0).unwrap();
176    let n40320 = T::from_f64(40320.0).unwrap();
177    let n40824 = T::from_f64(40824.0).unwrap();
178    let n41760 = T::from_f64(41760.0).unwrap();
179    let n42960 = T::from_f64(42960.0).unwrap();
180    let n43740 = T::from_f64(43740.0).unwrap();
181    let n47520 = T::from_f64(47520.0).unwrap();
182    let n48960 = T::from_f64(48960.0).unwrap();
183    let n51840 = T::from_f64(51840.0).unwrap();
184    let n58320 = T::from_f64(58320.0).unwrap();
185    let n59040 = T::from_f64(59040.0).unwrap();
186    let n64440 = T::from_f64(64440.0).unwrap();
187    let n69840 = T::from_f64(69840.0).unwrap();
188    let n70848 = T::from_f64(70848.0).unwrap();
189    let n73440 = T::from_f64(73440.0).unwrap();
190    let n73728 = T::from_f64(73728.0).unwrap();
191    let n92160 = T::from_f64(92160.0).unwrap();
192    let n104760 = T::from_f64(104760.0).unwrap();
193    let n132840 = T::from_f64(132840.0).unwrap();
194    let n142560 = T::from_f64(142560.0).unwrap();
195    let n141696 = T::from_f64(141696.0).unwrap();
196    let n172152 = T::from_f64(172152.0).unwrap();
197
198    let x1p2 = x1.powi(2);
199    let x1p3 = x1.powi(3);
200    let x1p4 = x1.powi(4);
201    let x1p5 = x1.powi(5);
202    let x1p6 = x1.powi(6);
203    let x2p2 = x2.powi(2);
204    let x2p3 = x2.powi(3);
205    let x2p4 = x2.powi(4);
206    let x2p5 = x2.powi(5);
207    let x2p6 = x2.powi(6);
208
209    let a = n8064 * x1p6
210        + (-n12096 * x2 - n32256) * x1p5
211        + (-n19440 * x2p2 + n40320 * x2 + n28560) * x1p4
212        + (n24480 * x2p3 + n51840 * x2p2 - n3360 * x2 + n26880) * x1p3
213        + (n15660 * x2p4 - n48960 * x2p3 - n64440 * x2p2 - n92160 * x2 - n29448) * x1p2
214        + (-n11016 * x2p5 - n20880 * x2p4 + n7440 * x2p3 + n59040 * x2p2 + n34704 * x2 - n6432)
215            * x1
216        - n2916 * x2p6
217        + n7344 * x2p5
218        + n17460 * x2p4
219        + n10080 * x2p3
220        + n15552 * x2p2
221        + n14688 * x2
222        + n2520;
223
224    let b = n40824 * x2p6
225        + (n40824 * x1 - n27216) * x2p5
226        + (-n43740 * x1p2 + n58320 * x1 - n132840) * x2p4
227        + (-n36720 * x1p3 + n73440 * x1p2 - n23760 * x1 + n38880) * x2p3
228        + (n15660 * x1p4 - n41760 * x1p3 + n104760 * x1p2 - n142560 * x1 + n172152) * x2p2
229        + (n7344 * x1p5 - n24480 * x1p4 + n7440 * x1p3 + n30240 * x1p2 - n141696 * x1 + n73728)
230            * x2
231        - n1296 * x1p6
232        + n5184 * x1p5
233        - n10740 * x1p4
234        + n19680 * x1p3
235        + n15552 * x1p2
236        - n38592 * x1
237        + n6120;
238
239    let offdiag = n6804 * x2p6
240        + (n11664 - n17496 * x1) * x2p5
241        + (-n27540 * x1p2 + n36720 * x1 - n5940) * x2p4
242        + (n20880 * x1p3 - n41760 * x1p2 + n69840 * x1 - n47520) * x2p3
243        + (n18360 * x1p4 - n48960 * x1p3 + n11160 * x1p2 + n30240 * x1 - n70848) * x2p2
244        + (-n7776 * x1p5 + n25920 * x1p4 - n42960 * x1p3 + n59040 * x1p2 + n31104 * x1 - n38592)
245            * x2
246        - n2016 * x1p6
247        + n8064 * x1p5
248        - n840 * x1p4
249        - n30720 * x1p3
250        + n17352 * x1p2
251        + n14688 * x1
252        - n4680;
253
254    [[a, offdiag], [offdiag, b]]
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use finitediff::FiniteDiff;
    use proptest::prelude::*;

    #[test]
    fn test_goldsteinprice_optimum() {
        assert_relative_eq!(
            goldsteinprice(&[0.0_f32, -1.0_f32]),
            3_f32,
            epsilon = f32::EPSILON
        );
        assert_relative_eq!(
            goldsteinprice(&[0.0_f64, -1.0_f64]),
            3_f64,
            epsilon = f64::EPSILON
        );

        // The gradient vanishes at the minimum.
        for &d in goldsteinprice_derivative(&[0.0_f64, -1.0_f64]).iter() {
            assert_relative_eq!(d, 0.0, epsilon = f64::EPSILON);
        }

        // All intermediate values are integers at this point, so the Hessian can be checked
        // essentially exactly. It must also be positive definite at a strict minimum.
        let hessian = goldsteinprice_hessian(&[0.0_f64, -1.0_f64]);
        let expected = [[504.0, -216.0], [-216.0, 864.0]];
        for (row, expected_row) in hessian.iter().zip(expected.iter()) {
            for (&v, &e) in row.iter().zip(expected_row.iter()) {
                assert_relative_eq!(v, e, max_relative = 1e-12);
            }
        }
        assert!(hessian[0][0] > 0.0);
        assert!(hessian[0][0] * hessian[1][1] - hessian[0][1] * hessian[1][0] > 0.0);
    }

    #[test]
    fn test_goldsteinprice_known_point() {
        // Exact values at (1, 1), derived by hand from the factored form f = A * B with
        // A = 28 and B = 67 there. The finite-difference property tests below use loose
        // tolerances and would not catch an error in a single coefficient.
        let param = [1.0_f64, 1.0_f64];
        assert_relative_eq!(goldsteinprice(&param), 1876.0, max_relative = 1e-12);

        let derivative = goldsteinprice_derivative(&param);
        for (&v, &e) in derivative.iter().zip([-5376.0, 8064.0].iter()) {
            assert_relative_eq!(v, e, max_relative = 1e-12);
        }

        let hessian = goldsteinprice_hessian(&param);
        let expected = [[21228.0, -25812.0], [-25812.0, 44748.0]];
        for (row, expected_row) in hessian.iter().zip(expected.iter()) {
            for (&v, &e) in row.iter().zip(expected_row.iter()) {
                assert_relative_eq!(v, e, max_relative = 1e-12);
            }
        }
    }

    proptest! {
        #[test]
        fn test_goldsteinprice_derivative_finitediff(a in -2.0..2.0, b in -2.0..2.0) {
            let param = [a, b];
            let derivative = goldsteinprice_derivative(&param);
            let derivative_fd = Vec::from(param).central_diff(&|x| goldsteinprice(&[x[0], x[1]]));
            for i in 0..derivative.len() {
                assert_relative_eq!(
                    derivative[i],
                    derivative_fd[i],
                    epsilon = 1e-3,
                    max_relative = 1e-1
                );
            }
        }
    }

    proptest! {
        #[test]
        fn test_goldsteinprice_derivative_finitediff_narrow(a in -0.5..0.5, b in -0.5..0.5) {
            // This evaluates the function on a narrower domain, which allows us to have a lower
            // epsilon, as the function is pretty steep at the boundary, which isn't great for
            // accuracy when using finite differentiation.
            let param = [a, b];
            let derivative = goldsteinprice_derivative(&param);
            let derivative_fd = Vec::from(param).central_diff(&|x| goldsteinprice(&[x[0], x[1]]));
            for i in 0..derivative.len() {
                assert_relative_eq!(
                    derivative[i],
                    derivative_fd[i],
                    epsilon = 1e-3,
                    max_relative = 1e-2
                );
            }
        }
    }

    proptest! {
        #[test]
        fn test_goldsteinprice_hessian_finitediff(a in -2.0..2.0, b in -2.0..2.0) {
            let param = [a, b];
            let hessian = goldsteinprice_hessian(&param);
            let hessian_fd =
                Vec::from(param).central_hessian(&|x| goldsteinprice_derivative(&[x[0], x[1]]).to_vec());
            let n = hessian.len();
            for i in 0..n {
                assert_eq!(hessian[i].len(), n);
                for j in 0..n {
                    if hessian_fd[i][j].is_finite() {
                        assert_relative_eq!(
                            hessian[i][j],
                            hessian_fd[i][j],
                            epsilon = 1e-5,
                            max_relative = 1e-1
                        );
                    }
                }
            }
        }
    }

    proptest! {
        #[test]
        fn test_goldsteinprice_hessian_finitediff_narrow(a in -0.5..0.5, b in -0.5..0.5) {
            // This evaluates the function on a narrower domain, which allows us to have a lower
            // epsilon, as the function is pretty steep at the boundary, which isn't great for
            // accuracy when using finite differentiation.
            let param = [a, b];
            let hessian = goldsteinprice_hessian(&param);
            let hessian_fd =
                Vec::from(param).central_hessian(&|x| goldsteinprice_derivative(&[x[0], x[1]]).to_vec());
            let n = hessian.len();
            for i in 0..n {
                assert_eq!(hessian[i].len(), n);
                for j in 0..n {
                    if hessian_fd[i][j].is_finite() {
                        assert_relative_eq!(
                            hessian[i][j],
                            hessian_fd[i][j],
                            epsilon = 1e-5,
                            max_relative = 1e-2
                        );
                    }
                }
            }
        }
    }
}