1#![allow(clippy::type_complexity)]
10
11pub mod diff;
12pub mod hessian;
13pub mod jacobian;
14
15use std::ops::AddAssign;
16
17use anyhow::Error;
18use num::{Float, FromPrimitive};
19
20use crate::PerturbationVectors;
21use diff::{central_diff_const, forward_diff_const};
22use hessian::{
23 central_hessian_const, central_hessian_vec_prod_const, forward_hessian_const,
24 forward_hessian_nograd_const, forward_hessian_nograd_sparse_const,
25 forward_hessian_vec_prod_const,
26};
27use jacobian::{
28 central_jacobian_const, central_jacobian_pert_const, central_jacobian_vec_prod_const,
29 forward_jacobian_const, forward_jacobian_pert_const, forward_jacobian_vec_prod_const,
30};
31
32pub(crate) type CostFn<'a, const N: usize, F> = &'a dyn Fn(&[F; N]) -> Result<F, Error>;
33pub(crate) type GradientFn<'a, const N: usize, F> = &'a dyn Fn(&[F; N]) -> Result<[F; N], Error>;
34pub(crate) type OpFn<'a, const N: usize, const M: usize, F> =
35 &'a dyn Fn(&[F; N]) -> Result<[F; M], Error>;
36
37#[inline(always)]
38pub fn forward_diff<const N: usize, F>(
39 f: CostFn<'_, N, F>,
40) -> impl Fn(&[F; N]) -> Result<[F; N], Error> + '_
41where
42 F: Float + FromPrimitive,
43{
44 move |p: &[F; N]| forward_diff_const(p, &f)
45}
46
47#[inline(always)]
48pub fn central_diff<const N: usize, F>(
49 f: CostFn<'_, N, F>,
50) -> impl Fn(&[F; N]) -> Result<[F; N], Error> + '_
51where
52 F: Float + FromPrimitive,
53{
54 move |p: &[F; N]| central_diff_const(p, &f)
55}
56
57#[inline(always)]
58pub fn forward_jacobian<const N: usize, const M: usize, F>(
59 f: OpFn<'_, N, M, F>,
60) -> impl Fn(&[F; N]) -> Result<[[F; N]; M], Error> + '_
61where
62 F: Float + FromPrimitive,
63{
64 move |p: &[F; N]| forward_jacobian_const(p, &f)
65}
66
67#[inline(always)]
68pub fn central_jacobian<const N: usize, const M: usize, F>(
69 f: OpFn<'_, N, M, F>,
70) -> impl Fn(&[F; N]) -> Result<[[F; N]; M], Error> + '_
71where
72 F: Float + FromPrimitive,
73{
74 move |p: &[F; N]| central_jacobian_const(p, &f)
75}
76
77#[inline(always)]
78pub fn forward_jacobian_vec_prod<const N: usize, const M: usize, F>(
79 f: OpFn<'_, N, M, F>,
80) -> impl Fn(&[F; N], &[F; N]) -> Result<[F; M], Error> + '_
81where
82 F: Float + FromPrimitive,
83{
84 move |p: &[F; N], v: &[F; N]| forward_jacobian_vec_prod_const(p, f, v)
85}
86
87#[inline(always)]
88pub fn central_jacobian_vec_prod<const N: usize, const M: usize, F>(
89 f: OpFn<'_, N, M, F>,
90) -> impl Fn(&[F; N], &[F; N]) -> Result<[F; M], Error> + '_
91where
92 F: Float + FromPrimitive,
93{
94 move |p: &[F; N], v: &[F; N]| central_jacobian_vec_prod_const(p, f, v)
95}
96
97#[inline(always)]
98pub fn forward_jacobian_pert<const N: usize, const M: usize, F>(
99 f: OpFn<'_, N, M, F>,
100) -> impl Fn(&[F; N], &PerturbationVectors) -> Result<[[F; N]; M], Error> + '_
101where
102 F: Float + FromPrimitive + AddAssign,
103{
104 move |p: &[F; N], pert: &PerturbationVectors| forward_jacobian_pert_const(p, f, pert)
105}
106
107#[inline(always)]
108pub fn central_jacobian_pert<const N: usize, const M: usize, F>(
109 f: OpFn<'_, N, M, F>,
110) -> impl Fn(&[F; N], &PerturbationVectors) -> Result<[[F; N]; M], Error> + '_
111where
112 F: Float + FromPrimitive + AddAssign,
113{
114 move |p: &[F; N], pert: &PerturbationVectors| central_jacobian_pert_const(p, f, pert)
115}
116
117#[inline(always)]
118pub fn forward_hessian<const N: usize, F>(
119 f: GradientFn<'_, N, F>,
120) -> impl Fn(&[F; N]) -> Result<[[F; N]; N], Error> + '_
121where
122 F: Float + FromPrimitive,
123{
124 move |p: &[F; N]| forward_hessian_const(p, f)
125}
126
127#[inline(always)]
128pub fn central_hessian<const N: usize, F>(
129 f: GradientFn<'_, N, F>,
130) -> impl Fn(&[F; N]) -> Result<[[F; N]; N], Error> + '_
131where
132 F: Float + FromPrimitive,
133{
134 move |p: &[F; N]| central_hessian_const(p, f)
135}
136
137#[inline(always)]
138pub fn forward_hessian_vec_prod<const N: usize, F>(
139 f: GradientFn<'_, N, F>,
140) -> impl Fn(&[F; N], &[F; N]) -> Result<[F; N], Error> + '_
141where
142 F: Float + FromPrimitive,
143{
144 move |p: &[F; N], v: &[F; N]| forward_hessian_vec_prod_const(p, f, v)
145}
146
147#[inline(always)]
148pub fn central_hessian_vec_prod<const N: usize, F>(
149 f: GradientFn<'_, N, F>,
150) -> impl Fn(&[F; N], &[F; N]) -> Result<[F; N], Error> + '_
151where
152 F: Float + FromPrimitive,
153{
154 move |p: &[F; N], v: &[F; N]| central_hessian_vec_prod_const(p, f, v)
155}
156
157#[inline(always)]
158pub fn forward_hessian_nograd<const N: usize, F>(
159 f: CostFn<'_, N, F>,
160) -> impl Fn(&[F; N]) -> Result<[[F; N]; N], Error> + '_
161where
162 F: Float + FromPrimitive + AddAssign,
163{
164 move |p: &[F; N]| forward_hessian_nograd_const(p, f)
165}
166
167#[inline(always)]
168pub fn forward_hessian_nograd_sparse<const N: usize, F>(
169 f: CostFn<'_, N, F>,
170) -> impl Fn(&[F; N], Vec<[usize; 2]>) -> Result<[[F; N]; N], Error> + '_
171where
172 F: Float + FromPrimitive + AddAssign,
173{
174 move |p: &[F; N], indices: Vec<[usize; 2]>| forward_hessian_nograd_sparse_const(p, f, indices)
175}
176
#[cfg(test)]
mod tests {
    use crate::{PerturbationVector, PerturbationVectors};

    use super::*;

    /// Absolute tolerance for comparing finite-difference results with analytic values.
    const COMP_ACC: f64 = 1e-6;

    /// Asserts element-wise that `actual` lies within `tol` of `expected`, reporting the
    /// failing index and both values instead of a bare assertion failure.
    fn assert_vec_close<const K: usize>(expected: &[f64; K], actual: &[f64; K], tol: f64) {
        for (i, (e, a)) in expected.iter().zip(actual.iter()).enumerate() {
            assert!(
                (e - a).abs() < tol,
                "element {}: expected {}, got {} (tolerance {})",
                i,
                e,
                a,
                tol
            );
        }
    }

    /// Matrix counterpart of `assert_vec_close`: compares every `[i][j]` entry.
    fn assert_mat_close<const R: usize, const C: usize>(
        expected: &[[f64; C]; R],
        actual: &[[f64; C]; R],
        tol: f64,
    ) {
        for (i, (e_row, a_row)) in expected.iter().zip(actual.iter()).enumerate() {
            for (j, (e, a)) in e_row.iter().zip(a_row.iter()).enumerate() {
                assert!(
                    (e - a).abs() < tol,
                    "entry [{}][{}]: expected {}, got {} (tolerance {})",
                    i,
                    j,
                    e,
                    a,
                    tol
                );
            }
        }
    }

    /// Scalar cost `x0 + x1²`; its gradient is `[1, 2·x1]`.
    fn f1(x: &[f64; 2]) -> Result<f64, Error> {
        Ok(x[0] + x[1].powi(2))
    }

    /// Coupled 6-dimensional operator whose Jacobian is tridiagonal; `res1` is its
    /// Jacobian at `x2()`.
    fn f2(x: &[f64; 6]) -> Result<[f64; 6], Error> {
        Ok([
            2.0 * (x[1].powi(3) - x[0].powi(2)),
            3.0 * (x[1].powi(3) - x[0].powi(2)) + 2.0 * (x[2].powi(3) - x[1].powi(2)),
            3.0 * (x[2].powi(3) - x[1].powi(2)) + 2.0 * (x[3].powi(3) - x[2].powi(2)),
            3.0 * (x[3].powi(3) - x[2].powi(2)) + 2.0 * (x[4].powi(3) - x[3].powi(2)),
            3.0 * (x[4].powi(3) - x[3].powi(2)) + 2.0 * (x[5].powi(3) - x[4].powi(2)),
            3.0 * (x[5].powi(3) - x[4].powi(2)),
        ])
    }

    /// Scalar cost `x0 + x1² + x2·x3²`; `res2` is its Hessian at `x3()`.
    fn f3(x: &[f64; 4]) -> Result<f64, Error> {
        Ok(x[0] + x[1].powi(2) + x[2] * x[3].powi(2))
    }

    /// Analytic gradient of `f3`.
    fn g(x: &[f64; 4]) -> Result<[f64; 4], Error> {
        Ok([1.0, 2.0 * x[1], x[3].powi(2), 2.0 * x[3] * x[2]])
    }

    fn x1() -> [f64; 2] {
        [1.0f64, 1.0f64]
    }

    fn x2() -> [f64; 6] {
        [1.0f64, 1.0, 1.0, 1.0, 1.0, 1.0]
    }

    fn x3() -> [f64; 4] {
        [1.0f64, 1.0, 1.0, 1.0]
    }

    /// Jacobian of `f2` at `x2()`; row `i` holds the derivatives of output `i`.
    fn res1() -> [[f64; 6]; 6] {
        [
            [-4.0, 6.0, 0.0, 0.0, 0.0, 0.0],
            [-6.0, 5.0, 6.0, 0.0, 0.0, 0.0],
            [0.0, -6.0, 5.0, 6.0, 0.0, 0.0],
            [0.0, 0.0, -6.0, 5.0, 6.0, 0.0],
            [0.0, 0.0, 0.0, -6.0, 5.0, 6.0],
            [0.0, 0.0, 0.0, 0.0, -6.0, 9.0],
        ]
    }

    /// Hessian of `f3` at `x3()`.
    fn res2() -> [[f64; 4]; 4] {
        [
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 2.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 2.0],
            [0.0, 0.0, 2.0, 2.0],
        ]
    }

    /// `res1() · p1()`, the expected Jacobian-vector product.
    fn res3() -> [f64; 6] {
        [8.0, 22.0, 27.0, 32.0, 37.0, 24.0]
    }

    /// Perturbation groups for `f2`: each group pairs columns of `res1` whose non-zero
    /// rows do not overlap, so they can be perturbed simultaneously.
    fn pert() -> PerturbationVectors {
        vec![
            PerturbationVector::new()
                .add(0, vec![0, 1])
                .add(3, vec![2, 3, 4]),
            PerturbationVector::new()
                .add(1, vec![0, 1, 2])
                .add(4, vec![3, 4, 5]),
            PerturbationVector::new()
                .add(2, vec![1, 2, 3])
                .add(5, vec![4, 5]),
        ]
    }

    fn p1() -> [f64; 6] {
        [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0]
    }

    fn p2() -> [f64; 4] {
        [2.0, 3.0, 4.0, 5.0]
    }

    #[test]
    fn test_forward_diff_func() {
        let grad = forward_diff(&f1);
        assert_vec_close(&[1.0, 2.0], &grad(&x1()).unwrap(), COMP_ACC);

        let p = [1.0, 2.0];
        assert_vec_close(&[1.0, 4.0], &grad(&p).unwrap(), COMP_ACC);
    }

    #[test]
    fn test_central_diff_func() {
        let grad = central_diff(&f1);
        assert_vec_close(&[1.0, 2.0], &grad(&x1()).unwrap(), COMP_ACC);

        let p = [1.0, 2.0];
        assert_vec_close(&[1.0, 4.0], &grad(&p).unwrap(), COMP_ACC);
    }

    #[test]
    fn test_forward_jacobian_func() {
        let jacobian = forward_jacobian(&f2);
        assert_mat_close(&res1(), &jacobian(&x2()).unwrap(), COMP_ACC);
    }

    #[test]
    fn test_central_jacobian_func() {
        let jacobian = central_jacobian(&f2);
        assert_mat_close(&res1(), &jacobian(&x2()).unwrap(), COMP_ACC);
    }

    #[test]
    fn test_forward_jacobian_vec_prod_vec_func() {
        let jacobian = forward_jacobian_vec_prod(&f2);
        let out = jacobian(&x2(), &p1()).unwrap();
        // Looser than the central variant below: the forward scheme is the less accurate
        // of the two, and the tighter `COMP_ACC` is not met here.
        assert_vec_close(&res3(), &out, 5.5 * COMP_ACC);
    }

    #[test]
    fn test_central_jacobian_vec_prod_vec_func() {
        let jacobian = central_jacobian_vec_prod(&f2);
        assert_vec_close(&res3(), &jacobian(&x2(), &p1()).unwrap(), COMP_ACC);
    }

    #[test]
    fn test_forward_jacobian_pert_func() {
        let jacobian = forward_jacobian_pert(&f2);
        assert_mat_close(&res1(), &jacobian(&x2(), &pert()).unwrap(), COMP_ACC);
    }

    #[test]
    fn test_central_jacobian_pert_func() {
        let jacobian = central_jacobian_pert(&f2);
        assert_mat_close(&res1(), &jacobian(&x2(), &pert()).unwrap(), COMP_ACC);
    }

    #[test]
    fn test_forward_hessian_func() {
        let hessian = forward_hessian(&g);
        assert_mat_close(&res2(), &hessian(&x3()).unwrap(), COMP_ACC);
    }

    #[test]
    fn test_central_hessian_func() {
        let hessian = central_hessian(&g);
        assert_mat_close(&res2(), &hessian(&x3()).unwrap(), COMP_ACC);
    }

    #[test]
    fn test_forward_hessian_vec_prod_func() {
        let hessian = forward_hessian_vec_prod(&g);
        // res2() · p2()
        let expected = [0.0, 6.0, 10.0, 18.0];
        assert_vec_close(&expected, &hessian(&x3(), &p2()).unwrap(), COMP_ACC);
    }

    #[test]
    fn test_central_hessian_vec_prod_func() {
        let hessian = central_hessian_vec_prod(&g);
        // res2() · p2()
        let expected = [0.0, 6.0, 10.0, 18.0];
        assert_vec_close(&expected, &hessian(&x3(), &p2()).unwrap(), COMP_ACC);
    }

    #[test]
    fn test_forward_hessian_nograd_func() {
        let hessian = forward_hessian_nograd(&f3);
        assert_mat_close(&res2(), &hessian(&x3()).unwrap(), COMP_ACC);
    }

    #[test]
    fn test_forward_hessian_nograd_sparse_func() {
        // Only the upper-triangular non-zeros are requested; [3][2] is still expected.
        let indices = vec![[1, 1], [2, 3], [3, 3]];
        let hessian = forward_hessian_nograd_sparse(&f3);
        assert_mat_close(&res2(), &hessian(&x3(), indices).unwrap(), COMP_ACC);
    }
}