argmin/solver/linesearch/condition/
armijo.rs

1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use super::LineSearchCondition;
9use crate::core::{ArgminFloat, Error};
10use argmin_math::ArgminDot;
11#[cfg(feature = "serde1")]
12use serde::{Deserialize, Serialize};
13
/// # Armijo Condition
///
/// Sufficient-decrease criterion for line searches: a step length is only accepted if the
/// cost at the trial point has dropped "sufficiently" relative to the cost at the starting
/// point. How much decrease is demanded is controlled by the factor `c`.
///
/// Instances are meant to be created via `ArmijoCondition::new`, which validates `c`. The
/// derived `Default` fills `c` with `F::default()` and does not go through that validation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct ArmijoCondition<F> {
    /// Sufficient-decrease factor that scales the required decrease.
    c: F,
}
22
23impl<F> ArmijoCondition<F>
24where
25    F: ArgminFloat,
26{
27    /// Construct a new [`ArmijoCondition`] instance.
28    ///
29    /// # Example
30    ///
31    /// ```
32    /// # use argmin::solver::linesearch::condition::ArmijoCondition;
33    /// let armijo = ArmijoCondition::new(0.0001f64);
34    /// ```
35    pub fn new(c: F) -> Result<Self, Error> {
36        if c <= float!(0.0) || c >= float!(1.0) {
37            return Err(argmin_error!(
38                InvalidParameter,
39                "ArmijoCondition: Parameter c must be in (0, 1)"
40            ));
41        }
42        Ok(ArmijoCondition { c })
43    }
44}
45
46impl<T, G, F> LineSearchCondition<T, G, F> for ArmijoCondition<F>
47where
48    G: ArgminDot<T, F>,
49    F: ArgminFloat,
50{
51    fn evaluate_condition(
52        &self,
53        current_cost: F,
54        _current_gradient: Option<&G>,
55        initial_cost: F,
56        initial_gradient: &G,
57        search_direction: &T,
58        step_length: F,
59    ) -> bool {
60        current_cost <= initial_cost + self.c * step_length * initial_gradient.dot(search_direction)
61    }
62
63    fn requires_current_gradient(&self) -> bool {
64        false
65    }
66}
67
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ArgminError;

    test_trait_impl!(armijo, ArmijoCondition<f64>);

    /// A valid `c` is stored bit-for-bit; values on or outside the bounds of (0, 1) are
    /// rejected with the documented error message.
    #[test]
    fn test_armijo_new() {
        let c: f64 = 0.01;
        let armijo = ArmijoCondition::new(c).unwrap();
        assert_eq!(c.to_ne_bytes(), armijo.c.to_ne_bytes());

        // Both interval bounds and values beyond them must be refused.
        for invalid_c in [1.0f64, 2.0f64, 0.0f64, -1.0f64] {
            assert_error!(
                ArmijoCondition::new(invalid_c),
                ArgminError,
                "Invalid parameter: \"ArmijoCondition: Parameter c must be in (0, 1)\""
            );
        }
    }

    /// Evaluates the condition on `f(x, y) = x^2 + y^2`, starting at `(-1, 0)` and moving
    /// along `(1, 0)`, for a table of step lengths with their expected verdicts.
    #[test]
    fn test_armijo() {
        let c: f64 = 0.50;
        let cond = ArmijoCondition::new(c).unwrap();

        let cost = |x: f64, y: f64| x.powf(2.0) + y.powf(2.0);
        let gradient = |x: f64, y: f64| vec![2.0 * x, 2.0 * y];

        let initial_x = -1.0;
        let initial_y = -0.0;
        let search_direction = vec![1.0, 0.0];

        // (step length, expected result of the condition)
        let cases = [
            (0.001, true),
            (0.03, true),
            (0.2, true),
            (0.5, true),
            (0.9, true),
            (0.99, true),
            (1.0, true),
            (1.0 + f64::EPSILON, false),
            (1.5, false),
            (1.8, false),
            (2.0, false),
            (2.3, false),
        ];

        for (alpha, expected) in cases {
            let trial_x = initial_x + alpha;
            let accepted = cond.evaluate_condition(
                cost(trial_x, initial_y),
                Some(&gradient(trial_x, initial_y)),
                cost(initial_x, initial_y),
                &gradient(initial_x, initial_y),
                &search_direction,
                alpha,
            );
            assert_eq!(accepted, expected);
        }
    }
}