argmin/solver/linesearch/condition/
strongwolfe.rs1use super::LineSearchCondition;
9use crate::core::{ArgminFloat, Error};
10use argmin_math::ArgminDot;
11#[cfg(feature = "serde1")]
12use serde::{Deserialize, Serialize};
13
/// Strong Wolfe conditions for deciding whether a step length is acceptable in a line search.
///
/// For a starting point `x`, search direction `p` and step length `alpha`, a step is accepted
/// when both of the following hold:
///
/// * sufficient decrease: `f(x + alpha * p) <= f(x) + c1 * alpha * grad f(x)^T p`
/// * curvature: `|grad f(x + alpha * p)^T p| <= c2 * |grad f(x)^T p|`
///
/// Use [`StrongWolfeCondition::new`] to construct an instance; it checks that `0 < c1 < c2 < 1`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct StrongWolfeCondition<F> {
    /// Parameter of the sufficient decrease test.
    c1: F,
    /// Parameter of the curvature test.
    c2: F,
}
26
27impl<F> StrongWolfeCondition<F>
28where
29 F: ArgminFloat,
30{
31 pub fn new(c1: F, c2: F) -> Result<Self, Error> {
40 if c1 <= float!(0.0) || c1 >= float!(1.0) {
41 return Err(argmin_error!(
42 InvalidParameter,
43 "StrongWolfeCondition: Parameter c1 must be in (0, 1)"
44 ));
45 }
46 if c2 <= c1 || c2 >= float!(1.0) {
47 return Err(argmin_error!(
48 InvalidParameter,
49 "StrongWolfeCondition: Parameter c2 must be in (c1, 1)"
50 ));
51 }
52 Ok(StrongWolfeCondition { c1, c2 })
53 }
54}
55
56impl<T, G, F> LineSearchCondition<T, G, F> for StrongWolfeCondition<F>
57where
58 G: ArgminDot<T, F>,
59 F: ArgminFloat,
60{
61 fn evaluate_condition(
62 &self,
63 current_cost: F,
64 current_gradient: Option<&G>,
65 initial_cost: F,
66 initial_gradient: &G,
67 search_direction: &T,
68 step_length: F,
69 ) -> bool {
70 let tmp = initial_gradient.dot(search_direction);
71 (current_cost <= initial_cost + self.c1 * step_length * tmp)
72 && current_gradient
73 .expect("Gradient not supplied to `evaluate_condition` of `StrongWolveCondition`")
74 .dot(search_direction)
75 .abs()
76 <= self.c2 * tmp.abs()
77 }
78
79 fn requires_current_gradient(&self) -> bool {
80 true
81 }
82}
83
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ArgminError;

    test_trait_impl!(strongwolfe, StrongWolfeCondition<f64>);

    #[test]
    fn test_strongwolfe_new() {
        let (c1, c2): (f64, f64) = (0.01, 0.08);
        let cond = StrongWolfeCondition::new(c1, c2).unwrap();
        // Compare bit patterns: the parameters must be stored exactly as given.
        assert_eq!(c1.to_ne_bytes(), cond.c1.to_ne_bytes());
        assert_eq!(c2.to_ne_bytes(), cond.c2.to_ne_bytes());

        // `c1` outside of (0, 1).
        let invalid_c1: [(f64, f64); 4] = [(1.0, 0.5), (0.0, 0.5), (-1.0, 0.5), (2.0, 0.5)];
        for (a, b) in invalid_c1 {
            assert_error!(
                StrongWolfeCondition::new(a, b),
                ArgminError,
                "Invalid parameter: \"StrongWolfeCondition: Parameter c1 must be in (0, 1)\""
            );
        }

        // `c1` valid, but `c2` outside of (c1, 1).
        let invalid_c2: [(f64, f64); 5] =
            [(0.5, -1.0), (0.5, 0.0), (0.5, 0.5), (0.5, 1.0), (0.5, 2.0)];
        for (a, b) in invalid_c2 {
            assert_error!(
                StrongWolfeCondition::new(a, b),
                ArgminError,
                "Invalid parameter: \"StrongWolfeCondition: Parameter c2 must be in (c1, 1)\""
            );
        }
    }

    /// Evaluates the condition for `f(x, y) = x^2 + y^2`, starting at `(-1, 0)` and searching
    /// along `(1, 0)`, and asserts the expected acceptance for every `(step length, accepted)`
    /// pair in `cases`.
    fn check_acceptance(c1: f64, c2: f64, cases: &[(f64, bool)]) {
        let cond = StrongWolfeCondition::new(c1, c2).unwrap();
        let cost = |x: f64, y: f64| x.powf(2.0) + y.powf(2.0);
        let grad = |x: f64, y: f64| vec![2.0 * x, 2.0 * y];
        let (x0, y0): (f64, f64) = (-1.0, -0.0);
        let direction = vec![1.0, 0.0];
        for &(alpha, expected) in cases {
            let accepted = cond.evaluate_condition(
                cost(x0 + alpha, y0),
                Some(&grad(x0 + alpha, y0)),
                cost(x0, y0),
                &grad(x0, y0),
                &direction,
                alpha,
            );
            assert_eq!(accepted, expected, "step length {}", alpha);
        }
    }

    #[test]
    fn test_strongwolfe() {
        // Accepted interval: [0.1, 1.9] (the curvature test is the binding one on both ends).
        check_acceptance(
            0.01,
            0.9,
            &[
                (0.001, false),
                (0.03, false),
                (0.1 - f64::EPSILON, false),
                (0.1, true),
                (0.15, true),
                (0.9, true),
                (0.99, true),
                (1.0, true),
                (1.9, true),
                (1.9 + f64::EPSILON, false),
                (2.0, false),
                (2.3, false),
            ],
        );

        // Accepted interval: [0.1, 1.0] (sufficient decrease now limits the upper end).
        check_acceptance(
            0.5,
            0.9,
            &[
                (0.001, false),
                (0.03, false),
                (0.1 - f64::EPSILON, false),
                (0.1, true),
                (0.15, true),
                (0.9, true),
                (0.99, true),
                (1.0, true),
                (1.0 + f64::EPSILON, false),
                (1.9, false),
                (2.0, false),
                (2.3, false),
            ],
        );
    }
}