argmin/solver/linesearch/
hagerzhang.rs

1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use crate::core::{
9    ArgminFloat, CostFunction, Error, Gradient, IterState, LineSearch, Problem, Solver,
10    TerminationReason, TerminationStatus, KV,
11};
12use argmin_math::{ArgminDot, ArgminScaledAdd};
13#[cfg(feature = "serde1")]
14use serde::{Deserialize, Serialize};
15
/// A point of the one-dimensional line-search function, stored as `(x, phi(x), phi'(x))`.
type Triplet<F> = (F, F, F);
17
/// # Hager-Zhang line search
///
/// The Hager-Zhang line search is a method to find a step length which satisfies either the
/// Wolfe conditions (sufficient decrease plus a lower bound on the slope) or the approximate
/// Wolfe conditions; see the termination test of this solver for the exact inequalities.
///
/// ## Requirements on the optimization problem
///
/// The optimization problem is required to implement [`CostFunction`] and [`Gradient`].
///
/// ## Reference
///
/// William W. Hager and Hongchao Zhang. "A new conjugate gradient method with guaranteed
/// descent and an efficient line search." SIAM J. Optim. 16(1), 2006, 170-192.
/// DOI: <https://doi.org/10.1137/030601880>
#[derive(Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct HagerZhangLineSearch<P, G, F> {
    /// delta: used in the sufficient-decrease part of the Wolfe conditions. Documented range is
    /// (0, 0.5); note that `with_delta_sigma` accepts any value in (0, 1).
    delta: F,
    /// sigma: [delta, 1), used in the slope part of the Wolfe conditions
    sigma: F,
    /// epsilon: [0, infinity), used in the approximate Wolfe termination
    epsilon: F,
    /// epsilon_k: `epsilon * |finit|`, computed in `init` (NaN before that)
    epsilon_k: F,
    /// theta: (0, 1), used in the update rules when the potential intervals [a, c] or [c, b]
    /// violate the opposite slope condition
    theta: F,
    /// gamma: (0, 1), determines when a bisection step is performed
    gamma: F,
    /// eta: (0, infinity), used in the lower bound for beta_k^N. It is stored by `with_eta` but
    /// not read by any of the line-search steps in this file.
    eta: F,
    /// initial a (lower bound of the step length)
    a_x_init: F,
    /// a
    a_x: F,
    /// phi(a)
    a_f: F,
    /// phi'(a)
    a_g: F,
    /// initial b (upper bound of the step length)
    b_x_init: F,
    /// b
    b_x: F,
    /// phi(b)
    b_f: F,
    /// phi'(b)
    b_g: F,
    /// initial c (initial trial step length, see `initial_step_length`)
    c_x_init: F,
    /// c
    c_x: F,
    /// phi(c)
    c_f: F,
    /// phi'(c)
    c_g: F,
    /// best x: the step length among a, b and c with the lowest function value
    best_x: F,
    /// best function value
    best_f: F,
    /// best slope
    best_g: F,
    /// initial parameter vector (taken from the state in `init`)
    init_param: Option<P>,
    /// initial cost, phi(0)
    finit: F,
    /// initial gradient (taken from the state or computed in `init`)
    init_grad: Option<G>,
    /// Search direction (set via `search_direction`)
    search_direction: Option<G>,
    /// Initial slope phi'(0): dot product of the initial gradient and the search direction
    dginit: F,
}
91
92impl<P, G, F> HagerZhangLineSearch<P, G, F>
93where
94    P: ArgminScaledAdd<G, F, P>,
95    G: ArgminDot<G, F>,
96    F: ArgminFloat,
97{
98    /// Construct a new instance of [`HagerZhangLineSearch`]
99    ///
100    /// # Example
101    ///
102    /// ```
103    /// # use argmin::solver::linesearch::HagerZhangLineSearch;
104    /// let hzls: HagerZhangLineSearch<Vec<f64>, Vec<f64>, f64> = HagerZhangLineSearch::new();
105    /// ```
106    pub fn new() -> Self {
107        HagerZhangLineSearch {
108            delta: float!(0.1),
109            sigma: float!(0.9),
110            epsilon: float!(1e-6),
111            epsilon_k: F::nan(),
112            theta: float!(0.5),
113            gamma: float!(0.66),
114            eta: float!(0.01),
115            a_x_init: F::epsilon(),
116            a_x: F::nan(),
117            a_f: F::nan(),
118            a_g: F::nan(),
119            b_x_init: float!(1e5),
120            b_x: F::nan(),
121            b_f: F::nan(),
122            b_g: F::nan(),
123            c_x_init: float!(1.0),
124            c_x: F::nan(),
125            c_f: F::nan(),
126            c_g: F::nan(),
127            best_x: float!(0.0),
128            best_f: F::infinity(),
129            best_g: F::nan(),
130            init_param: None,
131            init_grad: None,
132            search_direction: None,
133            dginit: F::nan(),
134            finit: F::infinity(),
135        }
136    }
137
138    /// Set delta and sigma.
139    ///
140    /// Delta defaults to `0.1` and must be in `(0, 1)`.
141    /// Sigma defaults to `0.9` and must be in `[delta, 1)`.
142    ///
143    /// Delta and Sigma correspond to the constants `c1` and `c2` of the strong Wolfe conditions,
144    /// respectively.
145    ///
146    /// # Example
147    ///
148    /// ```
149    /// # use argmin::solver::linesearch::HagerZhangLineSearch;
150    /// # use argmin::core::Error;
151    /// # fn main() -> Result<(), Error> {
152    /// let hzls: HagerZhangLineSearch<Vec<f64>, Vec<f64>, f64> =
153    ///     HagerZhangLineSearch::new().with_delta_sigma(0.2, 0.8)?;
154    /// # Ok(())
155    /// # }
156    /// ```
157    pub fn with_delta_sigma(mut self, delta: F, sigma: F) -> Result<Self, Error> {
158        if delta <= float!(0.0) || delta >= float!(1.0) || sigma < delta || sigma >= float!(1.0) {
159            return Err(argmin_error!(
160                InvalidParameter,
161                "`HagerZhangLineSearch`: delta must be in (0, 1) and sigma must be in [delta, 1)."
162            ));
163        }
164        self.delta = delta;
165        self.sigma = sigma;
166        Ok(self)
167    }
168
169    /// Set epsilon
170    ///
171    /// Used in the approximate strong Wolfe condition.
172    ///
173    /// Must be non-negative and defaults to `1e-6`.
174    ///
175    /// # Example
176    ///
177    /// ```
178    /// # use argmin::solver::linesearch::HagerZhangLineSearch;
179    /// # use argmin::core::Error;
180    /// # fn main() -> Result<(), Error> {
181    /// let hzls: HagerZhangLineSearch<Vec<f64>, Vec<f64>, f64> =
182    ///     HagerZhangLineSearch::new().with_epsilon(1e-8)?;
183    /// # Ok(())
184    /// # }
185    /// ```
186    pub fn with_epsilon(mut self, epsilon: F) -> Result<Self, Error> {
187        if epsilon < float!(0.0) {
188            return Err(argmin_error!(
189                InvalidParameter,
190                "`HagerZhangLineSearch`: epsilon must be >= 0."
191            ));
192        }
193        self.epsilon = epsilon;
194        Ok(self)
195    }
196
197    /// Set theta
198    ///
199    /// Used in the update rules when the potential intervals [a, c] or [c, b] violate the opposite
200    /// slope condition.
201    ///
202    /// Must be in `(0, 1)` and defaults to `0.5`.
203    ///
204    /// # Example
205    ///
206    /// ```
207    /// # use argmin::solver::linesearch::HagerZhangLineSearch;
208    /// # use argmin::core::Error;
209    /// # fn main() -> Result<(), Error> {
210    /// let hzls: HagerZhangLineSearch<Vec<f64>, Vec<f64>, f64> =
211    ///     HagerZhangLineSearch::new().with_theta(0.4)?;
212    /// # Ok(())
213    /// # }
214    /// ```
215    pub fn with_theta(mut self, theta: F) -> Result<Self, Error> {
216        if theta <= float!(0.0) || theta >= float!(1.0) {
217            return Err(argmin_error!(
218                InvalidParameter,
219                "`HagerZhangLineSearch`: theta must be in (0, 1)."
220            ));
221        }
222        self.theta = theta;
223        Ok(self)
224    }
225
226    /// Set gamma
227    ///
228    /// Determines when a bisection step is performed.
229    ///
230    /// Must be in `(0, 1)` and defaults to `0.66`.
231    ///
232    /// # Example
233    ///
234    /// ```
235    /// # use argmin::solver::linesearch::HagerZhangLineSearch;
236    /// # use argmin::core::Error;
237    /// # fn main() -> Result<(), Error> {
238    /// let hzls: HagerZhangLineSearch<Vec<f64>, Vec<f64>, f64> =
239    ///     HagerZhangLineSearch::new().with_gamma(0.7)?;
240    /// # Ok(())
241    /// # }
242    /// ```
243    pub fn with_gamma(mut self, gamma: F) -> Result<Self, Error> {
244        if gamma <= float!(0.0) || gamma >= float!(1.0) {
245            return Err(argmin_error!(
246                InvalidParameter,
247                "`HagerZhangLineSearch`: gamma must be in (0, 1)."
248            ));
249        }
250        self.gamma = gamma;
251        Ok(self)
252    }
253
254    /// Set eta
255    ///
256    /// Used in the lower bound for `beta_k^N`.
257    ///
258    /// Must be larger than zero and defaults to `0.01`.
259    ///
260    /// # Example
261    ///
262    /// ```
263    /// # use argmin::solver::linesearch::HagerZhangLineSearch;
264    /// # use argmin::core::Error;
265    /// # fn main() -> Result<(), Error> {
266    /// let hzls: HagerZhangLineSearch<Vec<f64>, Vec<f64>, f64> =
267    ///     HagerZhangLineSearch::new().with_eta(0.02)?;
268    /// # Ok(())
269    /// # }
270    /// ```
271    pub fn with_eta(mut self, eta: F) -> Result<Self, Error> {
272        if eta <= float!(0.0) {
273            return Err(argmin_error!(
274                InvalidParameter,
275                "`HagerZhangLineSearch`: eta must be > 0."
276            ));
277        }
278        self.eta = eta;
279        Ok(self)
280    }
281
282    /// Set lower and upper bound of step
283    ///
284    /// Defaults to a minimum step length of `EPSILON` and a maximum step length of `1e5`.
285    ///
286    /// The chosen values must satisfy `0 <= step_min < step_max`.
287    ///
288    /// # Example
289    ///
290    /// ```
291    /// # use argmin::solver::linesearch::HagerZhangLineSearch;
292    /// # use argmin::core::Error;
293    /// # fn main() -> Result<(), Error> {
294    /// let hzls: HagerZhangLineSearch<Vec<f64>, Vec<f64>, f64> =
295    ///     HagerZhangLineSearch::new().with_bounds(1e-3, 1.0)?;
296    /// # Ok(())
297    /// # }
298    /// ```
299    pub fn with_bounds(mut self, step_min: F, step_max: F) -> Result<Self, Error> {
300        if step_min < float!(0.0) || step_max <= step_min {
301            return Err(argmin_error!(
302                InvalidParameter,
303                concat!(
304                    "`HagerZhangLineSearch`: minimum and maximum step length must be chosen ",
305                    "such that 0 <= step_min < step_max."
306                )
307            ));
308        }
309        self.a_x_init = step_min;
310        self.b_x_init = step_max;
311        Ok(self)
312    }
313
314    fn update<O>(
315        &mut self,
316        problem: &mut Problem<O>,
317        (a_x, a_f, a_g): Triplet<F>,
318        (b_x, b_f, b_g): Triplet<F>,
319        (c_x, c_f, c_g): Triplet<F>,
320    ) -> Result<(Triplet<F>, Triplet<F>), Error>
321    where
322        O: CostFunction<Param = P, Output = F> + Gradient<Param = P, Gradient = G>,
323    {
324        // U0
325        if c_x <= a_x || c_x >= b_x {
326            // nothing changes.
327            return Ok(((a_x, a_f, a_g), (b_x, b_f, b_g)));
328        }
329
330        // U1
331        if c_g >= float!(0.0) {
332            return Ok(((a_x, a_f, a_g), (c_x, c_f, c_g)));
333        }
334
335        // U2
336        if c_g < float!(0.0) && c_f <= self.finit + self.epsilon_k {
337            return Ok(((c_x, c_f, c_g), (b_x, b_f, b_g)));
338        }
339
340        // U3
341        if c_g < float!(0.0) && c_f > self.finit + self.epsilon_k {
342            let mut ah_x = a_x;
343            let mut ah_f = a_f;
344            let mut ah_g = a_g;
345            let mut bh_x = c_x;
346            loop {
347                let d_x = (float!(1.0) - self.theta) * ah_x + self.theta * bh_x;
348                let d_f = self.calc(problem, d_x)?;
349                let d_g = self.calc_grad(problem, d_x)?;
350                if d_g >= float!(0.0) {
351                    return Ok(((ah_x, ah_f, ah_g), (d_x, d_f, d_g)));
352                }
353                if d_g < float!(0.0) && d_f <= self.finit + self.epsilon_k {
354                    ah_x = d_x;
355                    ah_f = d_f;
356                    ah_g = d_g;
357                }
358                if d_g < float!(0.0) && d_f > self.finit + self.epsilon_k {
359                    bh_x = d_x;
360                }
361            }
362        }
363
364        // return Ok(((a_x, a_f, a_g), (b_x, b_f, b_g)));
365        Err(argmin_error!(
366            PotentialBug,
367            "`HagerZhangLineSearch`: Reached unreachable point in `update` method."
368        ))
369    }
370
371    /// secant step
372    fn secant(&self, a_x: F, a_g: F, b_x: F, b_g: F) -> F {
373        (a_x * b_g - b_x * a_g) / (b_g - a_g)
374    }
375
376    /// double secant step
377    fn secant2<O>(
378        &mut self,
379        problem: &mut Problem<O>,
380        (a_x, a_f, a_g): Triplet<F>,
381        (b_x, b_f, b_g): Triplet<F>,
382    ) -> Result<(Triplet<F>, Triplet<F>), Error>
383    where
384        O: CostFunction<Param = P, Output = F> + Gradient<Param = P, Gradient = G>,
385    {
386        // S1
387        let c_x = self.secant(a_x, a_g, b_x, b_g);
388        let c_f = self.calc(problem, c_x)?;
389        let c_g = self.calc_grad(problem, c_x)?;
390        let mut c_bar_x: F = float!(0.0);
391
392        let ((aa_x, aa_f, aa_g), (bb_x, bb_f, bb_g)) =
393            self.update(problem, (a_x, a_f, a_g), (b_x, b_f, b_g), (c_x, c_f, c_g))?;
394
395        // S2
396        if (c_x - bb_x).abs() < F::epsilon() {
397            c_bar_x = self.secant(b_x, b_g, bb_x, bb_g);
398        }
399
400        // S3
401        if (c_x - aa_x).abs() < F::epsilon() {
402            c_bar_x = self.secant(a_x, a_g, aa_x, aa_g);
403        }
404
405        // S4
406        if (c_x - aa_x).abs() < F::epsilon() || (c_x - bb_x).abs() < F::epsilon() {
407            let c_bar_f = self.calc(problem, c_bar_x)?;
408            let c_bar_g = self.calc_grad(problem, c_bar_x)?;
409
410            let (a_bar, b_bar) = self.update(
411                problem,
412                (aa_x, aa_f, aa_g),
413                (bb_x, bb_f, bb_g),
414                (c_bar_x, c_bar_f, c_bar_g),
415            )?;
416            Ok((a_bar, b_bar))
417        } else {
418            Ok(((aa_x, aa_f, aa_g), (bb_x, bb_f, bb_g)))
419        }
420    }
421
422    fn calc<O>(&mut self, problem: &mut Problem<O>, alpha: F) -> Result<F, Error>
423    where
424        O: CostFunction<Param = P, Output = F>,
425    {
426        let tmp = self
427            .init_param
428            .as_ref()
429            .ok_or_else(argmin_error_closure!(
430                PotentialBug,
431                "`HagerZhangLineSearch`: `init_param` is `None` in `calc`."
432            ))?
433            .scaled_add(&alpha, self.search_direction.as_ref().unwrap());
434        problem.cost(&tmp)
435    }
436
437    fn calc_grad<O>(&mut self, problem: &mut Problem<O>, alpha: F) -> Result<F, Error>
438    where
439        O: Gradient<Param = P, Gradient = G>,
440    {
441        let tmp = self
442            .init_param
443            .as_ref()
444            .ok_or_else(argmin_error_closure!(
445                PotentialBug,
446                "`HagerZhangLineSearch`: `init_param` is `None` in `calc_grad`."
447            ))?
448            .scaled_add(&alpha, self.search_direction.as_ref().unwrap());
449        let grad = problem.gradient(&tmp)?;
450        Ok(self.search_direction.as_ref().unwrap().dot(&grad))
451    }
452
453    fn set_best(&mut self) {
454        if self.a_f <= self.b_f && self.a_f <= self.c_f {
455            self.best_x = self.a_x;
456            self.best_f = self.a_f;
457            self.best_g = self.a_g;
458        }
459
460        if self.b_f <= self.a_f && self.b_f <= self.c_f {
461            self.best_x = self.b_x;
462            self.best_f = self.b_f;
463            self.best_g = self.b_g;
464        }
465
466        if self.c_f <= self.a_f && self.c_f <= self.b_f {
467            self.best_x = self.c_x;
468            self.best_f = self.c_f;
469            self.best_g = self.c_g;
470        }
471    }
472}
473
474impl<P, G, F> Default for HagerZhangLineSearch<P, G, F>
475where
476    P: ArgminScaledAdd<G, F, P>,
477    G: ArgminDot<G, F>,
478    F: ArgminFloat,
479{
480    fn default() -> Self {
481        HagerZhangLineSearch::new()
482    }
483}
484
impl<P, G, F> LineSearch<G, F> for HagerZhangLineSearch<P, G, F> {
    /// Set search direction
    ///
    /// Stores the direction along which the step length is sought. `init` returns an error if
    /// no search direction has been set.
    fn search_direction(&mut self, search_direction: G) {
        self.search_direction = Some(search_direction);
    }

    /// Set initial alpha value
    ///
    /// The value is stored as the initial trial point `c`. It is not validated; this method
    /// always returns `Ok(())`.
    fn initial_step_length(&mut self, alpha: F) -> Result<(), Error> {
        self.c_x_init = alpha;
        Ok(())
    }
}
497
498impl<P, G, O, F> Solver<O, IterState<P, G, (), (), (), F>> for HagerZhangLineSearch<P, G, F>
499where
500    O: CostFunction<Param = P, Output = F> + Gradient<Param = P, Gradient = G>,
501    P: Clone + ArgminDot<G, F> + ArgminScaledAdd<G, F, P>,
502    G: Clone + ArgminDot<G, F>,
503    F: ArgminFloat,
504{
    /// Human-readable name of this solver.
    fn name(&self) -> &str {
        "Hager-Zhang line search"
    }
508
509    fn init(
510        &mut self,
511        problem: &mut Problem<O>,
512        mut state: IterState<P, G, (), (), (), F>,
513    ) -> Result<(IterState<P, G, (), (), (), F>, Option<KV>), Error> {
514        check_param!(
515            self.search_direction,
516            concat!(
517                "`HagerZhangLineSearch`: Search direction not initialized. ",
518                "Call `search_direction` before executing the solver."
519            )
520        );
521
522        self.init_param = Some(state.take_param().ok_or_else(argmin_error_closure!(
523            NotInitialized,
524            concat!(
525                "`HagerZhangLineSearch` requires an initial parameter vector. ",
526                "Please provide an initial guess via `Executor`s `configure` method."
527            )
528        ))?);
529
530        let cost = state.get_cost();
531        self.finit = if cost.is_infinite() {
532            problem.cost(self.init_param.as_ref().unwrap())?
533        } else {
534            cost
535        };
536
537        self.init_grad = Some(
538            state
539                .take_gradient()
540                .map(Result::Ok)
541                .unwrap_or_else(|| problem.gradient(self.init_param.as_ref().unwrap()))?,
542        );
543
544        self.a_x = self.a_x_init;
545        self.b_x = self.b_x_init;
546        self.c_x = self.c_x_init;
547
548        self.a_f = self.calc(problem, self.a_x)?;
549        self.a_g = self.calc_grad(problem, self.a_x)?;
550        self.b_f = self.calc(problem, self.b_x)?;
551        self.b_g = self.calc_grad(problem, self.b_x)?;
552        self.c_f = self.calc(problem, self.c_x)?;
553        self.c_g = self.calc_grad(problem, self.c_x)?;
554
555        self.epsilon_k = self.epsilon * self.finit.abs();
556
557        self.dginit = self
558            .init_grad
559            .as_ref()
560            .unwrap()
561            .dot(self.search_direction.as_ref().unwrap());
562
563        self.set_best();
564        let new_param = self
565            .init_param
566            .as_ref()
567            .unwrap()
568            .scaled_add(&self.best_x, self.search_direction.as_ref().unwrap());
569        let best_f = self.best_f;
570
571        Ok((state.param(new_param).cost(best_f), None))
572    }
573
574    fn next_iter(
575        &mut self,
576        problem: &mut Problem<O>,
577        state: IterState<P, G, (), (), (), F>,
578    ) -> Result<(IterState<P, G, (), (), (), F>, Option<KV>), Error> {
579        // L1
580        let aa = (self.a_x, self.a_f, self.a_g);
581        let bb = (self.b_x, self.b_f, self.b_g);
582        let ((mut at_x, mut at_f, mut at_g), (mut bt_x, mut bt_f, mut bt_g)) =
583            self.secant2(problem, aa, bb)?;
584
585        // L2
586        if bt_x - at_x > self.gamma * (self.b_x - self.a_x) {
587            let c_x = (at_x + bt_x) / float!(2.0);
588            let tmp = self
589                .init_param
590                .as_ref()
591                .unwrap()
592                .scaled_add(&c_x, self.search_direction.as_ref().unwrap());
593            let c_f = problem.cost(&tmp)?;
594            let grad = problem.gradient(&tmp)?;
595            let c_g = self.search_direction.as_ref().unwrap().dot(&grad);
596            let ((an_x, an_f, an_g), (bn_x, bn_f, bn_g)) = self.update(
597                problem,
598                (at_x, at_f, at_g),
599                (bt_x, bt_f, bt_g),
600                (c_x, c_f, c_g),
601            )?;
602            at_x = an_x;
603            at_f = an_f;
604            at_g = an_g;
605            bt_x = bn_x;
606            bt_f = bn_f;
607            bt_g = bn_g;
608        }
609
610        // L3
611        self.a_x = at_x;
612        self.a_f = at_f;
613        self.a_g = at_g;
614        self.b_x = bt_x;
615        self.b_f = bt_f;
616        self.b_g = bt_g;
617
618        self.set_best();
619        let new_param = self
620            .init_param
621            .as_ref()
622            .unwrap()
623            .scaled_add(&self.best_x, self.search_direction.as_ref().unwrap());
624        Ok((state.param(new_param).cost(self.best_f), None))
625    }
626
627    fn terminate(&mut self, _state: &IterState<P, G, (), (), (), F>) -> TerminationStatus {
628        if self.best_f - self.finit <= self.delta * self.best_x * self.dginit
629            && self.best_g >= self.sigma * self.dginit
630        {
631            return TerminationStatus::Terminated(TerminationReason::SolverConverged);
632        }
633        if (float!(2.0) * self.delta - float!(1.0)) * self.dginit >= self.best_g
634            && self.best_g >= self.sigma * self.dginit
635            && self.best_f <= self.finit + self.epsilon_k
636        {
637            return TerminationStatus::Terminated(TerminationReason::SolverConverged);
638        }
639        TerminationStatus::NotTerminated
640    }
641}
642
643#[cfg(test)]
644mod tests {
645    use super::*;
646    use crate::core::{test_utils::TestProblem, ArgminError, State};
647
    // Shared trait-implementation test generated by a crate-level macro (defined outside of
    // this file; NOTE(review): presumably checks auto traits such as `Send`/`Sync` - confirm).
    test_trait_impl!(hagerzhang, HagerZhangLineSearch<Vec<f64>, Vec<f64>, f64>);
649
650    #[test]
651    fn test_new() {
652        let hzls: HagerZhangLineSearch<Vec<f64>, Vec<f64>, f64> = HagerZhangLineSearch::new();
653        let HagerZhangLineSearch {
654            delta,
655            sigma,
656            epsilon,
657            epsilon_k,
658            theta,
659            gamma,
660            eta,
661            a_x_init,
662            a_x,
663            a_f,
664            a_g,
665            b_x_init,
666            b_x,
667            b_f,
668            b_g,
669            c_x_init,
670            c_x,
671            c_f,
672            c_g,
673            best_x,
674            best_f,
675            best_g,
676            init_param,
677            init_grad,
678            search_direction,
679            dginit,
680            finit,
681        } = hzls;
682
683        assert_eq!(delta.to_ne_bytes(), 0.1f64.to_ne_bytes());
684        assert_eq!(sigma.to_ne_bytes(), 0.9f64.to_ne_bytes());
685        assert_eq!(epsilon.to_ne_bytes(), 1e-6f64.to_ne_bytes());
686        assert!(epsilon_k.is_nan());
687        assert_eq!(theta.to_ne_bytes(), 0.5f64.to_ne_bytes());
688        assert_eq!(gamma.to_ne_bytes(), 0.66f64.to_ne_bytes());
689        assert_eq!(eta.to_ne_bytes(), 0.01f64.to_ne_bytes());
690        assert_eq!(a_x_init.to_ne_bytes(), f64::EPSILON.to_ne_bytes());
691        assert!(a_x.is_nan());
692        assert!(a_f.is_nan());
693        assert!(a_g.is_nan());
694        assert_eq!(b_x_init.to_ne_bytes(), 1e5f64.to_ne_bytes());
695        assert!(b_x.is_nan());
696        assert!(b_f.is_nan());
697        assert!(b_g.is_nan());
698        assert_eq!(c_x_init.to_ne_bytes(), 1.0f64.to_ne_bytes());
699        assert!(c_x.is_nan());
700        assert!(c_f.is_nan());
701        assert!(c_g.is_nan());
702        assert_eq!(best_x.to_ne_bytes(), 0.0f64.to_ne_bytes());
703        assert!(best_f.is_infinite());
704        assert!(best_f.is_sign_positive());
705        assert!(best_g.is_nan());
706        assert!(init_param.is_none());
707        assert!(init_grad.is_none());
708        assert!(search_direction.is_none());
709        assert!(dginit.is_nan());
710        assert!(finit.is_infinite());
711        assert!(finit.is_sign_positive());
712    }
713
714    #[test]
715    fn test_with_delta_sigma() {
716        // correct parameters
717        for (delta, sigma) in [
718            (0.2, 0.8),
719            (0.5, 0.5),
720            (0.0 + f64::EPSILON, 0.5),
721            (0.2, 1.0 - f64::EPSILON),
722            (0.5, 0.5),
723        ] {
724            let hzls: HagerZhangLineSearch<Vec<f64>, Vec<f64>, f64> = HagerZhangLineSearch::new();
725            let res = hzls.with_delta_sigma(delta, sigma);
726            assert!(res.is_ok());
727
728            let hzls = res.unwrap();
729            assert_eq!(hzls.delta.to_ne_bytes(), delta.to_ne_bytes());
730            assert_eq!(hzls.sigma.to_ne_bytes(), sigma.to_ne_bytes());
731        }
732
733        // incorrect parameters
734        for (delta, sigma) in [
735            (-1.0, 0.5),
736            (0.0, 0.5),
737            (1.0, 0.5),
738            (2.0, 0.5),
739            (0.5, 0.5 - f64::EPSILON),
740            (0.5, 0.0),
741            (0.5, 1.0),
742            (0.5, 2.0),
743            (0.6, 0.2),
744        ] {
745            let hzls: HagerZhangLineSearch<Vec<f64>, Vec<f64>, f64> = HagerZhangLineSearch::new();
746            let res = hzls.with_delta_sigma(delta, sigma);
747            assert_error!(
748                res,
749                ArgminError,
750                concat!(
751                    "Invalid parameter: \"`HagerZhangLineSearch`: ",
752                    "delta must be in (0, 1) and sigma must be in [delta, 1).\""
753                )
754            );
755        }
756    }
757
758    #[test]
759    fn test_with_epsilon() {
760        // correct parameters
761        for epsilon in [1e-6, 0.0, 1e-2, 1.0, 2.0] {
762            let hzls: HagerZhangLineSearch<Vec<f64>, Vec<f64>, f64> = HagerZhangLineSearch::new();
763            let res = hzls.with_epsilon(epsilon);
764            assert!(res.is_ok());
765
766            let hzls = res.unwrap();
767            assert_eq!(hzls.epsilon.to_ne_bytes(), epsilon.to_ne_bytes());
768        }
769
770        // incorrect parameters
771        for epsilon in [-f64::EPSILON, -1.0, -100.0, -42.0] {
772            let hzls: HagerZhangLineSearch<Vec<f64>, Vec<f64>, f64> = HagerZhangLineSearch::new();
773            let res = hzls.with_epsilon(epsilon);
774            assert_error!(
775                res,
776                ArgminError,
777                concat!(
778                    "Invalid parameter: \"`HagerZhangLineSearch`: ",
779                    "epsilon must be >= 0.\""
780                )
781            );
782        }
783    }
784
785    #[test]
786    fn test_with_theta() {
787        // correct parameters
788        for theta in [0.0 + f64::EPSILON, 1e-2, 0.5, 0.6, 1.0 - f64::EPSILON] {
789            let hzls: HagerZhangLineSearch<Vec<f64>, Vec<f64>, f64> = HagerZhangLineSearch::new();
790            let res = hzls.with_theta(theta);
791            assert!(res.is_ok());
792
793            let hzls = res.unwrap();
794            assert_eq!(hzls.theta.to_ne_bytes(), theta.to_ne_bytes());
795        }
796
797        // incorrect parameters
798        for theta in [0.0, 1.0, -100.0, 42.0] {
799            let hzls: HagerZhangLineSearch<Vec<f64>, Vec<f64>, f64> = HagerZhangLineSearch::new();
800            let res = hzls.with_theta(theta);
801            assert_error!(
802                res,
803                ArgminError,
804                concat!(
805                    "Invalid parameter: \"`HagerZhangLineSearch`: ",
806                    "theta must be in (0, 1).\""
807                )
808            );
809        }
810    }
811
812    #[test]
813    fn test_with_gamma() {
814        // correct parameters
815        for gamma in [0.0 + f64::EPSILON, 1e-2, 0.5, 0.6, 1.0 - f64::EPSILON] {
816            let hzls: HagerZhangLineSearch<Vec<f64>, Vec<f64>, f64> = HagerZhangLineSearch::new();
817            let res = hzls.with_gamma(gamma);
818            assert!(res.is_ok());
819
820            let hzls = res.unwrap();
821            assert_eq!(hzls.gamma.to_ne_bytes(), gamma.to_ne_bytes());
822        }
823
824        // incorrect parameters
825        for gamma in [0.0, 1.0, -100.0, 42.0] {
826            let hzls: HagerZhangLineSearch<Vec<f64>, Vec<f64>, f64> = HagerZhangLineSearch::new();
827            let res = hzls.with_gamma(gamma);
828            assert_error!(
829                res,
830                ArgminError,
831                concat!(
832                    "Invalid parameter: \"`HagerZhangLineSearch`: ",
833                    "gamma must be in (0, 1).\""
834                )
835            );
836        }
837    }
838
839    #[test]
840    fn test_with_eta() {
841        // correct parameters
842        for eta in [0.0 + f64::EPSILON, 1e-2, 0.5, 1.0, 10.0] {
843            let hzls: HagerZhangLineSearch<Vec<f64>, Vec<f64>, f64> = HagerZhangLineSearch::new();
844            let res = hzls.with_eta(eta);
845            assert!(res.is_ok());
846
847            let hzls = res.unwrap();
848            assert_eq!(hzls.eta.to_ne_bytes(), eta.to_ne_bytes());
849        }
850
851        // incorrect parameters
852        for eta in [0.0, -f64::EPSILON, -100.0, -42.0] {
853            let hzls: HagerZhangLineSearch<Vec<f64>, Vec<f64>, f64> = HagerZhangLineSearch::new();
854            let res = hzls.with_eta(eta);
855            assert_error!(
856                res,
857                ArgminError,
858                concat!(
859                    "Invalid parameter: \"`HagerZhangLineSearch`: ",
860                    "eta must be > 0.\""
861                )
862            );
863        }
864    }
865
866    #[test]
867    fn test_with_bounds() {
868        // correct parameters
869        for (min, max) in [
870            (0.2, 0.8),
871            (0.5 - f64::EPSILON, 0.5),
872            (0.5, 0.5 + f64::EPSILON),
873            (0.0, 0.5),
874            (0.0 + f64::EPSILON, 0.5),
875            (50.0, 100.0),
876        ] {
877            let hzls: HagerZhangLineSearch<Vec<f64>, Vec<f64>, f64> = HagerZhangLineSearch::new();
878            let res = hzls.with_bounds(min, max);
879            assert!(res.is_ok());
880
881            let hzls = res.unwrap();
882            assert_eq!(hzls.a_x_init.to_ne_bytes(), min.to_ne_bytes());
883            assert_eq!(hzls.b_x_init.to_ne_bytes(), max.to_ne_bytes());
884        }
885
886        // incorrect parameters
887        for (min, max) in [
888            (-1.0, 0.5),
889            (0.5, 0.5),
890            (0.5 + f64::EPSILON, 0.5),
891            (0.5, 0.0),
892            (-1000.0, -100.0),
893        ] {
894            let hzls: HagerZhangLineSearch<Vec<f64>, Vec<f64>, f64> = HagerZhangLineSearch::new();
895            let res = hzls.with_bounds(min, max);
896            assert_error!(
897                res,
898                ArgminError,
899                concat!(
900                    "Invalid parameter: \"`HagerZhangLineSearch`: minimum and maximum step length ",
901                    "must be chosen such that 0 <= step_min < step_max.\""
902                )
903            );
904        }
905    }
906
907    #[test]
908    fn test_init_search_direction_not_set() {
909        let mut hzls: HagerZhangLineSearch<Vec<f64>, Vec<f64>, f64> = HagerZhangLineSearch::new();
910        let res = hzls.init(&mut Problem::new(TestProblem::new()), IterState::new());
911        assert_error!(
912            res,
913            ArgminError,
914            concat!(
915                "Not initialized: \"`HagerZhangLineSearch`: Search direction not initialized. ",
916                "Call `search_direction` before executing the solver.\""
917            )
918        );
919    }
920
921    #[test]
922    fn test_init_param_not_set() {
923        let mut hzls: HagerZhangLineSearch<Vec<f64>, Vec<f64>, f64> = HagerZhangLineSearch::new();
924        hzls.search_direction(vec![1.0f64]);
925        let res = hzls.init(&mut Problem::new(TestProblem::new()), IterState::new());
926        assert_error!(
927            res,
928            ArgminError,
929            concat!(
930                "Not initialized: \"`HagerZhangLineSearch` requires an initial parameter vector. ",
931                "Please provide an initial guess via `Executor`s `configure` method.\""
932            )
933        );
934    }
935}