1use std::ops::AddAssign;
9
10use anyhow::Error;
11use num::{Float, FromPrimitive};
12
13use crate::utils::{mod_and_calc, restore_symmetry_const, KV};
14
15use super::{CostFn, GradientFn};
16
17pub fn forward_hessian_const<const N: usize, F>(
18 x: &[F; N],
19 grad: GradientFn<'_, N, F>,
20) -> Result<[[F; N]; N], Error>
21where
22 F: Float + FromPrimitive,
23{
24 let eps_sqrt = F::epsilon().sqrt();
25 let fx = (grad)(x)?;
26 let mut xt = *x;
27 let mut out = [[F::from_f64(0.0).unwrap(); N]; N];
28 for (i, o_item) in out.iter_mut().enumerate().take(N) {
29 let fx1 = mod_and_calc(&mut xt, grad, i, eps_sqrt)?;
30 for j in 0..N {
31 o_item[j] = (fx1[j] - fx[j]) / eps_sqrt;
32 }
33 }
34
35 Ok(restore_symmetry_const(out))
37}
38
39pub fn central_hessian_const<const N: usize, F>(
40 x: &[F; N],
41 grad: GradientFn<'_, N, F>,
42) -> Result<[[F; N]; N], Error>
43where
44 F: Float + FromPrimitive,
45{
46 let eps_cbrt = F::epsilon().cbrt();
47 let mut xt = x.to_owned();
48 let mut out = [[F::from_f64(0.0).unwrap(); N]; N];
49
50 for (i, o_item) in out.iter_mut().enumerate().take(N) {
51 let fx1 = mod_and_calc(&mut xt, grad, i, eps_cbrt)?;
52 let fx2 = mod_and_calc(&mut xt, grad, i, -eps_cbrt)?;
53 for j in 0..N {
54 o_item[j] = (fx1[j] - fx2[j]) / (F::from_f64(2.0).unwrap() * eps_cbrt);
55 }
56 }
57
58 Ok(restore_symmetry_const(out))
60}
61
62pub fn forward_hessian_vec_prod_const<const N: usize, F>(
63 x: &[F; N],
64 grad: GradientFn<'_, N, F>,
65 p: &[F; N],
66) -> Result<[F; N], Error>
67where
68 F: Float + FromPrimitive,
69{
70 let eps_sqrt = F::epsilon().sqrt();
71 let fx = (grad)(x)?;
72 let mut out = [F::from_f64(0.0).unwrap(); N];
73
74 let mut x1 = *x;
75 for i in 1..N {
76 x1[i] = x[i] + p[i] * eps_sqrt;
77 }
78 let fx1 = (grad)(&x1)?;
79
80 for i in 0..N {
81 out[i] = (fx1[i] - fx[i]) / eps_sqrt;
82 }
83 Ok(out)
84}
85
86pub fn central_hessian_vec_prod_const<const N: usize, F>(
87 x: &[F; N],
88 grad: GradientFn<'_, N, F>,
89 p: &[F; N],
90) -> Result<[F; N], Error>
91where
92 F: Float + FromPrimitive,
93{
94 let eps_cbrt = F::epsilon().cbrt();
95 let mut x1 = *x;
96 let mut x2 = *x;
97 for i in 1..N {
98 x1[i] = x[i] + p[i] * eps_cbrt;
99 x2[i] = x[i] - p[i] * eps_cbrt;
100 }
101 let fx1 = (grad)(&x1)?;
102 let fx2 = (grad)(&x2)?;
103
104 let mut out = [F::from_f64(0.0).unwrap(); N];
105 for i in 0..N {
106 out[i] = (fx1[i] - fx2[i]) / (F::from_f64(2.0).unwrap() * eps_cbrt);
107 }
108 Ok(out)
109}
110
111pub fn forward_hessian_nograd_const<const N: usize, F>(
112 x: &[F; N],
113 f: CostFn<'_, N, F>,
114) -> Result<[[F; N]; N], Error>
115where
116 F: Float + FromPrimitive + AddAssign,
117{
118 let eps_nograd = F::from_f64(2.0).unwrap() * F::epsilon();
120 let eps_sqrt_nograd = eps_nograd.sqrt();
121
122 let fx = (f)(x)?;
123 let mut xt = *x;
124
125 let mut fxei = [F::from_f64(0.0).unwrap(); N];
127 for (i, item) in fxei.iter_mut().enumerate().take(N) {
128 *item = mod_and_calc(&mut xt, f, i, eps_sqrt_nograd)?;
129 }
130
131 let mut out = [[F::from_f64(0.0).unwrap(); N]; N];
132
133 for i in 0..N {
134 for j in 0..=i {
135 let t = {
136 let xti = xt[i];
137 let xtj = xt[j];
138 xt[i] += eps_sqrt_nograd;
139 xt[j] += eps_sqrt_nograd;
140 let fxij = (f)(&xt)?;
141 xt[i] = xti;
142 xt[j] = xtj;
143 (fxij - fxei[i] - fxei[j] + fx) / eps_nograd
144 };
145 out[i][j] = t;
146 out[j][i] = t;
147 }
148 }
149 Ok(out)
150}
151
152pub fn forward_hessian_nograd_sparse_const<const N: usize, F>(
153 x: &[F; N],
154 f: CostFn<'_, N, F>,
155 indices: Vec<[usize; 2]>,
156) -> Result<[[F; N]; N], Error>
157where
158 F: Float + FromPrimitive + AddAssign,
159{
160 let eps_nograd = F::from_f64(2.0).unwrap() * F::epsilon();
162 let eps_sqrt_nograd = eps_nograd.sqrt();
163
164 let fx = (f)(x)?;
165 let mut xt = *x;
166
167 let mut idxs: Vec<usize> = indices
168 .iter()
169 .flat_map(|i| i.iter())
170 .cloned()
171 .collect::<Vec<usize>>();
172 idxs.sort();
173 idxs.dedup();
174
175 let mut fxei = KV::new(idxs.len());
176
177 for idx in idxs.iter() {
178 fxei.set(*idx, mod_and_calc(&mut xt, f, *idx, eps_sqrt_nograd)?);
179 }
180
181 let mut out = [[F::from_f64(0.0).unwrap(); N]; N];
182 for [i, j] in indices {
183 let t = {
184 let xti = xt[i];
185 let xtj = xt[j];
186 xt[i] += eps_sqrt_nograd;
187 xt[j] += eps_sqrt_nograd;
188 let fxij = (f)(&xt)?;
189 xt[i] = xti;
190 xt[j] = xtj;
191
192 let fxi = fxei.get(i).ok_or(anyhow::anyhow!("Bug?"))?;
193 let fxj = fxei.get(j).ok_or(anyhow::anyhow!("Bug?"))?;
194 (fxij - fxi - fxj + fx) / eps_nograd
195 };
196 out[i][j] = t;
197 out[j][i] = t;
198 }
199 Ok(out)
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing approximations with exact values.
    const COMP_ACC: f64 = 1e-6;

    /// Cost function `x0 + x1^2 + x2 * x3^2`.
    fn f(x: &[f64; 4]) -> Result<f64, Error> {
        let [x0, x1, x2, x3] = *x;
        Ok(x0 + x1.powi(2) + x2 * x3.powi(2))
    }

    /// Analytic gradient of `f`.
    fn g(x: &[f64; 4]) -> Result<[f64; 4], Error> {
        let [_, x1, x2, x3] = *x;
        Ok([1.0, 2.0 * x1, x3.powi(2), 2.0 * x3 * x2])
    }

    /// Evaluation point shared by all tests.
    fn x() -> [f64; 4] {
        [1.0f64; 4]
    }

    /// Direction vector for the Hessian-vector product tests.
    fn p() -> [f64; 4] {
        [2.0, 3.0, 4.0, 5.0]
    }

    /// Exact Hessian of `f` at `x()`.
    fn res1() -> [[f64; 4]; 4] {
        [
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 2.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 2.0],
            [0.0, 0.0, 2.0, 2.0],
        ]
    }

    /// Exact product of `res1()` with `p()`.
    fn res2() -> [f64; 4] {
        [0.0, 6.0, 10.0, 18.0]
    }

    /// Asserts that two vectors agree component-wise within `COMP_ACC`.
    fn assert_vector_close(expected: &[f64; 4], actual: &[f64; 4]) {
        for (want, got) in expected.iter().zip(actual.iter()) {
            assert!((want - got).abs() < COMP_ACC)
        }
    }

    /// Asserts that two matrices agree entry-wise within `COMP_ACC`.
    fn assert_matrix_close(expected: &[[f64; 4]; 4], actual: &[[f64; 4]; 4]) {
        for (want_row, got_row) in expected.iter().zip(actual.iter()) {
            assert_vector_close(want_row, got_row);
        }
    }

    #[test]
    fn test_forward_hessian_vec_f64() {
        let hessian = forward_hessian_const(&x(), &g).unwrap();
        assert_matrix_close(&res1(), &hessian);
    }

    #[test]
    fn test_central_hessian_vec_f64() {
        let hessian = central_hessian_const(&x(), &g).unwrap();
        assert_matrix_close(&res1(), &hessian);
    }

    #[test]
    fn test_forward_hessian_vec_prod_vec_f64() {
        let hessian = forward_hessian_vec_prod_const(&x(), &g, &p()).unwrap();
        assert_vector_close(&res2(), &hessian);
    }

    #[test]
    fn test_central_hessian_vec_prod_vec_f64() {
        let hessian = central_hessian_vec_prod_const(&x(), &g, &p()).unwrap();
        assert_vector_close(&res2(), &hessian);
    }

    #[test]
    fn test_forward_hessian_nograd_vec_f64() {
        let hessian = forward_hessian_nograd_const(&x(), &f).unwrap();
        assert_matrix_close(&res1(), &hessian);
    }

    #[test]
    fn test_forward_hessian_nograd_sparse_vec_f64() {
        let indices = vec![[1, 1], [2, 3], [3, 3]];
        let hessian = forward_hessian_nograd_sparse_const(&x(), &f, indices).unwrap();
        assert_matrix_close(&res1(), &hessian);
    }
}