use crate::core::{
    ArgminFloat, CostFunction, Error, Executor, Gradient, IterState, LineSearch, NLCGBetaUpdate,
    OptimizationResult, Problem, Solver, State, KV,
};
use argmin_math::{ArgminAdd, ArgminDot, ArgminL2Norm, ArgminMul};
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};

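/// # Non-linear Conjugate Gradient method
///
/// A generalization of the conjugate gradient method for non-linear optimization problems.
/// Each iteration runs a line search along the current search direction and then updates the
/// direction using a configurable beta update method (see [`NLCGBetaUpdate`]). The direction
/// can be reset to the steepest descent direction via [`restart_iters`](Self::restart_iters)
/// or [`restart_orthogonality`](Self::restart_orthogonality).
///
/// Requires an initial parameter vector.
///
/// ## Example
///
/// A minimal sketch, assuming `MoreThuenteLineSearch` as the line search and `PolakRibiere`
/// as the beta update; the `Quadratic` problem type below is made up for illustration:
///
/// ```ignore
/// use argmin::core::{CostFunction, Error, Executor, Gradient};
/// use argmin::solver::conjugategradient::beta::PolakRibiere;
/// use argmin::solver::conjugategradient::NonlinearConjugateGradient;
/// use argmin::solver::linesearch::MoreThuenteLineSearch;
///
/// // Simple convex problem: f(x) = x_0^2 + 2 * x_1^2.
/// struct Quadratic {}
///
/// impl CostFunction for Quadratic {
///     type Param = Vec<f64>;
///     type Output = f64;
///
///     fn cost(&self, p: &Self::Param) -> Result<Self::Output, Error> {
///         Ok(p[0].powi(2) + 2.0 * p[1].powi(2))
///     }
/// }
///
/// impl Gradient for Quadratic {
///     type Param = Vec<f64>;
///     type Gradient = Vec<f64>;
///
///     fn gradient(&self, p: &Self::Param) -> Result<Self::Gradient, Error> {
///         Ok(vec![2.0 * p[0], 4.0 * p[1]])
///     }
/// }
///
/// let linesearch = MoreThuenteLineSearch::new();
/// let beta_method = PolakRibiere::new();
/// let solver = NonlinearConjugateGradient::new(linesearch, beta_method);
///
/// let res = Executor::new(Quadratic {}, solver)
///     .configure(|state| state.param(vec![1.0, 2.0]).max_iters(20))
///     .run()?;
/// println!("{res}");
/// # Ok::<(), Error>(())
/// ```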
#[derive(Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct NonlinearConjugateGradient<P, L, B, F> {
    /// Current search direction
    p: Option<P>,
    /// Current beta value
    beta: F,
    /// Line search method
    linesearch: L,
    /// Beta update method
    beta_method: B,
    /// Restart the search direction every `restart_iter` iterations
    restart_iter: u64,
    /// Restart based on the orthogonality of consecutive gradients
    restart_orthogonality: Option<F>,
}

impl<P, L, B, F> NonlinearConjugateGradient<P, L, B, F>
where
    F: ArgminFloat,
{
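    /// Constructs a new instance of [`NonlinearConjugateGradient`] from a line search method
    /// and a beta update method.
    ///
    /// ## Example
    ///
    /// A sketch, assuming the `MoreThuenteLineSearch` and `PolakRibiere` types from this crate:
    ///
    /// ```ignore
    /// let nlcg: NonlinearConjugateGradient<Vec<f64>, _, _, f64> =
    ///     NonlinearConjugateGradient::new(MoreThuenteLineSearch::new(), PolakRibiere::new());
    /// ```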
    pub fn new(linesearch: L, beta_method: B) -> Self {
        NonlinearConjugateGradient {
            p: None,
            beta: F::nan(),
            linesearch,
            beta_method,
            restart_iter: u64::MAX,
            restart_orthogonality: None,
        }
    }

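    /// Specifies after how many iterations the search direction is reset to the steepest
    /// descent direction (`beta` is set to zero). Defaults to `u64::MAX`, i.e. no periodic
    /// restarts.
    ///
    /// ## Example
    ///
    /// ```ignore
    /// // Restart every 100 iterations.
    /// let nlcg = nlcg.restart_iters(100);
    /// ```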
    #[must_use]
    pub fn restart_iters(mut self, iters: u64) -> Self {
        self.restart_iter = iters;
        self
    }

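    /// Sets a threshold on the orthogonality of consecutive gradients, beyond which the
    /// search direction is reset to the steepest descent direction.
    ///
    /// A restart is triggered in iteration `k+1` whenever
    /// `|g_{k+1}^T g_k| / ||g_{k+1}||^2 >= v`, i.e. when consecutive gradients are no
    /// longer sufficiently orthogonal.
    ///
    /// ## Example
    ///
    /// ```ignore
    /// let nlcg = nlcg.restart_orthogonality(0.1);
    /// ```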
    #[must_use]
    pub fn restart_orthogonality(mut self, v: F) -> Self {
        self.restart_orthogonality = Some(v);
        self
    }
}

impl<O, P, G, L, B, F> Solver<O, IterState<P, G, (), (), (), F>>
    for NonlinearConjugateGradient<P, L, B, F>
where
    O: CostFunction<Param = P, Output = F> + Gradient<Param = P, Gradient = G>,
    P: Clone + ArgminAdd<P, P> + ArgminMul<F, P>,
    G: Clone + ArgminMul<F, P> + ArgminDot<G, F> + ArgminL2Norm<F>,
    L: Clone + LineSearch<P, F> + Solver<O, IterState<P, G, (), (), (), F>>,
    B: NLCGBetaUpdate<G, P, F>,
    F: ArgminFloat,
{
    fn name(&self) -> &str {
        "Nonlinear Conjugate Gradient"
    }

    fn init(
        &mut self,
        problem: &mut Problem<O>,
        state: IterState<P, G, (), (), (), F>,
    ) -> Result<(IterState<P, G, (), (), (), F>, Option<KV>), Error> {
        let param = state.get_param().ok_or_else(argmin_error_closure!(
            NotInitialized,
            concat!(
                "`NonlinearConjugateGradient` requires an initial parameter vector. ",
                "Please provide an initial guess via `Executor`s `configure` method."
            )
        ))?;
        let cost = problem.cost(param)?;
        let grad = problem.gradient(param)?;
        // The initial search direction is the steepest descent direction.
        self.p = Some(grad.mul(&(float!(-1.0))));
        Ok((state.cost(cost).gradient(grad), None))
    }

    fn next_iter(
        &mut self,
        problem: &mut Problem<O>,
        mut state: IterState<P, G, (), (), (), F>,
    ) -> Result<(IterState<P, G, (), (), (), F>, Option<KV>), Error> {
        let p = self.p.as_ref().ok_or_else(argmin_error_closure!(
            PotentialBug,
            "`NonlinearConjugateGradient`: Field `p` not set"
        ))?;
        let xk = state.take_param().ok_or_else(argmin_error_closure!(
            PotentialBug,
            "`NonlinearConjugateGradient`: No `param` in `state`"
        ))?;
        let grad = state
            .take_gradient()
            .map(Result::Ok)
            .unwrap_or_else(|| problem.gradient(&xk))?;
        let cur_cost = state.cost;

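        // Run the line search from `xk` along the current search direction `p`.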
        self.linesearch.search_direction(p.clone());

        let OptimizationResult {
            problem: line_problem,
            state: mut line_state,
            ..
        } = Executor::new(
            problem.take_problem().ok_or_else(argmin_error_closure!(
                PotentialBug,
                "`NonlinearConjugateGradient`: Failed to take `problem` for line search"
            ))?,
            self.linesearch.clone(),
        )
        .configure(|state| state.param(xk).gradient(grad.clone()).cost(cur_cost))
        .ctrlc(false)
        .run()?;

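        // Hand the problem back so that function evaluation counts from the line search
        // are carried over.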
        problem.consume_problem(line_problem);

        let xk1 = line_state.take_param().ok_or_else(argmin_error_closure!(
            PotentialBug,
            "`NonlinearConjugateGradient`: No `param` returned by line search"
        ))?;

        let new_grad = problem.gradient(&xk1)?;

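        // Determine whether to restart: either because consecutive gradients are no longer
        // sufficiently orthogonal, or because the configured iteration interval has elapsed.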
        let restart_orthogonality = match self.restart_orthogonality {
            Some(v) => new_grad.dot(&grad).abs() / new_grad.l2_norm().powi(2) >= v,
            None => false,
        };

        let restart_iter: bool =
            (state.get_iter().is_multiple_of(self.restart_iter)) && state.get_iter() != 0;

        if restart_iter || restart_orthogonality {
            self.beta = float!(0.0);
        } else {
            self.beta = self.beta_method.update(&grad, &new_grad, p);
        }

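        // Update the search direction: p_{k+1} = -g_{k+1} + beta * p_k.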
        self.p = Some(new_grad.mul(&(float!(-1.0))).add(&p.mul(&self.beta)));

        let cost = problem.cost(&xk1)?;

        Ok((
            state.param(xk1).cost(cost).gradient(new_grad),
            Some(kv!("beta" => self.beta;
                     "restart_iter" => restart_iter;
                     "restart_orthogonality" => restart_orthogonality;
            )),
        ))
    }
}

#[cfg(test)]
#[allow(clippy::let_unit_value)]
mod tests {
    use super::*;
    use crate::core::test_utils::TestProblem;
    use crate::core::ArgminError;
    use crate::solver::conjugategradient::beta::PolakRibiere;
    use crate::solver::linesearch::{
        condition::ArmijoCondition, BacktrackingLineSearch, MoreThuenteLineSearch,
    };
    use approx::assert_relative_eq;

    #[derive(Eq, PartialEq, Clone, Copy, Debug)]
    struct Linesearch {}

    #[derive(Eq, PartialEq, Clone, Copy, Debug)]
    struct BetaUpdate {}

    test_trait_impl!(
        nonlinear_cg,
        NonlinearConjugateGradient<
            TestProblem,
            MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64>,
            PolakRibiere,
            f64
        >
    );

    #[test]
    fn test_new() {
        let linesearch = Linesearch {};
        let beta_method = BetaUpdate {};
        let nlcg: NonlinearConjugateGradient<Vec<f64>, _, _, f64> =
            NonlinearConjugateGradient::new(linesearch, beta_method);
        let NonlinearConjugateGradient {
            p,
            beta,
            linesearch,
            beta_method,
            restart_iter,
            restart_orthogonality,
        } = nlcg;
        assert!(p.is_none());
        assert!(beta.is_nan());
        assert_eq!(linesearch, Linesearch {});
        assert_eq!(beta_method, BetaUpdate {});
        assert_eq!(restart_iter, u64::MAX);
        assert!(restart_orthogonality.is_none());
    }

    #[test]
    fn test_restart_iters() {
        let linesearch = ();
        let beta_method = ();
        let nlcg: NonlinearConjugateGradient<Vec<f64>, _, _, f64> =
            NonlinearConjugateGradient::new(linesearch, beta_method);
        assert_eq!(nlcg.restart_iter, u64::MAX);
        let nlcg = nlcg.restart_iters(100);
        assert_eq!(nlcg.restart_iter, 100);
    }

    #[test]
    fn test_restart_orthogonality() {
        let linesearch = ();
        let beta_method = ();
        let nlcg: NonlinearConjugateGradient<Vec<f64>, _, _, f64> =
            NonlinearConjugateGradient::new(linesearch, beta_method);
        assert!(nlcg.restart_orthogonality.is_none());
        let nlcg = nlcg.restart_orthogonality(0.1);
        assert_eq!(
            nlcg.restart_orthogonality.as_ref().unwrap().to_ne_bytes(),
            0.1f64.to_ne_bytes()
        );
    }

    #[test]
    fn test_init_param_not_initialized() {
        let linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(ArmijoCondition::new(0.2).unwrap());
        let beta_method = PolakRibiere::new();
        let mut nlcg: NonlinearConjugateGradient<Vec<f64>, _, _, f64> =
            NonlinearConjugateGradient::new(linesearch, beta_method);
        let res = nlcg.init(&mut Problem::new(TestProblem::new()), IterState::new());
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Not initialized: \"`NonlinearConjugateGradient` requires an initial parameter vector. ",
                "Please provide an initial guess via `Executor`s `configure` method.\""
            )
        );
    }

    #[test]
    fn test_init() {
        let linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(ArmijoCondition::new(0.2).unwrap());
        let beta_method = PolakRibiere::new();
        let mut nlcg: NonlinearConjugateGradient<Vec<f64>, _, _, f64> =
            NonlinearConjugateGradient::new(linesearch, beta_method);
        let state: IterState<Vec<f64>, Vec<f64>, (), (), (), f64> =
            IterState::new().param(vec![3.0, 4.0]);
        let (state_out, kv) = nlcg
            .init(&mut Problem::new(TestProblem::new()), state.clone())
            .unwrap();
        assert!(kv.is_none());
        assert_ne!(state_out, state);
        assert_eq!(state_out.cost.to_ne_bytes(), 1f64.to_ne_bytes());
        assert_eq!(
            state_out.grad.as_ref().unwrap()[0].to_ne_bytes(),
            3f64.to_ne_bytes()
        );
        assert_eq!(
            state_out.grad.as_ref().unwrap()[1].to_ne_bytes(),
            4f64.to_ne_bytes()
        );
        assert_eq!(
            state_out.param.as_ref().unwrap()[0].to_ne_bytes(),
            3f64.to_ne_bytes()
        );
        assert_eq!(
            state_out.param.as_ref().unwrap()[1].to_ne_bytes(),
            4f64.to_ne_bytes()
        );
        assert_eq!(
            nlcg.p.as_ref().unwrap()[0].to_ne_bytes(),
            (-3f64).to_ne_bytes()
        );
        assert_eq!(
            nlcg.p.as_ref().unwrap()[1].to_ne_bytes(),
            (-4f64).to_ne_bytes()
        );
    }

    #[test]
    fn test_next_iter_p_not_set() {
        let linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(ArmijoCondition::new(0.2).unwrap());
        let beta_method = PolakRibiere::new();
        let mut nlcg: NonlinearConjugateGradient<Vec<f64>, _, _, f64> =
            NonlinearConjugateGradient::new(linesearch, beta_method);
        let state = IterState::new().param(vec![1.0f64, 2.0f64]);
        assert!(nlcg.p.is_none());
        let res = nlcg.next_iter(&mut Problem::new(TestProblem::new()), state);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Potential bug: \"`NonlinearConjugateGradient`: ",
                "Field `p` not set\". This is potentially a bug. ",
                "Please file a report on https://github.com/argmin-rs/argmin/issues"
            )
        );
    }

    #[test]
    fn test_next_iter_state_param_not_set() {
        let linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(ArmijoCondition::new(0.2).unwrap());
        let beta_method = PolakRibiere::new();
        let mut nlcg: NonlinearConjugateGradient<Vec<f64>, _, _, f64> =
            NonlinearConjugateGradient::new(linesearch, beta_method);
        let state = IterState::new();
        nlcg.p = Some(vec![]);
        assert!(nlcg.p.is_some());
        let res = nlcg.next_iter(&mut Problem::new(TestProblem::new()), state);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Potential bug: \"`NonlinearConjugateGradient`: ",
                "No `param` in `state`\". This is potentially a bug. ",
                "Please file a report on https://github.com/argmin-rs/argmin/issues"
            )
        );
    }

    #[test]
    fn test_next_iter_problem_missing() {
        let linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(ArmijoCondition::new(0.2).unwrap());
        let beta_method = PolakRibiere::new();
        let mut nlcg: NonlinearConjugateGradient<Vec<f64>, _, _, f64> =
            NonlinearConjugateGradient::new(linesearch, beta_method);
        let state = IterState::new()
            .param(vec![1.0f64, 2.0])
            .gradient(vec![1.0f64, 2.0]);
        nlcg.p = Some(vec![]);
        assert!(nlcg.p.is_some());
        let mut problem = Problem::new(TestProblem::new());
        let _ = problem.take_problem().unwrap();
        let res = nlcg.next_iter(&mut problem, state);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Potential bug: \"`NonlinearConjugateGradient`: ",
                "Failed to take `problem` for line search\". This is potentially a bug. ",
                "Please file a report on https://github.com/argmin-rs/argmin/issues"
            )
        );
    }

    #[test]
    fn test_next_iter() {
        let linesearch: BacktrackingLineSearch<Vec<f64>, Vec<f64>, ArmijoCondition<f64>, f64> =
            BacktrackingLineSearch::new(ArmijoCondition::new(0.2).unwrap());
        let beta_method = PolakRibiere::new();
        let mut nlcg: NonlinearConjugateGradient<Vec<f64>, _, _, f64> =
            NonlinearConjugateGradient::new(linesearch, beta_method);
        let state = IterState::new()
            .param(vec![1.0f64, 2.0])
            .gradient(vec![1.0f64, 2.0]);
        let mut problem = Problem::new(TestProblem::new());
        let (state, kv) = nlcg.init(&mut problem, state).unwrap();
        assert!(kv.is_none());
        let (mut state, kv) = nlcg.next_iter(&mut problem, state).unwrap();
        state.update();
        let kv2 = kv!("beta" => 0.0; "restart_iter" => false; "restart_orthogonality" => false;);
        assert_eq!(kv.unwrap(), kv2);
        assert_relative_eq!(
            state.param.as_ref().unwrap()[0],
            1.0f64,
            epsilon = f64::EPSILON
        );
        assert_relative_eq!(
            state.param.as_ref().unwrap()[1],
            2.0f64,
            epsilon = f64::EPSILON
        );
        assert_relative_eq!(
            state.best_param.as_ref().unwrap()[0],
            1.0f64,
            epsilon = f64::EPSILON
        );
        assert_relative_eq!(
            state.best_param.as_ref().unwrap()[1],
            2.0f64,
            epsilon = f64::EPSILON
        );
        assert_relative_eq!(state.cost, 1.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(state.prev_cost, 1.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(state.best_cost, 1.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(
            state.grad.as_ref().unwrap()[0],
            1.0f64,
            epsilon = f64::EPSILON
        );
        assert_relative_eq!(
            state.grad.as_ref().unwrap()[1],
            2.0f64,
            epsilon = f64::EPSILON
        );
    }
}