finitediff/vec/
mod.rs

1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
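// The closure-returning signatures in this module are inherently verbose, so the
// `type_complexity` lint is silenced for the whole file.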
9#![allow(clippy::type_complexity)]
10
11pub mod diff;
12pub mod hessian;
13pub mod jacobian;
14
15use std::ops::AddAssign;
16
17use anyhow::Error;
18use num::{Float, FromPrimitive};
19
20use crate::PerturbationVectors;
21use diff::{central_diff_vec, forward_diff_vec};
22use hessian::{
23    central_hessian_vec, central_hessian_vec_prod_vec, forward_hessian_nograd_sparse_vec,
24    forward_hessian_nograd_vec, forward_hessian_vec, forward_hessian_vec_prod_vec,
25};
26use jacobian::{
27    central_jacobian_pert_vec, central_jacobian_vec, central_jacobian_vec_prod_vec,
28    forward_jacobian_pert_vec, forward_jacobian_vec, forward_jacobian_vec_prod_vec,
29};
30
/// Borrowed cost function: maps a parameter vector to a single value of type `F`, or an error.
pub(crate) type CostFn<'a, F> = &'a dyn Fn(&Vec<F>) -> Result<F, Error>;
/// Borrowed gradient function: maps a parameter vector to a `Vec<F>`, or an error.
///
/// Used as the input of the Hessian wrappers in this module that take a gradient.
pub(crate) type GradientFn<'a, F> = &'a dyn Fn(&Vec<F>) -> Result<Vec<F>, Error>;
/// Borrowed vector-valued function: maps a parameter vector to a `Vec<F>`, or an error.
///
/// Has the same shape as `GradientFn`; used as the input of the Jacobian wrappers in this module.
pub(crate) type OpFn<'a, F> = &'a dyn Fn(&Vec<F>) -> Result<Vec<F>, Error>;
34
35// pub trait GradientImpl<'a, F>: Fn(&Vec<F>) -> Result<Vec<F>, Error> + 'a {}
36// impl<'a, F, T: Fn(&Vec<F>) -> Result<Vec<F>, Error> + 'a> GradientImpl<'a, F> for T {}
37// pub fn forward_diff<F>(f: CostFn<'_, F>) -> impl GradientImpl<'_, F> { .. }
38
39#[inline(always)]
40pub fn forward_diff<F>(f: CostFn<'_, F>) -> impl Fn(&Vec<F>) -> Result<Vec<F>, Error> + '_
41where
42    F: Float + FromPrimitive,
43{
44    move |p: &Vec<F>| forward_diff_vec(p, f)
45}
46
47#[inline(always)]
48pub fn central_diff<F>(f: CostFn<'_, F>) -> impl Fn(&Vec<F>) -> Result<Vec<F>, Error> + '_
49where
50    F: Float + FromPrimitive,
51{
52    move |p: &Vec<F>| central_diff_vec(p, f)
53}
54
55#[inline(always)]
56pub fn forward_jacobian<F>(f: OpFn<'_, F>) -> impl Fn(&Vec<F>) -> Result<Vec<Vec<F>>, Error> + '_
57where
58    F: Float + FromPrimitive,
59{
60    move |p: &Vec<F>| forward_jacobian_vec(p, f)
61}
62
63#[inline(always)]
64pub fn central_jacobian<F>(f: OpFn<'_, F>) -> impl Fn(&Vec<F>) -> Result<Vec<Vec<F>>, Error> + '_
65where
66    F: Float + FromPrimitive,
67{
68    move |p: &Vec<F>| central_jacobian_vec(p, f)
69}
70
71#[inline(always)]
72pub fn forward_jacobian_vec_prod<F>(
73    f: OpFn<'_, F>,
74) -> impl Fn(&Vec<F>, &Vec<F>) -> Result<Vec<F>, Error> + '_
75where
76    F: Float + FromPrimitive,
77{
78    move |p: &Vec<F>, v: &Vec<F>| forward_jacobian_vec_prod_vec(p, f, v)
79}
80
81#[inline(always)]
82pub fn central_jacobian_vec_prod<F>(
83    f: OpFn<'_, F>,
84) -> impl Fn(&Vec<F>, &Vec<F>) -> Result<Vec<F>, Error> + '_
85where
86    F: Float + FromPrimitive,
87{
88    move |p: &Vec<F>, v: &Vec<F>| central_jacobian_vec_prod_vec(p, f, v)
89}
90
91#[inline(always)]
92pub fn forward_jacobian_pert<F>(
93    f: OpFn<'_, F>,
94) -> impl Fn(&Vec<F>, &PerturbationVectors) -> Result<Vec<Vec<F>>, Error> + '_
95where
96    F: Float + FromPrimitive + AddAssign,
97{
98    move |p: &Vec<F>, pert: &PerturbationVectors| forward_jacobian_pert_vec(p, f, pert)
99}
100
101#[inline(always)]
102pub fn central_jacobian_pert<F>(
103    f: OpFn<'_, F>,
104) -> impl Fn(&Vec<F>, &PerturbationVectors) -> Result<Vec<Vec<F>>, Error> + '_
105where
106    F: Float + FromPrimitive + AddAssign,
107{
108    move |p: &Vec<F>, pert: &PerturbationVectors| central_jacobian_pert_vec(p, f, pert)
109}
110
111#[inline(always)]
112pub fn forward_hessian<F>(
113    f: GradientFn<'_, F>,
114) -> impl Fn(&Vec<F>) -> Result<Vec<Vec<F>>, Error> + '_
115where
116    F: Float + FromPrimitive,
117{
118    move |p: &Vec<F>| forward_hessian_vec(p, f)
119}
120
121#[inline(always)]
122pub fn central_hessian<F>(
123    f: GradientFn<'_, F>,
124) -> impl Fn(&Vec<F>) -> Result<Vec<Vec<F>>, Error> + '_
125where
126    F: Float + FromPrimitive,
127{
128    move |p: &Vec<F>| central_hessian_vec(p, f)
129}
130
131#[inline(always)]
132pub fn forward_hessian_vec_prod<F>(
133    f: GradientFn<'_, F>,
134) -> impl Fn(&Vec<F>, &Vec<F>) -> Result<Vec<F>, Error> + '_
135where
136    F: Float + FromPrimitive,
137{
138    move |p: &Vec<F>, v: &Vec<F>| forward_hessian_vec_prod_vec(p, f, v)
139}
140
141#[inline(always)]
142pub fn central_hessian_vec_prod<F>(
143    f: GradientFn<'_, F>,
144) -> impl Fn(&Vec<F>, &Vec<F>) -> Result<Vec<F>, Error> + '_
145where
146    F: Float + FromPrimitive,
147{
148    move |p: &Vec<F>, v: &Vec<F>| central_hessian_vec_prod_vec(p, f, v)
149}
150
151#[inline(always)]
152pub fn forward_hessian_nograd<F>(
153    f: CostFn<'_, F>,
154) -> impl Fn(&Vec<F>) -> Result<Vec<Vec<F>>, Error> + '_
155where
156    F: Float + FromPrimitive + AddAssign,
157{
158    move |p: &Vec<F>| forward_hessian_nograd_vec(p, f)
159}
160
161#[inline(always)]
162pub fn forward_hessian_nograd_sparse<F>(
163    f: CostFn<'_, F>,
164) -> impl Fn(&Vec<F>, Vec<[usize; 2]>) -> Result<Vec<Vec<F>>, Error> + '_
165where
166    F: Float + FromPrimitive + AddAssign,
167{
168    move |p: &Vec<F>, indices: Vec<[usize; 2]>| forward_hessian_nograd_sparse_vec(p, f, indices)
169}
170
#[cfg(test)]
mod tests {
    use crate::{PerturbationVector, PerturbationVectors};

    use super::*;

    /// Absolute tolerance for comparing finite-difference output with the analytic values.
    const COMP_ACC: f64 = 1e-6;

    /// Asserts that each entry of `expected` is within `tol` of the same entry of `actual`.
    ///
    /// `actual` is indexed directly, so an `actual` shorter than `expected` panics.
    fn assert_vec_close(expected: &[f64], actual: &[f64], tol: f64) {
        for (idx, want) in expected.iter().enumerate() {
            assert!((*want - actual[idx]).abs() < tol)
        }
    }

    /// Row-by-row version of `assert_vec_close` for nested vectors.
    fn assert_mat_close(expected: &[Vec<f64>], actual: &[Vec<f64>], tol: f64) {
        for (row, want_row) in expected.iter().enumerate() {
            assert_vec_close(want_row, &actual[row], tol)
        }
    }

    /// Scalar cost function of two parameters.
    fn f1(x: &Vec<f64>) -> Result<f64, Error> {
        let linear = x[0];
        let quadratic = x[1].powi(2);
        Ok(linear + quadratic)
    }

    /// Vector-valued function of six parameters; each output couples neighbouring inputs.
    fn f2(x: &Vec<f64>) -> Result<Vec<f64>, Error> {
        // Shared building block: x[i + 1]^3 - x[i]^2.
        let term = |i: usize| x[i + 1].powi(3) - x[i].powi(2);
        Ok(vec![
            2.0 * term(0),
            3.0 * term(0) + 2.0 * term(1),
            3.0 * term(1) + 2.0 * term(2),
            3.0 * term(2) + 2.0 * term(3),
            3.0 * term(3) + 2.0 * term(4),
            3.0 * term(4),
        ])
    }

    /// Scalar cost function of four parameters; `g` is its analytic gradient.
    fn f3(x: &Vec<f64>) -> Result<f64, Error> {
        Ok(x[0] + x[1].powi(2) + x[2] * x[3].powi(2))
    }

    /// Analytic gradient of `f3`.
    fn g(x: &Vec<f64>) -> Result<Vec<f64>, Error> {
        Ok(vec![1.0, 2.0 * x[1], x[3].powi(2), 2.0 * x[3] * x[2]])
    }

    /// Evaluation point for `f1`.
    fn x1() -> Vec<f64> {
        vec![1.0; 2]
    }

    /// Evaluation point for `f2`.
    fn x2() -> Vec<f64> {
        vec![1.0; 6]
    }

    /// Evaluation point for `f3` and `g`.
    fn x3() -> Vec<f64> {
        vec![1.0; 4]
    }

    /// Expected Jacobian of `f2` at `x2()`.
    fn res1() -> Vec<Vec<f64>> {
        vec![
            vec![-4.0, 6.0, 0.0, 0.0, 0.0, 0.0],
            vec![-6.0, 5.0, 6.0, 0.0, 0.0, 0.0],
            vec![0.0, -6.0, 5.0, 6.0, 0.0, 0.0],
            vec![0.0, 0.0, -6.0, 5.0, 6.0, 0.0],
            vec![0.0, 0.0, 0.0, -6.0, 5.0, 6.0],
            vec![0.0, 0.0, 0.0, 0.0, -6.0, 9.0],
        ]
    }

    /// Expected Hessian of `f3` at `x3()`.
    fn res2() -> Vec<Vec<f64>> {
        vec![
            vec![0.0, 0.0, 0.0, 0.0],
            vec![0.0, 2.0, 0.0, 0.0],
            vec![0.0, 0.0, 0.0, 2.0],
            vec![0.0, 0.0, 2.0, 2.0],
        ]
    }

    /// Expected product of `res1()` with `p1()`.
    fn res3() -> Vec<f64> {
        vec![8.0, 22.0, 27.0, 32.0, 37.0, 24.0]
    }

    /// Perturbation vectors used by the `*_jacobian_pert` tests.
    fn pert() -> PerturbationVectors {
        let first = PerturbationVector::new()
            .add(0, vec![0, 1])
            .add(3, vec![2, 3, 4]);
        let second = PerturbationVector::new()
            .add(1, vec![0, 1, 2])
            .add(4, vec![3, 4, 5]);
        let third = PerturbationVector::new()
            .add(2, vec![1, 2, 3])
            .add(5, vec![4, 5]);
        vec![first, second, third]
    }

    /// Vector multiplied with the Jacobian in the `*_jacobian_vec_prod` tests.
    fn p1() -> Vec<f64> {
        vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0]
    }

    /// Vector multiplied with the Hessian in the `*_hessian_vec_prod` tests.
    fn p2() -> Vec<f64> {
        vec![2.0, 3.0, 4.0, 5.0]
    }

    #[test]
    fn test_forward_diff_func() {
        let grad = forward_diff(&f1);
        let at_ones = grad(&x1()).unwrap();
        assert_vec_close(&[1.0, 2.0], &at_ones, COMP_ACC);

        let point = vec![1.0, 2.0];
        let grad = forward_diff(&f1);
        let at_point = grad(&point).unwrap();
        assert_vec_close(&[1.0, 4.0], &at_point, COMP_ACC);
    }

    #[test]
    fn test_central_diff_func() {
        let grad = central_diff(&f1);
        let at_ones = grad(&x1()).unwrap();
        assert_vec_close(&[1.0f64, 2.0], &at_ones, COMP_ACC);

        let point = vec![1.0f64, 2.0f64];
        let grad = central_diff(&f1);
        let at_point = grad(&point).unwrap();
        assert_vec_close(&[1.0f64, 4.0], &at_point, COMP_ACC);
    }

    #[test]
    fn test_forward_jacobian_func() {
        let jacobian = forward_jacobian(&f2);
        let computed = jacobian(&x2()).unwrap();
        assert_mat_close(&res1(), &computed, COMP_ACC);
    }

    #[test]
    fn test_central_jacobian_vec_f64_trait() {
        let jacobian = central_jacobian(&f2);
        let computed = jacobian(&x2()).unwrap();
        assert_mat_close(&res1(), &computed, COMP_ACC);
    }

    #[test]
    fn test_forward_jacobian_vec_prod_vec_func() {
        let jacobian = forward_jacobian_vec_prod(&f2);
        let computed = jacobian(&x2(), &p1()).unwrap();
        // Forward differences are noticeably less accurate here, hence the looser tolerance.
        assert_vec_close(&res3(), &computed, 5.5 * COMP_ACC);
    }

    #[test]
    fn test_central_jacobian_vec_prod_vec_func() {
        let jacobian = central_jacobian_vec_prod(&f2);
        let computed = jacobian(&x2(), &p1()).unwrap();
        assert_vec_close(&res3(), &computed, COMP_ACC);
    }

    #[test]
    fn test_forward_jacobian_pert_func() {
        let jacobian = forward_jacobian_pert(&f2);
        let computed = jacobian(&x2(), &pert()).unwrap();
        assert_mat_close(&res1(), &computed, COMP_ACC);
    }

    #[test]
    fn test_central_jacobian_pert_func() {
        let jacobian = central_jacobian_pert(&f2);
        let computed = jacobian(&x2(), &pert()).unwrap();
        assert_mat_close(&res1(), &computed, COMP_ACC);
    }

    #[test]
    fn test_forward_hessian_func() {
        let hessian = forward_hessian(&g);
        let computed = hessian(&x3()).unwrap();
        assert_mat_close(&res2(), &computed, COMP_ACC);
    }

    #[test]
    fn test_central_hessian_func() {
        let hessian = central_hessian(&g);
        let computed = hessian(&x3()).unwrap();
        assert_mat_close(&res2(), &computed, COMP_ACC);
    }

    #[test]
    fn test_forward_hessian_vec_prod_func() {
        let hessian = forward_hessian_vec_prod(&g);
        let computed = hessian(&x3(), &p2()).unwrap();
        assert_vec_close(&[0.0, 6.0, 10.0, 18.0], &computed, COMP_ACC);
    }

    #[test]
    fn test_central_hessian_vec_prod_func() {
        let hessian = central_hessian_vec_prod(&g);
        let computed = hessian(&x3(), &p2()).unwrap();
        assert_vec_close(&[0.0, 6.0, 10.0, 18.0], &computed, COMP_ACC);
    }

    #[test]
    fn test_forward_hessian_nograd_func() {
        let hessian = forward_hessian_nograd(&f3);
        let computed = hessian(&x3()).unwrap();
        assert_mat_close(&res2(), &computed, COMP_ACC);
    }

    #[test]
    fn test_forward_hessian_nograd_sparse_func() {
        let indices = vec![[1, 1], [2, 3], [3, 3]];
        let hessian = forward_hessian_nograd_sparse(&f3);
        let computed = hessian(&x3(), indices).unwrap();
        assert_mat_close(&res2(), &computed, COMP_ACC);
    }
}