finitediff/
utils.rs

1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use std::ops::{Add, IndexMut};
9
10use anyhow::Error;
11use num::{Float, FromPrimitive};
12
/// Evaluates `f` at `x` with `y` added to the component at position `idx`.
///
/// The component is perturbed in place, `f` is called on the perturbed `x`, and the saved
/// original component is written back before returning. The write-back happens whether `f`
/// succeeds or returns an error, so a failing `f` does not leave `x` perturbed.
///
/// # Errors
///
/// Returns the error produced by `f`, unchanged.
///
/// # Panics
///
/// Panics if indexing `x` at `idx` panics; for slice-like containers such as `Vec` this is
/// the case when `idx >= x.len()`.
#[inline(always)]
pub fn mod_and_calc<F, C, T>(
    x: &mut C,
    f: &dyn Fn(&C) -> Result<T, Error>,
    idx: usize,
    y: F,
) -> Result<T, Error>
where
    F: Add<Output = F> + Copy,
    C: IndexMut<usize, Output = F>,
{
    let xtmp = x[idx];
    x[idx] = xtmp + y;
    // Keep the `Result` instead of applying `?` right away: the original component has to be
    // written back on the error path as well.
    let fx1 = f(x);
    x[idx] = xtmp;
    fx1
}
31
/// Evaluates `f` at the fixed-size array `x` with `y` added to the component at `idx`.
///
/// The component is perturbed in place, `f` is called on the perturbed array, and the saved
/// original component is written back before returning. The write-back happens whether `f`
/// succeeds or returns an error, so a failing `f` does not leave `x` perturbed.
///
/// # Errors
///
/// Returns the error produced by `f`, unchanged.
///
/// # Panics
///
/// Panics when `idx >= N`.
#[inline(always)]
pub fn mod_and_calc_const<const N: usize, F, T>(
    x: &mut [F; N],
    f: &dyn Fn(&[F; N]) -> Result<T, Error>,
    idx: usize,
    y: F,
) -> Result<T, Error>
where
    F: Add<Output = F> + Copy,
{
    assert!(idx < N);
    let xtmp = x[idx];
    x[idx] = xtmp + y;
    // Keep the `Result` instead of applying `?` right away: the original component has to be
    // written back on the error path as well.
    let fx1 = f(x);
    x[idx] = xtmp;
    fx1
}
50
51#[inline(always)]
52pub fn restore_symmetry_vec<F>(mut mat: Vec<Vec<F>>) -> Vec<Vec<F>>
53where
54    F: Float + FromPrimitive,
55{
56    for i in 0..mat.len() {
57        for j in (i + 1)..mat[i].len() {
58            let t = (mat[i][j] + mat[j][i]) / F::from_f64(2.0).unwrap();
59            mat[i][j] = t;
60            mat[j][i] = t;
61        }
62    }
63    mat
64}
65
66#[inline(always)]
67pub fn restore_symmetry_const<const N: usize, F>(mut mat: [[F; N]; N]) -> [[F; N]; N]
68where
69    F: Float + FromPrimitive,
70{
71    for i in 0..mat.len() {
72        for j in (i + 1)..mat[i].len() {
73            let t = (mat[i][j] + mat[j][i]) / F::from_f64(2.0).unwrap();
74            mat[i][j] = t;
75            mat[j][i] = t;
76        }
77    }
78    mat
79}
80
/// Restores symmetry of an `ndarray::Array2<F>` by averaging mirrored off-diagonal entries.
///
/// For every position `(i, j)` with `j > i`, both `mat[(i, j)]` and `mat[(j, i)]` are
/// replaced by their arithmetic mean. Diagonal entries are left untouched. The matrix is
/// taken by value, modified in place and returned.
///
/// Unfortunately, this is *really* slow!
///
/// # Panics
///
/// Panics if `mat[(j, i)]` is out of bounds for some visited `(i, j)`; the shape is not
/// checked, so a square matrix is presumably expected. Also panics if `2.0` cannot be
/// converted to `F`.
#[cfg(feature = "ndarray")]
#[inline(always)]
pub fn restore_symmetry_ndarray<F>(mut mat: ndarray::Array2<F>) -> ndarray::Array2<F>
where
    F: Float + FromPrimitive,
{
    let (rows, cols) = mat.dim();
    for r in 0..rows {
        for c in (r + 1)..cols {
            let sum = mat[(r, c)] + mat[(c, r)];
            let mean = sum / F::from_f64(2.0).unwrap();
            mat[(c, r)] = mean;
            mat[(r, c)] = mean;
        }
    }
    mat
}
100
/// Append-only key/value store with `usize` keys, backed by two parallel vectors.
///
/// Entries are kept in insertion order. Lookups via [`KV::get`] stop at the first stored key
/// that is greater than or equal to the requested one, so they are only reliable when keys
/// were inserted in ascending order.
pub struct KV<F> {
    // Keys, in insertion order.
    k: Vec<usize>,
    // Values; `v[i]` belongs to `k[i]`.
    v: Vec<F>,
}

impl<F: Copy> KV<F> {
    /// Creates an empty store with room reserved for `capacity` entries.
    pub fn new(capacity: usize) -> Self {
        let k = Vec::with_capacity(capacity);
        let v = Vec::with_capacity(capacity);
        Self { k, v }
    }

    /// Appends the pair `(k, v)` and returns `self` so calls can be chained.
    ///
    /// No check is made for an already present key or for key ordering.
    pub fn set(&mut self, k: usize, v: F) -> &mut Self {
        self.k.push(k);
        self.v.push(v);
        self
    }

    /// Returns a copy of the value stored under key `k`, if it can be found.
    ///
    /// The scan ends at the first stored key that is `>= k`: an exact match yields its
    /// value, a larger key yields `None` without looking at later entries. Keys inserted
    /// out of ascending order can therefore be missed.
    pub fn get(&self, k: usize) -> Option<F> {
        // The first key that is not smaller than `k` decides the outcome.
        let pos = self.k.iter().position(|&key| key >= k)?;
        if self.k[pos] == k {
            Some(self.v[pos])
        } else {
            None
        }
    }
}