1use num::{Float, FromPrimitive};
35use std::f64::consts::PI;
36use std::iter::Sum;
37
38pub fn levy<T>(param: &[T]) -> T
51where
52 T: Float + FromPrimitive + Sum,
53{
54 let plen = param.len();
55 assert!(plen >= 2);
56
57 let n1 = T::from_f64(1.0).unwrap();
58 let n2 = T::from_f64(2.0).unwrap();
59 let n4 = T::from_f64(4.0).unwrap();
60 let n10 = T::from_f64(10.0).unwrap();
61 let pi = T::from_f64(PI).unwrap();
62
63 let w = |x: T| n1 + (x - n1) / n4;
64
65 (pi * w(param[0])).sin().powi(2)
66 + param[1..(plen - 1)]
67 .iter()
68 .map(|x| w(*x))
69 .map(|wi: T| (wi - n1).powi(2) * (n1 + n10 * (pi * wi + n1).sin().powi(2)))
70 .sum()
71 + (w(param[plen - 1]) - n1).powi(2) * (n1 + (n2 * pi * w(param[plen - 1])).sin().powi(2))
72}
73
74pub fn levy_derivative<T>(param: &[T]) -> Vec<T>
76where
77 T: Float + FromPrimitive + Sum,
78{
79 let d = param.len();
80 assert!(d >= 2);
81
82 let n1 = T::from_f64(1.0).unwrap();
83 let n2 = T::from_f64(2.0).unwrap();
84 let n4 = T::from_f64(4.0).unwrap();
85 let n5 = T::from_f64(5.0).unwrap();
86 let n8 = T::from_f64(8.0).unwrap();
87 let n10 = T::from_f64(10.0).unwrap();
88 let n16 = T::from_f64(16.0).unwrap();
89 let pi = T::from_f64(PI).unwrap();
90
91 param
92 .iter()
93 .enumerate()
94 .map(|(i, x)| (i, x, pi * ((*x - n1) / n4 + n1)))
95 .map(|(i, &x, wp)| {
96 if i == 0 {
97 pi / n2 * wp.cos() * wp.sin()
98 } else if i == d - 1 {
99 ((n2 * wp).sin().powi(2) + n1) * (x - n1) / n8
100 + pi / n16 * (n2 * wp).cos() * (n2 * wp).sin() * (x - n1).powi(2)
101 } else {
102 (n10 * (wp + n1).sin().powi(2) + n1) * (x - n1) / n8
103 + n5 / n16 * pi * (wp + n1).cos() * (wp + n1).sin() * (x - n1).powi(2)
104 }
105 })
106 .collect()
107}
108
109pub fn levy_derivative_const<const N: usize, T>(param: &[T; N]) -> [T; N]
114where
115 T: Float + FromPrimitive + Sum,
116{
117 assert!(N >= 2);
118
119 let n1 = T::from_f64(1.0).unwrap();
120 let n0 = T::from_f64(0.0).unwrap();
121 let n2 = T::from_f64(2.0).unwrap();
122 let n4 = T::from_f64(4.0).unwrap();
123 let n5 = T::from_f64(5.0).unwrap();
124 let n8 = T::from_f64(8.0).unwrap();
125 let n10 = T::from_f64(10.0).unwrap();
126 let n16 = T::from_f64(16.0).unwrap();
127 let pi = T::from_f64(PI).unwrap();
128
129 let mut out = [n0; N];
130
131 param
132 .iter()
133 .zip(out.iter_mut())
134 .enumerate()
135 .map(|(i, (x, o))| (i, x, pi * ((*x - n1) / n4 + n1), o))
136 .map(|(i, &x, wp, o)| {
137 *o = if i == 0 {
138 pi / n2 * wp.cos() * wp.sin()
139 } else if i == N - 1 {
140 ((n2 * wp).sin().powi(2) + n1) * (x - n1) / n8
141 + pi / n16 * (n2 * wp).cos() * (n2 * wp).sin() * (x - n1).powi(2)
142 } else {
143 (n10 * (wp + n1).sin().powi(2) + n1) * (x - n1) / n8
144 + n5 / n16 * pi * (wp + n1).cos() * (wp + n1).sin() * (x - n1).powi(2)
145 }
146 })
147 .count();
148 out
149}
150
151pub fn levy_hessian<T>(param: &[T]) -> Vec<Vec<T>>
153where
154 T: Float + FromPrimitive + Sum,
155{
156 let d = param.len();
157 assert!(d >= 2);
158
159 let x = param;
160
161 let n0 = T::from_f64(0.0).unwrap();
162 let n1 = T::from_f64(1.0).unwrap();
163 let n2 = T::from_f64(2.0).unwrap();
164 let n4 = T::from_f64(4.0).unwrap();
165 let n5 = T::from_f64(5.0).unwrap();
166 let n6 = T::from_f64(6.0).unwrap();
167 let n8 = T::from_f64(8.0).unwrap();
168 let n10 = T::from_f64(10.0).unwrap();
169 let n32 = T::from_f64(32.0).unwrap();
170 let n64 = T::from_f64(64.0).unwrap();
171 let pi = T::from_f64(PI).unwrap();
172 let pi2 = T::from_f64(PI.powi(2)).unwrap();
173
174 let mut out = vec![vec![n0; d]; d];
175
176 for i in 0..d {
177 let xin1 = x[i] - n1;
178 let wp = pi * (xin1 / n4 + n1);
179 out[i][i] = if i == 0 {
180 pi2 / n8 * (wp.cos().powi(2) - wp.sin().powi(2))
181 } else if i == d - 1 {
182 -(n4 * pi * xin1 * (pi * x[i]).sin() + (pi2 * xin1.powi(2) - n2) * (pi * x[i]).cos()
183 - n6)
184 / n32
185 } else {
186 let wp1cos = (wp + n1).cos();
187 let wp1sin = (wp + n1).sin();
188 n5 / n4 * pi * wp1cos * wp1sin * xin1
189 + n5 / n64 * pi2 * xin1.powi(2) * (wp1cos.powi(2) - wp1sin.powi(2))
190 + (n10 * wp1sin.powi(2) + n1) / n8
191 }
192 }
193
194 out
195}
196
197pub fn levy_hessian_const<const N: usize, T>(param: &[T; N]) -> [[T; N]; N]
202where
203 T: Float + FromPrimitive + Sum,
204{
205 assert!(N >= 2);
206
207 let x = param;
208
209 let n0 = T::from_f64(0.0).unwrap();
210 let n1 = T::from_f64(1.0).unwrap();
211 let n2 = T::from_f64(2.0).unwrap();
212 let n4 = T::from_f64(4.0).unwrap();
213 let n5 = T::from_f64(5.0).unwrap();
214 let n6 = T::from_f64(6.0).unwrap();
215 let n8 = T::from_f64(8.0).unwrap();
216 let n10 = T::from_f64(10.0).unwrap();
217 let n32 = T::from_f64(32.0).unwrap();
218 let n64 = T::from_f64(64.0).unwrap();
219 let pi = T::from_f64(PI).unwrap();
220 let pi2 = T::from_f64(PI.powi(2)).unwrap();
221
222 let mut out = [[n0; N]; N];
223
224 for i in 0..N {
225 let xin1 = x[i] - n1;
226 let wp = pi * (xin1 / n4 + n1);
227 out[i][i] = if i == 0 {
228 pi2 / n8 * (wp.cos().powi(2) - wp.sin().powi(2))
229 } else if i == N - 1 {
230 -(n4 * pi * xin1 * (pi * x[i]).sin() + (pi2 * xin1.powi(2) - n2) * (pi * x[i]).cos()
231 - n6)
232 / n32
233 } else {
234 let wp1cos = (wp + n1).cos();
235 let wp1sin = (wp + n1).sin();
236 n5 / n4 * pi * wp1cos * wp1sin * xin1
237 + n5 / n64 * pi2 * xin1.powi(2) * (wp1cos.powi(2) - wp1sin.powi(2))
238 + (n10 * wp1sin.powi(2) + n1) / n8
239 }
240 }
241
242 out
243}
244
245pub fn levy_n13<T>(param: &[T; 2]) -> T
258where
259 T: Float + FromPrimitive + Sum,
260{
261 let [x1, x2] = *param;
262
263 let n1 = T::from_f64(1.0).unwrap();
264 let n2 = T::from_f64(2.0).unwrap();
265 let n3 = T::from_f64(3.0).unwrap();
266 let pi = T::from_f64(PI).unwrap();
267
268 (n3 * pi * x1).sin().powi(2)
269 + (x1 - n1).powi(2) * (n1 + (n3 * pi * x2).sin().powi(2))
270 + (x2 - n1).powi(2) * (n1 + (n2 * pi * x2).sin().powi(2))
271}
272
273pub fn levy_n13_derivative<T>(param: &[T; 2]) -> [T; 2]
275where
276 T: Float + FromPrimitive + Sum,
277{
278 let [x1, x2] = *param;
279
280 let n1 = T::from_f64(1.0).unwrap();
281 let n2 = T::from_f64(2.0).unwrap();
282 let n3 = T::from_f64(3.0).unwrap();
283 let n4 = T::from_f64(4.0).unwrap();
284 let n6 = T::from_f64(6.0).unwrap();
285 let pi = T::from_f64(PI).unwrap();
286
287 let x1t3 = n3 * pi * x1;
288 let x2t3 = n3 * pi * x2;
289 let x2t2 = n2 * pi * x2;
290 let x1t3s = x1t3.sin();
291 let x1t3c = x1t3.cos();
292 let x2t3s = x2t3.sin();
293 let x2t3c = x2t3.cos();
294 let x2t3s2 = x2t3s.powi(2);
295 let x2t2s = x2t2.sin();
296 let x2t2c = x2t2.cos();
297 let x2t2s2 = x2t2s.powi(2);
298
299 [
300 n6 * pi * x1t3c * x1t3s + n2 * (x2t3s2 + n1) * (x1 - n1),
301 n6 * pi * (x1 - n1).powi(2) * x2t3c * x2t3s
302 + n2 * (x2 - n1) * (x2t2s2 + n1)
303 + n4 * pi * (x2 - n1).powi(2) * x2t2c * x2t2s,
304 ]
305}
306
307pub fn levy_n13_hessian<T>(param: &[T; 2]) -> [[T; 2]; 2]
309where
310 T: Float + FromPrimitive + Sum,
311{
312 let [x1, x2] = *param;
313
314 let n1 = T::from_f64(1.0).unwrap();
315 let n2 = T::from_f64(2.0).unwrap();
316 let n3 = T::from_f64(3.0).unwrap();
317 let n8 = T::from_f64(8.0).unwrap();
318 let n12 = T::from_f64(12.0).unwrap();
319 let n16 = T::from_f64(16.0).unwrap();
320 let n18 = T::from_f64(18.0).unwrap();
321 let pi = T::from_f64(PI).unwrap();
322 let pi2 = T::from_f64(PI.powi(2)).unwrap();
323
324 let x1t3 = n3 * pi * x1;
325 let x2t3 = n3 * pi * x2;
326 let x2t2 = n2 * pi * x2;
327 let x1t3s = x1t3.sin();
328 let x1t3c = x1t3.cos();
329 let x2t3s = x2t3.sin();
330 let x2t3c = x2t3.cos();
331 let x1t3s2 = x1t3s.powi(2);
332 let x1t3c2 = x1t3c.powi(2);
333 let x2t3s2 = x2t3s.powi(2);
334 let x2t3c2 = x2t3c.powi(2);
335 let x2t2s = x2t2.sin();
336 let x2t2c = x2t2.cos();
337 let x2t2s2 = x2t2s.powi(2);
338 let x2t2c2 = x2t2c.powi(2);
339
340 let a = n18 * pi2 * (-x1t3s2 + x1t3c2) + n2 * (x2t3s2 + n1);
341 let b = n18 * pi2 * (x1 - n1).powi(2) * (-x2t3s2 + x2t3c2)
342 + n2 * (x2t2s2 + n1)
343 + n8 * pi2 * (x2 - n1).powi(2) * (-x2t2s2 + x2t2c2)
344 + n16 * pi * (x2 - n1) * x2t2s * x2t2c;
345 let offdiag = n12 * pi * (x1 - n1) * x2t3c * x2t3s;
346
347 [[a, offdiag], [offdiag, b]]
348}
349
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use finitediff::FiniteDiff;
    use proptest::prelude::*;
    use std::{f32, f64};

    /// The Levy function and every component of its gradient vanish at `(1, 1, 1)`.
    #[test]
    fn test_levy_optimum() {
        assert_relative_eq!(levy(&[1_f32, 1_f32, 1_f32]), 0.0, epsilon = f32::EPSILON);
        assert_relative_eq!(levy(&[1_f64, 1_f64, 1_f64]), 0.0, epsilon = f64::EPSILON);

        // Iterate over the whole gradient so that the last component is checked as well.
        let deriv = levy_derivative(&[1_f64, 1_f64, 1_f64]);
        assert_eq!(deriv.len(), 3);
        for d in deriv.iter() {
            assert_relative_eq!(*d, 0.0, epsilon = 1e-12, max_relative = 1e-12);
        }

        let deriv = levy_derivative_const(&[1_f64, 1_f64, 1_f64]);
        for d in deriv.iter() {
            assert_relative_eq!(*d, 0.0, epsilon = 1e-12, max_relative = 1e-12);
        }
    }

    /// Levy No. 13 and its gradient vanish at `(1, 1)`.
    #[test]
    fn test_levy_n13_optimum() {
        assert_relative_eq!(levy_n13(&[1_f32, 1_f32]), 0.0, epsilon = f32::EPSILON);
        assert_relative_eq!(levy_n13(&[1_f64, 1_f64]), 0.0, epsilon = f64::EPSILON);

        let deriv = levy_n13_derivative(&[1_f64, 1_f64]);
        for d in deriv.iter() {
            assert_relative_eq!(*d, 0.0, epsilon = 1e-12, max_relative = 1e-12);
        }
    }

    /// `levy` must reject inputs with fewer than two elements.
    #[test]
    #[should_panic]
    fn test_levy_param_length() {
        levy(&[0.0_f32]);
    }

    proptest! {
        /// Analytic gradient of Levy No. 13 against central finite differences.
        #[test]
        fn test_levy_n13_derivative_finitediff(a in -10.0..10.0, b in -10.0..10.0) {
            let param = [a, b];
            let derivative = levy_n13_derivative(&param);
            let derivative_fd = Vec::from(param).central_diff(&|x| levy_n13(&[x[0], x[1]]));
            for i in 0..derivative.len() {
                assert_relative_eq!(
                    derivative[i],
                    derivative_fd[i],
                    epsilon = 1e-5,
                    max_relative = 1e-2,
                );
            }
        }
    }

    proptest! {
        /// Analytic Hessian of Levy No. 13 against finite differences of the gradient.
        #[test]
        fn test_levy_n13_hessian_finitediff(a in -10.0..10.0, b in -10.0..10.0) {
            let param = [a, b];
            let hessian = levy_n13_hessian(&param);
            let hessian_fd =
                Vec::from(param).central_hessian(&|x| levy_n13_derivative(&[x[0], x[1]]).to_vec());
            let n = hessian.len();
            for i in 0..n {
                assert_eq!(hessian[i].len(), n);
                for j in 0..n {
                    // Non-finite analytic entries are skipped rather than compared.
                    if hessian[i][j].is_finite() {
                        assert_relative_eq!(
                            hessian[i][j],
                            hessian_fd[i][j],
                            epsilon = 1e-5,
                            max_relative = 1e-2
                        );
                    }
                }
            }
        }
    }

    proptest! {
        /// Analytic gradient of the Levy function against central finite differences.
        #[test]
        fn test_levy_derivative_finitediff(a in -10.0..10.0, b in -10.0..10.0, c in -10.0..10.0) {
            let param = [a, b, c];
            let derivative = levy_derivative(&param);
            let derivative_fd = Vec::from(param).central_diff(&|x| levy(&[x[0], x[1], x[2]]));
            for i in 0..derivative.len() {
                assert_relative_eq!(
                    derivative[i],
                    derivative_fd[i],
                    epsilon = 1e-5,
                    max_relative = 1e-2,
                );
            }
        }
    }

    proptest! {
        /// Same as above, but for the fixed-size array variant of the gradient.
        #[test]
        fn test_levy_derivative_const_finitediff(a in -10.0..10.0, b in -10.0..10.0, c in -10.0..10.0) {
            let param = [a, b, c];
            let derivative = levy_derivative_const(&param);
            let derivative_fd = Vec::from(param).central_diff(&|x| levy(&[x[0], x[1], x[2]]));
            for i in 0..derivative.len() {
                assert_relative_eq!(
                    derivative[i],
                    derivative_fd[i],
                    epsilon = 1e-5,
                    max_relative = 1e-2,
                );
            }
        }
    }

    proptest! {
        /// Analytic Hessian of the Levy function against finite differences of the gradient.
        #[test]
        fn test_levy_hessian_finitediff(a in -10.0..10.0, b in -10.0..10.0, c in -10.0..10.0) {
            let param = [a, b, c];
            let hessian = levy_hessian(&param);
            let hessian_fd = Vec::from(param).central_hessian(&|x| levy_derivative(&x).to_vec());
            let n = hessian.len();
            for i in 0..n {
                assert_eq!(hessian[i].len(), n);
                for j in 0..n {
                    // Non-finite analytic entries are skipped rather than compared.
                    if hessian[i][j].is_finite() {
                        assert_relative_eq!(
                            hessian[i][j],
                            hessian_fd[i][j],
                            epsilon = 1e-5,
                            max_relative = 1e-2
                        );
                    }
                }
            }
        }
    }

    proptest! {
        /// Same as above, but for the fixed-size array variant of the Hessian.
        #[test]
        fn test_levy_hessian_const_finitediff(a in -10.0..10.0, b in -10.0..10.0, c in -10.0..10.0) {
            let param = [a, b, c];
            let hessian = levy_hessian_const(&param);
            let hessian_fd = Vec::from(param).central_hessian(&|x| levy_derivative(&x).to_vec());
            let n = hessian.len();
            for i in 0..n {
                assert_eq!(hessian[i].len(), n);
                for j in 0..n {
                    // Non-finite analytic entries are skipped rather than compared.
                    if hessian[i][j].is_finite() {
                        assert_relative_eq!(
                            hessian[i][j],
                            hessian_fd[i][j],
                            epsilon = 1e-5,
                            max_relative = 1e-2
                        );
                    }
                }
            }
        }
    }
}