argmin/solver/quasinewton/sr1.rs

// Copyright 2019-2024 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

use crate::core::{
    ArgminFloat, CostFunction, Error, Executor, Gradient, IterState, LineSearch,
    OptimizationResult, Problem, Solver, TerminationReason, TerminationStatus, KV,
};
use argmin_math::{ArgminAdd, ArgminDot, ArgminL2Norm, ArgminMul, ArgminSub};
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};

/// # Symmetric rank-one (SR1) method
///
/// This method currently has problems: <https://github.com/argmin-rs/argmin/issues/221>.
///
/// ## Requirements on the optimization problem
///
/// The optimization problem is required to implement [`CostFunction`] and [`Gradient`].
///
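/// ## Example
///
/// A minimal usage sketch: minimizing the two-dimensional Rosenbrock function with a
/// `MoreThuenteLineSearch`. The problem struct, the identity matrix as initial inverse
/// Hessian and the iteration limit are illustrative choices, not requirements of this
/// solver.
///
/// ```no_run
/// # use argmin::core::{CostFunction, Error, Executor, Gradient};
/// # use argmin::solver::linesearch::MoreThuenteLineSearch;
/// # use argmin::solver::quasinewton::SR1;
/// # struct Rosenbrock {}
/// # impl CostFunction for Rosenbrock {
/// #     type Param = Vec<f64>;
/// #     type Output = f64;
/// #     fn cost(&self, p: &Self::Param) -> Result<Self::Output, Error> {
/// #         Ok((1.0 - p[0]).powi(2) + 100.0 * (p[1] - p[0].powi(2)).powi(2))
/// #     }
/// # }
/// # impl Gradient for Rosenbrock {
/// #     type Param = Vec<f64>;
/// #     type Gradient = Vec<f64>;
/// #     fn gradient(&self, p: &Self::Param) -> Result<Self::Gradient, Error> {
/// #         Ok(vec![
/// #             -2.0 * (1.0 - p[0]) - 400.0 * p[0] * (p[1] - p[0].powi(2)),
/// #             200.0 * (p[1] - p[0].powi(2)),
/// #         ])
/// #     }
/// # }
/// # fn main() -> Result<(), Error> {
/// // Set up the SR1 solver with a line search
/// let linesearch = MoreThuenteLineSearch::new();
/// let solver: SR1<_, f64> = SR1::new(linesearch);
///
/// // Run the solver, starting from (-1.2, 1.0) with the identity matrix as the
/// // initial guess for the inverse Hessian
/// let res = Executor::new(Rosenbrock {}, solver)
///     .configure(|state| {
///         state
///             .param(vec![-1.2, 1.0])
///             .inv_hessian(vec![vec![1.0, 0.0], vec![0.0, 1.0]])
///             .max_iters(100)
///     })
///     .run()?;
///
/// println!("{res}");
/// # Ok(())
/// # }
/// ```
///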
/// ## Reference
///
/// Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
/// Springer. ISBN 0-387-30303-0.
#[derive(Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct SR1<L, F> {
    /// parameter for skipping rule
    denominator_factor: F,
    /// line search
    linesearch: L,
35    /// Tolerance for the stopping criterion based on the change of the norm on the gradient
36    tol_grad: F,
37    /// Tolerance for the stopping criterion based on the change of the cost stopping criterion
38    tol_cost: F,
39}

impl<L, F> SR1<L, F>
where
    F: ArgminFloat,
{
    /// Construct a new instance of [`SR1`]
    ///
    /// # Example
    ///
    /// ```
    /// # use argmin::solver::quasinewton::SR1;
    /// # let linesearch = ();
    /// let sr1: SR1<_, f64> = SR1::new(linesearch);
    /// ```
    pub fn new(linesearch: L) -> Self {
        SR1 {
            denominator_factor: float!(1e-8),
            linesearch,
            tol_grad: F::epsilon().sqrt(),
            tol_cost: F::epsilon(),
        }
    }

    /// Set denominator factor
    ///
    /// If the absolute value of the denominator of the update falls below `denominator_factor`
    /// (scaled by the norms of the step and of the rank-one correction vector), the update of
    /// the inverse Hessian is skipped.
    ///
    /// Must be in `(0, 1)` and defaults to `1e-8`.
    ///
    /// # Example
    ///
    /// ```
    /// # use argmin::solver::quasinewton::SR1;
    /// # use argmin::core::Error;
    /// # fn main() -> Result<(), Error> {
    /// # let linesearch = ();
    /// let sr1: SR1<_, f64> = SR1::new(linesearch).with_denominator_factor(1e-7)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn with_denominator_factor(mut self, denominator_factor: F) -> Result<Self, Error> {
        if denominator_factor <= float!(0.0) || denominator_factor >= float!(1.0) {
            Err(argmin_error!(
                InvalidParameter,
                "`SR1`: denominator_factor must be in (0, 1)."
            ))
        } else {
            self.denominator_factor = denominator_factor;
            Ok(self)
        }
    }

    /// The algorithm stops if the norm of the gradient is below `tol_grad`.
    ///
    /// The provided value must be non-negative. Defaults to `sqrt(EPSILON)`.
    ///
    /// # Example
    ///
    /// ```
    /// # use argmin::solver::quasinewton::SR1;
    /// # use argmin::core::Error;
    /// # fn main() -> Result<(), Error> {
    /// # let linesearch = ();
    /// let sr1: SR1<_, f64> = SR1::new(linesearch).with_tolerance_grad(1e-6)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn with_tolerance_grad(mut self, tol_grad: F) -> Result<Self, Error> {
        if tol_grad < float!(0.0) {
            return Err(argmin_error!(
                InvalidParameter,
                "`SR1`: gradient tolerance must be >= 0."
            ));
        }
        self.tol_grad = tol_grad;
        Ok(self)
    }

    /// Set tolerance for the stopping criterion based on the change of the cost function value
    ///
    /// The provided value must be non-negative. Defaults to `EPSILON`.
    ///
    /// # Example
    ///
    /// ```
    /// # use argmin::solver::quasinewton::SR1;
    /// # use argmin::core::Error;
    /// # fn main() -> Result<(), Error> {
    /// # let linesearch = ();
    /// let sr1: SR1<_, f64> = SR1::new(linesearch).with_tolerance_cost(1e-6)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn with_tolerance_cost(mut self, tol_cost: F) -> Result<Self, Error> {
        if tol_cost < float!(0.0) {
            return Err(argmin_error!(
                InvalidParameter,
                "`SR1`: cost tolerance must be >= 0."
            ));
        }
        self.tol_cost = tol_cost;
        Ok(self)
    }
}

impl<O, L, P, G, H, F> Solver<O, IterState<P, G, (), H, (), F>> for SR1<L, F>
where
    O: CostFunction<Param = P, Output = F> + Gradient<Param = P, Gradient = G>,
    P: Clone
        + ArgminSub<P, P>
        + ArgminDot<G, F>
        + ArgminDot<P, F>
        + ArgminDot<P, H>
        + ArgminL2Norm<F>
        + ArgminMul<F, P>,
    G: Clone + ArgminSub<P, P> + ArgminL2Norm<F> + ArgminSub<G, G>,
    H: ArgminDot<G, P> + ArgminDot<P, P> + ArgminAdd<H, H> + ArgminMul<F, H>,
    L: Clone + LineSearch<P, F> + Solver<O, IterState<P, G, (), (), (), F>>,
    F: ArgminFloat,
{
    fn name(&self) -> &str {
        "SR1"
    }

    fn init(
        &mut self,
        problem: &mut Problem<O>,
        mut state: IterState<P, G, (), H, (), F>,
    ) -> Result<(IterState<P, G, (), H, (), F>, Option<KV>), Error> {
        let param = state.take_param().ok_or_else(argmin_error_closure!(
            NotInitialized,
            concat!(
                "`SR1` requires an initial parameter vector. ",
                "Please provide an initial guess via `Executor`s `configure` method."
            )
        ))?;

        let inv_hessian = state.take_inv_hessian().ok_or_else(argmin_error_closure!(
            NotInitialized,
            concat!(
                "`SR1` requires an initial inverse Hessian. ",
                "Please provide an initial guess via `Executor`s `configure` method."
            )
        ))?;

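        // Evaluate the initial cost and gradient only if they were not already
        // provided to the `Executor` via its `configure` method.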
        let cost = state.get_cost();
        let cost = if cost.is_infinite() {
            problem.cost(&param)?
        } else {
            cost
        };

        let grad = state
            .take_gradient()
            .map(Result::Ok)
            .unwrap_or_else(|| problem.gradient(&param))?;
        Ok((
            state
                .param(param)
                .cost(cost)
                .gradient(grad)
                .inv_hessian(inv_hessian),
            None,
        ))
    }

    fn next_iter(
        &mut self,
        problem: &mut Problem<O>,
        mut state: IterState<P, G, (), H, (), F>,
    ) -> Result<(IterState<P, G, (), H, (), F>, Option<KV>), Error> {
        let param = state.take_param().ok_or_else(argmin_error_closure!(
            PotentialBug,
            "`SR1`: Parameter vector in state not set."
        ))?;
        let cost = state.get_cost();

        let prev_grad = state.take_gradient().ok_or_else(argmin_error_closure!(
            PotentialBug,
            "`SR1`: Gradient in state not set."
        ))?;

        let mut inv_hessian = state.take_inv_hessian().ok_or_else(argmin_error_closure!(
            PotentialBug,
            "`SR1`: Inverse Hessian in state not set."
        ))?;

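        // Quasi-Newton search direction: p = -H_k * grad(f(x_k)), where H_k is the
        // current approximation of the inverse Hessian.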
        let p = inv_hessian.dot(&prev_grad).mul(&float!(-1.0));

        self.linesearch.search_direction(p);

        // Run the line search along the search direction
        let OptimizationResult {
            problem: line_problem,
            state: mut linesearch_state,
            ..
        } = Executor::new(problem.take_problem().unwrap(), self.linesearch.clone())
            .configure(|config| {
                config
                    .param(param.clone())
                    .gradient(prev_grad.clone())
                    .cost(cost)
            })
            .ctrlc(false)
            .run()?;

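        // New iterate and its cost as determined by the line search.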
        let xk1 = linesearch_state.take_param().unwrap();
        let next_cost = linesearch_state.get_cost();

        // Transfer the function evaluation counts from the line search run back to the main problem
        problem.consume_problem(line_problem);

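        // Evaluate the gradient at the new iterate and form the quasi-Newton pair
        // y_k = grad(f(x_{k+1})) - grad(f(x_k)) and s_k = x_{k+1} - x_k.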
        let grad = problem.gradient(&xk1)?;
        let yk = grad.sub(&prev_grad);

        let sk = xk1.sub(&param);

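        // Rank-one correction: with v = y_k - H * s_k (H being the stored approximation),
        // the update is H <- H + (v * v^T) / (v^T * s_k). The update is skipped whenever
        // the denominator v^T * s_k is small relative to ||s_k|| * ||v|| (cf. the safeguard
        // in Nocedal & Wright, ch. 6), since a tiny denominator makes it numerically unstable.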
        // let skmhkyk: P = sk.sub(&inv_hessian.dot(&yk));
        // let a: H = skmhkyk.dot(&skmhkyk);
        // let b: F = skmhkyk.dot(&yk);
        let ykmbksk: P = yk.sub(&inv_hessian.dot(&sk));
        let a: H = ykmbksk.dot(&ykmbksk);
        let b: F = ykmbksk.dot(&sk);

        // let hessian_update = b.abs() >= self.r * yk.l2_norm() * skmhkyk.l2_norm();
        let hessian_update = b.abs() >= self.denominator_factor * sk.l2_norm() * ykmbksk.l2_norm();

        if hessian_update {
            inv_hessian = inv_hessian.add(&a.mul(&(float!(1.0) / b)));
        }

        Ok((
            state
                .param(xk1)
                .cost(next_cost)
                .gradient(grad)
                .inv_hessian(inv_hessian),
            Some(kv!["denominator" => b; "hessian_update" => hessian_update;]),
        ))
    }

    fn terminate(&mut self, state: &IterState<P, G, (), H, (), F>) -> TerminationStatus {
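        // Converged if either the gradient norm or the change in cost between
        // iterations falls below its respective tolerance.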
        if state.get_gradient().unwrap().l2_norm() < self.tol_grad {
            return TerminationStatus::Terminated(TerminationReason::SolverConverged);
        }
        if (state.get_prev_cost() - state.get_cost()).abs() < self.tol_cost {
            return TerminationStatus::Terminated(TerminationReason::SolverConverged);
        }
        TerminationStatus::NotTerminated
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{test_utils::TestProblem, ArgminError, State};
    use crate::solver::linesearch::MoreThuenteLineSearch;

    test_trait_impl!(
        sr1,
        SR1<MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64>, f64>
    );

    #[test]
    fn test_new() {
        #[derive(Eq, PartialEq, Debug)]
        struct MyFakeLineSearch {}

        let sr1: SR1<_, f64> = SR1::new(MyFakeLineSearch {});
        let SR1 {
            denominator_factor,
            linesearch,
            tol_grad,
            tol_cost,
        } = sr1;

        assert_eq!(linesearch, MyFakeLineSearch {});
        assert_eq!(tol_grad.to_ne_bytes(), f64::EPSILON.sqrt().to_ne_bytes());
        assert_eq!(tol_cost.to_ne_bytes(), f64::EPSILON.to_ne_bytes());
        assert_eq!(denominator_factor.to_ne_bytes(), 1e-8f64.to_ne_bytes());
    }

    #[test]
    fn test_with_denominator_factor() {
        #[derive(Eq, PartialEq, Debug, Clone, Copy)]
        struct MyFakeLineSearch {}

        // correct parameters
        for tol in [f64::EPSILON, 1e-8, 1e-6, 1e-2, 1.0 - f64::EPSILON] {
            let sr1: SR1<_, f64> = SR1::new(MyFakeLineSearch {});
            let res = sr1.with_denominator_factor(tol);
            assert!(res.is_ok());

            let nm = res.unwrap();
            assert_eq!(nm.denominator_factor.to_ne_bytes(), tol.to_ne_bytes());
        }

        // incorrect parameters
        for tol in [-f64::EPSILON, 0.0, -1.0, 1.0] {
            let sr1: SR1<_, f64> = SR1::new(MyFakeLineSearch {});
            let res = sr1.with_denominator_factor(tol);
            assert_error!(
                res,
                ArgminError,
                "Invalid parameter: \"`SR1`: denominator_factor must be in (0, 1).\""
            );
        }
    }

    #[test]
    fn test_with_tolerance_grad() {
        #[derive(Eq, PartialEq, Debug, Clone, Copy)]
        struct MyFakeLineSearch {}

        // correct parameters
        for tol in [1e-6, 0.0, 1e-2, 1.0, 2.0] {
            let sr1: SR1<_, f64> = SR1::new(MyFakeLineSearch {});
            let res = sr1.with_tolerance_grad(tol);
            assert!(res.is_ok());

            let nm = res.unwrap();
            assert_eq!(nm.tol_grad.to_ne_bytes(), tol.to_ne_bytes());
        }

        // incorrect parameters
        for tol in [-f64::EPSILON, -1.0, -100.0, -42.0] {
            let sr1: SR1<_, f64> = SR1::new(MyFakeLineSearch {});
            let res = sr1.with_tolerance_grad(tol);
            assert_error!(
                res,
                ArgminError,
                "Invalid parameter: \"`SR1`: gradient tolerance must be >= 0.\""
            );
        }
    }

    #[test]
    fn test_with_tolerance_cost() {
        #[derive(Eq, PartialEq, Debug, Clone, Copy)]
        struct MyFakeLineSearch {}

        // correct parameters
        for tol in [1e-6, 0.0, 1e-2, 1.0, 2.0] {
            let sr1: SR1<_, f64> = SR1::new(MyFakeLineSearch {});
            let res = sr1.with_tolerance_cost(tol);
            assert!(res.is_ok());

            let nm = res.unwrap();
            assert_eq!(nm.tol_cost.to_ne_bytes(), tol.to_ne_bytes());
        }

        // incorrect parameters
        for tol in [-f64::EPSILON, -1.0, -100.0, -42.0] {
            let sr1: SR1<_, f64> = SR1::new(MyFakeLineSearch {});
            let res = sr1.with_tolerance_cost(tol);
            assert_error!(
                res,
                ArgminError,
                "Invalid parameter: \"`SR1`: cost tolerance must be >= 0.\""
            );
        }
    }

    #[test]
    fn test_init() {
        let linesearch = MoreThuenteLineSearch::new().with_c(1e-4, 0.9).unwrap();

        let param: Vec<f64> = vec![-1.0, 1.0];
        let inv_hessian: Vec<Vec<f64>> = vec![vec![1.0, 2.0], vec![3.0, 4.0]];

        let mut sr1: SR1<_, f64> = SR1::new(linesearch);

        // Forgot to initialize the parameter vector
        let state: IterState<Vec<f64>, Vec<f64>, (), Vec<Vec<f64>>, (), f64> = IterState::new();
        let problem = TestProblem::new();
        let res = sr1.init(&mut Problem::new(problem), state);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Not initialized: \"`SR1` requires an initial parameter vector. Please ",
                "provide an initial guess via `Executor`s `configure` method.\""
            )
        );

        // Forgot initial inverse Hessian guess
        let state: IterState<Vec<f64>, Vec<f64>, (), Vec<Vec<f64>>, (), f64> =
            IterState::new().param(param.clone());
        let problem = TestProblem::new();
        let res = sr1.init(&mut Problem::new(problem), state);

        assert_error!(
            res,
            ArgminError,
            concat!(
                "Not initialized: \"`SR1` requires an initial inverse Hessian. Please ",
                "provide an initial guess via `Executor`s `configure` method.\""
            )
        );

        // All good.
        let state: IterState<Vec<f64>, Vec<f64>, (), Vec<Vec<f64>>, (), f64> = IterState::new()
            .param(param.clone())
            .inv_hessian(inv_hessian.clone());
        let problem = TestProblem::new();
        let (mut state_out, kv) = sr1.init(&mut Problem::new(problem), state).unwrap();

        assert!(kv.is_none());

        let s_param = state_out.take_param().unwrap();

        for (s, p) in s_param.iter().zip(param.iter()) {
            assert_eq!(s.to_ne_bytes(), p.to_ne_bytes());
        }

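        // `TestProblem` returns the parameter vector itself as the gradient, so the
        // gradient computed during `init` must equal `param`.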
        let s_grad = state_out.take_gradient().unwrap();

        for (s, p) in s_grad.iter().zip(param.iter()) {
            assert_eq!(s.to_ne_bytes(), p.to_ne_bytes());
        }

        let s_inv_hessian = state_out.take_inv_hessian().unwrap();

        for (s, h) in s_inv_hessian
            .iter()
            .flatten()
            .zip(inv_hessian.iter().flatten())
        {
            assert_eq!(s.to_ne_bytes(), h.to_ne_bytes());
        }

        assert_eq!(state_out.get_cost().to_ne_bytes(), 1.0f64.to_ne_bytes())
    }

    #[test]
    fn test_init_provided_cost() {
        let linesearch = MoreThuenteLineSearch::new().with_c(1e-4, 0.9).unwrap();

        let param: Vec<f64> = vec![-1.0, 1.0];
        let inv_hessian: Vec<Vec<f64>> = vec![vec![1.0, 2.0], vec![3.0, 4.0]];

        let mut sr1: SR1<_, f64> = SR1::new(linesearch);

        let state: IterState<Vec<f64>, Vec<f64>, (), Vec<Vec<f64>>, (), f64> = IterState::new()
            .param(param)
            .inv_hessian(inv_hessian)
            .cost(1234.0);

        let problem = TestProblem::new();
        let (state_out, kv) = sr1.init(&mut Problem::new(problem), state).unwrap();

        assert!(kv.is_none());

        assert_eq!(state_out.get_cost().to_ne_bytes(), 1234.0f64.to_ne_bytes())
    }

    #[test]
    fn test_init_provided_grad() {
        let linesearch = MoreThuenteLineSearch::new().with_c(1e-4, 0.9).unwrap();

        let param: Vec<f64> = vec![-1.0, 1.0];
        let gradient: Vec<f64> = vec![4.0, 9.0];
        let inv_hessian: Vec<Vec<f64>> = vec![vec![1.0, 2.0], vec![3.0, 4.0]];

        let mut sr1: SR1<_, f64> = SR1::new(linesearch);

        let state: IterState<Vec<f64>, Vec<f64>, (), Vec<Vec<f64>>, (), f64> = IterState::new()
            .param(param)
            .inv_hessian(inv_hessian)
            .gradient(gradient.clone());

        let problem = TestProblem::new();
        let (mut state_out, kv) = sr1.init(&mut Problem::new(problem), state).unwrap();

        assert!(kv.is_none());

        let s_grad = state_out.take_gradient().unwrap();

        for (s, g) in s_grad.iter().zip(gradient.iter()) {
            assert_eq!(s.to_ne_bytes(), g.to_ne_bytes());
        }
    }
}