1use anyhow::Error;
9use num::Float;
10use num::FromPrimitive;
11
12use crate::utils::mod_and_calc;
13
14use super::CostFn;
15
16pub fn forward_diff_vec<F>(x: &Vec<F>, f: CostFn<'_, F>) -> Result<Vec<F>, Error>
17where
18 F: Float,
19{
20 let fx = (f)(x)?;
21 let mut xt = x.clone();
22 let eps_sqrt = F::epsilon().sqrt();
23 (0..x.len())
24 .map(|i| -> Result<F, Error> {
25 let fx1 = mod_and_calc(&mut xt, f, i, eps_sqrt)?;
26 Ok((fx1 - fx) / eps_sqrt)
27 })
28 .collect()
29}
30
31pub fn central_diff_vec<F>(x: &[F], f: CostFn<'_, F>) -> Result<Vec<F>, Error>
32where
33 F: Float + FromPrimitive,
34{
35 let mut xt = x.to_owned();
36 let eps_cbrt = F::epsilon().cbrt();
37 (0..x.len())
38 .map(|i| -> Result<_, Error> {
39 let fx1 = mod_and_calc(&mut xt, f, i, eps_cbrt)?;
40 let fx2 = mod_and_calc(&mut xt, f, i, -eps_cbrt)?;
41 Ok((fx1 - fx2) / (F::from_f64(2.0).unwrap() * eps_cbrt))
42 })
43 .collect()
44}
45
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing numerical gradients with analytical ones.
    const COMP_ACC: f64 = 1e-6;

    /// Test cost function `f(x) = x[0] + x[1]^2`; its analytical gradient is `[1, 2 * x[1]]`.
    // `forward_diff_vec` calls the cost function with a `&Vec<F>`, so this signature
    // cannot be relaxed to `&[f64]`.
    #[allow(clippy::ptr_arg)]
    fn f(x: &Vec<f64>) -> Result<f64, Error> {
        Ok(x[0] + x[1].powi(2))
    }

    /// Asserts that `actual` has the same length as `expected` and that every component
    /// agrees within `COMP_ACC`.
    fn assert_gradient_close(actual: &[f64], expected: &[f64]) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "gradient has unexpected length"
        );
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (e - a).abs() < COMP_ACC,
                "component {}: expected {}, got {}",
                i,
                e,
                a
            );
        }
    }

    #[test]
    fn test_forward_diff_vec_f64() {
        let p = vec![1.0f64, 1.0f64];
        let grad = forward_diff_vec(&p, &f).unwrap();
        assert_gradient_close(&grad, &[1.0, 2.0]);

        let p = vec![1.0f64, 2.0f64];
        let grad = forward_diff_vec(&p, &f).unwrap();
        assert_gradient_close(&grad, &[1.0, 4.0]);
    }

    #[test]
    fn test_central_diff_vec_f64() {
        let p = vec![1.0f64, 1.0f64];
        let grad = central_diff_vec(&p, &f).unwrap();
        assert_gradient_close(&grad, &[1.0, 2.0]);

        let p = vec![1.0f64, 2.0f64];
        let grad = central_diff_vec(&p, &f).unwrap();
        assert_gradient_close(&grad, &[1.0, 4.0]);
    }
}