argmin/solver/gaussnewton/
gaussnewton_method.rs

1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use crate::core::{
9    ArgminFloat, Error, IterState, Jacobian, Operator, Problem, Solver, State, TerminationReason,
10    TerminationStatus, KV,
11};
12use argmin_math::{ArgminDot, ArgminInv, ArgminL2Norm, ArgminMul, ArgminSub, ArgminTranspose};
13#[cfg(feature = "serde1")]
14use serde::{Deserialize, Serialize};
15
/// # Gauss-Newton method
///
/// Iterative solver for non-linear least squares problems. Every iteration moves the current
/// parameter vector along the Gauss-Newton direction, scaled by the step width `gamma`.
///
/// An initial parameter vector must be supplied before the solver is run.
///
/// ## Requirements on the optimization problem
///
/// The optimization problem is required to implement [`Operator`] and [`Jacobian`].
///
/// ## Reference
///
/// Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
/// Springer. ISBN 0-387-30303-0.
#[derive(Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct GaussNewton<F> {
    /// Step width applied to the Gauss-Newton direction (defaults to `1.0`)
    gamma: F,
    /// Threshold on the absolute change in cost below which the solver reports convergence
    tol: F,
}
38
39impl<F: ArgminFloat> GaussNewton<F> {
40    /// Construct a new instance of [`GaussNewton`].
41    ///
42    /// # Example
43    ///
44    /// ```
45    /// # use argmin::solver::gaussnewton::GaussNewton;
46    /// let gauss_newton: GaussNewton<f64> = GaussNewton::new();
47    /// ```
48    pub fn new() -> Self {
49        GaussNewton {
50            gamma: float!(1.0),
51            tol: F::epsilon().sqrt(),
52        }
53    }
54
55    /// Set step width gamma.
56    ///
57    /// Gamma must be within `(0, 1]`. Defaults to `1.0`.
58    ///
59    /// # Example
60    ///
61    /// ```
62    /// # use argmin::solver::gaussnewton::GaussNewton;
63    /// # use argmin::core::Error;
64    /// # fn main() -> Result<(), Error> {
65    /// let gauss_newton = GaussNewton::new().with_gamma(0.5f64)?;
66    /// # Ok(())
67    /// # }
68    /// ```
69    pub fn with_gamma(mut self, gamma: F) -> Result<Self, Error> {
70        if gamma <= float!(0.0) || gamma > float!(1.0) {
71            return Err(argmin_error!(
72                InvalidParameter,
73                "Gauss-Newton: gamma must be in  (0, 1]."
74            ));
75        }
76        self.gamma = gamma;
77        Ok(self)
78    }
79
80    /// Set tolerance for the stopping criterion based on cost difference.
81    ///
82    /// Tolerance must be larger than zero and defaults to `sqrt(EPSILON)`.
83    ///
84    /// # Example
85    ///
86    /// ```
87    /// # use argmin::solver::gaussnewton::GaussNewton;
88    /// # use argmin::core::Error;
89    /// # fn main() -> Result<(), Error> {
90    /// let gauss_newton = GaussNewton::new().with_tolerance(1e-4f64)?;
91    /// # Ok(())
92    /// # }
93    /// ```
94    pub fn with_tolerance(mut self, tol: F) -> Result<Self, Error> {
95        if tol <= float!(0.0) {
96            return Err(argmin_error!(
97                InvalidParameter,
98                "Gauss-Newton: tol must be positive."
99            ));
100        }
101        self.tol = tol;
102        Ok(self)
103    }
104}
105
106impl<F: ArgminFloat> Default for GaussNewton<F> {
107    fn default() -> GaussNewton<F> {
108        GaussNewton::new()
109    }
110}
111
112impl<O, P, J, R, F> Solver<O, IterState<P, (), J, (), R, F>> for GaussNewton<F>
113where
114    O: Operator<Param = P, Output = R> + Jacobian<Param = P, Jacobian = J>,
115    P: Clone + ArgminSub<P, P> + ArgminMul<F, P>,
116    R: ArgminL2Norm<F>,
117    J: Clone
118        + ArgminTranspose<J>
119        + ArgminInv<J>
120        + ArgminDot<J, J>
121        + ArgminDot<R, P>
122        + ArgminDot<P, P>,
123    F: ArgminFloat,
124{
125    fn name(&self) -> &str {
126        "Gauss-Newton method"
127    }
128
129    fn init(
130        &mut self,
131        problem: &mut Problem<O>,
132        mut state: IterState<P, (), J, (), R, F>,
133    ) -> Result<(IterState<P, (), J, (), R, F>, Option<KV>), Error> {
134        let init_param = state.take_param().ok_or_else(argmin_error_closure!(
135            NotInitialized,
136            concat!(
137                "`GaussNewton` requires an initial parameter vector. ",
138                "Please provide an initial guess via `Executor`s `configure` method."
139            )
140        ))?;
141        let residuals = problem.apply(&init_param)?;
142        let cost = residuals.l2_norm();
143        Ok((
144            state.param(init_param).residuals(residuals).cost(cost),
145            None,
146        ))
147    }
148
149    fn next_iter(
150        &mut self,
151        problem: &mut Problem<O>,
152        state: IterState<P, (), J, (), R, F>,
153    ) -> Result<(IterState<P, (), J, (), R, F>, Option<KV>), Error> {
154        let param = state.get_param().ok_or_else(argmin_error_closure!(
155            PotentialBug,
156            "`GaussNewton`: `param` not set"
157        ))?;
158        let residuals = state.get_residuals().ok_or_else(argmin_error_closure!(
159            PotentialBug,
160            "`GaussNewton`: `residuals` not set"
161        ))?;
162        let jacobian = problem.jacobian(param)?;
163
164        let p = jacobian
165            .clone()
166            .t()
167            .dot(&jacobian)
168            .inv()?
169            .dot(&jacobian.t().dot(residuals));
170
171        let new_param = param.sub(&p.mul(&self.gamma));
172        let residuals = problem.apply(&new_param)?;
173
174        let cost = residuals.l2_norm();
175
176        Ok((state.param(new_param).residuals(residuals).cost(cost), None))
177    }
178
179    fn terminate(&mut self, state: &IterState<P, (), J, (), R, F>) -> TerminationStatus {
180        if (state.get_prev_cost() - state.get_cost()).abs() < self.tol {
181            return TerminationStatus::Terminated(TerminationReason::SolverConverged);
182        }
183        TerminationStatus::NotTerminated
184    }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ArgminError;
    #[cfg(feature = "_ndarrayl")]
    use crate::core::Executor;
    #[cfg(feature = "_ndarrayl")]
    use approx::assert_relative_eq;
    #[cfg(feature = "_ndarrayl")]
    use ndarray::{Array, Array1, Array2};

    test_trait_impl!(gauss_newton_method, GaussNewton<f64>);

    /// Constant 2x2 Jacobian `[[1, 2], [3, 4]]` shared by all test problems.
    #[cfg(feature = "_ndarrayl")]
    fn fixed_jacobian() -> Result<Array2<f64>, Error> {
        Ok(Array::from_shape_vec((2, 2), vec![1f64, 2.0, 3.0, 4.0])?)
    }

    /// Problem with constant residuals `[0.5, 2.0]` and the constant Jacobian from
    /// [`fixed_jacobian`]; used by the tests which only exercise error paths.
    #[cfg(feature = "_ndarrayl")]
    struct TestProblem {}

    #[cfg(feature = "_ndarrayl")]
    impl Operator for TestProblem {
        type Param = Array1<f64>;
        type Output = Array1<f64>;

        fn apply(&self, _p: &Self::Param) -> Result<Self::Output, Error> {
            Ok(Array1::from_vec(vec![0.5, 2.0]))
        }
    }

    #[cfg(feature = "_ndarrayl")]
    impl Jacobian for TestProblem {
        type Param = Array1<f64>;
        type Jacobian = Array2<f64>;

        fn jacobian(&self, _p: &Self::Param) -> Result<Self::Jacobian, Error> {
            fixed_jacobian()
        }
    }

    #[test]
    fn test_new() {
        let GaussNewton { tol: t, gamma: g } = GaussNewton::<f64>::new();

        assert_eq!(g.to_ne_bytes(), (1.0f64).to_ne_bytes());
        assert_eq!(t.to_ne_bytes(), f64::EPSILON.sqrt().to_ne_bytes());
    }

    #[test]
    fn test_tolerance() {
        let tol1: f64 = 1e-4;

        let GaussNewton { tol: t, .. } = GaussNewton::new().with_tolerance(tol1).unwrap();

        assert_eq!(t.to_ne_bytes(), tol1.to_ne_bytes());
    }

    #[test]
    fn test_tolerance_error() {
        let tol = -2.0;
        let error = GaussNewton::new().with_tolerance(tol);
        assert_error!(
            error,
            ArgminError,
            "Invalid parameter: \"Gauss-Newton: tol must be positive.\""
        );
    }

    #[test]
    fn test_gamma() {
        let gamma: f64 = 0.5;

        let GaussNewton { gamma: g, .. } = GaussNewton::new().with_gamma(gamma).unwrap();

        assert_eq!(g.to_ne_bytes(), gamma.to_ne_bytes());
    }

    #[test]
    fn test_gamma_errors() {
        // Below the interval
        let gamma = -0.5;
        let error = GaussNewton::new().with_gamma(gamma);
        assert_error!(
            error,
            ArgminError,
            "Invalid parameter: \"Gauss-Newton: gamma must be in  (0, 1].\""
        );

        // Excluded lower bound
        let gamma = 0.0;
        let error = GaussNewton::new().with_gamma(gamma);
        assert_error!(
            error,
            ArgminError,
            "Invalid parameter: \"Gauss-Newton: gamma must be in  (0, 1].\""
        );

        // Above the interval
        let gamma = 2.0;
        let error = GaussNewton::new().with_gamma(gamma);
        assert_error!(
            error,
            ArgminError,
            "Invalid parameter: \"Gauss-Newton: gamma must be in  (0, 1].\""
        );
    }

    #[cfg(feature = "_ndarrayl")]
    #[test]
    fn test_init_param_not_initialized() {
        let mut gn = GaussNewton::<f64>::new();
        let res = gn.init(&mut Problem::new(TestProblem {}), IterState::new());
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Not initialized: \"`GaussNewton` requires an initial parameter vector. ",
                "Please provide an initial guess via `Executor`s `configure` method.\""
            )
        );
    }

    #[cfg(feature = "_ndarrayl")]
    #[test]
    fn test_next_iter_param_not_initialized() {
        let mut gn = GaussNewton::<f64>::new();
        let res = gn.next_iter(&mut Problem::new(TestProblem {}), IterState::new());
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Potential bug: \"`GaussNewton`: ",
                "`param` not set\". This is potentially a bug. ",
                "Please file a report on https://github.com/argmin-rs/argmin/issues"
            )
        );
    }

    #[cfg(feature = "_ndarrayl")]
    #[test]
    fn test_next_iter_residual_not_initialized() {
        let mut gn = GaussNewton::<f64>::new();
        let res = gn.next_iter(
            &mut Problem::new(TestProblem {}),
            IterState::new().param(vec![1f64, 2.0, 3.0, 4.0].into()),
        );
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Potential bug: \"`GaussNewton`: ",
                "`residuals` not set\". This is potentially a bug. ",
                "Please file a report on https://github.com/argmin-rs/argmin/issues"
            )
        );
    }

    #[cfg(feature = "_ndarrayl")]
    #[test]
    fn test_solver() {
        use std::cell::Cell;

        /// Returns residuals `[0.5, 2.0]` on the first `apply` call and `[0.3, 1.0]` on every
        /// later call, so that the first iteration reduces the cost.
        struct SwitchingProblem {
            first_call: Cell<bool>,
        }

        impl Operator for SwitchingProblem {
            type Param = Array1<f64>;
            type Output = Array1<f64>;

            fn apply(&self, _p: &Self::Param) -> Result<Self::Output, Error> {
                if self.first_call.replace(false) {
                    Ok(Array1::from_vec(vec![0.5, 2.0]))
                } else {
                    Ok(Array1::from_vec(vec![0.3, 1.0]))
                }
            }
        }

        impl Jacobian for SwitchingProblem {
            type Param = Array1<f64>;
            type Jacobian = Array2<f64>;

            fn jacobian(&self, _p: &Self::Param) -> Result<Self::Jacobian, Error> {
                fixed_jacobian()
            }
        }

        // Runs `solver` on a fresh problem, starting from [0, 0], for at most `max_iters`
        // iterations and returns the final state.
        let run = |solver: GaussNewton<f64>, max_iters| {
            let problem = SwitchingProblem {
                first_call: Cell::new(true),
            };
            let init_param = Array1::from_vec(vec![0.0, 0.0]);

            Executor::new(problem, solver)
                .configure(|config| config.param(init_param).max_iters(max_iters))
                .run()
                .unwrap()
                .state
        };

        // Single iteration, starting from [0, 0], gamma = 1
        let state = run(GaussNewton::new(), 1);
        let param = state.get_best_param().unwrap().clone();
        assert_relative_eq!(param[0], -1.0, epsilon = f64::EPSILON.sqrt());
        assert_relative_eq!(param[1], 0.25, epsilon = f64::EPSILON.sqrt());
        // Assert that cost matches residual:
        assert_relative_eq!(state.get_residuals().unwrap().l2_norm(), state.get_cost());

        // Two iterations, starting from [0, 0], gamma = 1
        let state = run(GaussNewton::new(), 2);
        let param = state.get_best_param().unwrap().clone();
        assert_relative_eq!(param[0], -1.0, epsilon = f64::EPSILON.sqrt());
        assert_relative_eq!(param[1], 0.25, epsilon = f64::EPSILON.sqrt());
        // Assert that cost matches residual:
        assert_relative_eq!(state.get_residuals().unwrap().l2_norm(), state.get_cost());

        // Single iteration, starting from [0, 0], gamma = 0.5
        let state = run(GaussNewton::new().with_gamma(0.5).unwrap(), 1);
        let param = state.get_best_param().unwrap().clone();
        assert_relative_eq!(param[0], -0.5, epsilon = f64::EPSILON.sqrt());
        assert_relative_eq!(param[1], 0.125, epsilon = f64::EPSILON.sqrt());
        // Assert that cost matches residual:
        assert_relative_eq!(state.get_residuals().unwrap().l2_norm(), state.get_cost());

        // Two iterations, starting from [0, 0], gamma = 0.5
        let state = run(GaussNewton::new().with_gamma(0.5).unwrap(), 2);
        let param = state.get_best_param().unwrap().clone();
        assert_relative_eq!(param[0], -0.5, epsilon = f64::EPSILON.sqrt());
        assert_relative_eq!(param[1], 0.125, epsilon = f64::EPSILON.sqrt());
        // Assert that cost matches residual:
        assert_relative_eq!(state.get_residuals().unwrap().l2_norm(), state.get_cost());
    }
}