argmin/solver/gradientdescent/steepestdescent.rs

// Copyright 2018-2024 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

use crate::core::{
    ArgminFloat, CostFunction, Error, Executor, Gradient, IterState, LineSearch,
    OptimizationResult, Problem, Solver, State, KV,
};
use argmin_math::ArgminMul;
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};

/// # Steepest descent
///
/// Iteratively takes steps in the direction of the strongest negative gradient. In each iteration,
/// a line search is used to obtain an appropriate step length.
///
/// ## Requirements on the optimization problem
///
/// The optimization problem is required to implement [`CostFunction`] and [`Gradient`].
///
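/// ## Example
///
/// A minimal usage sketch: the solver is wrapped in an [`Executor`], which drives the
/// iterations. The `Sphere` problem below is purely illustrative, and any line search
/// implementing [`LineSearch`] (here a `MoreThuenteLineSearch`) can be used.
///
/// ```
/// use argmin::core::{CostFunction, Error, Executor, Gradient};
/// use argmin::solver::gradientdescent::SteepestDescent;
/// use argmin::solver::linesearch::MoreThuenteLineSearch;
///
/// // Minimize the sphere function f(x) = x_1^2 + x_2^2.
/// struct Sphere {}
///
/// impl CostFunction for Sphere {
///     type Param = Vec<f64>;
///     type Output = f64;
///
///     fn cost(&self, p: &Self::Param) -> Result<Self::Output, Error> {
///         Ok(p[0].powi(2) + p[1].powi(2))
///     }
/// }
///
/// impl Gradient for Sphere {
///     type Param = Vec<f64>;
///     type Gradient = Vec<f64>;
///
///     fn gradient(&self, p: &Self::Param) -> Result<Self::Gradient, Error> {
///         Ok(vec![2.0 * p[0], 2.0 * p[1]])
///     }
/// }
///
/// # fn main() -> Result<(), Error> {
/// let linesearch: MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64> =
///     MoreThuenteLineSearch::new();
/// let solver = SteepestDescent::new(linesearch);
///
/// // Run at most 10 iterations starting from (1, 2).
/// let result = Executor::new(Sphere {}, solver)
///     .configure(|state| state.param(vec![1.0, 2.0]).max_iters(10))
///     .run()?;
/// println!("{result}");
/// # Ok(())
/// # }
/// ```
///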
/// ## Reference
///
/// Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
/// Springer. ISBN 0-387-30303-0.
#[derive(Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct SteepestDescent<L> {
    /// line search
    linesearch: L,
}

impl<L> SteepestDescent<L> {
    /// Construct a new instance of [`SteepestDescent`]
    ///
    /// Requires a line search.
    ///
    /// # Example
    ///
    /// ```
    /// # use argmin::solver::gradientdescent::SteepestDescent;
    /// # use argmin::solver::linesearch::MoreThuenteLineSearch;
    /// let linesearch: MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64> =
    ///     MoreThuenteLineSearch::new();
    /// let sd = SteepestDescent::new(linesearch);
    /// ```
    pub fn new(linesearch: L) -> Self {
        SteepestDescent { linesearch }
    }
}

impl<O, L, P, G, F> Solver<O, IterState<P, G, (), (), (), F>> for SteepestDescent<L>
where
    O: CostFunction<Param = P, Output = F> + Gradient<Param = P, Gradient = G>,
    P: Clone,
    G: Clone + ArgminMul<F, G>,
    L: Clone + LineSearch<G, F> + Solver<O, IterState<P, G, (), (), (), F>>,
    F: ArgminFloat,
{
    fn name(&self) -> &str {
        "Steepest Descent"
    }

    fn next_iter(
        &mut self,
        problem: &mut Problem<O>,
        state: IterState<P, G, (), (), (), F>,
    ) -> Result<(IterState<P, G, (), (), (), F>, Option<KV>), Error> {
        let param_new = state
            .get_param()
            .ok_or_else(argmin_error_closure!(
                NotInitialized,
                concat!(
                    "`SteepestDescent` requires an initial parameter vector. ",
                    "Please provide an initial guess via `Executor`s `configure` method."
                )
            ))?
            .clone();
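
        // Evaluate cost and gradient at the current parameter vector.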
        let new_cost = problem.cost(&param_new)?;
        let new_grad = problem.gradient(&param_new)?;

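        // The steepest descent direction is the negative of the gradient.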
        self.linesearch
            .search_direction(new_grad.mul(&(float!(-1.0))));

        // Run line search
        let OptimizationResult {
            problem: line_problem,
            state: mut linesearch_state,
            ..
        } = Executor::new(
            problem.take_problem().ok_or_else(argmin_error_closure!(
                PotentialBug,
                "`SteepestDescent`: Failed to take `problem` for line search"
            ))?,
            self.linesearch.clone(),
        )
        .configure(|config| config.param(param_new).gradient(new_grad).cost(new_cost))
        .ctrlc(false)
        .run()?;

        // Get back problem and function evaluation counts
        problem.consume_problem(line_problem);

        Ok((
            state
                .param(
                    linesearch_state
                        .take_param()
                        .ok_or_else(argmin_error_closure!(
                            PotentialBug,
                            "`SteepestDescent`: No `param` returned by line search"
                        ))?,
                )
                .cost(linesearch_state.get_cost()),
            None,
        ))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::test_utils::TestProblem;
    use crate::core::ArgminError;
    use crate::solver::linesearch::{
        condition::ArmijoCondition, BacktrackingLineSearch, MoreThuenteLineSearch,
    };
    use approx::assert_relative_eq;

    test_trait_impl!(
        steepest_descent,
        SteepestDescent<MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64>>
    );

    #[test]
    fn test_new() {
        let linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(ArmijoCondition::new(0.2).unwrap());
        let SteepestDescent { linesearch: ls } = SteepestDescent::new(linesearch.clone());
        assert_eq!(ls, linesearch);
    }

    #[test]
    fn test_next_iter_param_not_initialized() {
        let linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(ArmijoCondition::new(0.2).unwrap());
        let mut sd = SteepestDescent::new(linesearch);
        let res = sd.next_iter(&mut Problem::new(TestProblem::new()), IterState::new());
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Not initialized: \"`SteepestDescent` requires an initial parameter vector. ",
                "Please provide an initial guess via `Executor`s `configure` method.\""
            )
        );
    }

    #[test]
    fn test_next_iter_prev_param_not_erased() {
        let linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(ArmijoCondition::new(0.2).unwrap());
        let mut sd = SteepestDescent::new(linesearch);
        let (state, _kv) = sd
            .next_iter(
                &mut Problem::new(TestProblem::new()),
                IterState::new().param(vec![1.0, 2.0]),
            )
            .unwrap();
        state.prev_param.unwrap();
    }

    #[test]
    fn test_next_iter_regression() {
        struct SDProblem {}

        impl CostFunction for SDProblem {
            type Param = Vec<f64>;
            type Output = f64;

            fn cost(&self, p: &Self::Param) -> Result<Self::Output, Error> {
                Ok(p[0].powi(2) + p[1].powi(2))
            }
        }

        impl Gradient for SDProblem {
            type Param = Vec<f64>;
            type Gradient = Vec<f64>;

            fn gradient(&self, p: &Self::Param) -> Result<Self::Param, Error> {
                Ok(vec![2.0 * p[0], 2.0 * p[1]])
            }
        }

        let linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(ArmijoCondition::new(0.2).unwrap());
        let mut sd = SteepestDescent::new(linesearch);
        let (state, kv) = sd
            .next_iter(
                &mut Problem::new(SDProblem {}),
                IterState::new().param(vec![1.0, 2.0]),
            )
            .unwrap();

        assert!(kv.is_none());

        assert_relative_eq!(
            state.param.as_ref().unwrap()[0],
            -0.4580000000000002,
            epsilon = f64::EPSILON
        );
        assert_relative_eq!(
            state.param.as_ref().unwrap()[1],
            -0.9160000000000004,
            epsilon = f64::EPSILON
        );
        assert_relative_eq!(state.cost, 1.048820000000001, epsilon = f64::EPSILON);
    }
}