argmin/solver/linesearch/condition/
armijo.rs1use super::LineSearchCondition;
9use crate::core::{ArgminFloat, Error};
10use argmin_math::ArgminDot;
11#[cfg(feature = "serde1")]
12use serde::{Deserialize, Serialize};
13
/// Armijo (sufficient decrease) condition for line searches.
///
/// A trial step `alpha` along the search direction `p`, starting from `x`, is accepted when
/// `f(x + alpha * p) <= f(x) + c * alpha * (grad_f(x) · p)`.
///
/// Use [`ArmijoCondition::new`] to construct a validated instance with `c` in `(0, 1)`.
///
/// NOTE(review): the derived `Default` produces `c = 0`, which `new` would reject.
/// Confirm whether callers rely on `Default` before tightening this.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct ArmijoCondition<F> {
    /// Sufficient-decrease parameter `c`.
    c: F,
}
22
23impl<F> ArmijoCondition<F>
24where
25 F: ArgminFloat,
26{
27 pub fn new(c: F) -> Result<Self, Error> {
36 if c <= float!(0.0) || c >= float!(1.0) {
37 return Err(argmin_error!(
38 InvalidParameter,
39 "ArmijoCondition: Parameter c must be in (0, 1)"
40 ));
41 }
42 Ok(ArmijoCondition { c })
43 }
44}
45
46impl<T, G, F> LineSearchCondition<T, G, F> for ArmijoCondition<F>
47where
48 G: ArgminDot<T, F>,
49 F: ArgminFloat,
50{
51 fn evaluate_condition(
52 &self,
53 current_cost: F,
54 _current_gradient: Option<&G>,
55 initial_cost: F,
56 initial_gradient: &G,
57 search_direction: &T,
58 step_length: F,
59 ) -> bool {
60 current_cost <= initial_cost + self.c * step_length * initial_gradient.dot(search_direction)
61 }
62
63 fn requires_current_gradient(&self) -> bool {
64 false
65 }
66}
67
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ArgminError;

    test_trait_impl!(armijo, ArmijoCondition<f64>);

    #[test]
    fn test_armijo_new() {
        // A valid `c` is stored unchanged (compared bit-for-bit).
        let c: f64 = 0.01;
        let ArmijoCondition { c: stored } = ArmijoCondition::new(c).unwrap();
        assert_eq!(c.to_ne_bytes(), stored.to_ne_bytes());

        // Values on or outside the boundaries of (0, 1) must be rejected.
        for invalid in [1.0f64, 2.0, 0.0, -1.0] {
            assert_error!(
                ArmijoCondition::new(invalid),
                ArgminError,
                "Invalid parameter: \"ArmijoCondition: Parameter c must be in (0, 1)\""
            );
        }
    }

    #[test]
    fn test_armijo() {
        // f(x, y) = x^2 + y^2, starting at (-1, 0) and stepping along +x with c = 0.5.
        // Then f(x0 + alpha * p) = (1 - alpha)^2 and the Armijo bound is 1 - alpha,
        // so the condition holds exactly for alpha in [0, 1].
        let cond = ArmijoCondition::new(0.50f64).unwrap();
        let cost = |x: f64, y: f64| x.powf(2.0) + y.powf(2.0);
        let grad = |x: f64, y: f64| vec![2.0 * x, 2.0 * y];
        let (x0, y0): (f64, f64) = (-1.0, -0.0);
        let direction = vec![1.0, 0.0];

        let accepts = |alpha: f64| {
            cond.evaluate_condition(
                cost(x0 + alpha, y0),
                Some(&grad(x0 + alpha, y0)),
                cost(x0, y0),
                &grad(x0, y0),
                &direction,
                alpha,
            )
        };

        for alpha in [0.001, 0.03, 0.2, 0.5, 0.9, 0.99, 1.0] {
            assert!(
                accepts(alpha),
                "step {} should satisfy the Armijo condition",
                alpha
            );
        }
        for alpha in [1.0 + f64::EPSILON, 1.5, 1.8, 2.0, 2.3] {
            assert!(
                !accepts(alpha),
                "step {} should violate the Armijo condition",
                alpha
            );
        }
    }
}