argmin/solver/linesearch/condition/strongwolfe.rs

1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use super::LineSearchCondition;
9use crate::core::{ArgminFloat, Error};
10use argmin_math::ArgminDot;
11#[cfg(feature = "serde1")]
12use serde::{Deserialize, Serialize};
13
/// # Strong Wolfe conditions
///
/// Assures that a step length satisfies a "sufficient decrease" in cost function value (see
/// [`ArmijoCondition`](`crate::solver::linesearch::condition::ArmijoCondition`)) as well as that
/// the absolute value of the slope has been reduced sufficiently (thus making it more likely to be
/// close to a critical point).
///
/// For a step length `alpha`, a search direction `p` and a starting point `x` with cost `f(x)` and
/// gradient `g(x)`, a step length is accepted if both of the following hold:
///
/// * `f(x + alpha * p) <= f(x) + c1 * alpha * g(x)^T p` (sufficient decrease)
/// * `|g(x + alpha * p)^T p| <= c2 * |g(x)^T p|` (curvature)
///
/// The constants must satisfy `0 < c1 < c2 < 1`, which is enforced by
/// [`StrongWolfeCondition::new`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct StrongWolfeCondition<F> {
    /// Sufficient decrease (Armijo) constant.
    c1: F,
    /// Curvature constant.
    c2: F,
}
26
27impl<F> StrongWolfeCondition<F>
28where
29    F: ArgminFloat,
30{
31    /// Construct a new instance of [`StrongWolfeCondition`].
32    ///
33    /// # Example
34    ///
35    /// ```
36    /// # use argmin::solver::linesearch::condition::StrongWolfeCondition;
37    /// let strongwolfe = StrongWolfeCondition::new(0.0001f64, 0.1f64);
38    /// ```
39    pub fn new(c1: F, c2: F) -> Result<Self, Error> {
40        if c1 <= float!(0.0) || c1 >= float!(1.0) {
41            return Err(argmin_error!(
42                InvalidParameter,
43                "StrongWolfeCondition: Parameter c1 must be in (0, 1)"
44            ));
45        }
46        if c2 <= c1 || c2 >= float!(1.0) {
47            return Err(argmin_error!(
48                InvalidParameter,
49                "StrongWolfeCondition: Parameter c2 must be in (c1, 1)"
50            ));
51        }
52        Ok(StrongWolfeCondition { c1, c2 })
53    }
54}
55
56impl<T, G, F> LineSearchCondition<T, G, F> for StrongWolfeCondition<F>
57where
58    G: ArgminDot<T, F>,
59    F: ArgminFloat,
60{
61    fn evaluate_condition(
62        &self,
63        current_cost: F,
64        current_gradient: Option<&G>,
65        initial_cost: F,
66        initial_gradient: &G,
67        search_direction: &T,
68        step_length: F,
69    ) -> bool {
70        let tmp = initial_gradient.dot(search_direction);
71        (current_cost <= initial_cost + self.c1 * step_length * tmp)
72            && current_gradient
73                .expect("Gradient not supplied to `evaluate_condition` of `StrongWolveCondition`")
74                .dot(search_direction)
75                .abs()
76                <= self.c2 * tmp.abs()
77    }
78
79    fn requires_current_gradient(&self) -> bool {
80        true
81    }
82}
83
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ArgminError;

    test_trait_impl!(strongwolfe, StrongWolfeCondition<f64>);

    #[test]
    fn test_strongwolfe_new() {
        // Valid parameters must be stored bit-for-bit unchanged.
        let (c1, c2): (f64, f64) = (0.01, 0.08);
        let cond = StrongWolfeCondition::new(c1, c2).unwrap();
        assert_eq!(c1.to_ne_bytes(), cond.c1.to_ne_bytes());
        assert_eq!(c2.to_ne_bytes(), cond.c2.to_ne_bytes());

        // c1 outside of (0, 1)
        for bad_c1 in [1.0f64, 0.0, -1.0, 2.0] {
            assert_error!(
                StrongWolfeCondition::new(bad_c1, 0.5),
                ArgminError,
                "Invalid parameter: \"StrongWolfeCondition: Parameter c1 must be in (0, 1)\""
            );
        }

        // c2 outside of (c1, 1)
        for bad_c2 in [-1.0f64, 0.0, 0.5, 1.0, 2.0] {
            assert_error!(
                StrongWolfeCondition::new(0.5, bad_c2),
                ArgminError,
                "Invalid parameter: \"StrongWolfeCondition: Parameter c2 must be in (c1, 1)\""
            );
        }
    }

    /// Checks the condition on `x^2 + y^2`, starting at `(-1, -0)` and moving along `(1, 0)`,
    /// asserting for every `(step_length, accepted)` pair in `cases` whether it is accepted.
    fn check_quadratic(c1: f64, c2: f64, cases: &[(f64, bool)]) {
        let cond = StrongWolfeCondition::new(c1, c2).unwrap();
        let cost = |x: f64, y: f64| x.powf(2.0) + y.powf(2.0);
        let grad = |x: f64, y: f64| vec![2.0 * x, 2.0 * y];
        let (x0, y0): (f64, f64) = (-1.0, -0.0);
        let direction: Vec<f64> = vec![1.0, 0.0];
        for &(alpha, expected) in cases {
            let x = x0 + alpha;
            let accepted = cond.evaluate_condition(
                cost(x, y0),
                Some(&grad(x, y0)),
                cost(x0, y0),
                &grad(x0, y0),
                &direction,
                alpha,
            );
            assert_eq!(accepted, expected);
        }
    }

    #[test]
    fn test_strongwolfe() {
        // Sufficient decrease (Armijo) part effectively inactive: c1 is so small that only the
        // curvature part restricts the accepted step lengths for this cost function.
        check_quadratic(
            0.01,
            0.9,
            &[
                (0.001, false),
                (0.03, false),
                (0.1 - f64::EPSILON, false),
                (0.1, true),
                (0.15, true),
                (0.9, true),
                (0.99, true),
                (1.0, true),
                (1.9, true),
                (1.9 + f64::EPSILON, false),
                (2.0, false),
                (2.3, false),
            ],
        );

        // Sufficient decrease (Armijo) part active: it rejects step lengths beyond 1.0 that
        // the curvature part alone would still accept.
        check_quadratic(
            0.5,
            0.9,
            &[
                (0.001, false),
                (0.03, false),
                (0.1 - f64::EPSILON, false),
                (0.1, true),
                (0.15, true),
                (0.9, true),
                (0.99, true),
                (1.0, true),
                (1.0 + f64::EPSILON, false),
                (1.9, false),
                (2.0, false),
                (2.3, false),
            ],
        );
    }
}