1use anyhow::Error;
9use num::Float;
10use num::FromPrimitive;
11
12use crate::utils::mod_and_calc_const;
13
14use super::CostFn;
15
16pub fn forward_diff_const<const N: usize, F>(
17 x: &[F; N],
18 f: CostFn<'_, N, F>,
19) -> Result<[F; N], Error>
20where
21 F: Float + FromPrimitive,
22{
23 let fx = (f)(x)?;
24 let mut xt = *x;
25 let eps_sqrt = F::epsilon().sqrt();
26 let mut out = [F::from_f64(0.0).unwrap(); N];
27 out.iter_mut()
28 .enumerate()
29 .map(|(i, o)| -> Result<_, Error> {
30 let fx1 = mod_and_calc_const(&mut xt, f, i, eps_sqrt)?;
31 *o = (fx1 - fx) / eps_sqrt;
32 Ok(())
33 })
34 .count();
35 Ok(out)
36}
37
38pub fn central_diff_const<const N: usize, F>(
39 x: &[F; N],
40 f: CostFn<'_, N, F>,
41) -> Result<[F; N], Error>
42where
43 F: Float + FromPrimitive,
44{
45 let mut xt = *x;
46 let eps_cbrt = F::epsilon().cbrt();
47 let mut out = [F::from_f64(0.0).unwrap(); N];
48 out.iter_mut()
49 .enumerate()
50 .map(|(i, o)| -> Result<_, Error> {
51 let fx1 = mod_and_calc_const(&mut xt, f, i, eps_cbrt)?;
52 let fx2 = mod_and_calc_const(&mut xt, f, i, -eps_cbrt)?;
53 *o = (fx1 - fx2) / (F::from_f64(2.0).unwrap() * eps_cbrt);
54 Ok(())
55 })
56 .count();
57 Ok(out)
58}
59
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing numerical gradients to analytic ones.
    const COMP_ACC: f64 = 1e-6;

    /// Test cost function `f(x) = x0 + x1^2`, whose analytic gradient is `[1, 2 * x1]`.
    fn f(x: &[f64; 2]) -> Result<f64, Error> {
        Ok(x[0] + x[1].powi(2))
    }

    /// Cost function that always fails, used to check error propagation.
    fn failing(_x: &[f64; 2]) -> Result<f64, Error> {
        Err(Error::msg("cost function failed"))
    }

    /// Asserts that every component of `grad` is within `COMP_ACC` of `expected`.
    fn assert_close(grad: &[f64; 2], expected: &[f64; 2]) {
        for (g, e) in grad.iter().zip(expected.iter()) {
            assert!(
                (e - g).abs() < COMP_ACC,
                "expected {:?}, got {:?}",
                expected,
                grad
            );
        }
    }

    #[test]
    fn test_forward_diff_const_f64() {
        let grad = forward_diff_const(&[1.0f64, 1.0], &f).unwrap();
        assert_close(&grad, &[1.0, 2.0]);

        let grad = forward_diff_const(&[1.0f64, 2.0], &f).unwrap();
        assert_close(&grad, &[1.0, 4.0]);
    }

    #[test]
    fn test_central_diff_const_f64() {
        let grad = central_diff_const(&[1.0f64, 1.0], &f).unwrap();
        assert_close(&grad, &[1.0, 2.0]);

        let grad = central_diff_const(&[1.0f64, 2.0], &f).unwrap();
        assert_close(&grad, &[1.0, 4.0]);
    }

    #[test]
    fn test_forward_diff_const_propagates_initial_error() {
        assert!(forward_diff_const(&[1.0f64, 1.0], &failing).is_err());
    }
}