1use num::{Float, FromPrimitive};
21use std::{iter::Sum, ops::AddAssign};
22
23pub fn rosenbrock<T>(param: &[T]) -> T
37where
38 T: Float + FromPrimitive + Sum,
39{
40 rosenbrock_ab(
41 param,
42 T::from_f64(1.0).unwrap(),
43 T::from_f64(100.0).unwrap(),
44 )
45}
46
47pub fn rosenbrock_ab<T>(param: &[T], a: T, b: T) -> T
51where
52 T: Float + FromPrimitive + Sum,
53{
54 param
55 .iter()
56 .zip(param.iter().skip(1))
57 .map(|(&xi, &xi1)| (a - xi).powi(2) + b * (xi1 - xi.powi(2)).powi(2))
58 .sum()
59}
60pub fn rosenbrock_derivative<T>(param: &[T]) -> Vec<T>
64where
65 T: Float + FromPrimitive + AddAssign,
66{
67 rosenbrock_ab_derivative(
68 param,
69 T::from_f64(1.0).unwrap(),
70 T::from_f64(100.0).unwrap(),
71 )
72}
73
74pub fn rosenbrock_ab_derivative<T>(param: &[T], a: T, b: T) -> Vec<T>
78where
79 T: Float + FromPrimitive + AddAssign,
80{
81 let n0 = T::from_f64(0.0).unwrap();
82 let n2 = T::from_f64(2.0).unwrap();
83 let n4 = T::from_f64(4.0).unwrap();
84
85 let n = param.len();
86
87 let mut result = vec![n0; n];
88
89 for i in 0..(n - 1) {
90 let xi = param[i];
91 let xi1 = param[i + 1];
92
93 let t1 = -n4 * b * xi * (xi1 - xi.powi(2));
94 let t2 = n2 * b * (xi1 - xi.powi(2));
95
96 result[i] += t1 + n2 * (xi - a);
97 result[i + 1] += t2;
98 }
99 result
100}
101
102pub fn rosenbrock_hessian<T>(param: &[T]) -> Vec<Vec<T>>
106where
107 T: Float + FromPrimitive + AddAssign,
108{
109 rosenbrock_ab_hessian(
110 param,
111 T::from_f64(1.0).unwrap(),
112 T::from_f64(100.0).unwrap(),
113 )
114}
115
116pub fn rosenbrock_ab_hessian<T>(param: &[T], a: T, b: T) -> Vec<Vec<T>>
120where
121 T: Float + FromPrimitive + AddAssign,
122{
123 let n0 = T::from_f64(0.0).unwrap();
124 let n2 = T::from_f64(2.0).unwrap();
125 let n4 = T::from_f64(4.0).unwrap();
126 let n12 = T::from_f64(12.0).unwrap();
127
128 let n = param.len();
129 let mut hessian = vec![vec![n0; n]; n];
130
131 for i in 0..n - 1 {
132 let xi = param[i];
133 let xi1 = param[i + 1];
134
135 hessian[i][i] += n12 * b * xi.powi(2) - n4 * b * xi1 + n2 * a;
136 hessian[i + 1][i + 1] = n2 * b;
137 hessian[i][i + 1] = -n4 * b * xi;
138 hessian[i + 1][i] = -n4 * b * xi;
139 }
140 hessian
141}
142
143pub fn rosenbrock_derivative_const<const N: usize, T>(param: &[T; N]) -> [T; N]
150where
151 T: Float + FromPrimitive + AddAssign,
152{
153 rosenbrock_ab_derivative_const(
154 param,
155 T::from_f64(1.0).unwrap(),
156 T::from_f64(100.0).unwrap(),
157 )
158}
159
160pub fn rosenbrock_ab_derivative_const<const N: usize, T>(param: &[T; N], a: T, b: T) -> [T; N]
167where
168 T: Float + FromPrimitive + AddAssign,
169{
170 let n0 = T::from_f64(0.0).unwrap();
171 let n2 = T::from_f64(2.0).unwrap();
172 let n4 = T::from_f64(4.0).unwrap();
173
174 let mut result = [n0; N];
175
176 for i in 0..(N - 1) {
177 let xi = param[i];
178 let xi1 = param[i + 1];
179
180 let t1 = -n4 * b * xi * (xi1 - xi.powi(2));
181 let t2 = n2 * b * (xi1 - xi.powi(2));
182
183 result[i] += t1 + n2 * (xi - a);
184 result[i + 1] += t2;
185 }
186 result
187}
188
189pub fn rosenbrock_hessian_const<const N: usize, T>(param: &[T; N]) -> [[T; N]; N]
196where
197 T: Float + FromPrimitive + AddAssign,
198{
199 rosenbrock_ab_hessian_const(
200 param,
201 T::from_f64(1.0).unwrap(),
202 T::from_f64(100.0).unwrap(),
203 )
204}
205
206pub fn rosenbrock_ab_hessian_const<const N: usize, T>(x: &[T; N], a: T, b: T) -> [[T; N]; N]
213where
214 T: Float + FromPrimitive + AddAssign,
215{
216 let n0 = T::from_f64(0.0).unwrap();
217 let n2 = T::from_f64(2.0).unwrap();
218 let n4 = T::from_f64(4.0).unwrap();
219 let n12 = T::from_f64(12.0).unwrap();
220
221 let mut hessian = [[n0; N]; N];
222
223 for i in 0..(N - 1) {
224 let xi = x[i];
225 let xi1 = x[i + 1];
226
227 hessian[i][i] += n12 * b * xi.powi(2) - n4 * b * xi1 + n2 * a;
228 hessian[i + 1][i + 1] = n2 * b;
229 hessian[i][i + 1] = -n4 * b * xi;
230 hessian[i + 1][i] = -n4 * b * xi;
231 }
232 hessian
233}
234
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use finitediff::FiniteDiff;
    use proptest::prelude::*;

    /// The function value at the global minimum `(1, ..., 1)` is zero.
    #[test]
    fn test_rosenbrock_optimum() {
        assert_relative_eq!(rosenbrock(&[1.0_f32, 1.0_f32]), 0.0, epsilon = f32::EPSILON);
        assert_relative_eq!(rosenbrock(&[1.0, 1.0]), 0.0, epsilon = f64::EPSILON);
        assert_relative_eq!(rosenbrock(&[1.0, 1.0, 1.0]), 0.0, epsilon = f64::EPSILON);
    }

    /// The gradient vanishes at the global minimum.
    #[test]
    fn test_rosenbrock_derivative_optimum() {
        let derivative = rosenbrock_derivative(&[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]);
        for elem in derivative {
            assert_relative_eq!(elem, 0.0, epsilon = f64::EPSILON);
        }
    }

    /// The fixed-size gradient vanishes at the global minimum.
    #[test]
    fn test_rosenbrock_derivative_const_optimum() {
        let derivative = rosenbrock_derivative_const(&[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]);
        for elem in derivative {
            assert_relative_eq!(elem, 0.0, epsilon = f64::EPSILON);
        }
    }

    /// Hessian at a fixed point compared against hand-computed values.
    #[test]
    fn test_rosenbrock_hessian() {
        let hessian = rosenbrock_hessian(&[0.0, 0.1, 0.2, 0.3]);
        let res = vec![
            vec![-38.0, 0.0, 0.0, 0.0],
            vec![0.0, 134.0, -40.0, 0.0],
            vec![0.0, -40.0, 130.0, -80.0],
            vec![0.0, 0.0, -80.0, 200.0],
        ];
        let n = hessian.len();
        for i in 0..n {
            assert_eq!(hessian[i].len(), n);
            for j in 0..n {
                assert_relative_eq!(
                    hessian[i][j],
                    res[i][j],
                    epsilon = 1e-5,
                    max_relative = 1e-2
                );
            }
        }
    }

    /// Fixed-size Hessian at a fixed point compared against hand-computed values.
    #[test]
    fn test_rosenbrock_hessian_const() {
        let hessian = rosenbrock_hessian_const(&[0.0, 0.1, 0.2, 0.3]);
        let res = vec![
            vec![-38.0, 0.0, 0.0, 0.0],
            vec![0.0, 134.0, -40.0, 0.0],
            vec![0.0, -40.0, 130.0, -80.0],
            vec![0.0, 0.0, -80.0, 200.0],
        ];
        let n = hessian.len();
        for i in 0..n {
            assert_eq!(hessian[i].len(), n);
            for j in 0..n {
                assert_relative_eq!(
                    hessian[i][j],
                    res[i][j],
                    epsilon = 1e-5,
                    max_relative = 1e-2
                );
            }
        }
    }

    proptest! {
        /// Analytic gradient agrees with a central finite-difference gradient.
        #[test]
        fn test_rosenbrock_derivative_finitediff(a in -1.0..1.0,
                                                 b in -1.0..1.0,
                                                 c in -1.0..1.0,
                                                 d in -1.0..1.0,
                                                 e in -1.0..1.0,
                                                 f in -1.0..1.0,
                                                 g in -1.0..1.0,
                                                 h in -1.0..1.0) {
            let param = [a, b, c, d, e, f, g, h];
            let derivative = rosenbrock_derivative(&param);
            let derivative_fd = Vec::from(param).central_diff(&|x| rosenbrock(&x));
            for i in 0..derivative.len() {
                assert_relative_eq!(
                    derivative[i],
                    derivative_fd[i],
                    epsilon = 1e-4,
                    max_relative = 1e-2
                );
            }
        }
    }

    proptest! {
        /// Fixed-size analytic gradient agrees with a central finite-difference gradient.
        #[test]
        fn test_rosenbrock_derivative_const_finitediff(a in -1.0..1.0,
                                                       b in -1.0..1.0,
                                                       c in -1.0..1.0,
                                                       d in -1.0..1.0,
                                                       e in -1.0..1.0,
                                                       f in -1.0..1.0,
                                                       g in -1.0..1.0,
                                                       h in -1.0..1.0) {
            let param = [a, b, c, d, e, f, g, h];
            let derivative = rosenbrock_derivative_const(&param);
            let derivative_fd = Vec::from(param).central_diff(&|x| rosenbrock(&x));
            for i in 0..derivative.len() {
                assert_relative_eq!(
                    derivative[i],
                    derivative_fd[i],
                    epsilon = 1e-4,
                    max_relative = 1e-2
                );
            }
        }
    }

    proptest! {
        /// Analytic Hessian agrees with a forward finite-difference of the gradient.
        #[test]
        fn test_rosenbrock_hessian_finitediff(a in -1.0..1.0,
                                              b in -1.0..1.0,
                                              c in -1.0..1.0,
                                              d in -1.0..1.0,
                                              e in -1.0..1.0,
                                              f in -1.0..1.0,
                                              g in -1.0..1.0,
                                              h in -1.0..1.0) {
            let param = [a, b, c, d, e, f, g, h];
            let hessian = rosenbrock_hessian(&param);
            let hessian_fd =
                Vec::from(param).forward_hessian(&|x| rosenbrock_derivative(&x));
            let n = hessian.len();
            for i in 0..n {
                assert_eq!(hessian[i].len(), n);
                for j in 0..n {
                    assert_relative_eq!(
                        hessian[i][j],
                        hessian_fd[i][j],
                        epsilon = 1e-4,
                        max_relative = 1e-2
                    );
                }
            }
        }
    }

    proptest! {
        /// Fixed-size analytic Hessian agrees with a forward finite-difference of the gradient.
        #[test]
        fn test_rosenbrock_hessian_const_finitediff(a in -1.0..1.0,
                                                    b in -1.0..1.0,
                                                    c in -1.0..1.0,
                                                    d in -1.0..1.0,
                                                    e in -1.0..1.0,
                                                    f in -1.0..1.0,
                                                    g in -1.0..1.0,
                                                    h in -1.0..1.0) {
            let param = [a, b, c, d, e, f, g, h];
            let hessian = rosenbrock_hessian_const(&param);
            let hessian_fd =
                Vec::from(param).forward_hessian(&|x| rosenbrock_derivative(&x));
            let n = hessian.len();
            for i in 0..n {
                assert_eq!(hessian[i].len(), n);
                for j in 0..n {
                    assert_relative_eq!(
                        hessian[i][j],
                        hessian_fd[i][j],
                        epsilon = 1e-4,
                        max_relative = 1e-2
                    );
                }
            }
        }
    }
}