// argmin_testfunctions/booth.rs

use num::{Float, FromPrimitive};

22pub fn booth<T>(param: &[T; 2]) -> T
34where
35 T: Float + FromPrimitive,
36{
37 let n2 = T::from_f64(2.0).unwrap();
38 let n5 = T::from_f64(5.0).unwrap();
39 let n7 = T::from_f64(7.0).unwrap();
40
41 let [x1, x2] = *param;
42 (x1 + n2 * x2 - n7).powi(2) + (n2 * x1 + x2 - n5).powi(2)
43}
44
45pub fn booth_derivative<T>(param: &[T; 2]) -> [T; 2]
47where
48 T: Float + FromPrimitive,
49{
50 let n8 = T::from_f64(8.0).unwrap();
51 let n10 = T::from_f64(10.0).unwrap();
52 let n34 = T::from_f64(34.0).unwrap();
53 let n38 = T::from_f64(38.0).unwrap();
54
55 let [x1, x2] = *param;
56
57 [n10 * x1 + n8 * x2 - n34, n8 * x1 + n10 * x2 - n38]
58}
59
60pub fn booth_hessian<T>(_param: &[T; 2]) -> [[T; 2]; 2]
64where
65 T: Float + FromPrimitive,
66{
67 let n8 = T::from_f64(8.0).unwrap();
68 let n10 = T::from_f64(10.0).unwrap();
69
70 [[n10, n8], [n8, n10]]
71}
72
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use finitediff::FiniteDiff;
    use proptest::prelude::*;

    /// The global minimum `f(1, 3) = 0` must hold, and the gradient must vanish there,
    /// for both `f32` and `f64`.
    #[test]
    fn test_booth_optimum() {
        assert_relative_eq!(booth(&[1_f32, 3_f32]), 0.0, epsilon = f32::EPSILON);
        assert_relative_eq!(booth(&[1_f64, 3_f64]), 0.0, epsilon = f64::EPSILON);

        for d in booth_derivative(&[1_f32, 3_f32]) {
            assert_relative_eq!(d, 0.0, epsilon = f32::EPSILON);
        }
        for d in booth_derivative(&[1_f64, 3_f64]) {
            assert_relative_eq!(d, 0.0, epsilon = f64::EPSILON);
        }
    }

    proptest! {
        /// The analytic gradient must agree with a central finite-difference
        /// approximation of `booth`.
        #[test]
        fn test_booth_derivative_finitediff(a in -10.0..10.0, b in -10.0..10.0) {
            let param = [a, b];
            let derivative = booth_derivative(&param);
            let derivative_fd = Vec::from(param).central_diff(&|x| booth(&[x[0], x[1]]));
            assert_eq!(derivative_fd.len(), derivative.len());
            for (&analytic, &numeric) in derivative.iter().zip(&derivative_fd) {
                assert_relative_eq!(
                    analytic,
                    numeric,
                    epsilon = 1e-4,
                    max_relative = 1e-2
                );
            }
        }
    }

    proptest! {
        /// The analytic Hessian must agree with a finite-difference Hessian built from the
        /// analytic gradient. This checks it against an independent source rather than
        /// against a copy of its own constants.
        #[test]
        fn test_booth_hessian(a in -10.0..10.0, b in -10.0..10.0) {
            let param = [a, b];
            let hessian = booth_hessian(&param);
            let hessian_fd = Vec::from(param)
                .central_hessian(&|x| booth_derivative(&[x[0], x[1]]).to_vec());
            assert_eq!(hessian_fd.len(), hessian.len());
            for (row, row_fd) in hessian.iter().zip(&hessian_fd) {
                assert_eq!(row_fd.len(), row.len());
                for (&analytic, &numeric) in row.iter().zip(row_fd) {
                    assert_relative_eq!(
                        analytic,
                        numeric,
                        epsilon = 1e-5,
                        max_relative = 1e-2
                    );
                }
            }
        }
    }
}