argmin/solver/gradientdescent/steepestdescent.rs

use crate::core::{
    ArgminFloat, CostFunction, Error, Executor, Gradient, IterState, LineSearch,
    OptimizationResult, Problem, Solver, State, KV,
};
use argmin_math::ArgminMul;
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};

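/// # Steepest descent
///
/// Iteratively takes steps in the direction of the strongest negative gradient. In each
/// iteration, a line search is employed to obtain an appropriate step length.
///
/// ## Requirements on the optimization problem
///
/// The optimization problem is required to implement [`CostFunction`] and [`Gradient`].
///
/// ## Reference
///
/// Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization. Springer.
/// ISBN 0-387-30303-0.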
#[derive(Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct SteepestDescent<L> {
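    /// Line search used to determine the step length in each iteration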
    linesearch: L,
}

impl<L> SteepestDescent<L> {
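    /// Constructs a new instance of [`SteepestDescent`] from a line search method.
    ///
    /// # Example
    ///
    /// A sketch using the crate's [`MoreThuenteLineSearch`](crate::solver::linesearch::MoreThuenteLineSearch);
    /// any line search implementing [`LineSearch`] can be used instead:
    ///
    /// ```
    /// # use argmin::solver::gradientdescent::SteepestDescent;
    /// # use argmin::solver::linesearch::MoreThuenteLineSearch;
    /// let linesearch: MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64> = MoreThuenteLineSearch::new();
    /// let sd = SteepestDescent::new(linesearch);
    /// ```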
    pub fn new(linesearch: L) -> Self {
        SteepestDescent { linesearch }
    }
}

impl<O, L, P, G, F> Solver<O, IterState<P, G, (), (), (), F>> for SteepestDescent<L>
where
    O: CostFunction<Param = P, Output = F> + Gradient<Param = P, Gradient = G>,
    P: Clone,
    G: Clone + ArgminMul<F, G>,
    L: Clone + LineSearch<G, F> + Solver<O, IterState<P, G, (), (), (), F>>,
    F: ArgminFloat,
{
    fn name(&self) -> &str {
        "Steepest Descent"
    }

    fn next_iter(
        &mut self,
        problem: &mut Problem<O>,
        state: IterState<P, G, (), (), (), F>,
    ) -> Result<(IterState<P, G, (), (), (), F>, Option<KV>), Error> {
        let param_new = state
            .get_param()
            .ok_or_else(argmin_error_closure!(
                NotInitialized,
                concat!(
                    "`SteepestDescent` requires an initial parameter vector. ",
                    "Please provide an initial guess via `Executor`'s `configure` method."
                )
            ))?
            .clone();
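        // Evaluate cost and gradient at the current parameter vector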
        let new_cost = problem.cost(&param_new)?;
        let new_grad = problem.gradient(&param_new)?;

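        // The search direction is the direction of steepest descent: the negative gradient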
        self.linesearch
            .search_direction(new_grad.mul(&(float!(-1.0))));

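        // Run the line search on the wrapped problem via a nested `Executor`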
        let OptimizationResult {
            problem: line_problem,
            state: mut linesearch_state,
            ..
        } = Executor::new(
            problem.take_problem().ok_or_else(argmin_error_closure!(
                PotentialBug,
                "`SteepestDescent`: Failed to take `problem` for line search"
            ))?,
            self.linesearch.clone(),
        )
        .configure(|config| config.param(param_new).gradient(new_grad).cost(new_cost))
        .ctrlc(false)
        .run()?;

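        // Hand the problem back and absorb the function evaluation counts of the line search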
        problem.consume_problem(line_problem);

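        // Update the state with the parameter vector and cost found by the line search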
        Ok((
            state
                .param(
                    linesearch_state
                        .take_param()
                        .ok_or_else(argmin_error_closure!(
                            PotentialBug,
                            "`SteepestDescent`: No `param` returned by line search"
                        ))?,
                )
                .cost(linesearch_state.get_cost()),
            None,
        ))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::test_utils::TestProblem;
    use crate::core::ArgminError;
    use crate::solver::linesearch::{
        condition::ArmijoCondition, BacktrackingLineSearch, MoreThuenteLineSearch,
    };
    use approx::assert_relative_eq;

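    // Generated tests (via `test_trait_impl!`) checking that the concrete solver type
    // implements the commonly required auxiliary traits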
    test_trait_impl!(
        steepest_descent,
        SteepestDescent<MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64>>
    );

    #[test]
    fn test_new() {
        let linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(ArmijoCondition::new(0.2).unwrap());
        let SteepestDescent { linesearch: ls } = SteepestDescent::new(linesearch.clone());
        assert_eq!(ls, linesearch);
    }

    #[test]
    fn test_next_iter_param_not_initialized() {
        let linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(ArmijoCondition::new(0.2).unwrap());
        let mut sd = SteepestDescent::new(linesearch);
        let res = sd.next_iter(&mut Problem::new(TestProblem::new()), IterState::new());
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Not initialized: \"`SteepestDescent` requires an initial parameter vector. ",
                "Please provide an initial guess via `Executor`'s `configure` method.\""
            )
        );
    }

    #[test]
    fn test_next_iter_prev_param_not_erased() {
        let linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(ArmijoCondition::new(0.2).unwrap());
        let mut sd = SteepestDescent::new(linesearch);
        let (state, _kv) = sd
            .next_iter(
                &mut Problem::new(TestProblem::new()),
                IterState::new().param(vec![1.0, 2.0]),
            )
            .unwrap();
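        // The previous parameter vector must still be present after the iteration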
        state.prev_param.unwrap();
    }

    #[test]
    fn test_next_iter_regression() {
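        // A 2D sphere function f(p) = p[0]^2 + p[1]^2 with gradient (2 * p[0], 2 * p[1])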
        struct SDProblem {}

        impl CostFunction for SDProblem {
            type Param = Vec<f64>;
            type Output = f64;

            fn cost(&self, p: &Self::Param) -> Result<Self::Output, Error> {
                Ok(p[0].powi(2) + p[1].powi(2))
            }
        }

        impl Gradient for SDProblem {
            type Param = Vec<f64>;
            type Gradient = Vec<f64>;

            fn gradient(&self, p: &Self::Param) -> Result<Self::Gradient, Error> {
                Ok(vec![2.0 * p[0], 2.0 * p[1]])
            }
        }

        let linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(ArmijoCondition::new(0.2).unwrap());
        let mut sd = SteepestDescent::new(linesearch);
        let (state, kv) = sd
            .next_iter(
                &mut Problem::new(SDProblem {}),
                IterState::new().param(vec![1.0, 2.0]),
            )
            .unwrap();

        assert!(kv.is_none());

        assert_relative_eq!(
            state.param.as_ref().unwrap()[0],
            -0.4580000000000002,
            epsilon = f64::EPSILON
        );
        assert_relative_eq!(
            state.param.as_ref().unwrap()[1],
            -0.9160000000000004,
            epsilon = f64::EPSILON
        );
        assert_relative_eq!(state.cost, 1.048820000000001, epsilon = f64::EPSILON);
    }
}