finitediff/array/
mod.rs

1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
// Some types in this module are inherently complex, so clippy's `type_complexity` lint is silenced.
9#![allow(clippy::type_complexity)]
10
11pub mod diff;
12pub mod hessian;
13pub mod jacobian;
14
15use std::ops::AddAssign;
16
17use anyhow::Error;
18use num::{Float, FromPrimitive};
19
20use crate::PerturbationVectors;
21use diff::{central_diff_const, forward_diff_const};
22use hessian::{
23    central_hessian_const, central_hessian_vec_prod_const, forward_hessian_const,
24    forward_hessian_nograd_const, forward_hessian_nograd_sparse_const,
25    forward_hessian_vec_prod_const,
26};
27use jacobian::{
28    central_jacobian_const, central_jacobian_pert_const, central_jacobian_vec_prod_const,
29    forward_jacobian_const, forward_jacobian_pert_const, forward_jacobian_vec_prod_const,
30};
31
32pub(crate) type CostFn<'a, const N: usize, F> = &'a dyn Fn(&[F; N]) -> Result<F, Error>;
33pub(crate) type GradientFn<'a, const N: usize, F> = &'a dyn Fn(&[F; N]) -> Result<[F; N], Error>;
34pub(crate) type OpFn<'a, const N: usize, const M: usize, F> =
35    &'a dyn Fn(&[F; N]) -> Result<[F; M], Error>;
36
37#[inline(always)]
38pub fn forward_diff<const N: usize, F>(
39    f: CostFn<'_, N, F>,
40) -> impl Fn(&[F; N]) -> Result<[F; N], Error> + '_
41where
42    F: Float + FromPrimitive,
43{
44    move |p: &[F; N]| forward_diff_const(p, &f)
45}
46
47#[inline(always)]
48pub fn central_diff<const N: usize, F>(
49    f: CostFn<'_, N, F>,
50) -> impl Fn(&[F; N]) -> Result<[F; N], Error> + '_
51where
52    F: Float + FromPrimitive,
53{
54    move |p: &[F; N]| central_diff_const(p, &f)
55}
56
57#[inline(always)]
58pub fn forward_jacobian<const N: usize, const M: usize, F>(
59    f: OpFn<'_, N, M, F>,
60) -> impl Fn(&[F; N]) -> Result<[[F; N]; M], Error> + '_
61where
62    F: Float + FromPrimitive,
63{
64    move |p: &[F; N]| forward_jacobian_const(p, &f)
65}
66
67#[inline(always)]
68pub fn central_jacobian<const N: usize, const M: usize, F>(
69    f: OpFn<'_, N, M, F>,
70) -> impl Fn(&[F; N]) -> Result<[[F; N]; M], Error> + '_
71where
72    F: Float + FromPrimitive,
73{
74    move |p: &[F; N]| central_jacobian_const(p, &f)
75}
76
77#[inline(always)]
78pub fn forward_jacobian_vec_prod<const N: usize, const M: usize, F>(
79    f: OpFn<'_, N, M, F>,
80) -> impl Fn(&[F; N], &[F; N]) -> Result<[F; M], Error> + '_
81where
82    F: Float + FromPrimitive,
83{
84    move |p: &[F; N], v: &[F; N]| forward_jacobian_vec_prod_const(p, f, v)
85}
86
87#[inline(always)]
88pub fn central_jacobian_vec_prod<const N: usize, const M: usize, F>(
89    f: OpFn<'_, N, M, F>,
90) -> impl Fn(&[F; N], &[F; N]) -> Result<[F; M], Error> + '_
91where
92    F: Float + FromPrimitive,
93{
94    move |p: &[F; N], v: &[F; N]| central_jacobian_vec_prod_const(p, f, v)
95}
96
97#[inline(always)]
98pub fn forward_jacobian_pert<const N: usize, const M: usize, F>(
99    f: OpFn<'_, N, M, F>,
100) -> impl Fn(&[F; N], &PerturbationVectors) -> Result<[[F; N]; M], Error> + '_
101where
102    F: Float + FromPrimitive + AddAssign,
103{
104    move |p: &[F; N], pert: &PerturbationVectors| forward_jacobian_pert_const(p, f, pert)
105}
106
107#[inline(always)]
108pub fn central_jacobian_pert<const N: usize, const M: usize, F>(
109    f: OpFn<'_, N, M, F>,
110) -> impl Fn(&[F; N], &PerturbationVectors) -> Result<[[F; N]; M], Error> + '_
111where
112    F: Float + FromPrimitive + AddAssign,
113{
114    move |p: &[F; N], pert: &PerturbationVectors| central_jacobian_pert_const(p, f, pert)
115}
116
117#[inline(always)]
118pub fn forward_hessian<const N: usize, F>(
119    f: GradientFn<'_, N, F>,
120) -> impl Fn(&[F; N]) -> Result<[[F; N]; N], Error> + '_
121where
122    F: Float + FromPrimitive,
123{
124    move |p: &[F; N]| forward_hessian_const(p, f)
125}
126
127#[inline(always)]
128pub fn central_hessian<const N: usize, F>(
129    f: GradientFn<'_, N, F>,
130) -> impl Fn(&[F; N]) -> Result<[[F; N]; N], Error> + '_
131where
132    F: Float + FromPrimitive,
133{
134    move |p: &[F; N]| central_hessian_const(p, f)
135}
136
137#[inline(always)]
138pub fn forward_hessian_vec_prod<const N: usize, F>(
139    f: GradientFn<'_, N, F>,
140) -> impl Fn(&[F; N], &[F; N]) -> Result<[F; N], Error> + '_
141where
142    F: Float + FromPrimitive,
143{
144    move |p: &[F; N], v: &[F; N]| forward_hessian_vec_prod_const(p, f, v)
145}
146
147#[inline(always)]
148pub fn central_hessian_vec_prod<const N: usize, F>(
149    f: GradientFn<'_, N, F>,
150) -> impl Fn(&[F; N], &[F; N]) -> Result<[F; N], Error> + '_
151where
152    F: Float + FromPrimitive,
153{
154    move |p: &[F; N], v: &[F; N]| central_hessian_vec_prod_const(p, f, v)
155}
156
157#[inline(always)]
158pub fn forward_hessian_nograd<const N: usize, F>(
159    f: CostFn<'_, N, F>,
160) -> impl Fn(&[F; N]) -> Result<[[F; N]; N], Error> + '_
161where
162    F: Float + FromPrimitive + AddAssign,
163{
164    move |p: &[F; N]| forward_hessian_nograd_const(p, f)
165}
166
167#[inline(always)]
168pub fn forward_hessian_nograd_sparse<const N: usize, F>(
169    f: CostFn<'_, N, F>,
170) -> impl Fn(&[F; N], Vec<[usize; 2]>) -> Result<[[F; N]; N], Error> + '_
171where
172    F: Float + FromPrimitive + AddAssign,
173{
174    move |p: &[F; N], indices: Vec<[usize; 2]>| forward_hessian_nograd_sparse_const(p, f, indices)
175}
176
#[cfg(test)]
mod tests {
    use crate::{PerturbationVector, PerturbationVectors};

    use super::*;

    /// Absolute tolerance for comparing finite-difference results with analytic values.
    const COMP_ACC: f64 = 1e-6;

    /// Asserts element-wise that `actual` lies within `tol` of `expected`.
    ///
    /// On failure, reports which element diverged.
    fn assert_vec_close<const N: usize>(expected: &[f64; N], actual: &[f64; N], tol: f64) {
        for (i, (e, a)) in expected.iter().zip(actual.iter()).enumerate() {
            assert!(
                (e - a).abs() < tol,
                "element {}: expected {}, got {} (tolerance {})",
                i,
                e,
                a,
                tol
            );
        }
    }

    /// Asserts element-wise that the matrix `actual` lies within `tol` of `expected`.
    ///
    /// On failure, reports which `(row, col)` entry diverged.
    fn assert_mat_close<const N: usize, const M: usize>(
        expected: &[[f64; N]; M],
        actual: &[[f64; N]; M],
        tol: f64,
    ) {
        for (i, (e_row, a_row)) in expected.iter().zip(actual.iter()).enumerate() {
            for (j, (e, a)) in e_row.iter().zip(a_row.iter()).enumerate() {
                assert!(
                    (e - a).abs() < tol,
                    "element ({}, {}): expected {}, got {} (tolerance {})",
                    i,
                    j,
                    e,
                    a,
                    tol
                );
            }
        }
    }

    /// Scalar test function `x0 + x1^2`; its gradient is `[1, 2 * x1]`.
    fn f1(x: &[f64; 2]) -> Result<f64, Error> {
        Ok(x[0] + x[1].powi(2))
    }

    /// Vector-valued test operator whose Jacobian is tridiagonal (see `res1`).
    fn f2(x: &[f64; 6]) -> Result<[f64; 6], Error> {
        Ok([
            2.0 * (x[1].powi(3) - x[0].powi(2)),
            3.0 * (x[1].powi(3) - x[0].powi(2)) + 2.0 * (x[2].powi(3) - x[1].powi(2)),
            3.0 * (x[2].powi(3) - x[1].powi(2)) + 2.0 * (x[3].powi(3) - x[2].powi(2)),
            3.0 * (x[3].powi(3) - x[2].powi(2)) + 2.0 * (x[4].powi(3) - x[3].powi(2)),
            3.0 * (x[4].powi(3) - x[3].powi(2)) + 2.0 * (x[5].powi(3) - x[4].powi(2)),
            3.0 * (x[5].powi(3) - x[4].powi(2)),
        ])
    }

    /// Scalar test function `x0 + x1^2 + x2 * x3^2`.
    fn f3(x: &[f64; 4]) -> Result<f64, Error> {
        Ok(x[0] + x[1].powi(2) + x[2] * x[3].powi(2))
    }

    /// Analytic gradient of `f3`.
    fn g(x: &[f64; 4]) -> Result<[f64; 4], Error> {
        Ok([1.0, 2.0 * x[1], x[3].powi(2), 2.0 * x[3] * x[2]])
    }

    fn x1() -> [f64; 2] {
        [1.0f64, 1.0f64]
    }

    fn x2() -> [f64; 6] {
        [1.0f64, 1.0, 1.0, 1.0, 1.0, 1.0]
    }

    fn x3() -> [f64; 4] {
        [1.0f64, 1.0, 1.0, 1.0]
    }

    /// Analytic Jacobian of `f2` at `x2()`, indexed `[output][input]`.
    fn res1() -> [[f64; 6]; 6] {
        [
            [-4.0, 6.0, 0.0, 0.0, 0.0, 0.0],
            [-6.0, 5.0, 6.0, 0.0, 0.0, 0.0],
            [0.0, -6.0, 5.0, 6.0, 0.0, 0.0],
            [0.0, 0.0, -6.0, 5.0, 6.0, 0.0],
            [0.0, 0.0, 0.0, -6.0, 5.0, 6.0],
            [0.0, 0.0, 0.0, 0.0, -6.0, 9.0],
        ]
    }

    /// Analytic Hessian of `f3` at `x3()`.
    fn res2() -> [[f64; 4]; 4] {
        [
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 2.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 2.0],
            [0.0, 0.0, 2.0, 2.0],
        ]
    }

    /// `res1() * p1()`: the Jacobian-vector product of `f2` at `x2()` along `p1()`.
    fn res3() -> [f64; 6] {
        [8.0, 22.0, 27.0, 32.0, 37.0, 24.0]
    }

    /// Perturbation groups for `f2`.
    ///
    /// Each group pairs input columns whose affected output rows do not overlap:
    /// (0, 3), (1, 4) and (2, 5).
    fn pert() -> PerturbationVectors {
        vec![
            PerturbationVector::new()
                .add(0, vec![0, 1])
                .add(3, vec![2, 3, 4]),
            PerturbationVector::new()
                .add(1, vec![0, 1, 2])
                .add(4, vec![3, 4, 5]),
            PerturbationVector::new()
                .add(2, vec![1, 2, 3])
                .add(5, vec![4, 5]),
        ]
    }

    fn p1() -> [f64; 6] {
        [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0]
    }

    fn p2() -> [f64; 4] {
        [2.0, 3.0, 4.0, 5.0]
    }

    #[test]
    fn test_forward_diff_func() {
        let grad = forward_diff(&f1);
        assert_vec_close(&[1.0, 2.0], &grad(&x1()).unwrap(), COMP_ACC);
        assert_vec_close(&[1.0, 4.0], &grad(&[1.0, 2.0]).unwrap(), COMP_ACC);
    }

    #[test]
    fn test_central_diff_func() {
        let grad = central_diff(&f1);
        assert_vec_close(&[1.0, 2.0], &grad(&x1()).unwrap(), COMP_ACC);
        assert_vec_close(&[1.0, 4.0], &grad(&[1.0, 2.0]).unwrap(), COMP_ACC);
    }

    #[test]
    fn test_forward_jacobian_func() {
        let jacobian = forward_jacobian(&f2);
        assert_mat_close(&res1(), &jacobian(&x2()).unwrap(), COMP_ACC);
    }

    #[test]
    fn test_central_jacobian_func() {
        let jacobian = central_jacobian(&f2);
        assert_mat_close(&res1(), &jacobian(&x2()).unwrap(), COMP_ACC);
    }

    #[test]
    fn test_forward_jacobian_vec_prod_func() {
        let jacobian_vec_prod = forward_jacobian_vec_prod(&f2);
        let out = jacobian_vec_prod(&x2(), &p1()).unwrap();
        // The forward scheme is noticeably less accurate here, hence the looser tolerance.
        assert_vec_close(&res3(), &out, 5.5 * COMP_ACC);
    }

    #[test]
    fn test_central_jacobian_vec_prod_func() {
        let jacobian_vec_prod = central_jacobian_vec_prod(&f2);
        let out = jacobian_vec_prod(&x2(), &p1()).unwrap();
        assert_vec_close(&res3(), &out, COMP_ACC);
    }

    #[test]
    fn test_forward_jacobian_pert_func() {
        let jacobian = forward_jacobian_pert(&f2);
        assert_mat_close(&res1(), &jacobian(&x2(), &pert()).unwrap(), COMP_ACC);
    }

    #[test]
    fn test_central_jacobian_pert_func() {
        let jacobian = central_jacobian_pert(&f2);
        assert_mat_close(&res1(), &jacobian(&x2(), &pert()).unwrap(), COMP_ACC);
    }

    #[test]
    fn test_forward_hessian_func() {
        let hessian = forward_hessian(&g);
        assert_mat_close(&res2(), &hessian(&x3()).unwrap(), COMP_ACC);
    }

    #[test]
    fn test_central_hessian_func() {
        let hessian = central_hessian(&g);
        assert_mat_close(&res2(), &hessian(&x3()).unwrap(), COMP_ACC);
    }

    #[test]
    fn test_forward_hessian_vec_prod_func() {
        let hessian_vec_prod = forward_hessian_vec_prod(&g);
        let out = hessian_vec_prod(&x3(), &p2()).unwrap();
        assert_vec_close(&[0.0, 6.0, 10.0, 18.0], &out, COMP_ACC);
    }

    #[test]
    fn test_central_hessian_vec_prod_func() {
        let hessian_vec_prod = central_hessian_vec_prod(&g);
        let out = hessian_vec_prod(&x3(), &p2()).unwrap();
        assert_vec_close(&[0.0, 6.0, 10.0, 18.0], &out, COMP_ACC);
    }

    #[test]
    fn test_forward_hessian_nograd_func() {
        let hessian = forward_hessian_nograd(&f3);
        assert_mat_close(&res2(), &hessian(&x3()).unwrap(), COMP_ACC);
    }

    #[test]
    fn test_forward_hessian_nograd_sparse_func() {
        // Only the non-zero upper-triangle entries are requested; the expected result
        // also contains the mirrored (3, 2) entry.
        let indices = vec![[1, 1], [2, 3], [3, 3]];
        let hessian = forward_hessian_nograd_sparse(&f3);
        assert_mat_close(&res2(), &hessian(&x3(), indices).unwrap(), COMP_ACC);
    }
}