finitediff/array/
diff.rs

1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use anyhow::Error;
9use num::Float;
10use num::FromPrimitive;
11
12use crate::utils::mod_and_calc_const;
13
14use super::CostFn;
15
16pub fn forward_diff_const<const N: usize, F>(
17    x: &[F; N],
18    f: CostFn<'_, N, F>,
19) -> Result<[F; N], Error>
20where
21    F: Float + FromPrimitive,
22{
23    let fx = (f)(x)?;
24    let mut xt = *x;
25    let eps_sqrt = F::epsilon().sqrt();
26    let mut out = [F::from_f64(0.0).unwrap(); N];
27    out.iter_mut()
28        .enumerate()
29        .map(|(i, o)| -> Result<_, Error> {
30            let fx1 = mod_and_calc_const(&mut xt, f, i, eps_sqrt)?;
31            *o = (fx1 - fx) / eps_sqrt;
32            Ok(())
33        })
34        .count();
35    Ok(out)
36}
37
38pub fn central_diff_const<const N: usize, F>(
39    x: &[F; N],
40    f: CostFn<'_, N, F>,
41) -> Result<[F; N], Error>
42where
43    F: Float + FromPrimitive,
44{
45    let mut xt = *x;
46    let eps_cbrt = F::epsilon().cbrt();
47    let mut out = [F::from_f64(0.0).unwrap(); N];
48    out.iter_mut()
49        .enumerate()
50        .map(|(i, o)| -> Result<_, Error> {
51            let fx1 = mod_and_calc_const(&mut xt, f, i, eps_cbrt)?;
52            let fx2 = mod_and_calc_const(&mut xt, f, i, -eps_cbrt)?;
53            *o = (fx1 - fx2) / (F::from_f64(2.0).unwrap() * eps_cbrt);
54            Ok(())
55        })
56        .count();
57    Ok(out)
58}
59
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing a numerical gradient against
    /// the analytical one.
    const COMP_ACC: f64 = 1e-6;

    /// Test function `x[0] + x[1]^2`, whose analytical gradient is
    /// `[1, 2 * x[1]]`.
    fn f(x: &[f64; 2]) -> Result<f64, Error> {
        Ok(x[0] + x[1].powi(2))
    }

    /// Asserts that every component of `actual` is within `COMP_ACC` of the
    /// matching component of `expected`.
    fn assert_close(expected: &[f64; 2], actual: &[f64; 2]) {
        for (i, (e, a)) in expected.iter().zip(actual).enumerate() {
            assert!(
                (e - a).abs() < COMP_ACC,
                "component {}: expected {}, got {}",
                i,
                e,
                a
            );
        }
    }

    #[test]
    fn test_forward_diff_const_f64() {
        let p = [1.0f64, 1.0f64];
        let grad = forward_diff_const(&p, &f).unwrap();
        assert_close(&[1.0f64, 2.0], &grad);

        let p = [1.0f64, 2.0f64];
        let grad = forward_diff_const(&p, &f).unwrap();
        assert_close(&[1.0f64, 4.0], &grad);
    }

    #[test]
    fn test_central_diff_vec_f64() {
        let p = [1.0f64, 1.0f64];
        let grad = central_diff_const(&p, &f).unwrap();
        assert_close(&[1.0f64, 2.0], &grad);

        let p = [1.0f64, 2.0f64];
        let grad = central_diff_const(&p, &f).unwrap();
        assert_close(&[1.0f64, 4.0], &grad);
    }
}