argmin/solver/conjugategradient/nonlinear_cg.rs

// Copyright 2018-2024 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

use crate::core::{
    ArgminFloat, CostFunction, Error, Executor, Gradient, IterState, LineSearch, NLCGBetaUpdate,
    OptimizationResult, Problem, Solver, State, KV,
};
use argmin_math::{ArgminAdd, ArgminDot, ArgminL2Norm, ArgminMul};
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};

/// # Non-linear Conjugate Gradient method
///
/// A generalization of the conjugate gradient method for nonlinear optimization problems.
///
/// Requires an initial parameter vector.
///
/// ## Requirements on the optimization problem
///
/// The optimization problem is required to implement [`CostFunction`] and [`Gradient`].
///
/// ## Reference
///
/// Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
/// Springer. ISBN 0-387-30303-0.
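///
/// ## Example
///
/// A minimal usage sketch. `MyProblem` is a hypothetical user-defined type which is assumed
/// to implement [`CostFunction`] and [`Gradient`]; it is not part of this crate:
///
/// ```ignore
/// use argmin::core::Executor;
/// use argmin::solver::conjugategradient::beta::PolakRibiere;
/// use argmin::solver::conjugategradient::NonlinearConjugateGradient;
/// use argmin::solver::linesearch::MoreThuenteLineSearch;
///
/// // Choose a line search and a beta update method.
/// let linesearch = MoreThuenteLineSearch::new();
/// let beta_method = PolakRibiere::new();
/// let solver: NonlinearConjugateGradient<Vec<f64>, _, _, f64> =
///     NonlinearConjugateGradient::new(linesearch, beta_method);
///
/// // Run the solver on the problem, starting from an initial guess.
/// let result = Executor::new(MyProblem {}, solver)
///     .configure(|state| state.param(vec![1.2, 1.2]).max_iters(100))
///     .run()?;
/// ```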
#[derive(Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct NonlinearConjugateGradient<P, L, B, F> {
    /// Current search direction p
    p: Option<P>,
    /// Scaling factor beta for the previous search direction
    beta: F,
    /// Line search
    linesearch: L,
    /// Beta update method
    beta_method: B,
    /// Number of iterations after which a restart is performed
    restart_iter: u64,
    /// Restart threshold based on orthogonality of consecutive gradients
    restart_orthogonality: Option<F>,
}

impl<P, L, B, F> NonlinearConjugateGradient<P, L, B, F>
where
    F: ArgminFloat,
{
    /// Construct a new instance of `NonlinearConjugateGradient`.
    ///
    /// Takes a [`LineSearch`] and a [`NLCGBetaUpdate`] as input.
    ///
    /// # Example
    ///
    /// ```
    /// # use argmin::solver::conjugategradient::NonlinearConjugateGradient;
    /// # let linesearch = ();
    /// # let beta_method = ();
    /// let nlcg: NonlinearConjugateGradient<Vec<f64>, _, _, f64> =
    ///     NonlinearConjugateGradient::new(linesearch, beta_method);
    /// ```
    pub fn new(linesearch: L, beta_method: B) -> Self {
        NonlinearConjugateGradient {
            p: None,
            beta: F::nan(),
            linesearch,
            beta_method,
            restart_iter: u64::MAX,
            restart_orthogonality: None,
        }
    }

    /// Specify the number of iterations after which a restart should be performed.
    ///
    /// This allows the algorithm to "forget" previous information which may not be helpful
    /// anymore.
    ///
    /// # Example
    ///
    /// ```
    /// # use argmin::solver::conjugategradient::NonlinearConjugateGradient;
    /// # let linesearch = ();
    /// # let beta_method = ();
    /// # let nlcg: NonlinearConjugateGradient<Vec<f64>, _, _, f64> = NonlinearConjugateGradient::new(linesearch, beta_method);
    /// let nlcg = nlcg.restart_iters(100);
    /// ```
    #[must_use]
    pub fn restart_iters(mut self, iters: u64) -> Self {
        self.restart_iter = iters;
        self
    }

    /// Set the value for the orthogonality measure.
    ///
    /// Setting this parameter leads to a restart of the algorithm (setting beta = 0) whenever
    /// two consecutive gradients are no longer sufficiently orthogonal, i.e. whenever the
    /// following condition is met:
    ///
    /// `|\nabla f_k^T * \nabla f_{k-1}| / | \nabla f_k |^2 >= v`
    ///
    /// A typical value for `v` is 0.1.
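    ///
    /// As an illustrative example (numbers chosen here for demonstration, not taken from the
    /// reference): with `\nabla f_k = (1, 0)` and `\nabla f_{k-1} = (0.5, 1)`, the measure
    /// evaluates to `|1 * 0.5 + 0 * 1| / 1^2 = 0.5`, which exceeds `v = 0.1` and would
    /// therefore trigger a restart.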
    ///
    /// # Example
    ///
    /// ```
    /// # use argmin::solver::conjugategradient::NonlinearConjugateGradient;
    /// # let linesearch = ();
    /// # let beta_method = ();
    /// # let nlcg: NonlinearConjugateGradient<Vec<f64>, _, _, f64> = NonlinearConjugateGradient::new(linesearch, beta_method);
    /// let nlcg = nlcg.restart_orthogonality(0.1);
    /// ```
    #[must_use]
    pub fn restart_orthogonality(mut self, v: F) -> Self {
        self.restart_orthogonality = Some(v);
        self
    }
}

impl<O, P, G, L, B, F> Solver<O, IterState<P, G, (), (), (), F>>
    for NonlinearConjugateGradient<P, L, B, F>
where
    O: CostFunction<Param = P, Output = F> + Gradient<Param = P, Gradient = G>,
    P: Clone + ArgminAdd<P, P> + ArgminMul<F, P>,
    G: Clone + ArgminMul<F, P> + ArgminDot<G, F> + ArgminL2Norm<F>,
    L: Clone + LineSearch<P, F> + Solver<O, IterState<P, G, (), (), (), F>>,
    B: NLCGBetaUpdate<G, P, F>,
    F: ArgminFloat,
{
    fn name(&self) -> &str {
        "Nonlinear Conjugate Gradient"
    }

    fn init(
        &mut self,
        problem: &mut Problem<O>,
        state: IterState<P, G, (), (), (), F>,
    ) -> Result<(IterState<P, G, (), (), (), F>, Option<KV>), Error> {
        let param = state.get_param().ok_or_else(argmin_error_closure!(
            NotInitialized,
            concat!(
                "`NonlinearConjugateGradient` requires an initial parameter vector. ",
                "Please provide an initial guess via `Executor`'s `configure` method."
            )
        ))?;
        let cost = problem.cost(param)?;
        let grad = problem.gradient(param)?;
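        // The initial search direction is the steepest descent direction: p_0 = -\nabla f(x_0).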
        self.p = Some(grad.mul(&(float!(-1.0))));
        Ok((state.cost(cost).gradient(grad), None))
    }

    fn next_iter(
        &mut self,
        problem: &mut Problem<O>,
        mut state: IterState<P, G, (), (), (), F>,
    ) -> Result<(IterState<P, G, (), (), (), F>, Option<KV>), Error> {
        let p = self.p.as_ref().ok_or_else(argmin_error_closure!(
            PotentialBug,
            "`NonlinearConjugateGradient`: Field `p` not set"
        ))?;
        let xk = state.take_param().ok_or_else(argmin_error_closure!(
            PotentialBug,
            "`NonlinearConjugateGradient`: No `param` in `state`"
        ))?;
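        // Reuse the gradient cached in `state` if available; otherwise compute it at `xk`.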
        let grad = state
            .take_gradient()
            .map(Result::Ok)
            .unwrap_or_else(|| problem.gradient(&xk))?;
        let cur_cost = state.cost;

        // Set the search direction for the line search
        self.linesearch.search_direction(p.clone());

        // Run the line search as an inner solver
        let OptimizationResult {
            problem: line_problem,
            state: mut line_state,
            ..
        } = Executor::new(
            problem.take_problem().ok_or_else(argmin_error_closure!(
                PotentialBug,
                "`NonlinearConjugateGradient`: Failed to take `problem` for line search"
            ))?,
            self.linesearch.clone(),
        )
        .configure(|state| state.param(xk).gradient(grad.clone()).cost(cur_cost))
        .ctrlc(false)
        .run()?;

        // Take care of the counts of function evaluations performed by the line search
        problem.consume_problem(line_problem);

        let xk1 = line_state.take_param().ok_or_else(argmin_error_closure!(
            PotentialBug,
            "`NonlinearConjugateGradient`: No `param` returned by line search"
        ))?;

        // Update of beta
        let new_grad = problem.gradient(&xk1)?;

        let restart_orthogonality = match self.restart_orthogonality {
            Some(v) => new_grad.dot(&grad).abs() / new_grad.l2_norm().powi(2) >= v,
            None => false,
        };

        let restart_iter: bool =
            state.get_iter().is_multiple_of(self.restart_iter) && state.get_iter() != 0;

        if restart_iter || restart_orthogonality {
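            // Restart: discard the accumulated direction information so that the
            // next search direction falls back to steepest descent.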
            self.beta = float!(0.0);
        } else {
            self.beta = self.beta_method.update(&grad, &new_grad, p);
        }

        // Update of p
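        // p_{k+1} = -\nabla f(x_{k+1}) + beta * p_k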
        self.p = Some(new_grad.mul(&(float!(-1.0))).add(&p.mul(&self.beta)));

        // Housekeeping
        let cost = problem.cost(&xk1)?;

        Ok((
            state.param(xk1).cost(cost).gradient(new_grad),
            Some(kv!(
                "beta" => self.beta;
                "restart_iter" => restart_iter;
                "restart_orthogonality" => restart_orthogonality;
            )),
        ))
    }
}

#[cfg(test)]
#[allow(clippy::let_unit_value)]
mod tests {
    use super::*;
    use crate::core::test_utils::TestProblem;
    use crate::core::ArgminError;
    use crate::solver::conjugategradient::beta::PolakRibiere;
    use crate::solver::linesearch::{
        condition::ArmijoCondition, BacktrackingLineSearch, MoreThuenteLineSearch,
    };
    use approx::assert_relative_eq;

    #[derive(Eq, PartialEq, Clone, Copy, Debug)]
    struct Linesearch {}

    #[derive(Eq, PartialEq, Clone, Copy, Debug)]
    struct BetaUpdate {}

    test_trait_impl!(
        nonlinear_cg,
        NonlinearConjugateGradient<
            TestProblem,
            MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64>,
            PolakRibiere,
            f64
        >
    );

    #[test]
    fn test_new() {
        let linesearch = Linesearch {};
        let beta_method = BetaUpdate {};
        let nlcg: NonlinearConjugateGradient<Vec<f64>, _, _, f64> =
            NonlinearConjugateGradient::new(linesearch, beta_method);
        let NonlinearConjugateGradient {
            p,
            beta,
            linesearch: ls,
            beta_method: bm,
            restart_iter,
            restart_orthogonality,
        } = nlcg;
        assert!(p.is_none());
        assert!(beta.is_nan());
        assert_eq!(ls, linesearch);
        assert_eq!(bm, beta_method);
        assert_eq!(restart_iter, u64::MAX);
        assert!(restart_orthogonality.is_none());
    }

    #[test]
    fn test_restart_iters() {
        let linesearch = ();
        let beta_method = ();
        let nlcg: NonlinearConjugateGradient<Vec<f64>, _, _, f64> =
            NonlinearConjugateGradient::new(linesearch, beta_method);
        assert_eq!(nlcg.restart_iter, u64::MAX);
        let nlcg = nlcg.restart_iters(100);
        assert_eq!(nlcg.restart_iter, 100);
    }

    #[test]
    fn test_restart_orthogonality() {
        let linesearch = ();
        let beta_method = ();
        let nlcg: NonlinearConjugateGradient<Vec<f64>, _, _, f64> =
            NonlinearConjugateGradient::new(linesearch, beta_method);
        assert!(nlcg.restart_orthogonality.is_none());
        let nlcg = nlcg.restart_orthogonality(0.1);
        assert_eq!(
            nlcg.restart_orthogonality.as_ref().unwrap().to_ne_bytes(),
            0.1f64.to_ne_bytes()
        );
    }

    #[test]
    fn test_init_param_not_initialized() {
        let linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(ArmijoCondition::new(0.2).unwrap());
        let beta_method = PolakRibiere::new();
        let mut nlcg: NonlinearConjugateGradient<Vec<f64>, _, _, f64> =
            NonlinearConjugateGradient::new(linesearch, beta_method);
        let res = nlcg.init(&mut Problem::new(TestProblem::new()), IterState::new());
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Not initialized: \"`NonlinearConjugateGradient` requires an initial parameter vector. ",
                "Please provide an initial guess via `Executor`'s `configure` method.\""
            )
        );
    }

    #[test]
    fn test_init() {
        let linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(ArmijoCondition::new(0.2).unwrap());
        let beta_method = PolakRibiere::new();
        let mut nlcg: NonlinearConjugateGradient<Vec<f64>, _, _, f64> =
            NonlinearConjugateGradient::new(linesearch, beta_method);
        let state: IterState<Vec<f64>, Vec<f64>, (), (), (), f64> =
            IterState::new().param(vec![3.0, 4.0]);
        let (state_out, kv) = nlcg
            .init(&mut Problem::new(TestProblem::new()), state.clone())
            .unwrap();
        assert!(kv.is_none());
        assert_ne!(state_out, state);
        assert_eq!(state_out.cost.to_ne_bytes(), 1f64.to_ne_bytes());
        assert_eq!(
            state_out.grad.as_ref().unwrap()[0].to_ne_bytes(),
            3f64.to_ne_bytes()
        );
        assert_eq!(
            state_out.grad.as_ref().unwrap()[1].to_ne_bytes(),
            4f64.to_ne_bytes()
        );
        assert_eq!(
            state_out.param.as_ref().unwrap()[0].to_ne_bytes(),
            3f64.to_ne_bytes()
        );
        assert_eq!(
            state_out.param.as_ref().unwrap()[1].to_ne_bytes(),
            4f64.to_ne_bytes()
        );
        assert_eq!(
            nlcg.p.as_ref().unwrap()[0].to_ne_bytes(),
            (-3f64).to_ne_bytes()
        );
        assert_eq!(
            nlcg.p.as_ref().unwrap()[1].to_ne_bytes(),
            (-4f64).to_ne_bytes()
        );
    }

    #[test]
    fn test_next_iter_p_not_set() {
        let linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(ArmijoCondition::new(0.2).unwrap());
        let beta_method = PolakRibiere::new();
        let mut nlcg: NonlinearConjugateGradient<Vec<f64>, _, _, f64> =
            NonlinearConjugateGradient::new(linesearch, beta_method);
        let state = IterState::new().param(vec![1.0f64, 2.0f64]);
        assert!(nlcg.p.is_none());
        let res = nlcg.next_iter(&mut Problem::new(TestProblem::new()), state);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Potential bug: \"`NonlinearConjugateGradient`: ",
                "Field `p` not set\". This is potentially a bug. ",
                "Please file a report on https://github.com/argmin-rs/argmin/issues"
            )
        );
    }

    #[test]
    fn test_next_iter_state_param_not_set() {
        let linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(ArmijoCondition::new(0.2).unwrap());
        let beta_method = PolakRibiere::new();
        let mut nlcg: NonlinearConjugateGradient<Vec<f64>, _, _, f64> =
            NonlinearConjugateGradient::new(linesearch, beta_method);
        let state = IterState::new();
        nlcg.p = Some(vec![]);
        assert!(nlcg.p.is_some());
        let res = nlcg.next_iter(&mut Problem::new(TestProblem::new()), state);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Potential bug: \"`NonlinearConjugateGradient`: ",
                "No `param` in `state`\". This is potentially a bug. ",
                "Please file a report on https://github.com/argmin-rs/argmin/issues"
            )
        );
    }

    #[test]
    fn test_next_iter_problem_missing() {
        let linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(ArmijoCondition::new(0.2).unwrap());
        let beta_method = PolakRibiere::new();
        let mut nlcg: NonlinearConjugateGradient<Vec<f64>, _, _, f64> =
            NonlinearConjugateGradient::new(linesearch, beta_method);
        let state = IterState::new()
            .param(vec![1.0f64, 2.0])
            .gradient(vec![1.0f64, 2.0]);
        nlcg.p = Some(vec![]);
        assert!(nlcg.p.is_some());
        let mut problem = Problem::new(TestProblem::new());
        let _ = problem.take_problem().unwrap();
        let res = nlcg.next_iter(&mut problem, state);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Potential bug: \"`NonlinearConjugateGradient`: ",
                "Failed to take `problem` for line search\". This is potentially a bug. ",
                "Please file a report on https://github.com/argmin-rs/argmin/issues"
            )
        );
    }

    #[test]
    fn test_next_iter() {
        let linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(ArmijoCondition::new(0.2).unwrap());
        let beta_method = PolakRibiere::new();
        let mut nlcg: NonlinearConjugateGradient<Vec<f64>, _, _, f64> =
            NonlinearConjugateGradient::new(linesearch, beta_method);
        let state = IterState::new()
            .param(vec![1.0f64, 2.0])
            .gradient(vec![1.0f64, 2.0]);
        let mut problem = Problem::new(TestProblem::new());
        let (state, kv) = nlcg.init(&mut problem, state).unwrap();
        assert!(kv.is_none());
        let (mut state, kv) = nlcg.next_iter(&mut problem, state).unwrap();
        state.update();
        let kv2 = kv!("beta" => 0.0; "restart_iter" => false; "restart_orthogonality" => false;);
        assert_eq!(kv.unwrap(), kv2);
        assert_relative_eq!(
            state.param.as_ref().unwrap()[0],
            1.0f64,
            epsilon = f64::EPSILON
        );
        assert_relative_eq!(
            state.param.as_ref().unwrap()[1],
            2.0f64,
            epsilon = f64::EPSILON
        );
        assert_relative_eq!(
            state.best_param.as_ref().unwrap()[0],
            1.0f64,
            epsilon = f64::EPSILON
        );
        assert_relative_eq!(
            state.best_param.as_ref().unwrap()[1],
            2.0f64,
            epsilon = f64::EPSILON
        );
        assert_relative_eq!(state.cost, 1.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(state.prev_cost, 1.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(state.best_cost, 1.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(
            state.grad.as_ref().unwrap()[0],
            1.0f64,
            epsilon = f64::EPSILON
        );
        assert_relative_eq!(
            state.grad.as_ref().unwrap()[1],
            2.0f64,
            epsilon = f64::EPSILON
        );
    }
}