// argmin/solver/gradientdescent/steepestdescent.rs
use crate::core::{
ArgminFloat, CostFunction, Error, Executor, Gradient, IterState, LineSearch,
OptimizationResult, Problem, Solver, State, KV,
};
use argmin_math::ArgminMul;
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
/// # Steepest descent
///
/// Iteratively takes steps in the direction of the strongest negative gradient,
/// with a step length obtained from the supplied line search `L`.
#[derive(Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct SteepestDescent<L> {
    // Line search method used to determine the step length in each iteration.
    linesearch: L,
}
impl<L> SteepestDescent<L> {
pub fn new(linesearch: L) -> Self {
SteepestDescent { linesearch }
}
}
impl<O, L, P, G, F> Solver<O, IterState<P, G, (), (), (), F>> for SteepestDescent<L>
where
    O: CostFunction<Param = P, Output = F> + Gradient<Param = P, Gradient = G>,
    P: Clone,
    G: Clone + ArgminMul<F, G>,
    L: Clone + LineSearch<G, F> + Solver<O, IterState<P, G, (), (), (), F>>,
    F: ArgminFloat,
{
    fn name(&self) -> &str {
        "Steepest Descent"
    }

    /// Performs a single iteration: evaluates cost and gradient at the current
    /// parameter vector, sets the negative gradient as the search direction, and
    /// runs the line search to obtain the next parameter vector and cost.
    ///
    /// # Errors
    ///
    /// Returns `NotInitialized` when no initial parameter vector was provided,
    /// and `PotentialBug` when the inner line search misbehaves.
    fn next_iter(
        &mut self,
        problem: &mut Problem<O>,
        state: IterState<P, G, (), (), (), F>,
    ) -> Result<(IterState<P, G, (), (), (), F>, Option<KV>), Error> {
        // An initial guess must have been supplied via `Executor::configure`.
        let param_new = state
            .get_param()
            .ok_or_else(argmin_error_closure!(
                NotInitialized,
                concat!(
                    "`SteepestDescent` requires an initial parameter vector. ",
                    "Please provide an initial guess via `Executor`s `configure` method."
                )
            ))?
            .clone();
        // NOTE: fixed mojibake `¶m_new` (corrupted `&param_new`) in both calls below.
        let new_cost = problem.cost(&param_new)?;
        let new_grad = problem.gradient(&param_new)?;

        // The steepest descent direction is the negative gradient.
        self.linesearch
            .search_direction(new_grad.mul(&(float!(-1.0))));

        // Run the line search as an inner optimization. The problem is temporarily
        // moved into the inner `Executor` and handed back below so that function
        // evaluation counts are accumulated on the outer `Problem`.
        let OptimizationResult {
            problem: line_problem,
            state: mut linesearch_state,
            ..
        } = Executor::new(
            problem.take_problem().ok_or_else(argmin_error_closure!(
                PotentialBug,
                "`SteepestDescent`: Failed to take `problem` for line search"
            ))?,
            self.linesearch.clone(),
        )
        .configure(|config| config.param(param_new).gradient(new_grad).cost(new_cost))
        .ctrlc(false)
        .run()?;

        // Return the problem (with updated evaluation counts) to the outer solver.
        problem.consume_problem(line_problem);

        Ok((
            state
                .param(
                    linesearch_state
                        .take_param()
                        // Message fixed for consistency: this solver is
                        // `SteepestDescent`, not `GradientDescent`.
                        .ok_or_else(argmin_error_closure!(
                            PotentialBug,
                            "`SteepestDescent`: No `param` returned by line search"
                        ))?,
                )
                .cost(linesearch_state.get_cost()),
            None,
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::test_utils::TestProblem;
    use crate::core::ArgminError;
    use crate::solver::linesearch::{
        condition::ArmijoCondition, BacktrackingLineSearch, MoreThuenteLineSearch,
    };
    use approx::assert_relative_eq;

    // Verifies the standard trait implementations (Clone, Send, Sync, ...)
    // via the project-wide helper macro.
    test_trait_impl!(
        steepest_descent,
        SteepestDescent<MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64>>
    );

    // `new` must store the passed line search unchanged.
    #[test]
    fn test_new() {
        let linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(ArmijoCondition::new(0.2).unwrap());
        let SteepestDescent { linesearch: ls } = SteepestDescent::new(linesearch.clone());
        assert_eq!(ls, linesearch);
    }

    // `next_iter` without an initial parameter vector must fail with the
    // documented `NotInitialized` error message.
    #[test]
    fn test_next_iter_param_not_initialized() {
        let linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(ArmijoCondition::new(0.2).unwrap());
        let mut sd = SteepestDescent::new(linesearch);
        let res = sd.next_iter(&mut Problem::new(TestProblem::new()), IterState::new());
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Not initialized: \"`SteepestDescent` requires an initial parameter vector. ",
                "Please provide an initial guess via `Executor`s `configure` method.\""
            )
        );
    }

    // After one iteration the previous parameter vector must still be present
    // in the returned state (it must not be erased by the line search run).
    #[test]
    fn test_next_iter_prev_param_not_erased() {
        let linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(ArmijoCondition::new(0.2).unwrap());
        let mut sd = SteepestDescent::new(linesearch);
        let (state, _kv) = sd
            .next_iter(
                &mut Problem::new(TestProblem::new()),
                IterState::new().param(vec![1.0, 2.0]),
            )
            .unwrap();
        // Panics if `prev_param` is `None`.
        state.prev_param.unwrap();
    }

    // Regression test: one iteration on the sphere function f(x) = x0^2 + x1^2
    // starting at (1, 2) must reproduce the known parameter vector and cost.
    #[test]
    fn test_next_iter_regression() {
        struct SDProblem {}
        impl CostFunction for SDProblem {
            type Param = Vec<f64>;
            type Output = f64;
            fn cost(&self, p: &Self::Param) -> Result<Self::Output, Error> {
                Ok(p[0].powi(2) + p[1].powi(2))
            }
        }
        impl Gradient for SDProblem {
            type Param = Vec<f64>;
            type Gradient = Vec<f64>;
            fn gradient(&self, p: &Self::Param) -> Result<Self::Param, Error> {
                Ok(vec![2.0 * p[0], 2.0 * p[1]])
            }
        }
        let linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(ArmijoCondition::new(0.2).unwrap());
        let mut sd = SteepestDescent::new(linesearch);
        let (state, kv) = sd
            .next_iter(
                &mut Problem::new(SDProblem {}),
                IterState::new().param(vec![1.0, 2.0]),
            )
            .unwrap();
        // No key-value metadata is emitted by this solver.
        assert!(kv.is_none());
        assert_relative_eq!(
            state.param.as_ref().unwrap()[0],
            -0.4580000000000002,
            epsilon = f64::EPSILON
        );
        assert_relative_eq!(
            state.param.as_ref().unwrap()[1],
            -0.9160000000000004,
            epsilon = f64::EPSILON
        );
        assert_relative_eq!(state.cost, 1.048820000000001, epsilon = f64::EPSILON);
    }
}