use crate::core::{ArgminFloat, Error, IterState, Operator, Problem, Solver, State, KV};
use argmin_math::{ArgminConj, ArgminDot, ArgminL2Norm, ArgminMul, ArgminScaledAdd, ArgminSub};
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
/// State of the conjugate gradient solver.
///
/// `P` is the type of the right-hand side and of the search directions, `F` is the
/// floating point type of the scalar quantities.
#[derive(Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct ConjugateGradient<P, F> {
/// Right-hand side `b`; `init` forms the initial residual from `b` and the operator
/// applied to the initial parameter vector.
b: P,
/// Search direction consumed by the next call to `next_iter` (`None` until `init` ran).
p: Option<P>,
/// Search direction that was used by the most recent iteration (`None` before the
/// first iteration).
p_prev: Option<P>,
/// Inner product of the current residual with its conjugate; NaN until `init` ran.
rtr: F,
}
impl<P, F> ConjugateGradient<P, F>
where
    F: ArgminFloat,
{
    /// Builds a new solver instance for the right-hand side `b`.
    ///
    /// Both search directions start out unset and the stored residual inner product is
    /// NaN; they receive meaningful values once the solver has been initialized and
    /// iterated.
    pub fn new(b: P) -> Self {
        let rtr = F::nan();
        Self {
            b,
            p: None,
            p_prev: None,
            rtr,
        }
    }

    /// Gives access to the search direction used by the previous iteration.
    ///
    /// # Errors
    ///
    /// Returns a `NotInitialized` error as long as no previous search direction has
    /// been stored.
    pub fn get_prev_p(&self) -> Result<&P, Error> {
        let not_initialized = argmin_error_closure!(
            NotInitialized,
            "Field `p_prev` of `ConjugateGradient` not initialized."
        );
        let previous = self.p_prev.as_ref();
        previous.ok_or_else(not_initialized)
    }
}
/// `Solver` implementation of the conjugate gradient method.
///
/// The residual is kept in the iteration state, while the search directions and the
/// residual inner product `rtr` live in the solver itself.
impl<P, O, R, F> Solver<O, IterState<P, (), (), (), R, F>> for ConjugateGradient<P, F>
where
    O: Operator<Param = P, Output = P>,
    P: Clone + ArgminDot<P, F> + ArgminSub<P, R> + ArgminScaledAdd<P, F, P> + ArgminConj,
    R: ArgminMul<F, R> + ArgminMul<F, P> + ArgminConj + ArgminDot<R, F> + ArgminScaledAdd<P, F, R>,
    F: ArgminFloat + ArgminL2Norm<F>,
{
    /// Human readable name of the solver.
    fn name(&self) -> &str {
        "Conjugate Gradient"
    }

    /// Prepares the first iteration.
    ///
    /// Applies the operator to the initial parameter vector `x0`, stores the residual
    /// `r0 = -(b - A x0)` in the returned state, and sets the first search direction
    /// `p = -r0` as well as `rtr = r0 . conj(r0)` on the solver.
    ///
    /// # Errors
    ///
    /// Returns a `NotInitialized` error if `state` holds no parameter vector, and
    /// propagates any error raised while applying the operator.
    fn init(
        &mut self,
        problem: &mut Problem<O>,
        state: IterState<P, (), (), (), R, F>,
    ) -> Result<(IterState<P, (), (), (), R, F>, Option<KV>), Error> {
        let init_param = state.get_param().ok_or_else(argmin_error_closure!(
            NotInitialized,
            concat!(
                "`ConjugateGradient` requires an initial parameter vector. ",
                "Please provide an initial guess via `Executor`s `configure` method."
            )
        ))?;

        let ap = problem.apply(init_param)?;

        // The type annotation selects the `ArgminMul<F, R>` impl; the assignment to
        // `self.p` below selects `ArgminMul<F, P>`.
        let r0: R = self.b.sub(&ap).mul(&(float!(-1.0)));
        self.p = Some(r0.mul(&(float!(-1.0))));
        self.rtr = r0.dot(&r0.conj());

        Ok((state.residuals(r0), None))
    }

    /// Performs one conjugate gradient step.
    ///
    /// With the current direction `p` and residual `r` it computes
    /// `alpha = rtr / (p . conj(A p))`, advances the parameter vector and the residual
    /// by `alpha * p` and `alpha * A p` respectively, then derives
    /// `beta = rtr_new / rtr` and the next direction `-r + beta * p`.
    /// The cost written to the state is the `l2_norm` of the new residual inner
    /// product; `alpha` and `beta` are reported via the returned `KV`.
    ///
    /// # Errors
    ///
    /// Returns a `PotentialBug` error if the search direction, the residuals or the
    /// parameter vector are missing, and propagates any error raised while applying
    /// the operator.
    fn next_iter(
        &mut self,
        problem: &mut Problem<O>,
        mut state: IterState<P, (), (), (), R, F>,
    ) -> Result<(IterState<P, (), (), (), R, F>, Option<KV>), Error> {
        let p = self.p.take().ok_or_else(argmin_error_closure!(
            PotentialBug,
            "`ConjugateGradient`: Field `p` not set"
        ))?;

        let r = state.take_residuals().ok_or_else(argmin_error_closure!(
            PotentialBug,
            "`ConjugateGradient`: Residuals in `state` not set"
        ))?;

        let apk = problem.apply(&p)?;

        let alpha = self.rtr.div(p.dot(&apk.conj()));

        let state_param = state.get_param().ok_or_else(argmin_error_closure!(
            PotentialBug,
            "`ConjugateGradient`: Parameter vector in `state` not set"
        ))?;
        let new_param = state_param.scaled_add(&alpha, &p);

        let r = r.scaled_add(&alpha, &apk);

        let rtr_n = r.dot(&r.conj());
        let beta = rtr_n.div(self.rtr);
        self.rtr = rtr_n;

        let p_n = <R as ArgminMul<F, P>>::mul(&r, &(float!(-1.0))).scaled_add(&beta, &p);

        // `rtr_n` already holds `r . conj(r)` for the updated residual, so it is reused
        // here instead of evaluating the same dot product a second time.
        let norm = rtr_n.l2_norm();

        self.p = Some(p_n);
        self.p_prev = Some(p);

        Ok((
            state.param(new_param).residuals(r).cost(norm),
            Some(kv!("alpha" => alpha; "beta" => beta;)),
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{test_utils::TestProblem, ArgminError};
    use approx::assert_relative_eq;

    test_trait_impl!(conjugate_gradient, ConjugateGradient<Vec<f64>, f64>);

    /// A freshly constructed solver stores `b` and leaves everything else unset.
    #[test]
    fn test_new() {
        let cg: ConjugateGradient<_, f64> = ConjugateGradient::new(vec![1.0f64, 2.0]);
        let ConjugateGradient { b, p, p_prev, rtr } = cg;
        assert_eq!(b[0].to_ne_bytes(), 1.0f64.to_ne_bytes());
        assert_eq!(b[1].to_ne_bytes(), 2.0f64.to_ne_bytes());
        assert!(p.is_none());
        assert!(p_prev.is_none());
        assert!(rtr.is_nan());
    }

    /// `get_prev_p` reports an error while no previous direction exists.
    #[test]
    fn test_get_prev_p_not_initialized() {
        let cg: ConjugateGradient<_, f64> = ConjugateGradient::new(vec![1.0f64, 2.0]);
        let res: Result<_, _> = cg.get_prev_p();
        assert_error!(
            res,
            ArgminError,
            "Not initialized: \"Field `p_prev` of `ConjugateGradient` not initialized.\""
        );
    }

    /// `get_prev_p` hands out the stored previous direction unchanged.
    #[test]
    fn test_get_prev_p() {
        let mut cg: ConjugateGradient<_, f64> = ConjugateGradient::new(vec![1.0f64, 2.0]);
        cg.p_prev = Some(vec![3.0f64, 4.0]);
        let res: Result<_, _> = cg.get_prev_p();
        assert!(res.is_ok());
        let p_prev = res.unwrap();
        assert_eq!(p_prev[0].to_ne_bytes(), 3.0f64.to_ne_bytes());
        assert_eq!(p_prev[1].to_ne_bytes(), 4.0f64.to_ne_bytes());
    }

    /// `init` fails when the state carries no initial parameter vector.
    #[test]
    fn test_init_param_not_initialized() {
        let mut cg: ConjugateGradient<_, f64> = ConjugateGradient::new(vec![1.0f64, 2.0]);
        let res = cg.init(&mut Problem::new(TestProblem::new()), IterState::new());
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Not initialized: \"`ConjugateGradient` requires an initial parameter vector. ",
                "Please provide an initial guess via `Executor`s `configure` method.\""
            )
        );
    }

    /// `init` sets the residual, the first search direction and `rtr`.
    #[test]
    fn test_init() {
        let mut cg: ConjugateGradient<_, f64> = ConjugateGradient::new(vec![1.0f64, 2.0]);
        let state: IterState<Vec<f64>, (), (), (), Vec<f64>, f64> =
            IterState::new().param(vec![3.0, 4.0]);
        let (state_out, kv) = cg
            .init(&mut Problem::new(TestProblem::new()), state.clone())
            .unwrap();
        assert!(kv.is_none());
        let ConjugateGradient { b, p, p_prev, rtr } = cg;
        assert_relative_eq!(b[0], 1.0, epsilon = f64::EPSILON);
        assert_relative_eq!(b[1], 2.0, epsilon = f64::EPSILON);
        let r0 = [2.0f64, 2.0];
        assert_relative_eq!(
            r0[0],
            state_out.get_residuals().as_ref().unwrap()[0],
            epsilon = f64::EPSILON
        );
        assert_relative_eq!(
            r0[1],
            state_out.get_residuals().as_ref().unwrap()[1],
            epsilon = f64::EPSILON
        );
        let pp = [-2.0f64, -2.0];
        assert_relative_eq!(pp[0], p.as_ref().unwrap()[0], epsilon = f64::EPSILON);
        assert_relative_eq!(pp[1], p.as_ref().unwrap()[1], epsilon = f64::EPSILON);
        assert_relative_eq!(rtr, 8.0, epsilon = f64::EPSILON);
        assert!(p_prev.is_none());
    }

    /// `next_iter` fails when the solver holds no search direction.
    #[test]
    fn test_next_iter_p_not_set() {
        let mut cg: ConjugateGradient<_, f64> = ConjugateGradient::new(vec![1.0f64, 2.0]);
        let state = IterState::new().param(vec![1.0f64]);
        assert!(cg.p.is_none());
        let res = cg.next_iter(&mut Problem::new(TestProblem::new()), state);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Potential bug: \"`ConjugateGradient`: ",
                "Field `p` not set\". This is potentially a bug. ",
                "Please file a report on https://github.com/argmin-rs/argmin/issues"
            )
        );
    }

    /// `next_iter` fails when the state holds no residuals.
    #[test]
    fn test_next_iter_r_not_set() {
        let mut cg: ConjugateGradient<_, f64> = ConjugateGradient::new(vec![1.0f64, 2.0]);
        let state = IterState::new().param(vec![1.0f64]);
        cg.p = Some(vec![]);
        let res = cg.next_iter(&mut Problem::new(TestProblem::new()), state);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Potential bug: \"`ConjugateGradient`: ",
                "Residuals in `state` not set\". This is potentially a bug. ",
                "Please file a report on https://github.com/argmin-rs/argmin/issues"
            )
        );
    }

    /// `next_iter` fails when the state holds no parameter vector.
    #[test]
    fn test_next_iter_state_param_not_set() {
        let mut cg: ConjugateGradient<_, f64> = ConjugateGradient::new(vec![1.0f64, 2.0]);
        let state = IterState::new().residuals(vec![]);
        cg.p = Some(vec![]);
        assert!(state.param.is_none());
        let res = cg.next_iter(&mut Problem::new(TestProblem::new()), state);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Potential bug: \"`ConjugateGradient`: ",
                "Parameter vector in `state` not set\". This is potentially a bug. ",
                "Please file a report on https://github.com/argmin-rs/argmin/issues"
            )
        );
    }

    /// One iteration matches a scalar re-computation of the update formulas.
    #[test]
    fn test_next_iter() {
        let mut cg: ConjugateGradient<_, f64> = ConjugateGradient::new(vec![2.0f64]);
        let state = IterState::new().param(vec![1.0f64]);
        let mut problem = Problem::new(TestProblem::new());
        let (state, _) = cg.init(&mut problem, state).unwrap();
        let rtr = cg.rtr;
        let p = cg.p.clone().unwrap()[0];
        let r = state.get_residuals().unwrap()[0];
        // The reference computation treats applying the test problem to `p` as
        // yielding `p` itself.
        let apk = p;
        let alpha = rtr / (p * apk);
        let new_param = 1.0 + alpha * p;
        let r = r + alpha * apk;
        // Residual inner product, mirroring `r.dot(&r.conj())` in the solver; it must
        // not carry a negative sign.
        let rtr_n = r * r;
        let beta = rtr_n / rtr;
        let p_n = -r + beta * p;
        let norm = (r * r).l2_norm();
        let (state, kv) = cg.next_iter(&mut problem, state).unwrap();
        assert!(kv.is_some());
        assert_relative_eq!(r, state.get_residuals().unwrap()[0]);
        assert_relative_eq!(p_n, cg.p.as_ref().unwrap()[0]);
        assert_relative_eq!(p, cg.p_prev.as_ref().unwrap()[0]);
        assert_relative_eq!(rtr_n, cg.rtr);
        assert_relative_eq!(norm, state.get_cost());
        assert_relative_eq!(new_param, state.get_param().unwrap()[0]);
    }
}