argmin_testfunctions/
goldsteinprice.rs1use num::{Float, FromPrimitive};
25
26pub fn goldsteinprice<T>(param: &[T; 2]) -> T
42where
43 T: Float + FromPrimitive,
44{
45 let [x1, x2] = *param;
46 let n1 = T::from_f64(1.0).unwrap();
47 let n2 = T::from_f64(2.0).unwrap();
48 let n3 = T::from_f64(3.0).unwrap();
49 let n6 = T::from_f64(6.0).unwrap();
50 let n12 = T::from_f64(12.0).unwrap();
51 let n14 = T::from_f64(14.0).unwrap();
52 let n18 = T::from_f64(18.0).unwrap();
53 let n19 = T::from_f64(19.0).unwrap();
54 let n27 = T::from_f64(27.0).unwrap();
55 let n30 = T::from_f64(30.0).unwrap();
56 let n32 = T::from_f64(32.0).unwrap();
57 let n36 = T::from_f64(36.0).unwrap();
58 let n48 = T::from_f64(48.0).unwrap();
59 (n1 + (x1 + x2 + n1).powi(2)
60 * (n19 - n14 * (x1 + x2) + n3 * (x1.powi(2) + x2.powi(2)) + n6 * x1 * x2))
61 * (n30
62 + (n2 * x1 - n3 * x2).powi(2)
63 * (n18 - n32 * x1 + n12 * x1.powi(2) + n48 * x2 - n36 * x1 * x2 + n27 * x2.powi(2)))
64}
65
66pub fn goldsteinprice_derivative<T>(param: &[T; 2]) -> [T; 2]
68where
69 T: Float + FromPrimitive,
70{
71 let [x1, x2] = *param;
72
73 let n1 = T::from_f64(1.0).unwrap();
74 let n2 = T::from_f64(2.0).unwrap();
75 let n3 = T::from_f64(3.0).unwrap();
76 let n4 = T::from_f64(4.0).unwrap();
77 let n6 = T::from_f64(6.0).unwrap();
78 let n12 = T::from_f64(12.0).unwrap();
79 let n14 = T::from_f64(14.0).unwrap();
80 let n18 = T::from_f64(18.0).unwrap();
81 let n19 = T::from_f64(19.0).unwrap();
82 let n24 = T::from_f64(24.0).unwrap();
83 let n27 = T::from_f64(27.0).unwrap();
84 let n30 = T::from_f64(30.0).unwrap();
85 let n32 = T::from_f64(32.0).unwrap();
86 let n36 = T::from_f64(36.0).unwrap();
87 let n48 = T::from_f64(48.0).unwrap();
88 let n54 = T::from_f64(54.0).unwrap();
89
90 let x1s = x1.powi(2);
91 let x2s = x2.powi(2);
92
93 [
94 (n2 * (x1 + x2 + n1) * (n3 * x1s + n6 * x2 * x1 - n14 * x1 + n3 * x2s - n14 * x2 + n19)
95 + (x1 + x2 + n1).powi(2) * (n6 * x1 + n6 * x2 - n14))
96 * ((n2 * x1 - n3 * x2).powi(2)
97 * (n12 * x1s - n36 * x2 * x1 - n32 * x1 + n27 * x2s + n48 * x2 + n18)
98 + n30)
99 + ((x1 + x2 + n1).powi(2)
100 * (n3 * x1s + n6 * x2 * x1 - n14 * x1 + n3 * x2s - n14 * x2 + n19)
101 + n1)
102 * (n4
103 * (n2 * x1 - n3 * x2)
104 * (n12 * x1s - n36 * x2 * x1 - n32 * x1 + n27 * x2s + n48 * x2 + n18)
105 + (n2 * x1 - n3 * x2).powi(2) * (n24 * x1 - n36 * x2 - n32)),
106 ((x2 + x1 + n1).powi(2) * (n3 * x2s + n6 * x1 * x2 - n14 * x2 + n3 * x1s - n14 * x1 + n19)
107 + n1)
108 * ((n2 * x1 - n3 * x2).powi(2) * (n54 * x2 - n36 * x1 + n48)
109 - n6 * (n2 * x1 - n3 * x2)
110 * (n27 * x2s - n36 * x1 * x2 + n48 * x2 + n12 * x1s - n32 * x1 + n18))
111 + (n2
112 * (x2 + x1 + n1)
113 * (n3 * x2s + n6 * x1 * x2 - n14 * x2 + n3 * x1s - n14 * x1 + n19)
114 + (x2 + x1 + n1).powi(2) * (n6 * x2 + n6 * x1 - n14))
115 * ((n2 * x1 - n3 * x2).powi(2)
116 * (n27 * x2s - n36 * x1 * x2 + n48 * x2 + n12 * x1s - n32 * x1 + n18)
117 + n30),
118 ]
119}
120
121pub fn goldsteinprice_hessian<T>(param: &[T; 2]) -> [[T; 2]; 2]
123where
124 T: Float + FromPrimitive,
125{
126 let [x1, x2] = *param;
127
128 let n840 = T::from_f64(840.0).unwrap();
129 let n1296 = T::from_f64(1296.0).unwrap();
130 let n2016 = T::from_f64(2016.0).unwrap();
131 let n2520 = T::from_f64(2520.0).unwrap();
132 let n2916 = T::from_f64(2916.0).unwrap();
133 let n3360 = T::from_f64(3360.0).unwrap();
134 let n4680 = T::from_f64(4680.0).unwrap();
135 let n5184 = T::from_f64(5184.0).unwrap();
136 let n5940 = T::from_f64(5940.0).unwrap();
137 let n6120 = T::from_f64(6120.0).unwrap();
138 let n6432 = T::from_f64(6432.0).unwrap();
139 let n6804 = T::from_f64(6804.0).unwrap();
140 let n7344 = T::from_f64(7344.0).unwrap();
141 let n7440 = T::from_f64(7440.0).unwrap();
142 let n7776 = T::from_f64(7776.0).unwrap();
143 let n8064 = T::from_f64(8064.0).unwrap();
144 let n10080 = T::from_f64(10080.0).unwrap();
145 let n10740 = T::from_f64(10740.0).unwrap();
146 let n11016 = T::from_f64(11016.0).unwrap();
147 let n11160 = T::from_f64(11160.0).unwrap();
148 let n11664 = T::from_f64(11664.0).unwrap();
149 let n12096 = T::from_f64(12096.0).unwrap();
150 let n14688 = T::from_f64(14688.0).unwrap();
151 let n15552 = T::from_f64(15552.0).unwrap();
152 let n15660 = T::from_f64(15660.0).unwrap();
153 let n17352 = T::from_f64(17352.0).unwrap();
154 let n17460 = T::from_f64(17460.0).unwrap();
155 let n17496 = T::from_f64(17496.0).unwrap();
156 let n18360 = T::from_f64(18360.0).unwrap();
157 let n19440 = T::from_f64(19440.0).unwrap();
158 let n19680 = T::from_f64(19680.0).unwrap();
159 let n20880 = T::from_f64(20880.0).unwrap();
160 let n23760 = T::from_f64(23760.0).unwrap();
161 let n24480 = T::from_f64(24480.0).unwrap();
162 let n25920 = T::from_f64(25920.0).unwrap();
163 let n26880 = T::from_f64(26880.0).unwrap();
164 let n27216 = T::from_f64(27216.0).unwrap();
165 let n27540 = T::from_f64(27540.0).unwrap();
166 let n28560 = T::from_f64(28560.0).unwrap();
167 let n29448 = T::from_f64(29448.0).unwrap();
168 let n30240 = T::from_f64(30240.0).unwrap();
169 let n30720 = T::from_f64(30720.0).unwrap();
170 let n31104 = T::from_f64(31104.0).unwrap();
171 let n32256 = T::from_f64(32256.0).unwrap();
172 let n34704 = T::from_f64(34704.0).unwrap();
173 let n36720 = T::from_f64(36720.0).unwrap();
174 let n38592 = T::from_f64(38592.0).unwrap();
175 let n38880 = T::from_f64(38880.0).unwrap();
176 let n40320 = T::from_f64(40320.0).unwrap();
177 let n40824 = T::from_f64(40824.0).unwrap();
178 let n41760 = T::from_f64(41760.0).unwrap();
179 let n42960 = T::from_f64(42960.0).unwrap();
180 let n43740 = T::from_f64(43740.0).unwrap();
181 let n47520 = T::from_f64(47520.0).unwrap();
182 let n48960 = T::from_f64(48960.0).unwrap();
183 let n51840 = T::from_f64(51840.0).unwrap();
184 let n58320 = T::from_f64(58320.0).unwrap();
185 let n59040 = T::from_f64(59040.0).unwrap();
186 let n64440 = T::from_f64(64440.0).unwrap();
187 let n69840 = T::from_f64(69840.0).unwrap();
188 let n70848 = T::from_f64(70848.0).unwrap();
189 let n73440 = T::from_f64(73440.0).unwrap();
190 let n73728 = T::from_f64(73728.0).unwrap();
191 let n92160 = T::from_f64(92160.0).unwrap();
192 let n104760 = T::from_f64(104760.0).unwrap();
193 let n132840 = T::from_f64(132840.0).unwrap();
194 let n142560 = T::from_f64(142560.0).unwrap();
195 let n141696 = T::from_f64(141696.0).unwrap();
196 let n172152 = T::from_f64(172152.0).unwrap();
197
198 let x1p2 = x1.powi(2);
199 let x1p3 = x1.powi(3);
200 let x1p4 = x1.powi(4);
201 let x1p5 = x1.powi(5);
202 let x1p6 = x1.powi(6);
203 let x2p2 = x2.powi(2);
204 let x2p3 = x2.powi(3);
205 let x2p4 = x2.powi(4);
206 let x2p5 = x2.powi(5);
207 let x2p6 = x2.powi(6);
208
209 let a = n8064 * x1p6
210 + (-n12096 * x2 - n32256) * x1p5
211 + (-n19440 * x2p2 + n40320 * x2 + n28560) * x1p4
212 + (n24480 * x2p3 + n51840 * x2p2 - n3360 * x2 + n26880) * x1p3
213 + (n15660 * x2p4 - n48960 * x2p3 - n64440 * x2p2 - n92160 * x2 - n29448) * x1p2
214 + (-n11016 * x2p5 - n20880 * x2p4 + n7440 * x2p3 + n59040 * x2p2 + n34704 * x2 - n6432)
215 * x1
216 - n2916 * x2p6
217 + n7344 * x2p5
218 + n17460 * x2p4
219 + n10080 * x2p3
220 + n15552 * x2p2
221 + n14688 * x2
222 + n2520;
223
224 let b = n40824 * x2p6
225 + (n40824 * x1 - n27216) * x2p5
226 + (-n43740 * x1p2 + n58320 * x1 - n132840) * x2p4
227 + (-n36720 * x1p3 + n73440 * x1p2 - n23760 * x1 + n38880) * x2p3
228 + (n15660 * x1p4 - n41760 * x1p3 + n104760 * x1p2 - n142560 * x1 + n172152) * x2p2
229 + (n7344 * x1p5 - n24480 * x1p4 + n7440 * x1p3 + n30240 * x1p2 - n141696 * x1 + n73728)
230 * x2
231 - n1296 * x1p6
232 + n5184 * x1p5
233 - n10740 * x1p4
234 + n19680 * x1p3
235 + n15552 * x1p2
236 - n38592 * x1
237 + n6120;
238
239 let offdiag = n6804 * x2p6
240 + (n11664 - n17496 * x1) * x2p5
241 + (-n27540 * x1p2 + n36720 * x1 - n5940) * x2p4
242 + (n20880 * x1p3 - n41760 * x1p2 + n69840 * x1 - n47520) * x2p3
243 + (n18360 * x1p4 - n48960 * x1p3 + n11160 * x1p2 + n30240 * x1 - n70848) * x2p2
244 + (-n7776 * x1p5 + n25920 * x1p4 - n42960 * x1p3 + n59040 * x1p2 + n31104 * x1 - n38592)
245 * x2
246 - n2016 * x1p6
247 + n8064 * x1p5
248 - n840 * x1p4
249 - n30720 * x1p3
250 + n17352 * x1p2
251 + n14688 * x1
252 - n4680;
253
254 [[a, offdiag], [offdiag, b]]
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use finitediff::FiniteDiff;
    use proptest::prelude::*;

    /// Asserts that the analytic gradient at `param` agrees with a central finite-difference
    /// estimate of the gradient of `goldsteinprice`.
    fn check_derivative(param: [f64; 2], max_relative: f64) {
        let derivative = goldsteinprice_derivative(&param);
        let derivative_fd = Vec::from(param).central_diff(&|x| goldsteinprice(&[x[0], x[1]]));
        assert_eq!(derivative_fd.len(), derivative.len());
        for (analytic, numeric) in derivative.iter().zip(derivative_fd.iter()) {
            assert_relative_eq!(
                *analytic,
                *numeric,
                epsilon = 1e-3,
                max_relative = max_relative
            );
        }
    }

    /// Asserts that the analytic Hessian at `param` agrees with a central finite-difference
    /// estimate built from `goldsteinprice_derivative`, skipping non-finite estimates.
    fn check_hessian(param: [f64; 2], max_relative: f64) {
        let hessian = goldsteinprice_hessian(&param);
        let hessian_fd = Vec::from(param)
            .central_hessian(&|x| goldsteinprice_derivative(&[x[0], x[1]]).to_vec());
        assert_eq!(hessian_fd.len(), hessian.len());
        for (row, row_fd) in hessian.iter().zip(hessian_fd.iter()) {
            assert_eq!(row_fd.len(), row.len());
            for (analytic, numeric) in row.iter().zip(row_fd.iter()) {
                if numeric.is_finite() {
                    assert_relative_eq!(
                        *analytic,
                        *numeric,
                        epsilon = 1e-5,
                        max_relative = max_relative
                    );
                }
            }
        }
    }

    #[test]
    fn test_goldsteinprice_optimum() {
        assert_relative_eq!(
            goldsteinprice(&[0.0_f32, -1.0_f32]),
            3_f32,
            epsilon = f32::EPSILON
        );
        assert_relative_eq!(
            goldsteinprice(&[0.0_f64, -1.0_f64]),
            3_f64,
            epsilon = f64::EPSILON
        );

        let deriv = goldsteinprice_derivative(&[0.0_f64, -1.0_f64]);
        for d in deriv.iter() {
            assert_relative_eq!(*d, 0.0, epsilon = f64::EPSILON);
        }
    }

    #[test]
    fn test_goldsteinprice_hessian_optimum() {
        // At the optimum (0, -1) every intermediate value is a small integer, so the
        // Hessian is exact: [[504, -216], [-216, 864]].
        let hessian = goldsteinprice_hessian(&[0.0_f64, -1.0_f64]);
        let expected = [[504.0, -216.0], [-216.0, 864.0]];
        for (row, expected_row) in hessian.iter().zip(expected.iter()) {
            for (h, e) in row.iter().zip(expected_row.iter()) {
                assert_relative_eq!(*h, *e, epsilon = 1e-9);
            }
        }
    }

    proptest! {
        #[test]
        fn test_goldsteinprice_derivative_finitediff(a in -2.0..2.0, b in -2.0..2.0) {
            check_derivative([a, b], 1e-1);
        }

        #[test]
        fn test_goldsteinprice_derivative_finitediff_narrow(a in -0.5..0.5, b in -0.5..0.5) {
            check_derivative([a, b], 1e-2);
        }

        #[test]
        fn test_goldsteinprice_hessian_finitediff(a in -2.0..2.0, b in -2.0..2.0) {
            check_hessian([a, b], 1e-1);
        }

        #[test]
        fn test_goldsteinprice_hessian_finitediff_narrow(a in -0.5..0.5, b in -0.5..0.5) {
            check_hessian([a, b], 1e-2);
        }
    }
}