1use std::ops::AddAssign;
9
10use anyhow::Error;
11use num::{Float, FromPrimitive};
12
13use crate::pert::PerturbationVectors;
14use crate::utils::{mod_and_calc, mod_and_calc_const};
15
16use super::OpFn;
17
18pub fn forward_jacobian_const<const N: usize, const M: usize, F>(
19 x: &[F; N],
20 fs: OpFn<'_, N, M, F>,
21) -> Result<[[F; N]; M], Error>
22where
23 F: Float + FromPrimitive,
24{
25 let fx = (fs)(x)?;
26 let mut xt = *x;
27 let eps_sqrt = F::epsilon().sqrt();
28 let mut out = [[F::from_f64(0.0).unwrap(); N]; M];
29
30 for i in 0..N {
31 let fx1 = mod_and_calc_const(&mut xt, fs, i, eps_sqrt)?;
32
33 for j in 0..M {
34 out[j][i] = (fx1[j] - fx[j]) / eps_sqrt;
35 }
36 }
37 Ok(out)
38}
39
40pub fn central_jacobian_const<const N: usize, const M: usize, F>(
41 x: &[F; N],
42 fs: OpFn<'_, N, M, F>,
43) -> Result<[[F; N]; M], Error>
44where
45 F: Float + FromPrimitive,
46{
47 let mut xt = *x;
48 let eps_cbrt = F::epsilon().cbrt();
49 let mut out = [[F::from_f64(0.0).unwrap(); N]; M];
50 for i in 0..M {
51 let fx1 = mod_and_calc(&mut xt, fs, i, eps_cbrt)?;
52 let fx2 = mod_and_calc(&mut xt, fs, i, -eps_cbrt)?;
53
54 for j in 0..M {
55 out[j][i] = (fx1[j] - fx2[j]) / (F::from_f64(2.0).unwrap() * eps_cbrt);
56 }
57 }
58 Ok(out)
59}
60
61pub fn forward_jacobian_vec_prod_const<const N: usize, const M: usize, F>(
62 x: &[F; N],
63 fs: OpFn<'_, N, M, F>,
64 p: &[F; N],
65) -> Result<[F; M], Error>
66where
67 F: Float + FromPrimitive,
68{
69 let fx = (fs)(x)?;
70 let eps_sqrt = F::epsilon().sqrt();
71 let mut x1 = [F::from_f64(0.0).unwrap(); N];
72 x1.iter_mut()
73 .enumerate()
74 .map(|(i, o)| *o = x[i] + eps_sqrt * p[i])
75 .count();
76
77 let fx1 = (fs)(&x1)?;
78 let mut out = [F::from_f64(0.0).unwrap(); M];
79 out.iter_mut()
80 .enumerate()
81 .map(|(i, o)| {
82 *o = (fx1[i] - fx[i]) / eps_sqrt;
83 })
84 .count();
85 Ok(out)
86}
87
88pub fn central_jacobian_vec_prod_const<const N: usize, const M: usize, F>(
89 x: &[F; N],
90 fs: OpFn<'_, N, M, F>,
91 p: &[F; N],
92) -> Result<[F; M], Error>
93where
94 F: Float + FromPrimitive,
95{
96 let eps_cbrt = F::epsilon().cbrt();
97 let mut x1 = [F::from_f64(0.0).unwrap(); N];
98 let mut x2 = [F::from_f64(0.0).unwrap(); N];
99 x1.iter_mut()
100 .zip(x2.iter_mut())
101 .enumerate()
102 .map(|(i, (x1, x2))| {
103 let tmp = eps_cbrt * p[i];
104 *x1 = x[i] + tmp;
105 *x2 = x[i] - tmp;
106 })
107 .count();
108 let fx1 = (fs)(&x1)?;
109 let fx2 = (fs)(&x2)?;
110 let mut out = [F::from_f64(0.0).unwrap(); M];
111 out.iter_mut()
112 .enumerate()
113 .map(|(i, o)| {
114 *o = (fx1[i] - fx2[i]) / (F::from_f64(2.0).unwrap() * eps_cbrt);
115 })
116 .count();
117 Ok(out)
118}
119
120pub fn forward_jacobian_pert_const<const N: usize, const M: usize, F>(
121 x: &[F; N],
122 fs: OpFn<'_, N, M, F>,
123 pert: &PerturbationVectors,
124) -> Result<[[F; N]; M], Error>
125where
126 F: Float + FromPrimitive + AddAssign,
127{
128 let fx = (fs)(x)?;
129 let eps_sqrt = F::epsilon().sqrt();
130 let mut xt = *x;
131 let mut out = [[F::from_f64(0.0).unwrap(); N]; M];
132 for pert_item in pert.iter() {
133 for j in pert_item.x_idx.iter() {
134 xt[*j] += eps_sqrt;
135 }
136
137 let fx1 = (fs)(&xt)?;
138
139 for j in pert_item.x_idx.iter() {
140 xt[*j] = x[*j];
141 }
142
143 for (k, x_idx) in pert_item.x_idx.iter().enumerate() {
144 for j in pert_item.r_idx[k].iter() {
145 out[*j][*x_idx] = (fx1[*j] - fx[*j]) / eps_sqrt;
146 }
147 }
148 }
149 Ok(out)
150}
151
152pub fn central_jacobian_pert_const<const N: usize, const M: usize, F>(
153 x: &[F; N],
154 fs: OpFn<'_, N, M, F>,
155 pert: &PerturbationVectors,
156) -> Result<[[F; N]; M], Error>
157where
158 F: Float + FromPrimitive + AddAssign,
159{
160 let eps_cbrt = F::epsilon().cbrt();
161 let mut xt = *x;
162 let mut out = [[F::from_f64(0.0).unwrap(); N]; M];
163 for pert_item in pert.iter() {
164 for j in pert_item.x_idx.iter() {
165 xt[*j] += eps_cbrt;
166 }
167
168 let fx1 = (fs)(&xt)?;
169
170 for j in pert_item.x_idx.iter() {
171 xt[*j] = x[*j] - eps_cbrt;
172 }
173
174 let fx2 = (fs)(&xt)?;
175
176 for j in pert_item.x_idx.iter() {
177 xt[*j] = x[*j];
178 }
179
180 for (k, x_idx) in pert_item.x_idx.iter().enumerate() {
181 for j in pert_item.r_idx[k].iter() {
182 out[*j][*x_idx] = (fx1[*j] - fx2[*j]) / (F::from_f64(2.0).unwrap() * eps_cbrt);
183 }
184 }
185 }
186 Ok(out)
187}
188
#[cfg(test)]
mod tests {
    use crate::PerturbationVector;

    use super::*;

    /// Absolute tolerance used when comparing approximations to exact values.
    const COMP_ACC: f64 = 1e-6;

    /// Test function `R^6 -> R^6`. Output `i` depends only on `x[i-1]`, `x[i]` and
    /// `x[i+1]`, so its Jacobian is tridiagonal.
    fn f(x: &[f64; 6]) -> Result<[f64; 6], Error> {
        Ok([
            2.0 * (x[1].powi(3) - x[0].powi(2)),
            3.0 * (x[1].powi(3) - x[0].powi(2)) + 2.0 * (x[2].powi(3) - x[1].powi(2)),
            3.0 * (x[2].powi(3) - x[1].powi(2)) + 2.0 * (x[3].powi(3) - x[2].powi(2)),
            3.0 * (x[3].powi(3) - x[2].powi(2)) + 2.0 * (x[4].powi(3) - x[3].powi(2)),
            3.0 * (x[4].powi(3) - x[3].powi(2)) + 2.0 * (x[5].powi(3) - x[4].powi(2)),
            3.0 * (x[5].powi(3) - x[4].powi(2)),
        ])
    }

    /// Exact Jacobian of `f` at `x()`.
    fn res1() -> [[f64; 6]; 6] {
        [
            [-4.0, 6.0, 0.0, 0.0, 0.0, 0.0],
            [-6.0, 5.0, 6.0, 0.0, 0.0, 0.0],
            [0.0, -6.0, 5.0, 6.0, 0.0, 0.0],
            [0.0, 0.0, -6.0, 5.0, 6.0, 0.0],
            [0.0, 0.0, 0.0, -6.0, 5.0, 6.0],
            [0.0, 0.0, 0.0, 0.0, -6.0, 9.0],
        ]
    }

    /// Exact Jacobian-vector product `res1() * p()`.
    fn res2() -> [f64; 6] {
        [8.0, 22.0, 27.0, 32.0, 37.0, 24.0]
    }

    /// Evaluation point.
    fn x() -> [f64; 6] {
        [1.0f64, 1.0, 1.0, 1.0, 1.0, 1.0]
    }

    /// Direction used for the Jacobian-vector products.
    fn p() -> [f64; 6] {
        [1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0]
    }

    /// Column groups {0, 3}, {1, 4}, {2, 5}. Columns three apart have non-zeros in
    /// disjoint rows of a tridiagonal Jacobian, so each group can be shifted at once.
    fn pert() -> PerturbationVectors {
        vec![
            PerturbationVector::new()
                .add(0, vec![0, 1])
                .add(3, vec![2, 3, 4]),
            PerturbationVector::new()
                .add(1, vec![0, 1, 2])
                .add(4, vec![3, 4, 5]),
            PerturbationVector::new()
                .add(2, vec![1, 2, 3])
                .add(5, vec![4, 5]),
        ]
    }

    /// Asserts that every entry of `got` is strictly within `tol` of the matching entry
    /// of `want`, reporting the first offending position.
    fn assert_mat_close(got: &[[f64; 6]; 6], want: &[[f64; 6]; 6], tol: f64) {
        for (i, (got_row, want_row)) in got.iter().zip(want.iter()).enumerate() {
            for (j, (g, w)) in got_row.iter().zip(want_row.iter()).enumerate() {
                assert!(
                    (w - g).abs() < tol,
                    "entry ({}, {}): got {}, expected {}",
                    i,
                    j,
                    g,
                    w
                );
            }
        }
    }

    /// Asserts that every element of `got` is strictly within `tol` of the matching
    /// element of `want`, reporting the first offending position.
    fn assert_vec_close(got: &[f64; 6], want: &[f64; 6], tol: f64) {
        for (i, (g, w)) in got.iter().zip(want.iter()).enumerate() {
            assert!(
                (w - g).abs() < tol,
                "element {}: got {}, expected {}",
                i,
                g,
                w
            );
        }
    }

    #[test]
    fn test_forward_jacobian_const_f64() {
        let jacobian = forward_jacobian_const(&x(), &f).unwrap();
        assert_mat_close(&jacobian, &res1(), COMP_ACC);
    }

    #[test]
    fn test_central_jacobian_const_f64() {
        let jacobian = central_jacobian_const(&x(), &f).unwrap();
        assert_mat_close(&jacobian, &res1(), COMP_ACC);
    }

    #[test]
    fn test_forward_jacobian_vec_prod_const_f64() {
        let jacobian = forward_jacobian_vec_prod_const(&x(), &f, &p()).unwrap();
        // A forward difference is only first-order accurate, so this comparison uses a
        // looser tolerance than the central-difference variants.
        assert_vec_close(&jacobian, &res2(), 11.0 * COMP_ACC);
    }

    #[test]
    fn test_central_jacobian_vec_prod_const_f64() {
        let jacobian = central_jacobian_vec_prod_const(&x(), &f, &p()).unwrap();
        assert_vec_close(&jacobian, &res2(), COMP_ACC);
    }

    #[test]
    fn test_forward_jacobian_pert_const_f64() {
        let jacobian = forward_jacobian_pert_const(&x(), &f, &pert()).unwrap();
        assert_mat_close(&jacobian, &res1(), COMP_ACC);
    }

    #[test]
    fn test_central_jacobian_pert_const_f64() {
        let jacobian = central_jacobian_pert_const(&x(), &f, &pert()).unwrap();
        assert_mat_close(&jacobian, &res1(), COMP_ACC);
    }
}