use crate::core::{
ArgminFloat, CostFunction, Error, Executor, Gradient, IterState, LineSearch, NLCGBetaUpdate,
OptimizationResult, Problem, Solver, State, KV,
};
use argmin_math::{ArgminAdd, ArgminDot, ArgminL2Norm, ArgminMul};
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
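/// # Nonlinear Conjugate Gradient method
///
/// A generalization of the conjugate gradient method to nonlinear optimization problems.
///
/// Each iteration performs a line search along the current search direction `p` and then
/// updates the direction via `p = -g + beta * p`, where `g` is the gradient at the new
/// iterate and `beta` is computed by the supplied [`NLCGBetaUpdate`] (e.g. Polak-Ribiere).
/// The direction can be reset to steepest descent (`beta = 0`), either periodically
/// ([`restart_iters`](NonlinearConjugateGradient::restart_iters)) or whenever consecutive
/// gradients lose orthogonality
/// ([`restart_orthogonality`](NonlinearConjugateGradient::restart_orthogonality)).
///
/// ## Example
///
/// A minimal sketch (marked `ignore`, hence not compiled as a doctest); `Rosenbrock`
/// stands in for any user-defined problem implementing [`CostFunction`] and [`Gradient`]:
///
/// ```ignore
/// use argmin::core::Executor;
/// use argmin::solver::conjugategradient::beta::PolakRibiere;
/// use argmin::solver::conjugategradient::NonlinearConjugateGradient;
/// use argmin::solver::linesearch::MoreThuenteLineSearch;
///
/// let linesearch = MoreThuenteLineSearch::new();
/// let beta_method = PolakRibiere::new();
/// let solver: NonlinearConjugateGradient<Vec<f64>, _, _, f64> =
///     NonlinearConjugateGradient::new(linesearch, beta_method)
///         .restart_iters(10)
///         .restart_orthogonality(0.1);
///
/// let res = Executor::new(Rosenbrock {}, solver)
///     .configure(|state| state.param(vec![-1.2, 1.0]).max_iters(100))
///     .run()?;
/// ```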
#[derive(Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct NonlinearConjugateGradient<P, L, B, F> {
    /// Current search direction
    p: Option<P>,
    /// Most recently computed beta value
    beta: F,
    /// Line search method
    linesearch: L,
    /// Method for computing beta
    beta_method: B,
    /// Reset the search direction to steepest descent every `restart_iter` iterations
    restart_iter: u64,
    /// If set, reset the search direction whenever
    /// `|g_{k+1}^T g_k| / ||g_{k+1}||^2` exceeds this threshold
    restart_orthogonality: Option<F>,
}
impl<P, L, B, F> NonlinearConjugateGradient<P, L, B, F>
where
F: ArgminFloat,
{
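    /// Constructs a new instance of [`NonlinearConjugateGradient`] from a line search
    /// method and a method for computing `beta`.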
pub fn new(linesearch: L, beta_method: B) -> Self {
NonlinearConjugateGradient {
p: None,
beta: F::nan(),
linesearch,
beta_method,
restart_iter: u64::MAX,
restart_orthogonality: None,
}
}
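    /// Resets the search direction to steepest descent every `iters` iterations.
    ///
    /// Defaults to `u64::MAX`, i.e. effectively no periodic restarts.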
#[must_use]
pub fn restart_iters(mut self, iters: u64) -> Self {
self.restart_iter = iters;
self
}
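    /// Resets the search direction to steepest descent whenever consecutive gradients
    /// are insufficiently orthogonal, i.e. when `|g_{k+1}^T g_k| / ||g_{k+1}||^2 >= v`.
    ///
    /// Disabled by default.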
#[must_use]
pub fn restart_orthogonality(mut self, v: F) -> Self {
self.restart_orthogonality = Some(v);
self
}
}
impl<O, P, G, L, B, F> Solver<O, IterState<P, G, (), (), (), F>>
for NonlinearConjugateGradient<P, L, B, F>
where
O: CostFunction<Param = P, Output = F> + Gradient<Param = P, Gradient = G>,
P: Clone + ArgminAdd<P, P> + ArgminMul<F, P>,
G: Clone + ArgminMul<F, P> + ArgminDot<G, F> + ArgminL2Norm<F>,
L: Clone + LineSearch<P, F> + Solver<O, IterState<P, G, (), (), (), F>>,
B: NLCGBetaUpdate<G, P, F>,
F: ArgminFloat,
{
fn name(&self) -> &str {
"Nonlinear Conjugate Gradient"
}
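    /// Initializes the solver: evaluates cost and gradient at the initial parameter
    /// vector and sets the initial search direction to the negative gradient
    /// (steepest descent).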
fn init(
&mut self,
problem: &mut Problem<O>,
state: IterState<P, G, (), (), (), F>,
) -> Result<(IterState<P, G, (), (), (), F>, Option<KV>), Error> {
let param = state.get_param().ok_or_else(argmin_error_closure!(
NotInitialized,
            concat!(
                "`NonlinearConjugateGradient` requires an initial parameter vector. ",
                "Please provide an initial guess via `Executor`'s `configure` method."
            )
))?;
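        // Evaluate cost and gradient at the initial guess; the first search direction
        // is the direction of steepest descent.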
let cost = problem.cost(param)?;
let grad = problem.gradient(param)?;
self.p = Some(grad.mul(&(float!(-1.0))));
Ok((state.cost(cost).gradient(grad), None))
}
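    /// Performs one iteration: a line search along the current direction, followed by
    /// the beta update (or a restart) and the direction update `p = -g + beta * p`.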
fn next_iter(
&mut self,
problem: &mut Problem<O>,
mut state: IterState<P, G, (), (), (), F>,
) -> Result<(IterState<P, G, (), (), (), F>, Option<KV>), Error> {
let p = self.p.as_ref().ok_or_else(argmin_error_closure!(
PotentialBug,
"`NonlinearConjugateGradient`: Field `p` not set"
))?;
let xk = state.take_param().ok_or_else(argmin_error_closure!(
PotentialBug,
"`NonlinearConjugateGradient`: No `param` in `state`"
))?;
        let grad = match state.take_gradient() {
            Some(grad) => grad,
            None => problem.gradient(&xk)?,
        };
        let cur_cost = state.get_cost();
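        // Run the line search along the current search direction, starting from the
        // current iterate and reusing the already known cost and gradient.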
self.linesearch.search_direction(p.clone());
let OptimizationResult {
problem: line_problem,
state: mut line_state,
..
} = Executor::new(
problem.take_problem().ok_or_else(argmin_error_closure!(
PotentialBug,
"`NonlinearConjugateGradient`: Failed to take `problem` for line search"
))?,
self.linesearch.clone(),
)
.configure(|state| state.param(xk).gradient(grad.clone()).cost(cur_cost))
.ctrlc(false)
.run()?;
problem.consume_problem(line_problem);
let xk1 = line_state.take_param().ok_or_else(argmin_error_closure!(
PotentialBug,
"`NonlinearConjugateGradient`: No `param` returned by line search"
))?;
let new_grad = problem.gradient(&xk1)?;
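        // Check the restart criteria: restart if consecutive gradients are no longer
        // sufficiently orthogonal (|g_{k+1}^T g_k| / ||g_{k+1}||^2 >= v) or on every
        // `restart_iter`-th iteration (excluding iteration 0).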
let restart_orthogonality = match self.restart_orthogonality {
Some(v) => new_grad.dot(&grad).abs() / new_grad.l2_norm().powi(2) >= v,
None => false,
};
let restart_iter: bool =
(state.get_iter() % self.restart_iter == 0) && state.get_iter() != 0;
if restart_iter || restart_orthogonality {
self.beta = float!(0.0);
} else {
self.beta = self.beta_method.update(&grad, &new_grad, p);
}
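        // Update the search direction: p_{k+1} = -g_{k+1} + beta * p_k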
self.p = Some(new_grad.mul(&(float!(-1.0))).add(&p.mul(&self.beta)));
let cost = problem.cost(&xk1)?;
Ok((
state.param(xk1).cost(cost).gradient(new_grad),
            Some(kv!(
                "beta" => self.beta;
                "restart_iter" => restart_iter;
                "restart_orthogonality" => restart_orthogonality;
            )),
))
}
}
#[cfg(test)]
#[allow(clippy::let_unit_value)]
mod tests {
use super::*;
use crate::core::test_utils::TestProblem;
use crate::core::ArgminError;
use crate::solver::conjugategradient::beta::PolakRibiere;
use crate::solver::linesearch::{
condition::ArmijoCondition, BacktrackingLineSearch, MoreThuenteLineSearch,
};
use approx::assert_relative_eq;
#[derive(Eq, PartialEq, Clone, Copy, Debug)]
struct Linesearch {}
#[derive(Eq, PartialEq, Clone, Copy, Debug)]
struct BetaUpdate {}
test_trait_impl!(
nonlinear_cg,
NonlinearConjugateGradient<
TestProblem,
MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64>,
PolakRibiere,
f64
>
);
#[test]
fn test_new() {
let linesearch = Linesearch {};
let beta_method = BetaUpdate {};
let nlcg: NonlinearConjugateGradient<Vec<f64>, _, _, f64> =
NonlinearConjugateGradient::new(linesearch, beta_method);
let NonlinearConjugateGradient {
p,
beta,
linesearch,
beta_method,
restart_iter,
restart_orthogonality,
} = nlcg;
assert!(p.is_none());
assert!(beta.is_nan());
        assert_eq!(linesearch, Linesearch {});
        assert_eq!(beta_method, BetaUpdate {});
assert_eq!(restart_iter, u64::MAX);
assert!(restart_orthogonality.is_none());
}
#[test]
fn test_restart_iters() {
let linesearch = ();
let beta_method = ();
let nlcg: NonlinearConjugateGradient<Vec<f64>, _, _, f64> =
NonlinearConjugateGradient::new(linesearch, beta_method);
assert_eq!(nlcg.restart_iter, u64::MAX);
let nlcg = nlcg.restart_iters(100);
assert_eq!(nlcg.restart_iter, 100);
}
#[test]
fn test_restart_orthogonality() {
let linesearch = ();
let beta_method = ();
let nlcg: NonlinearConjugateGradient<Vec<f64>, _, _, f64> =
NonlinearConjugateGradient::new(linesearch, beta_method);
assert!(nlcg.restart_orthogonality.is_none());
let nlcg = nlcg.restart_orthogonality(0.1);
assert_eq!(
nlcg.restart_orthogonality.as_ref().unwrap().to_ne_bytes(),
0.1f64.to_ne_bytes()
);
}
#[test]
fn test_init_param_not_initialized() {
let linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
BacktrackingLineSearch::new(ArmijoCondition::new(0.2).unwrap());
let beta_method = PolakRibiere::new();
let mut nlcg: NonlinearConjugateGradient<Vec<f64>, _, _, f64> =
NonlinearConjugateGradient::new(linesearch, beta_method);
let res = nlcg.init(&mut Problem::new(TestProblem::new()), IterState::new());
assert_error!(
res,
ArgminError,
            concat!(
                "Not initialized: \"`NonlinearConjugateGradient` requires an initial parameter vector. ",
                "Please provide an initial guess via `Executor`'s `configure` method.\""
            )
);
}
#[test]
fn test_init() {
let linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
BacktrackingLineSearch::new(ArmijoCondition::new(0.2).unwrap());
let beta_method = PolakRibiere::new();
let mut nlcg: NonlinearConjugateGradient<Vec<f64>, _, _, f64> =
NonlinearConjugateGradient::new(linesearch, beta_method);
let state: IterState<Vec<f64>, Vec<f64>, (), (), (), f64> =
IterState::new().param(vec![3.0, 4.0]);
let (state_out, kv) = nlcg
.init(&mut Problem::new(TestProblem::new()), state.clone())
.unwrap();
assert!(kv.is_none());
assert_ne!(state_out, state);
assert_eq!(state_out.cost.to_ne_bytes(), 1f64.to_ne_bytes());
assert_eq!(
state_out.grad.as_ref().unwrap()[0].to_ne_bytes(),
3f64.to_ne_bytes()
);
assert_eq!(
state_out.grad.as_ref().unwrap()[1].to_ne_bytes(),
4f64.to_ne_bytes()
);
assert_eq!(
state_out.param.as_ref().unwrap()[0].to_ne_bytes(),
3f64.to_ne_bytes()
);
assert_eq!(
state_out.param.as_ref().unwrap()[1].to_ne_bytes(),
4f64.to_ne_bytes()
);
assert_eq!(
nlcg.p.as_ref().unwrap()[0].to_ne_bytes(),
(-3f64).to_ne_bytes()
);
assert_eq!(
nlcg.p.as_ref().unwrap()[1].to_ne_bytes(),
(-4f64).to_ne_bytes()
);
}
#[test]
fn test_next_iter_p_not_set() {
let linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
BacktrackingLineSearch::new(ArmijoCondition::new(0.2).unwrap());
let beta_method = PolakRibiere::new();
let mut nlcg: NonlinearConjugateGradient<Vec<f64>, _, _, f64> =
NonlinearConjugateGradient::new(linesearch, beta_method);
let state = IterState::new().param(vec![1.0f64, 2.0f64]);
assert!(nlcg.p.is_none());
let res = nlcg.next_iter(&mut Problem::new(TestProblem::new()), state);
assert_error!(
res,
ArgminError,
concat!(
"Potential bug: \"`NonlinearConjugateGradient`: ",
"Field `p` not set\". This is potentially a bug. ",
"Please file a report on https://github.com/argmin-rs/argmin/issues"
)
);
}
#[test]
fn test_next_iter_state_param_not_set() {
let linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
BacktrackingLineSearch::new(ArmijoCondition::new(0.2).unwrap());
let beta_method = PolakRibiere::new();
let mut nlcg: NonlinearConjugateGradient<Vec<f64>, _, _, f64> =
NonlinearConjugateGradient::new(linesearch, beta_method);
let state = IterState::new();
nlcg.p = Some(vec![]);
assert!(nlcg.p.is_some());
let res = nlcg.next_iter(&mut Problem::new(TestProblem::new()), state);
assert_error!(
res,
ArgminError,
concat!(
"Potential bug: \"`NonlinearConjugateGradient`: ",
"No `param` in `state`\". This is potentially a bug. ",
"Please file a report on https://github.com/argmin-rs/argmin/issues"
)
);
}
#[test]
fn test_next_iter_problem_missing() {
let linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
BacktrackingLineSearch::new(ArmijoCondition::new(0.2).unwrap());
let beta_method = PolakRibiere::new();
let mut nlcg: NonlinearConjugateGradient<Vec<f64>, _, _, f64> =
NonlinearConjugateGradient::new(linesearch, beta_method);
let state = IterState::new()
.param(vec![1.0f64, 2.0])
.gradient(vec![1.0f64, 2.0]);
nlcg.p = Some(vec![]);
assert!(nlcg.p.is_some());
let mut problem = Problem::new(TestProblem::new());
let _ = problem.take_problem().unwrap();
let res = nlcg.next_iter(&mut problem, state);
assert_error!(
res,
ArgminError,
concat!(
"Potential bug: \"`NonlinearConjugateGradient`: ",
"Failed to take `problem` for line search\". This is potentially a bug. ",
"Please file a report on https://github.com/argmin-rs/argmin/issues"
)
);
}
#[test]
fn test_next_iter() {
let linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
BacktrackingLineSearch::new(ArmijoCondition::new(0.2).unwrap());
let beta_method = PolakRibiere::new();
let mut nlcg: NonlinearConjugateGradient<Vec<f64>, _, _, f64> =
NonlinearConjugateGradient::new(linesearch, beta_method);
let state = IterState::new()
.param(vec![1.0f64, 2.0])
.gradient(vec![1.0f64, 2.0]);
let mut problem = Problem::new(TestProblem::new());
let (state, kv) = nlcg.init(&mut problem, state).unwrap();
assert!(kv.is_none());
let (mut state, kv) = nlcg.next_iter(&mut problem, state).unwrap();
state.update();
let kv2 = kv!("beta" => 0.0; "restart_iter" => false; "restart_orthogonality" => false;);
assert_eq!(kv.unwrap(), kv2);
assert_relative_eq!(
state.param.as_ref().unwrap()[0],
1.0f64,
epsilon = f64::EPSILON
);
assert_relative_eq!(
state.param.as_ref().unwrap()[1],
2.0f64,
epsilon = f64::EPSILON
);
assert_relative_eq!(
state.best_param.as_ref().unwrap()[0],
1.0f64,
epsilon = f64::EPSILON
);
assert_relative_eq!(
state.best_param.as_ref().unwrap()[1],
2.0f64,
epsilon = f64::EPSILON
);
assert_relative_eq!(state.cost, 1.0f64, epsilon = f64::EPSILON);
assert_relative_eq!(state.prev_cost, 1.0f64, epsilon = f64::EPSILON);
assert_relative_eq!(state.best_cost, 1.0f64, epsilon = f64::EPSILON);
assert_relative_eq!(
state.grad.as_ref().unwrap()[0],
1.0f64,
epsilon = f64::EPSILON
);
assert_relative_eq!(
state.grad.as_ref().unwrap()[1],
2.0f64,
epsilon = f64::EPSILON
);
}
}