1use crate::core::{
9 ArgminFloat, Error, IterState, Jacobian, Operator, Problem, Solver, State, TerminationReason,
10 TerminationStatus, KV,
11};
12use argmin_math::{ArgminDot, ArgminInv, ArgminL2Norm, ArgminMul, ArgminSub, ArgminTranspose};
13#[cfg(feature = "serde1")]
14use serde::{Deserialize, Serialize};
15
/// Gauss-Newton solver configuration.
///
/// Holds the two scalar settings used by the solver: the step width applied to each update
/// and the tolerance used by the stopping criterion.
#[derive(Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct GaussNewton<F> {
    /// Step width the computed update direction is scaled with before it is subtracted
    /// from the current parameter vector.
    gamma: F,
    /// Threshold on the absolute change in cost between two consecutive iterations below
    /// which the solver reports convergence.
    tol: F,
}
38
39impl<F: ArgminFloat> GaussNewton<F> {
40 pub fn new() -> Self {
49 GaussNewton {
50 gamma: float!(1.0),
51 tol: F::epsilon().sqrt(),
52 }
53 }
54
55 pub fn with_gamma(mut self, gamma: F) -> Result<Self, Error> {
70 if gamma <= float!(0.0) || gamma > float!(1.0) {
71 return Err(argmin_error!(
72 InvalidParameter,
73 "Gauss-Newton: gamma must be in (0, 1]."
74 ));
75 }
76 self.gamma = gamma;
77 Ok(self)
78 }
79
80 pub fn with_tolerance(mut self, tol: F) -> Result<Self, Error> {
95 if tol <= float!(0.0) {
96 return Err(argmin_error!(
97 InvalidParameter,
98 "Gauss-Newton: tol must be positive."
99 ));
100 }
101 self.tol = tol;
102 Ok(self)
103 }
104}
105
106impl<F: ArgminFloat> Default for GaussNewton<F> {
107 fn default() -> GaussNewton<F> {
108 GaussNewton::new()
109 }
110}
111
112impl<O, P, J, R, F> Solver<O, IterState<P, (), J, (), R, F>> for GaussNewton<F>
113where
114 O: Operator<Param = P, Output = R> + Jacobian<Param = P, Jacobian = J>,
115 P: Clone + ArgminSub<P, P> + ArgminMul<F, P>,
116 R: ArgminL2Norm<F>,
117 J: Clone
118 + ArgminTranspose<J>
119 + ArgminInv<J>
120 + ArgminDot<J, J>
121 + ArgminDot<R, P>
122 + ArgminDot<P, P>,
123 F: ArgminFloat,
124{
125 fn name(&self) -> &str {
126 "Gauss-Newton method"
127 }
128
129 fn init(
130 &mut self,
131 problem: &mut Problem<O>,
132 mut state: IterState<P, (), J, (), R, F>,
133 ) -> Result<(IterState<P, (), J, (), R, F>, Option<KV>), Error> {
134 let init_param = state.take_param().ok_or_else(argmin_error_closure!(
135 NotInitialized,
136 concat!(
137 "`GaussNewton` requires an initial parameter vector. ",
138 "Please provide an initial guess via `Executor`s `configure` method."
139 )
140 ))?;
141 let residuals = problem.apply(&init_param)?;
142 let cost = residuals.l2_norm();
143 Ok((
144 state.param(init_param).residuals(residuals).cost(cost),
145 None,
146 ))
147 }
148
149 fn next_iter(
150 &mut self,
151 problem: &mut Problem<O>,
152 state: IterState<P, (), J, (), R, F>,
153 ) -> Result<(IterState<P, (), J, (), R, F>, Option<KV>), Error> {
154 let param = state.get_param().ok_or_else(argmin_error_closure!(
155 PotentialBug,
156 "`GaussNewton`: `param` not set"
157 ))?;
158 let residuals = state.get_residuals().ok_or_else(argmin_error_closure!(
159 PotentialBug,
160 "`GaussNewton`: `residuals` not set"
161 ))?;
162 let jacobian = problem.jacobian(param)?;
163
164 let p = jacobian
165 .clone()
166 .t()
167 .dot(&jacobian)
168 .inv()?
169 .dot(&jacobian.t().dot(residuals));
170
171 let new_param = param.sub(&p.mul(&self.gamma));
172 let residuals = problem.apply(&new_param)?;
173
174 let cost = residuals.l2_norm();
175
176 Ok((state.param(new_param).residuals(residuals).cost(cost), None))
177 }
178
179 fn terminate(&mut self, state: &IterState<P, (), J, (), R, F>) -> TerminationStatus {
180 if (state.get_prev_cost() - state.get_cost()).abs() < self.tol {
181 return TerminationStatus::Terminated(TerminationReason::SolverConverged);
182 }
183 TerminationStatus::NotTerminated
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ArgminError;
    #[cfg(feature = "_ndarrayl")]
    use crate::core::Executor;
    #[cfg(feature = "_ndarrayl")]
    use approx::assert_relative_eq;
    #[cfg(feature = "_ndarrayl")]
    use ndarray::{Array, Array1, Array2};

    test_trait_impl!(gauss_newton_method, GaussNewton<f64>);

    /// Stateless fixture shared by the error-path tests: constant residuals `[0.5, 2.0]`
    /// and the constant Jacobian `[[1, 2], [3, 4]]`.
    #[cfg(feature = "_ndarrayl")]
    struct TestProblem {}

    #[cfg(feature = "_ndarrayl")]
    impl Operator for TestProblem {
        type Param = Array1<f64>;
        type Output = Array1<f64>;

        fn apply(&self, _p: &Self::Param) -> Result<Self::Output, Error> {
            Ok(Array1::from_vec(vec![0.5, 2.0]))
        }
    }

    #[cfg(feature = "_ndarrayl")]
    impl Jacobian for TestProblem {
        type Param = Array1<f64>;
        type Jacobian = Array2<f64>;

        fn jacobian(&self, _p: &Self::Param) -> Result<Self::Jacobian, Error> {
            Ok(Array::from_shape_vec((2, 2), vec![1f64, 2.0, 3.0, 4.0])?)
        }
    }

    // Defaults are compared bit-wise to avoid float equality lints.
    #[test]
    fn test_new() {
        let GaussNewton { tol: t, gamma: g } = GaussNewton::<f64>::new();

        assert_eq!(g.to_ne_bytes(), (1.0f64).to_ne_bytes());
        assert_eq!(t.to_ne_bytes(), f64::EPSILON.sqrt().to_ne_bytes());
    }

    #[test]
    fn test_tolerance() {
        let tol1: f64 = 1e-4;

        let GaussNewton { tol: t, .. } = GaussNewton::new().with_tolerance(tol1).unwrap();

        assert_eq!(t.to_ne_bytes(), tol1.to_ne_bytes());
    }

    #[test]
    fn test_tolerance_error() {
        let tol = -2.0;
        let error = GaussNewton::new().with_tolerance(tol);
        assert_error!(
            error,
            ArgminError,
            "Invalid parameter: \"Gauss-Newton: tol must be positive.\""
        );
    }

    #[test]
    fn test_gamma() {
        let gamma: f64 = 0.5;

        let GaussNewton { gamma: g, .. } = GaussNewton::new().with_gamma(gamma).unwrap();

        assert_eq!(g.to_ne_bytes(), gamma.to_ne_bytes());
    }

    // Below the interval, on its open lower bound, and above the interval.
    #[test]
    fn test_gamma_errors() {
        let invalid_gammas: Vec<f64> = vec![-0.5, 0.0, 2.0];

        for gamma in invalid_gammas {
            let error = GaussNewton::new().with_gamma(gamma);
            assert_error!(
                error,
                ArgminError,
                "Invalid parameter: \"Gauss-Newton: gamma must be in (0, 1].\""
            );
        }
    }

    #[cfg(feature = "_ndarrayl")]
    #[test]
    fn test_init_param_not_initialized() {
        let mut gn = GaussNewton::<f64>::new();
        let res = gn.init(&mut Problem::new(TestProblem {}), IterState::new());
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Not initialized: \"`GaussNewton` requires an initial parameter vector. ",
                "Please provide an initial guess via `Executor`s `configure` method.\""
            )
        );
    }

    #[cfg(feature = "_ndarrayl")]
    #[test]
    fn test_next_iter_param_not_initialized() {
        let mut gn = GaussNewton::<f64>::new();
        let res = gn.next_iter(&mut Problem::new(TestProblem {}), IterState::new());
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Potential bug: \"`GaussNewton`: ",
                "`param` not set\". This is potentially a bug. ",
                "Please file a report on https://github.com/argmin-rs/argmin/issues"
            )
        );
    }

    #[cfg(feature = "_ndarrayl")]
    #[test]
    fn test_next_iter_residual_not_initialized() {
        let mut gn = GaussNewton::<f64>::new();
        let res = gn.next_iter(
            &mut Problem::new(TestProblem {}),
            IterState::new().param(vec![1f64, 2.0, 3.0, 4.0].into()),
        );
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Potential bug: \"`GaussNewton`: ",
                "`residuals` not set\". This is potentially a bug. ",
                "Please file a report on https://github.com/argmin-rs/argmin/issues"
            )
        );
    }

    #[cfg(feature = "_ndarrayl")]
    #[test]
    fn test_solver() {
        use crate::core::State;
        use std::cell::RefCell;

        /// Returns residuals `[0.5, 2.0]` on the first `apply` call and `[0.3, 1.0]` on
        /// every later call; the Jacobian is constant.
        struct CountingProblem {
            counter: RefCell<usize>,
        }

        impl Operator for CountingProblem {
            type Param = Array1<f64>;
            type Output = Array1<f64>;

            fn apply(&self, _p: &Self::Param) -> Result<Self::Output, Error> {
                if *self.counter.borrow() == 0 {
                    let mut c = self.counter.borrow_mut();
                    *c += 1;
                    Ok(Array1::from_vec(vec![0.5, 2.0]))
                } else {
                    Ok(Array1::from_vec(vec![0.3, 1.0]))
                }
            }
        }

        impl Jacobian for CountingProblem {
            type Param = Array1<f64>;
            type Jacobian = Array2<f64>;

            fn jacobian(&self, _p: &Self::Param) -> Result<Self::Jacobian, Error> {
                Ok(Array::from_shape_vec((2, 2), vec![1f64, 2.0, 3.0, 4.0])?)
            }
        }

        /// Runs `solver` on a fresh `CountingProblem` from the origin for `max_iters`
        /// iterations, then checks the best parameter vector against `expected` and that
        /// the stored cost equals the L2 norm of the stored residuals.
        fn run_and_check(solver: GaussNewton<f64>, max_iters: u64, expected: [f64; 2]) {
            let problem = CountingProblem {
                counter: RefCell::new(0),
            };
            let init_param = Array1::from_vec(vec![0.0, 0.0]);

            let state = Executor::new(problem, solver)
                .configure(|config| config.param(init_param).max_iters(max_iters))
                .run()
                .unwrap()
                .state;
            let param = state.get_best_param().unwrap().clone();
            assert_relative_eq!(param[0], expected[0], epsilon = f64::EPSILON.sqrt());
            assert_relative_eq!(param[1], expected[1], epsilon = f64::EPSILON.sqrt());

            assert_relative_eq!(state.get_residuals().unwrap().l2_norm(), state.get_cost());
        }

        // Default step width, one and two iterations.
        run_and_check(GaussNewton::new(), 1, [-1.0, 0.25]);
        run_and_check(GaussNewton::new(), 2, [-1.0, 0.25]);

        // Halved step width, one and two iterations.
        run_and_check(GaussNewton::new().with_gamma(0.5).unwrap(), 1, [-0.5, 0.125]);
        run_and_check(GaussNewton::new().with_gamma(0.5).unwrap(), 2, [-0.5, 0.125]);
    }
}