argmin/solver/linesearch/
morethuente.rs

1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8// Deactivating this lint here because it would make the Boolean expressions more difficult to
9// read.
10#![allow(clippy::nonminimal_bool)]
11
12use crate::core::{
13    ArgminFloat, CostFunction, Error, Gradient, IterState, LineSearch, Problem, Solver, State,
14    TerminationReason, KV,
15};
16use argmin_math::{ArgminDot, ArgminScaledAdd};
17#[cfg(feature = "serde1")]
18use serde::{Deserialize, Serialize};
19
/// # More-Thuente line search
///
/// The More-Thuente line search is a method which finds an appropriate step length from a starting
/// point and a search direction. This point obeys the strong Wolfe conditions.
///
/// With the method [`with_c`](`MoreThuenteLineSearch::with_c`) the scaling factors for the
/// sufficient decrease condition and the curvature condition can be supplied. By default they are
/// set to `c1 = 1e-4` and `c2 = 0.9`.
///
/// Bounds on the range where step lengths are being searched for can be set with
/// [`with_bounds`](`MoreThuenteLineSearch::with_bounds`) which accepts a lower and an upper bound.
/// Both values need to be non-negative and `lower < upper`.
///
/// One of the reasons for the algorithm to terminate is when the relative width of the
/// uncertainty interval is smaller than a given tolerance (default: `1e-10`). This tolerance can
/// be set via [`with_width_tolerance`](`MoreThuenteLineSearch::with_width_tolerance`) and must be
/// non-negative.
///
/// TODO: Add missing stopping criteria!
///
/// ## Requirements on the optimization problem
///
/// The optimization problem is required to implement [`CostFunction`] and [`Gradient`].
///
/// ## References
///
/// This implementation follows the excellent MATLAB implementation of Dianne P. O'Leary at
/// <http://www.cs.umd.edu/users/oleary/software/>
///
/// Jorge J. More and David J. Thuente. "Line search algorithms with guaranteed sufficient
/// decrease." ACM Trans. Math. Softw. 20, 3 (September 1994), 286-307.
/// DOI: <https://doi.org/10.1145/192115.192132>
#[derive(Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct MoreThuenteLineSearch<P, G, F> {
    /// Search direction (has to be set before the solver is initialized)
    search_direction: Option<G>,
    /// Initial parameter vector, i.e. the point the line search starts from
    init_param: Option<P>,
    /// Cost function value at the initial parameter vector
    finit: F,
    /// Gradient at the initial parameter vector
    init_grad: Option<G>,
    /// Directional derivative at the starting point (dot product of the initial gradient and
    /// the search direction)
    dginit: F,
    /// `ftol * dginit`: slope used in the sufficient decrease test
    dgtest: F,
    /// c1: scaling factor of the sufficient decrease condition
    ftol: F,
    /// c2: scaling factor of the curvature condition
    gtol: F,
    /// Extrapolation factor used for the upper step bound while the minimum is not bracketed
    xtrapf: F,
    /// Width of the current interval of uncertainty
    width: F,
    /// Width of the interval of uncertainty before the most recent update (initialized to twice
    /// `width`); a bisection step is forced if the interval does not shrink below 0.66 times
    /// this value
    width1: F,
    /// Relative tolerance on the width of the interval of uncertainty
    xtol: F,
    /// Initial step length
    alpha: F,
    /// Lower bound on the step length
    stpmin: F,
    /// Upper bound on the step length
    stpmax: F,
    /// Current trial step
    stp: Step<F>,
    /// stx (one endpoint of uncertainty interval); also the step that is fallen back to on
    /// unusual termination
    stx: Step<F>,
    /// sty (another endpoint of uncertainty interval)
    sty: Step<F>,
    /// Cost function value at the most recently evaluated trial step
    f: F,
    /// Whether the minimum has been bracketed
    brackt: bool,
    /// Whether the search is still in its first stage
    stage1: bool,
    /// Status code returned by the most recent call to `cstep`
    infoc: usize,
}
100
/// A step length bundled with the function value and the derivative observed at that step.
#[derive(Clone, Eq, PartialEq, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
struct Step<F> {
    /// Step length
    pub x: F,
    /// Function value at step length `x`
    pub fx: F,
    /// Derivative at step length `x`
    pub gx: F,
}

impl<F> Step<F> {
    /// Build a `Step` from a step length `x`, its function value `fx` and its derivative `gx`.
    pub fn new(x: F, fx: F, gx: F) -> Self {
        Self { x, fx, gx }
    }
}
115
116impl<F> Default for Step<F>
117where
118    F: ArgminFloat,
119{
120    fn default() -> Self {
121        Step {
122            x: float!(0.0),
123            fx: float!(0.0),
124            gx: float!(0.0),
125        }
126    }
127}
128
129impl<P, G, F> MoreThuenteLineSearch<P, G, F>
130where
131    F: ArgminFloat,
132{
133    /// Construct a new instance of `MoreThuenteLineSearch`
134    ///
135    /// # Example
136    ///
137    /// ```
138    /// # use argmin::solver::linesearch::MoreThuenteLineSearch;
139    /// let mtls: MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64> = MoreThuenteLineSearch::new();
140    /// ```
141    pub fn new() -> Self {
142        MoreThuenteLineSearch {
143            search_direction: None,
144            init_param: None,
145            finit: F::infinity(),
146            init_grad: None,
147            dginit: float!(0.0),
148            dgtest: float!(0.0),
149            ftol: float!(1e-4),
150            gtol: float!(0.9),
151            xtrapf: float!(4.0),
152            width: F::nan(),
153            width1: F::nan(),
154            xtol: float!(1e-10),
155            alpha: float!(1.0),
156            stpmin: F::epsilon().sqrt(),
157            stpmax: F::infinity(),
158            stp: Step::default(),
159            stx: Step::default(),
160            sty: Step::default(),
161            f: F::nan(),
162            brackt: false,
163            stage1: true,
164            infoc: 1,
165        }
166    }
167
168    /// Set the constants c1 and c2 for the sufficient decrease and curvature conditions,
169    /// respectively. `0 < c1 < c2 < 1` must hold.
170    ///
171    /// The default values are `c1 = 1e-4` and `c2 = 0.9`.
172    ///
173    /// # Example
174    ///
175    /// ```
176    /// # use argmin::solver::linesearch::MoreThuenteLineSearch;
177    /// # use argmin::core::Error;
178    /// # fn main() -> Result<(), Error> {
179    /// let mtls: MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64> =
180    ///     MoreThuenteLineSearch::new().with_c(1e-3, 0.8)?;
181    /// # Ok(())
182    /// # }
183    /// ```
184    pub fn with_c(mut self, c1: F, c2: F) -> Result<Self, Error> {
185        if c1 <= float!(0.0) || c1 >= c2 {
186            return Err(argmin_error!(
187                InvalidParameter,
188                "`MoreThuenteLineSearch`: Parameter c1 must be in (0, c2)."
189            ));
190        }
191        if c2 <= c1 || c2 >= float!(1.0) {
192            return Err(argmin_error!(
193                InvalidParameter,
194                "`MoreThuenteLineSearch`: Parameter c2 must be in (c1, 1)."
195            ));
196        }
197        self.ftol = c1;
198        self.gtol = c2;
199        Ok(self)
200    }
201
202    /// Set lower and upper bound of step
203    ///
204    /// Defaults are `step_min = sqrt(EPS)` and `step_max = INF`.
205    ///
206    /// `step_min` must be smaller than `step_max`.
207    ///
208    /// # Example
209    ///
210    /// ```
211    /// # use argmin::solver::linesearch::MoreThuenteLineSearch;
212    /// # use argmin::core::Error;
213    /// # fn main() -> Result<(), Error> {
214    /// let mtls: MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64> =
215    ///     MoreThuenteLineSearch::new().with_bounds(1e-6, 10.0)?;
216    /// # Ok(())
217    /// # }
218    /// ```
219    pub fn with_bounds(mut self, step_min: F, step_max: F) -> Result<Self, Error> {
220        if step_min < float!(0.0) {
221            return Err(argmin_error!(
222                InvalidParameter,
223                "`MoreThuenteLineSearch`: step_min must be >= 0.0."
224            ));
225        }
226        if step_max <= step_min {
227            return Err(argmin_error!(
228                InvalidParameter,
229                "`MoreThuenteLineSearch`: step_min must be smaller than step_max."
230            ));
231        }
232        self.stpmin = step_min;
233        self.stpmax = step_max;
234        Ok(self)
235    }
236
237    /// Set relative tolerance on width of uncertainty interval
238    ///
239    /// The algorithm terminates when the relative width of the uncertainty interval is below the
240    /// supplied tolerance.
241    ///
242    /// Must be non-negative and defaults to `1e-10`.
243    ///
244    /// # Example
245    ///
246    /// ```
247    /// # use argmin::solver::linesearch::MoreThuenteLineSearch;
248    /// # use argmin::core::Error;
249    /// # fn main() -> Result<(), Error> {
250    /// let mtls: MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64> =
251    ///     MoreThuenteLineSearch::new().with_width_tolerance(1e-9)?;
252    /// # Ok(())
253    /// # }
254    /// ```
255    pub fn with_width_tolerance(mut self, xtol: F) -> Result<Self, Error> {
256        if xtol < float!(0.0) {
257            return Err(argmin_error!(
258                InvalidParameter,
259                "`MoreThuenteLineSearch`: relative width tolerance must be >= 0.0."
260            ));
261        }
262        self.xtol = xtol;
263        Ok(self)
264    }
265}
266
267impl<P, G, F> Default for MoreThuenteLineSearch<P, G, F>
268where
269    F: ArgminFloat,
270{
271    fn default() -> Self {
272        MoreThuenteLineSearch::new()
273    }
274}
275
276impl<P, G, F> LineSearch<G, F> for MoreThuenteLineSearch<P, G, F>
277where
278    F: ArgminFloat,
279{
280    /// Set search direction
281    fn search_direction(&mut self, search_direction: G) {
282        self.search_direction = Some(search_direction);
283    }
284
285    /// Set initial alpha value
286    fn initial_step_length(&mut self, alpha: F) -> Result<(), Error> {
287        if alpha <= float!(0.0) {
288            return Err(argmin_error!(
289                InvalidParameter,
290                "MoreThuenteLineSearch: Initial alpha must be > 0."
291            ));
292        }
293        self.alpha = alpha;
294        Ok(())
295    }
296}
297
298impl<P, G, O, F> Solver<O, IterState<P, G, (), (), (), F>> for MoreThuenteLineSearch<P, G, F>
299where
300    O: CostFunction<Param = P, Output = F> + Gradient<Param = P, Gradient = G>,
301    P: Clone + ArgminDot<G, F> + ArgminScaledAdd<G, F, P>,
302    G: Clone + ArgminDot<G, F>,
303    F: ArgminFloat,
304{
305    fn name(&self) -> &str {
306        "More-Thuente Line search"
307    }
308
309    fn init(
310        &mut self,
311        problem: &mut Problem<O>,
312        mut state: IterState<P, G, (), (), (), F>,
313    ) -> Result<(IterState<P, G, (), (), (), F>, Option<KV>), Error> {
314        check_param!(
315            self.search_direction,
316            concat!(
317                "`MoreThuenteLineSearch`: Search direction not initialized. ",
318                "Call `search_direction` before executing the solver."
319            )
320        );
321
322        self.init_param = Some(state.take_param().ok_or_else(argmin_error_closure!(
323            NotInitialized,
324            concat!(
325                "`MoreThuenteLineSearch` requires an initial parameter vector. ",
326                "Please provide an initial guess via `Executor`s `configure` method."
327            )
328        ))?);
329
330        let cost = state.get_cost();
331        self.finit = if cost.is_infinite() {
332            problem.cost(self.init_param.as_ref().unwrap())?
333        } else {
334            cost
335        };
336
337        self.init_grad = Some(
338            state
339                .take_gradient()
340                .map(Result::Ok)
341                .unwrap_or_else(|| problem.gradient(self.init_param.as_ref().unwrap()))?,
342        );
343
344        self.dginit = self
345            .init_grad
346            .as_ref()
347            .unwrap()
348            .dot(self.search_direction.as_ref().unwrap());
349
350        // compute search direction in 1D
351        if self.dginit >= float!(0.0) {
352            return Err(argmin_error!(
353                ConditionViolated,
354                "`MoreThuenteLineSearch`: Search direction must be a descent direction."
355            ));
356        }
357
358        self.stage1 = true;
359        self.brackt = false;
360
361        self.dgtest = self.ftol * self.dginit;
362        self.width = self.stpmax - self.stpmin;
363        self.width1 = float!(2.0) * self.width;
364        self.f = self.finit;
365
366        self.stp = Step::new(self.alpha, F::nan(), F::nan());
367        self.stx = Step::new(float!(0.0), self.finit, self.dginit);
368        self.sty = Step::new(float!(0.0), self.finit, self.dginit);
369
370        Ok((state, None))
371    }
372
373    fn next_iter(
374        &mut self,
375        problem: &mut Problem<O>,
376        state: IterState<P, G, (), (), (), F>,
377    ) -> Result<(IterState<P, G, (), (), (), F>, Option<KV>), Error> {
378        // set the minimum and maximum steps to correspond to the present interval of uncertainty
379        let mut info = 0;
380        let (stmin, stmax) = if self.brackt {
381            (self.stx.x.min(self.sty.x), self.stx.x.max(self.sty.x))
382        } else {
383            (
384                self.stx.x,
385                self.stp.x + self.xtrapf * (self.stp.x - self.stx.x),
386            )
387        };
388
389        // alpha needs to be within bounds
390        self.stp.x = self.stp.x.max(self.stpmin);
391        self.stp.x = self.stp.x.min(self.stpmax);
392
393        // If an unusual termination is to occur then let alpha be the lowest point obtained so
394        // far.
395        if (self.brackt && (self.stp.x <= stmin || self.stp.x >= stmax))
396            || (self.brackt && (stmax - stmin) <= self.xtol * stmax)
397            || self.infoc == 0
398        {
399            self.stp.x = self.stx.x;
400        }
401
402        // Evaluate the function and gradient at new stp.x and compute the directional derivative
403        let new_param = self
404            .init_param
405            .as_ref()
406            .unwrap()
407            .scaled_add(&self.stp.x, self.search_direction.as_ref().unwrap());
408        self.f = problem.cost(&new_param)?;
409        let new_grad = problem.gradient(&new_param)?;
410        let cur_cost = self.f;
411        let cur_param = new_param;
412        let cur_grad = new_grad.clone();
413        // self.stx.fx = new_cost;
414        let dg = self.search_direction.as_ref().unwrap().dot(&new_grad);
415        let ftest1 = self.finit + self.stp.x * self.dgtest;
416        // self.stp.fx = new_cost;
417        // self.stp.gx = dg;
418
419        if (self.brackt && (self.stp.x <= stmin || self.stp.x >= stmax)) || self.infoc == 0 {
420            info = 6;
421        }
422
423        if (self.stp.x - self.stpmax).abs() < F::epsilon() && self.f <= ftest1 && dg <= self.dgtest
424        {
425            info = 5;
426        }
427
428        if (self.stp.x - self.stpmin).abs() < F::epsilon() && (self.f > ftest1 || dg >= self.dgtest)
429        {
430            info = 4;
431        }
432
433        if self.brackt && stmax - stmin <= self.xtol * stmax {
434            info = 2;
435        }
436
437        if self.f <= ftest1 && dg.abs() <= self.gtol * (-self.dginit) {
438            info = 1;
439        }
440
441        if info != 0 {
442            return Ok((
443                state
444                    .param(cur_param)
445                    .cost(cur_cost)
446                    .gradient(cur_grad)
447                    .terminate_with(TerminationReason::SolverConverged),
448                None,
449            ));
450        }
451
452        if self.stage1 && self.f <= ftest1 && dg >= self.ftol.min(self.gtol) * self.dginit {
453            self.stage1 = false;
454        }
455
456        if self.stage1 && self.f <= self.stp.fx && self.f > ftest1 {
457            let fm = self.f - self.stp.x * self.dgtest;
458            let fxm = self.stx.fx - self.stx.x * self.dgtest;
459            let fym = self.sty.fx - self.sty.x * self.dgtest;
460            let dgm = dg - self.dgtest;
461            let dgxm = self.stx.gx - self.dgtest;
462            let dgym = self.sty.gx - self.dgtest;
463
464            let (stx1, sty1, stp1, brackt1, _stmin, _stmax, infoc) = cstep(
465                Step::new(self.stx.x, fxm, dgxm),
466                Step::new(self.sty.x, fym, dgym),
467                Step::new(self.stp.x, fm, dgm),
468                self.brackt,
469                stmin,
470                stmax,
471            )?;
472
473            self.stx.x = stx1.x;
474            self.sty.x = sty1.x;
475            self.stp.x = stp1.x;
476            self.stx.fx = self.stx.fx + stx1.x * self.dgtest;
477            self.sty.fx = self.sty.fx + sty1.x * self.dgtest;
478            self.stx.gx = self.stx.gx + self.dgtest;
479            self.sty.gx = self.sty.gx + self.dgtest;
480            self.brackt = brackt1;
481            self.stp = stp1;
482            self.infoc = infoc;
483        } else {
484            let (stx1, sty1, stp1, brackt1, _stmin, _stmax, infoc) = cstep(
485                self.stx.clone(),
486                self.sty.clone(),
487                Step::new(self.stp.x, self.f, dg),
488                self.brackt,
489                stmin,
490                stmax,
491            )?;
492            self.stx = stx1;
493            self.sty = sty1;
494            self.stp = stp1;
495            self.f = self.stp.fx;
496            // dg = self.stp.gx;
497            self.brackt = brackt1;
498            self.infoc = infoc;
499        }
500
501        if self.brackt {
502            if (self.sty.x - self.stx.x).abs() >= float!(0.66) * self.width1 {
503                self.stp.x = self.stx.x + float!(0.5) * (self.sty.x - self.stx.x);
504            }
505            self.width1 = self.width;
506            self.width = (self.sty.x - self.stx.x).abs();
507        }
508
509        Ok((state, None))
510    }
511}
512
513type CstepReturnValue<F> = (Step<F>, Step<F>, Step<F>, bool, F, F, usize);
514
515fn cstep<F: ArgminFloat>(
516    stx: Step<F>,
517    sty: Step<F>,
518    stp: Step<F>,
519    brackt: bool,
520    stpmin: F,
521    stpmax: F,
522) -> Result<CstepReturnValue<F>, Error> {
523    let mut info: usize = 0;
524    let bound: bool;
525    let mut stpf: F;
526    let stpc: F;
527    let stpq: F;
528    let mut brackt = brackt;
529
530    // check inputs
531    if (brackt && (stp.x <= stx.x.min(sty.x) || stp.x >= stx.x.max(sty.x)))
532        || stx.gx * (stp.x - stx.x) >= float!(0.0)
533        || stpmax < stpmin
534    {
535        return Ok((stx, sty, stp, brackt, stpmin, stpmax, info));
536    }
537
538    // determine if the derivatives have opposite sign
539    let sgnd = stp.gx * (stx.gx / stx.gx.abs());
540
541    if stp.fx > stx.fx {
542        // First case. A higher function value. The minimum is bracketed. If the cubic step is closer to
543        // stx.x than the quadratic step, the cubic step is taken, else the average of the cubic and
544        // the quadratic steps is taken.
545        info = 1;
546        bound = true;
547        let theta = float!(3.0) * (stx.fx - stp.fx) / (stp.x - stx.x) + stx.gx + stp.gx;
548        let tmp = [theta, stx.gx, stp.gx];
549        // Check for a NaN or Inf in tmp before sorting
550        if tmp.iter().any(|n| n.is_nan() || n.is_infinite()) {
551            return Err(argmin_error!(
552                ConditionViolated,
553                "MoreThuenteLineSearch: NaN or Inf encountered during iteration"
554            ));
555        }
556        let s = tmp.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
557        let mut gamma = *s * ((theta / *s).powi(2) - (stx.gx / *s) * (stp.gx / *s)).sqrt();
558        if stp.x < stx.x {
559            gamma = -gamma;
560        }
561
562        let p = (gamma - stx.gx) + theta;
563        let q = ((gamma - stx.gx) + gamma) + stp.gx;
564        let r = p / q;
565        stpc = stx.x + r * (stp.x - stx.x);
566        stpq = stx.x
567            + ((stx.gx / ((stx.fx - stp.fx) / (stp.x - stx.x) + stx.gx)) / float!(2.0))
568                * (stp.x - stx.x);
569        if (stpc - stx.x).abs() < (stpq - stx.x).abs() {
570            stpf = stpc;
571        } else {
572            stpf = stpc + (stpq - stpc) / float!(2.0);
573        }
574        brackt = true;
575    } else if sgnd < float!(0.0) {
576        // Second case. A lower function value and derivatives of opposite sign. The minimum is
577        // bracketed. If the cubic step is closer to stx.x than the quadratic (secant) step, the
578        // cubic step is taken, else the quadratic step is taken.
579        info = 2;
580        bound = false;
581        let theta = float!(3.0) * (stx.fx - stp.fx) / (stp.x - stx.x) + stx.gx + stp.gx;
582        let tmp = [theta, stx.gx, stp.gx];
583        // Check for a NaN or Inf in tmp before sorting
584        if tmp.iter().any(|n| n.is_nan() || n.is_infinite()) {
585            return Err(argmin_error!(
586                ConditionViolated,
587                "MoreThuenteLineSearch: NaN or Inf encountered during iteration"
588            ));
589        }
590        let s = tmp.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
591        let mut gamma = *s * ((theta / *s).powi(2) - (stx.gx / *s) * (stp.gx / *s)).sqrt();
592        if stp.x > stx.x {
593            gamma = -gamma;
594        }
595        let p = (gamma - stp.gx) + theta;
596        let q = ((gamma - stp.gx) + gamma) + stx.gx;
597        let r = p / q;
598        stpc = stp.x + r * (stx.x - stp.x);
599        stpq = stp.x + (stp.gx / (stp.gx - stx.gx)) * (stx.x - stp.x);
600        if (stpc - stp.x).abs() > (stpq - stp.x).abs() {
601            stpf = stpc;
602        } else {
603            stpf = stpq;
604        }
605        brackt = true;
606    } else if stp.gx.abs() < stx.gx.abs() {
607        // Third case. A lower function value, derivatives of the same sign, and the magnitude of
608        // the derivative decreases. The cubic step is only used if the cubic tends to infinity in
609        // the direction of the step or if the minimum of the cubic is beyond stp.x. Otherwise the
610        // cubic step is defined to be either stpmin or stpmax. The quadratic (secant) step is
611        // also computed and if the minimum is bracketed then the step closest to stx.x is taken,
612        // else the step farthest away is taken.
613        info = 3;
614        bound = true;
615        let theta = float!(3.0) * (stx.fx - stp.fx) / (stp.x - stx.x) + stx.gx + stp.gx;
616        let tmp = [theta, stx.gx, stp.gx];
617        // Check for a NaN or Inf in tmp before sorting
618        if tmp.iter().any(|n| n.is_nan() || n.is_infinite()) {
619            return Err(argmin_error!(
620                ConditionViolated,
621                "`MoreThuenteLineSearch`: NaN or Inf encountered during iteration"
622            ));
623        }
624        let s = tmp.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
625        // the case gamma == 0 only arises if the cubic does not tend to infinity in the direction
626        // of the step.
627
628        let mut gamma = *s
629            * float!(0.0)
630                .max((theta / *s).powi(2) - (stx.gx / *s) * (stp.gx / *s))
631                .sqrt();
632        if stp.x > stx.x {
633            gamma = -gamma;
634        }
635
636        let p = (gamma - stp.gx) + theta;
637        let q = (gamma + (stx.gx - stp.gx)) + gamma;
638        let r = p / q;
639        if r < float!(0.0) && gamma != float!(0.0) {
640            stpc = stp.x + r * (stx.x - stp.x);
641        } else if stp.x > stx.x {
642            stpc = stpmax;
643        } else {
644            stpc = stpmin;
645        }
646        stpq = stp.x + (stp.gx / (stp.gx - stx.gx)) * (stx.x - stp.x);
647        if brackt {
648            if (stp.x - stpc).abs() < (stp.x - stpq).abs() {
649                stpf = stpc;
650            } else {
651                stpf = stpq;
652            }
653        } else if (stp.x - stpc).abs() > (stp.x - stpq).abs() {
654            stpf = stpc;
655        } else {
656            stpf = stpq;
657        }
658    } else {
659        // Fourth case. A lower function value, derivatives of the same sign, and the magnitude of
660        // the derivative does not decrease. If the minimum is not bracketed, the step is either
661        // stpmin or stpmax, else the cubic step is taken.
662        info = 4;
663        bound = false;
664        if brackt {
665            let theta = float!(3.0) * (stp.fx - sty.fx) / (sty.x - stp.x) + sty.gx + stp.gx;
666            let tmp = [theta, sty.gx, stp.gx];
667            // Check for a NaN or Inf in tmp before sorting
668            if tmp.iter().any(|n| n.is_nan() || n.is_infinite()) {
669                return Err(argmin_error!(
670                    ConditionViolated,
671                    "MoreThuenteLineSearch: NaN or Inf encountered during iteration"
672                ));
673            }
674            let s = tmp.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
675            let mut gamma = *s * ((theta / *s).powi(2) - (sty.gx / *s) * (stp.gx / *s)).sqrt();
676            if stp.x > sty.x {
677                gamma = -gamma;
678            }
679            let p = (gamma - stp.gx) + theta;
680            let q = ((gamma - stp.gx) + gamma) + sty.gx;
681            let r = p / q;
682            stpc = stp.x + r * (sty.x - stp.x);
683            stpf = stpc;
684        } else if stp.x > stx.x {
685            stpf = stpmax;
686        } else {
687            stpf = stpmin;
688        }
689    }
690    // Update the interval of uncertainty. This update does not depend on the new step or the case
691    // analysis above.
692
693    let mut stx_o = stx;
694    let mut sty_o = sty;
695    let mut stp_o = stp;
696    if stp_o.fx > stx_o.fx {
697        sty_o = Step::new(stp_o.x, stp_o.fx, stp_o.gx);
698    } else {
699        if sgnd < float!(0.0) {
700            sty_o = Step::new(stx_o.x, stx_o.fx, stx_o.gx);
701        }
702        stx_o = Step::new(stp_o.x, stp_o.fx, stp_o.gx);
703    }
704
705    // compute the new step and safeguard it.
706
707    stpf = stpmax.min(stpf);
708    stpf = stpmin.max(stpf);
709
710    stp_o.x = stpf;
711    if brackt && bound {
712        if sty_o.x > stx_o.x {
713            stp_o.x = stp_o.x.min(stx_o.x + float!(0.66) * (sty_o.x - stx_o.x));
714        } else {
715            stp_o.x = stp_o.x.max(stx_o.x + float!(0.66) * (sty_o.x - stx_o.x));
716        }
717    }
718
719    Ok((stx_o, sty_o, stp_o, brackt, stpmin, stpmax, info))
720}
721
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{test_utils::TestProblem, ArgminError};

    /// Concrete instantiation of the line search exercised by every test in this module.
    type Mtls = MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64>;

    test_trait_impl!(morethuente, MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64>);

    /// A freshly constructed solver carries the documented defaults.
    #[test]
    fn test_new() {
        let mtls = Mtls::new();

        // Problem-dependent data is unset.
        assert!(mtls.search_direction.is_none());
        assert!(mtls.init_param.is_none());
        assert!(mtls.init_grad.is_none());

        // Initial cost is +INF, the upper step bound is +INF.
        assert!(mtls.finit.is_infinite());
        assert!(mtls.finit.is_sign_positive());
        assert!(mtls.stpmax.is_infinite());
        assert!(mtls.stpmax.is_sign_positive());

        // Values that are only defined after `init`.
        assert!(mtls.width.is_nan());
        assert!(mtls.width1.is_nan());
        assert!(mtls.f.is_nan());

        // Numeric defaults, compared bit by bit.
        assert_eq!(mtls.dginit.to_ne_bytes(), 0.0f64.to_ne_bytes());
        assert_eq!(mtls.dgtest.to_ne_bytes(), 0.0f64.to_ne_bytes());
        assert_eq!(mtls.ftol.to_ne_bytes(), 1e-4f64.to_ne_bytes());
        assert_eq!(mtls.gtol.to_ne_bytes(), 0.9f64.to_ne_bytes());
        assert_eq!(mtls.xtrapf.to_ne_bytes(), 4.0f64.to_ne_bytes());
        assert_eq!(mtls.xtol.to_ne_bytes(), 1e-10f64.to_ne_bytes());
        assert_eq!(mtls.alpha.to_ne_bytes(), 1.0f64.to_ne_bytes());
        assert_eq!(
            mtls.stpmin.to_ne_bytes(),
            f64::EPSILON.sqrt().to_ne_bytes()
        );

        // Steps and flags.
        assert_eq!(mtls.stp, Step::default());
        assert_eq!(mtls.stx, Step::default());
        assert_eq!(mtls.sty, Step::default());
        assert!(!mtls.brackt);
        assert!(mtls.stage1);
        assert_eq!(mtls.infoc, 1);
    }

    /// Valid `c1`/`c2` are stored as `ftol`/`gtol`.
    #[test]
    fn test_with_c_correct() {
        let mtls = Mtls::new().with_c(0.1, 0.9).unwrap();
        assert_eq!(mtls.ftol.to_ne_bytes(), 0.1f64.to_ne_bytes());
        assert_eq!(mtls.gtol.to_ne_bytes(), 0.9f64.to_ne_bytes());
    }

    /// `c1 >= c2` is rejected.
    #[test]
    fn test_with_c_c1_larger_than_c2() {
        let res = Mtls::new().with_c(0.9, 0.1);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Invalid parameter: \"`MoreThuenteLineSearch`: ",
                "Parameter c1 must be in (0, c2).\""
            )
        );
    }

    /// Negative `c1` is rejected.
    #[test]
    fn test_with_c_c1_smaller_than_0() {
        let res = Mtls::new().with_c(-0.9, 0.99);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Invalid parameter: \"`MoreThuenteLineSearch`: ",
                "Parameter c1 must be in (0, c2).\""
            )
        );
    }

    /// `c2 >= 1` is rejected.
    #[test]
    fn test_with_c_c2_larger_than_1() {
        let res = Mtls::new().with_c(0.1, 1.01);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Invalid parameter: \"`MoreThuenteLineSearch`: ",
                "Parameter c2 must be in (c1, 1).\""
            )
        );
    }

    /// Valid bounds are stored as `stpmin`/`stpmax`.
    #[test]
    fn test_with_bounds_correct() {
        let mtls = Mtls::new().with_bounds(0.1, 0.9).unwrap();
        assert_eq!(mtls.stpmin.to_ne_bytes(), 0.1f64.to_ne_bytes());
        assert_eq!(mtls.stpmax.to_ne_bytes(), 0.9f64.to_ne_bytes());
    }

    /// A negative lower bound is rejected.
    #[test]
    fn test_with_bounds_step_min_smaller_than_0() {
        let res = Mtls::new().with_bounds(-0.1, 0.99);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Invalid parameter: \"`MoreThuenteLineSearch`: ",
                "step_min must be >= 0.0.\""
            )
        );
    }

    /// A lower bound above the upper bound is rejected.
    #[test]
    fn test_with_bounds_step_min_larger_than_step_max() {
        let res = Mtls::new().with_bounds(10.0, 0.99);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Invalid parameter: \"`MoreThuenteLineSearch`: ",
                "step_min must be smaller than step_max.\""
            )
        );
    }

    /// A valid width tolerance is stored as `xtol`.
    #[test]
    fn test_with_width_tolerance_correct() {
        let mtls = Mtls::new().with_width_tolerance(1e-9).unwrap();
        assert_eq!(mtls.xtol.to_ne_bytes(), 1e-9f64.to_ne_bytes());
    }

    /// A negative width tolerance is rejected.
    #[test]
    fn test_with_width_tolerance_negative_xtol() {
        let res = Mtls::new().with_width_tolerance(-1e-10);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Invalid parameter: \"`MoreThuenteLineSearch`: ",
                "relative width tolerance must be >= 0.0.\""
            )
        );
    }

    /// `init` fails if no search direction was supplied.
    #[test]
    fn test_init_search_direction_not_set() {
        let mut mtls = Mtls::new();
        let res = mtls.init(&mut Problem::new(TestProblem::new()), IterState::new());
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Not initialized: \"`MoreThuenteLineSearch`: Search direction not initialized. ",
                "Call `search_direction` before executing the solver.\""
            )
        );
    }

    /// `init` fails if the state carries no initial parameter vector.
    #[test]
    fn test_init_param_not_set() {
        let mut mtls = Mtls::new();
        mtls.search_direction(vec![1.0f64]);
        let res = mtls.init(&mut Problem::new(TestProblem::new()), IterState::new());
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Not initialized: \"`MoreThuenteLineSearch` requires an initial parameter vector. ",
                "Please provide an initial guess via `Executor`s `configure` method.\""
            )
        );
    }
}