finitediff/vec/
diff.rs

1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use anyhow::Error;
9use num::Float;
10use num::FromPrimitive;
11
12use crate::utils::mod_and_calc;
13
14use super::CostFn;
15
16pub fn forward_diff_vec<F>(x: &Vec<F>, f: CostFn<'_, F>) -> Result<Vec<F>, Error>
17where
18    F: Float,
19{
20    let fx = (f)(x)?;
21    let mut xt = x.clone();
22    let eps_sqrt = F::epsilon().sqrt();
23    (0..x.len())
24        .map(|i| -> Result<F, Error> {
25            let fx1 = mod_and_calc(&mut xt, f, i, eps_sqrt)?;
26            Ok((fx1 - fx) / eps_sqrt)
27        })
28        .collect()
29}
30
31pub fn central_diff_vec<F>(x: &[F], f: CostFn<'_, F>) -> Result<Vec<F>, Error>
32where
33    F: Float + FromPrimitive,
34{
35    let mut xt = x.to_owned();
36    let eps_cbrt = F::epsilon().cbrt();
37    (0..x.len())
38        .map(|i| -> Result<_, Error> {
39            let fx1 = mod_and_calc(&mut xt, f, i, eps_cbrt)?;
40            let fx2 = mod_and_calc(&mut xt, f, i, -eps_cbrt)?;
41            Ok((fx1 - fx2) / (F::from_f64(2.0).unwrap() * eps_cbrt))
42        })
43        .collect()
44}
45
#[cfg(test)]
mod tests {
    use super::*;

    /// Maximum absolute deviation tolerated between a numerical and an analytical gradient.
    const COMP_ACC: f64 = 1e-6;

    /// Test function `x[0] + x[1]^2`; its analytical gradient is `[1, 2 * x[1]]`.
    fn f(x: &Vec<f64>) -> Result<f64, Error> {
        Ok(x[0] + x[1].powi(2))
    }

    /// Asserts that `grad` and `expected` have equal length and agree component-wise to within
    /// `COMP_ACC`, naming the offending component on failure.
    fn assert_grad_close(grad: &[f64], expected: &[f64]) {
        assert_eq!(grad.len(), expected.len(), "gradient has unexpected length");
        // Iterate directly instead of `.map(..).count()`: the assertions are side effects and
        // should not hide inside a lazy adaptor.
        for (i, (g, e)) in grad.iter().zip(expected).enumerate() {
            assert!(
                (e - g).abs() < COMP_ACC,
                "component {}: expected {}, got {}",
                i,
                e,
                g
            );
        }
    }

    #[test]
    fn test_forward_diff_vec_f64() {
        let p = vec![1.0f64, 1.0f64];
        let grad = forward_diff_vec(&p, &f).unwrap();
        assert_grad_close(&grad, &[1.0, 2.0]);

        let p = vec![1.0f64, 2.0f64];
        let grad = forward_diff_vec(&p, &f).unwrap();
        assert_grad_close(&grad, &[1.0, 4.0]);
    }

    #[test]
    fn test_central_diff_vec_f64() {
        let p = vec![1.0f64, 1.0f64];
        let grad = central_diff_vec(&p, &f).unwrap();
        assert_grad_close(&grad, &[1.0, 2.0]);

        let p = vec![1.0f64, 2.0f64];
        let grad = central_diff_vec(&p, &f).unwrap();
        assert_grad_close(&grad, &[1.0, 4.0]);
    }
}