#![allow(clippy::nonminimal_bool)]
use crate::core::{
ArgminFloat, CostFunction, Error, Gradient, IterState, LineSearch, Problem, Solver, State,
TerminationReason, KV,
};
use argmin_math::{ArgminDot, ArgminScaledAdd};
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
/// State and configuration of the More-Thuente line search.
///
/// `P` is the parameter type, `G` the type of the gradient and of the search direction, and `F`
/// the floating point type. The search evaluates the problem at
/// `init_param.scaled_add(step, search_direction)` for a sequence of trial steps.
#[derive(Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct MoreThuenteLineSearch<P, G, F> {
/// Search direction; must be set via `search_direction` before `init` is called.
search_direction: Option<G>,
/// Parameter vector the search starts from (taken from the state in `init`).
init_param: Option<P>,
/// Cost at `init_param` (positive infinity until `init` has run).
finit: F,
/// Gradient at `init_param`.
init_grad: Option<G>,
/// Dot product of `init_grad` and `search_direction`, computed in `init`.
dginit: F,
/// `ftol * dginit`, computed in `init`.
dgtest: F,
/// Parameter `c1` of `with_c`; scales `dginit` in the decrease test (default 1e-4).
ftol: F,
/// Parameter `c2` of `with_c`; scales `dginit` in the derivative test (default 0.9).
gtol: F,
/// Factor used to compute the upper step bound while the minimizer is not bracketed
/// (default 4).
xtrapf: F,
/// Width `|sty.x - stx.x|` of the current interval; `stpmax - stpmin` after `init`.
width: F,
/// Value `width` had before its most recent update; twice `width` after `init`.
width1: F,
/// Relative width tolerance of the interval (default 1e-10).
xtol: F,
/// Initial step length (default 1).
alpha: F,
/// Lower bound on the step length (default `sqrt(epsilon)`).
stpmin: F,
/// Upper bound on the step length (default positive infinity).
stpmax: F,
/// Current trial step.
stp: Step<F>,
/// Interval endpoint; `cstep` moves it to the trial step whenever the trial's function value
/// is not higher than the value stored here.
stx: Step<F>,
/// The other interval endpoint.
sty: Step<F>,
/// Cost at the most recently evaluated trial point.
f: F,
/// Flag returned by `cstep`; `true` once a minimizer has been bracketed.
brackt: bool,
/// `true` while the first stage of the search is active.
stage1: bool,
/// Info code returned by the most recent `cstep` call; 0 means `cstep` rejected its inputs,
/// which makes the next iteration terminate.
infoc: usize,
}
/// A point on the search line: a step length together with the function value and the
/// derivative recorded for it.
#[derive(Clone, Eq, PartialEq, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
struct Step<F> {
    /// Step length
    pub x: F,
    /// Function value recorded for `x`
    pub fx: F,
    /// Derivative recorded for `x`
    pub gx: F,
}

impl<F> Step<F> {
    /// Bundles a step length `x` with its function value `fx` and derivative `gx`.
    pub fn new(x: F, fx: F, gx: F) -> Self {
        Self { x, fx, gx }
    }
}
impl<F> Default for Step<F>
where
F: ArgminFloat,
{
fn default() -> Self {
Step {
x: float!(0.0),
fx: float!(0.0),
gx: float!(0.0),
}
}
}
impl<P, G, F> MoreThuenteLineSearch<P, G, F>
where
    F: ArgminFloat,
{
    /// Construct a new instance with default settings.
    ///
    /// Defaults: `c1 = 1e-4`, `c2 = 0.9`, relative width tolerance `1e-10`, initial step
    /// length `1`, step bounds `[sqrt(epsilon), inf)`. The search direction and the initial
    /// parameter vector are not set.
    pub fn new() -> Self {
        MoreThuenteLineSearch {
            search_direction: None,
            init_param: None,
            finit: F::infinity(),
            init_grad: None,
            dginit: float!(0.0),
            dgtest: float!(0.0),
            ftol: float!(1e-4),
            gtol: float!(0.9),
            xtrapf: float!(4.0),
            width: F::nan(),
            width1: F::nan(),
            xtol: float!(1e-10),
            alpha: float!(1.0),
            stpmin: F::epsilon().sqrt(),
            stpmax: F::infinity(),
            stp: Step::default(),
            stx: Step::default(),
            sty: Step::default(),
            f: F::nan(),
            brackt: false,
            stage1: true,
            infoc: 1,
        }
    }

    /// Set the constants `c1` and `c2`, where `0 < c1 < c2 < 1`.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidParameter` error if `c1` is not in `(0, c2)` or `c2` is not in
    /// `(c1, 1)`. NaN values are rejected as well.
    pub fn with_c(mut self, c1: F, c2: F) -> Result<Self, Error> {
        // Comparisons with NaN are always false, so NaN has to be rejected explicitly;
        // otherwise it would pass every range check below and be stored.
        if c1.is_nan() || c1 <= float!(0.0) || c1 >= c2 {
            return Err(argmin_error!(
                InvalidParameter,
                "`MoreThuenteLineSearch`: Parameter c1 must be in (0, c2)."
            ));
        }
        if c2.is_nan() || c2 <= c1 || c2 >= float!(1.0) {
            return Err(argmin_error!(
                InvalidParameter,
                "`MoreThuenteLineSearch`: Parameter c2 must be in (c1, 1)."
            ));
        }
        self.ftol = c1;
        self.gtol = c2;
        Ok(self)
    }

    /// Set lower and upper bound of the step length, where `0 <= step_min < step_max`.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidParameter` error if `step_min` is negative or if `step_max` is not
    /// larger than `step_min`. NaN values are rejected as well.
    pub fn with_bounds(mut self, step_min: F, step_max: F) -> Result<Self, Error> {
        if step_min.is_nan() || step_min < float!(0.0) {
            return Err(argmin_error!(
                InvalidParameter,
                "`MoreThuenteLineSearch`: step_min must be >= 0.0."
            ));
        }
        if step_max.is_nan() || step_max <= step_min {
            return Err(argmin_error!(
                InvalidParameter,
                "`MoreThuenteLineSearch`: step_min must be smaller than step_max."
            ));
        }
        self.stpmin = step_min;
        self.stpmax = step_max;
        Ok(self)
    }

    /// Set the relative tolerance on the width of the interval; must be non-negative.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidParameter` error if `xtol` is negative or NaN.
    pub fn with_width_tolerance(mut self, xtol: F) -> Result<Self, Error> {
        if xtol.is_nan() || xtol < float!(0.0) {
            return Err(argmin_error!(
                InvalidParameter,
                "`MoreThuenteLineSearch`: relative width tolerance must be >= 0.0."
            ));
        }
        self.xtol = xtol;
        Ok(self)
    }
}
impl<P, G, F> Default for MoreThuenteLineSearch<P, G, F>
where
    F: ArgminFloat,
{
    /// Equivalent to calling `new`.
    fn default() -> Self {
        Self::new()
    }
}
impl<P, G, F> LineSearch<G, F> for MoreThuenteLineSearch<P, G, F>
where
    F: ArgminFloat,
{
    /// Set the search direction. This must be called before the solver is initialized.
    fn search_direction(&mut self, search_direction: G) {
        self.search_direction = Some(search_direction);
    }

    /// Set the initial step length `alpha`, which must be larger than zero.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidParameter` error if `alpha` is not larger than zero or is NaN.
    fn initial_step_length(&mut self, alpha: F) -> Result<(), Error> {
        // `alpha <= 0` is false for NaN, so NaN needs its own check to be rejected.
        if alpha.is_nan() || alpha <= float!(0.0) {
            return Err(argmin_error!(
                InvalidParameter,
                "MoreThuenteLineSearch: Initial alpha must be > 0."
            ));
        }
        self.alpha = alpha;
        Ok(())
    }
}
impl<P, G, O, F> Solver<O, IterState<P, G, (), (), (), F>> for MoreThuenteLineSearch<P, G, F>
where
    O: CostFunction<Param = P, Output = F> + Gradient<Param = P, Gradient = G>,
    P: Clone + ArgminDot<G, F> + ArgminScaledAdd<G, F, P>,
    G: Clone + ArgminDot<G, F>,
    F: ArgminFloat,
{
    /// Name of the solver.
    fn name(&self) -> &str {
        "More-Thuente Line search"
    }

    /// Prepares the line search.
    ///
    /// Takes the initial parameter vector from `state`, reuses cost and gradient from `state`
    /// when they are available (otherwise evaluates them via `problem`), computes the
    /// directional derivative at the starting point and resets the internal search state.
    ///
    /// # Errors
    ///
    /// Fails if the search direction or the initial parameter vector is missing, if a problem
    /// evaluation fails, or with `ConditionViolated` if the directional derivative at the
    /// starting point is not negative.
    fn init(
        &mut self,
        problem: &mut Problem<O>,
        mut state: IterState<P, G, (), (), (), F>,
    ) -> Result<(IterState<P, G, (), (), (), F>, Option<KV>), Error> {
        check_param!(
            self.search_direction,
            concat!(
                "`MoreThuenteLineSearch`: Search direction not initialized. ",
                "Call `search_direction` before executing the solver."
            )
        );

        self.init_param = Some(state.take_param().ok_or_else(argmin_error_closure!(
            NotInitialized,
            concat!(
                "`MoreThuenteLineSearch` requires an initial parameter vector. ",
                "Please provide an initial guess via `Executor`s `configure` method."
            )
        ))?);

        // An infinite cost in the state is treated as "not provided".
        let cost = state.get_cost();
        self.finit = if cost.is_infinite() {
            problem.cost(self.init_param.as_ref().unwrap())?
        } else {
            cost
        };

        self.init_grad = Some(
            state
                .take_gradient()
                .map(Result::Ok)
                .unwrap_or_else(|| problem.gradient(self.init_param.as_ref().unwrap()))?,
        );

        self.dginit = self
            .init_grad
            .as_ref()
            .unwrap()
            .dot(self.search_direction.as_ref().unwrap());

        if self.dginit >= float!(0.0) {
            return Err(argmin_error!(
                ConditionViolated,
                "`MoreThuenteLineSearch`: Search direction must be a descent direction."
            ));
        }

        self.stage1 = true;
        self.brackt = false;
        // Reset together with the other flags: a value of 0 left over from an earlier run would
        // make the first iteration terminate immediately.
        self.infoc = 1;

        self.dgtest = self.ftol * self.dginit;
        self.width = self.stpmax - self.stpmin;
        self.width1 = float!(2.0) * self.width;
        self.f = self.finit;

        self.stp = Step::new(self.alpha, F::nan(), F::nan());
        self.stx = Step::new(float!(0.0), self.finit, self.dginit);
        self.sty = Step::new(float!(0.0), self.finit, self.dginit);

        Ok((state, None))
    }

    /// Performs one iteration: evaluates cost and gradient at the current trial step, checks
    /// the termination tests and, if none of them holds, computes the next trial step with
    /// `cstep`.
    ///
    /// When a termination test holds, the trial point, its cost and its gradient are stored in
    /// the returned state, which is terminated with `TerminationReason::SolverConverged`
    /// (this is also the case for the tests with info codes 2, 4, 5 and 6).
    ///
    /// # Errors
    ///
    /// Propagates errors of the problem evaluations and of `cstep`.
    fn next_iter(
        &mut self,
        problem: &mut Problem<O>,
        state: IterState<P, G, (), (), (), F>,
    ) -> Result<(IterState<P, G, (), (), (), F>, Option<KV>), Error> {
        // Smallest and largest step admissible for the present interval.
        let (stmin, stmax) = if self.brackt {
            (self.stx.x.min(self.sty.x), self.stx.x.max(self.sty.x))
        } else {
            (
                self.stx.x,
                self.stp.x + self.xtrapf * (self.stp.x - self.stx.x),
            )
        };

        // Force the trial step into [stpmin, stpmax].
        self.stp.x = self.stp.x.max(self.stpmin).min(self.stpmax);

        // If one of the unusual terminations (info 6 or 2 below) is about to occur, evaluate at
        // `stx.x` instead.
        let out_of_interval = self.brackt && (self.stp.x <= stmin || self.stp.x >= stmax);
        let interval_too_small = self.brackt && (stmax - stmin) <= self.xtol * stmax;
        if out_of_interval || interval_too_small || self.infoc == 0 {
            self.stp.x = self.stx.x;
        }

        // Cost, gradient and directional derivative at the trial point.
        let trial_param = self
            .init_param
            .as_ref()
            .unwrap()
            .scaled_add(&self.stp.x, self.search_direction.as_ref().unwrap());
        self.f = problem.cost(&trial_param)?;
        let trial_grad = problem.gradient(&trial_param)?;
        let dg = self.search_direction.as_ref().unwrap().dot(&trial_grad);
        let ftest1 = self.finit + self.stp.x * self.dgtest;

        // Termination tests. Later tests overwrite earlier ones, so info 1 has the highest
        // priority. All of them use the (possibly replaced) trial step.
        let mut info = 0;
        if (self.brackt && (self.stp.x <= stmin || self.stp.x >= stmax)) || self.infoc == 0 {
            info = 6;
        }
        if (self.stp.x - self.stpmax).abs() < F::epsilon() && self.f <= ftest1 && dg <= self.dgtest
        {
            info = 5;
        }
        if (self.stp.x - self.stpmin).abs() < F::epsilon() && (self.f > ftest1 || dg >= self.dgtest)
        {
            info = 4;
        }
        if self.brackt && stmax - stmin <= self.xtol * stmax {
            info = 2;
        }
        if self.f <= ftest1 && dg.abs() <= self.gtol * (-self.dginit) {
            info = 1;
        }

        if info != 0 {
            return Ok((
                state
                    .param(trial_param)
                    .cost(self.f)
                    .gradient(trial_grad)
                    .terminate_with(TerminationReason::SolverConverged),
                None,
            ));
        }

        // Leave the first stage once the decrease test holds and the derivative is large enough.
        if self.stage1 && self.f <= ftest1 && dg >= self.ftol.min(self.gtol) * self.dginit {
            self.stage1 = false;
        }

        // During the first stage, if the trial value is not above the one stored at `stx` but
        // still fails the decrease test, the step is predicted from the modified function
        // `f(x) - x * dgtest`.
        if self.stage1 && self.f <= self.stx.fx && self.f > ftest1 {
            let fm = self.f - self.stp.x * self.dgtest;
            let fxm = self.stx.fx - self.stx.x * self.dgtest;
            let fym = self.sty.fx - self.sty.x * self.dgtest;
            let dgm = dg - self.dgtest;
            let dgxm = self.stx.gx - self.dgtest;
            let dgym = self.sty.gx - self.dgtest;

            let (stx1, sty1, stp1, brackt1, _stmin, _stmax, infoc) = cstep(
                Step::new(self.stx.x, fxm, dgxm),
                Step::new(self.sty.x, fym, dgym),
                Step::new(self.stp.x, fm, dgm),
                self.brackt,
                stmin,
                stmax,
            )?;

            // `cstep` may have moved the endpoints, so the original function values and
            // derivatives are recovered from the values it returned, not from the values stored
            // before the call.
            self.stx = Step::new(
                stx1.x,
                stx1.fx + stx1.x * self.dgtest,
                stx1.gx + self.dgtest,
            );
            self.sty = Step::new(
                sty1.x,
                sty1.fx + sty1.x * self.dgtest,
                sty1.gx + self.dgtest,
            );
            // Store the unmodified values of this iteration with the new trial step, exactly as
            // the other branch does.
            self.stp = Step::new(stp1.x, self.f, dg);
            self.brackt = brackt1;
            self.infoc = infoc;
        } else {
            let (stx1, sty1, stp1, brackt1, _stmin, _stmax, infoc) = cstep(
                self.stx.clone(),
                self.sty.clone(),
                Step::new(self.stp.x, self.f, dg),
                self.brackt,
                stmin,
                stmax,
            )?;

            self.stx = stx1;
            self.sty = sty1;
            self.stp = stp1;
            self.brackt = brackt1;
            self.infoc = infoc;
        }

        // Use the midpoint of the interval if its width did not shrink enough.
        if self.brackt {
            if (self.sty.x - self.stx.x).abs() >= float!(0.66) * self.width1 {
                self.stp.x = self.stx.x + float!(0.5) * (self.sty.x - self.stx.x);
            }
            self.width1 = self.width;
            self.width = (self.sty.x - self.stx.x).abs();
        }

        Ok((state, None))
    }
}
/// Return value of `cstep`: the updated `stx`, `sty` and `stp`, the `brackt` flag, the two step
/// bounds exactly as they were passed in, and the info code.
type CstepReturnValue<F> = (Step<F>, Step<F>, Step<F>, bool, F, F, usize);
/// Computes the next trial step and updates the interval `[stx, sty]`.
///
/// # Parameters
///
/// * `stx`, `sty`: endpoints of the current interval with their function values and derivatives.
/// * `stp`: current trial step with its function value and derivative.
/// * `brackt`: whether a minimizer has already been bracketed.
/// * `stpmin`, `stpmax`: bounds the new trial step is clamped to.
///
/// # Returns
///
/// The updated `stx`, `sty` and `stp` (only the step length `stp.x` is new; its function value
/// and derivative are carried over), the updated `brackt` flag, `stpmin` and `stpmax` as passed
/// in, and an info code: 1-4 identify the case that produced the step, 0 means that the inputs
/// were rejected and returned unchanged.
///
/// # Errors
///
/// Returns `ConditionViolated` if a NaN or an infinite value shows up in the quantities of the
/// cubic interpolation.
fn cstep<F: ArgminFloat>(
    stx: Step<F>,
    sty: Step<F>,
    stp: Step<F>,
    brackt: bool,
    stpmin: F,
    stpmax: F,
) -> Result<CstepReturnValue<F>, Error> {
    let mut brackt = brackt;

    // Reject inconsistent inputs: a trial step outside of a bracketing interval, a derivative at
    // `stx` that does not point towards the trial step, or swapped bounds.
    if (brackt && (stp.x <= stx.x.min(sty.x) || stp.x >= stx.x.max(sty.x)))
        || stx.gx * (stp.x - stx.x) >= float!(0.0)
        || stpmax < stpmin
    {
        return Ok((stx, sty, stp, brackt, stpmin, stpmax, 0));
    }

    // Negative if the derivatives at `stx` and `stp` have opposite signs. `stx.gx` cannot be
    // zero here because of the input check above.
    let sgnd = stp.gx * (stx.gx / stx.gx.abs());

    let info: usize;
    let bound: bool;
    let mut stpf: F;

    if stp.fx > stx.fx {
        // Case 1: higher function value, so a minimizer is bracketed. Take the cubic step if it
        // is closer to `stx.x` than the quadratic step, else the average of both.
        info = 1;
        bound = true;
        let theta = float!(3.0) * (stx.fx - stp.fx) / (stp.x - stx.x) + stx.gx + stp.gx;
        let tmp = [theta, stx.gx, stp.gx];
        if tmp.iter().any(|n| n.is_nan() || n.is_infinite()) {
            return Err(argmin_error!(
                ConditionViolated,
                "MoreThuenteLineSearch: NaN or Inf encountered during iteration"
            ));
        }
        // Scale by the largest magnitude. The sign of `gamma` must only depend on the test
        // below, which requires a non-negative scaling factor; the maximum of the signed values
        // can be negative (or zero) and would flip `gamma` (or produce NaN).
        let s = theta.abs().max(stx.gx.abs()).max(stp.gx.abs());
        let mut gamma = s * ((theta / s).powi(2) - (stx.gx / s) * (stp.gx / s)).sqrt();
        if stp.x < stx.x {
            gamma = -gamma;
        }
        let p = (gamma - stx.gx) + theta;
        let q = ((gamma - stx.gx) + gamma) + stp.gx;
        let r = p / q;
        let stpc = stx.x + r * (stp.x - stx.x);
        let stpq = stx.x
            + ((stx.gx / ((stx.fx - stp.fx) / (stp.x - stx.x) + stx.gx)) / float!(2.0))
                * (stp.x - stx.x);
        stpf = if (stpc - stx.x).abs() < (stpq - stx.x).abs() {
            stpc
        } else {
            stpc + (stpq - stpc) / float!(2.0)
        };
        brackt = true;
    } else if sgnd < float!(0.0) {
        // Case 2: function value not higher and derivatives of opposite sign, so a minimizer is
        // bracketed. Take whichever of the cubic and the secant step is farther from `stp.x`.
        info = 2;
        bound = false;
        let theta = float!(3.0) * (stx.fx - stp.fx) / (stp.x - stx.x) + stx.gx + stp.gx;
        let tmp = [theta, stx.gx, stp.gx];
        if tmp.iter().any(|n| n.is_nan() || n.is_infinite()) {
            return Err(argmin_error!(
                ConditionViolated,
                "MoreThuenteLineSearch: NaN or Inf encountered during iteration"
            ));
        }
        let s = theta.abs().max(stx.gx.abs()).max(stp.gx.abs());
        let mut gamma = s * ((theta / s).powi(2) - (stx.gx / s) * (stp.gx / s)).sqrt();
        if stp.x > stx.x {
            gamma = -gamma;
        }
        let p = (gamma - stp.gx) + theta;
        let q = ((gamma - stp.gx) + gamma) + stx.gx;
        let r = p / q;
        let stpc = stp.x + r * (stx.x - stp.x);
        let stpq = stp.x + (stp.gx / (stp.gx - stx.gx)) * (stx.x - stp.x);
        stpf = if (stpc - stp.x).abs() > (stpq - stp.x).abs() {
            stpc
        } else {
            stpq
        };
        brackt = true;
    } else if stp.gx.abs() < stx.gx.abs() {
        // Case 3: function value not higher, derivatives of the same sign and the magnitude of
        // the derivative decreases. The cubic step is only used if it lies beyond `stp.x`;
        // otherwise the corresponding bound takes its place.
        info = 3;
        bound = true;
        let theta = float!(3.0) * (stx.fx - stp.fx) / (stp.x - stx.x) + stx.gx + stp.gx;
        let tmp = [theta, stx.gx, stp.gx];
        if tmp.iter().any(|n| n.is_nan() || n.is_infinite()) {
            return Err(argmin_error!(
                ConditionViolated,
                "`MoreThuenteLineSearch`: NaN or Inf encountered during iteration"
            ));
        }
        let s = theta.abs().max(stx.gx.abs()).max(stp.gx.abs());
        // The radicand can be negative in this case; it is clamped to zero.
        let mut gamma = s
            * float!(0.0)
                .max((theta / s).powi(2) - (stx.gx / s) * (stp.gx / s))
                .sqrt();
        if stp.x > stx.x {
            gamma = -gamma;
        }
        let p = (gamma - stp.gx) + theta;
        let q = (gamma + (stx.gx - stp.gx)) + gamma;
        let r = p / q;
        let stpc = if r < float!(0.0) && gamma != float!(0.0) {
            stp.x + r * (stx.x - stp.x)
        } else if stp.x > stx.x {
            stpmax
        } else {
            stpmin
        };
        let stpq = stp.x + (stp.gx / (stp.gx - stx.gx)) * (stx.x - stp.x);
        // Bracketed: take the step closer to `stp.x`; not bracketed: the one farther away.
        let cubic_is_closer = (stp.x - stpc).abs() < (stp.x - stpq).abs();
        let cubic_is_farther = (stp.x - stpc).abs() > (stp.x - stpq).abs();
        stpf = if (brackt && cubic_is_closer) || (!brackt && cubic_is_farther) {
            stpc
        } else {
            stpq
        };
    } else {
        // Case 4: function value not higher, derivatives of the same sign and the magnitude of
        // the derivative does not decrease. Use the cubic step between `stp` and `sty` if
        // bracketed, otherwise one of the bounds.
        info = 4;
        bound = false;
        if brackt {
            let theta = float!(3.0) * (stp.fx - sty.fx) / (sty.x - stp.x) + sty.gx + stp.gx;
            let tmp = [theta, sty.gx, stp.gx];
            if tmp.iter().any(|n| n.is_nan() || n.is_infinite()) {
                return Err(argmin_error!(
                    ConditionViolated,
                    "MoreThuenteLineSearch: NaN or Inf encountered during iteration"
                ));
            }
            let s = theta.abs().max(sty.gx.abs()).max(stp.gx.abs());
            let mut gamma = s * ((theta / s).powi(2) - (sty.gx / s) * (stp.gx / s)).sqrt();
            if stp.x > sty.x {
                gamma = -gamma;
            }
            let p = (gamma - stp.gx) + theta;
            let q = ((gamma - stp.gx) + gamma) + sty.gx;
            let r = p / q;
            stpf = stp.x + r * (sty.x - stp.x);
        } else if stp.x > stx.x {
            stpf = stpmax;
        } else {
            stpf = stpmin;
        }
    }

    // Update the interval. This does not depend on the new step or on the case analysis above.
    let mut stx_o = stx;
    let mut sty_o = sty;
    let mut stp_o = stp;
    if stp_o.fx > stx_o.fx {
        sty_o = Step::new(stp_o.x, stp_o.fx, stp_o.gx);
    } else {
        if sgnd < float!(0.0) {
            sty_o = Step::new(stx_o.x, stx_o.fx, stx_o.gx);
        }
        stx_o = Step::new(stp_o.x, stp_o.fx, stp_o.gx);
    }

    // Clamp the new step to the bounds.
    stpf = stpmax.min(stpf);
    stpf = stpmin.max(stpf);
    stp_o.x = stpf;

    // In the cases marked `bound`, keep the new step within 66% of the interval, measured from
    // `stx`.
    if brackt && bound {
        let limit = stx_o.x + float!(0.66) * (sty_o.x - stx_o.x);
        if sty_o.x > stx_o.x {
            stp_o.x = stp_o.x.min(limit);
        } else {
            stp_o.x = stp_o.x.max(limit);
        }
    }

    Ok((stx_o, sty_o, stp_o, brackt, stpmin, stpmax, info))
}
// Unit tests for construction, the builder methods and the failure paths of `init`.
#[cfg(test)]
mod tests {
use super::*;
use crate::core::{test_utils::TestProblem, ArgminError};
test_trait_impl!(morethuente, MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64>);
// `new` must populate every field with its documented default. The struct is destructured
// exhaustively so that adding a field without extending this test fails to compile.
#[test]
fn test_new() {
let mtls: MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64> = MoreThuenteLineSearch::new();
let MoreThuenteLineSearch {
search_direction,
init_param,
finit,
init_grad,
dginit,
dgtest,
ftol,
gtol,
xtrapf,
width,
width1,
xtol,
alpha,
stpmin,
stpmax,
stp,
stx,
sty,
f,
brackt,
stage1,
infoc,
} = mtls;
assert!(search_direction.is_none());
assert!(init_param.is_none());
assert!(finit.is_infinite());
assert!(finit.is_sign_positive());
assert!(init_grad.is_none());
// Floats are compared via their byte representation, i.e. bit for bit.
assert_eq!(dginit.to_ne_bytes(), 0.0f64.to_ne_bytes());
assert_eq!(dgtest.to_ne_bytes(), 0.0f64.to_ne_bytes());
assert_eq!(ftol.to_ne_bytes(), 1e-4f64.to_ne_bytes());
assert_eq!(gtol.to_ne_bytes(), 0.9f64.to_ne_bytes());
assert_eq!(xtrapf.to_ne_bytes(), 4.0f64.to_ne_bytes());
assert!(width.is_nan());
assert!(width1.is_nan());
assert_eq!(xtol.to_ne_bytes(), 1e-10f64.to_ne_bytes());
assert_eq!(alpha.to_ne_bytes(), 1.0f64.to_ne_bytes());
assert_eq!(stpmin.to_ne_bytes(), f64::EPSILON.sqrt().to_ne_bytes());
assert!(stpmax.is_infinite());
assert!(stpmax.is_sign_positive());
assert_eq!(stp, Step::default());
assert_eq!(stx, Step::default());
assert_eq!(sty, Step::default());
assert!(f.is_nan());
assert!(!brackt);
assert!(stage1);
assert_eq!(infoc, 1);
}
// Valid `c1`/`c2` are stored in `ftol`/`gtol`.
#[test]
fn test_with_c_correct() {
let mtls: MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64> = MoreThuenteLineSearch::new();
let res = mtls.with_c(0.1, 0.9);
assert!(res.is_ok());
let mtls = res.unwrap();
assert_eq!(mtls.ftol.to_ne_bytes(), 0.1f64.to_ne_bytes());
assert_eq!(mtls.gtol.to_ne_bytes(), 0.9f64.to_ne_bytes());
}
// `c1 >= c2` is reported as an invalid `c1`.
#[test]
fn test_with_c_c1_larger_than_c2() {
let mtls: MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64> = MoreThuenteLineSearch::new();
let res = mtls.with_c(0.9, 0.1);
assert_error!(
res,
ArgminError,
concat!(
"Invalid parameter: \"`MoreThuenteLineSearch`: ",
"Parameter c1 must be in (0, c2).\""
)
);
}
// A negative `c1` is rejected.
#[test]
fn test_with_c_c1_smaller_than_0() {
let mtls: MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64> = MoreThuenteLineSearch::new();
let res = mtls.with_c(-0.9, 0.99);
assert_error!(
res,
ArgminError,
concat!(
"Invalid parameter: \"`MoreThuenteLineSearch`: ",
"Parameter c1 must be in (0, c2).\""
)
);
}
// `c2 >= 1` is rejected.
#[test]
fn test_with_c_c2_larger_than_1() {
let mtls: MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64> = MoreThuenteLineSearch::new();
let res = mtls.with_c(0.1, 1.01);
assert_error!(
res,
ArgminError,
concat!(
"Invalid parameter: \"`MoreThuenteLineSearch`: ",
"Parameter c2 must be in (c1, 1).\""
)
);
}
// Valid bounds are stored in `stpmin`/`stpmax`.
#[test]
fn test_with_bounds_correct() {
let mtls: MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64> = MoreThuenteLineSearch::new();
let res = mtls.with_bounds(0.1, 0.9);
assert!(res.is_ok());
let mtls = res.unwrap();
assert_eq!(mtls.stpmin.to_ne_bytes(), 0.1f64.to_ne_bytes());
assert_eq!(mtls.stpmax.to_ne_bytes(), 0.9f64.to_ne_bytes());
}
// A negative lower bound is rejected.
#[test]
fn test_with_bounds_step_min_smaller_than_0() {
let mtls: MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64> = MoreThuenteLineSearch::new();
let res = mtls.with_bounds(-0.1, 0.99);
assert_error!(
res,
ArgminError,
concat!(
"Invalid parameter: \"`MoreThuenteLineSearch`: ",
"step_min must be >= 0.0.\""
)
);
}
// A lower bound above the upper bound is rejected.
#[test]
fn test_with_bounds_step_min_larger_than_step_max() {
let mtls: MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64> = MoreThuenteLineSearch::new();
let res = mtls.with_bounds(10.0, 0.99);
assert_error!(
res,
ArgminError,
concat!(
"Invalid parameter: \"`MoreThuenteLineSearch`: ",
"step_min must be smaller than step_max.\""
)
);
}
// A valid width tolerance is stored in `xtol`.
#[test]
fn test_with_width_tolerance_correct() {
let mtls: MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64> = MoreThuenteLineSearch::new();
let res = mtls.with_width_tolerance(1e-9);
assert!(res.is_ok());
let mtls = res.unwrap();
assert_eq!(mtls.xtol.to_ne_bytes(), 1e-9f64.to_ne_bytes());
}
// A negative width tolerance is rejected.
#[test]
fn test_with_width_tolerance_negative_xtol() {
let mtls: MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64> = MoreThuenteLineSearch::new();
let res = mtls.with_width_tolerance(-1e-10);
assert_error!(
res,
ArgminError,
concat!(
"Invalid parameter: \"`MoreThuenteLineSearch`: ",
"relative width tolerance must be >= 0.0.\""
)
);
}
// `init` fails if no search direction was set.
#[test]
fn test_init_search_direction_not_set() {
let mut mtls: MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64> = MoreThuenteLineSearch::new();
let res = mtls.init(&mut Problem::new(TestProblem::new()), IterState::new());
assert_error!(
res,
ArgminError,
concat!(
"Not initialized: \"`MoreThuenteLineSearch`: Search direction not initialized. ",
"Call `search_direction` before executing the solver.\""
)
);
}
// `init` fails if the state carries no initial parameter vector.
#[test]
fn test_init_param_not_set() {
let mut mtls: MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64> = MoreThuenteLineSearch::new();
mtls.search_direction(vec![1.0f64]);
let res = mtls.init(&mut Problem::new(TestProblem::new()), IterState::new());
assert_error!(
res,
ArgminError,
concat!(
"Not initialized: \"`MoreThuenteLineSearch` requires an initial parameter vector. ",
"Please provide an initial guess via `Executor`s `configure` method.\""
)
);
}
}