argmin/solver/conjugategradient/beta.rs
1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! # Beta update methods for [`NonlinearConjugateGradient`](`crate::solver::conjugategradient::NonlinearConjugateGradient`)
9//!
10//! These methods define the update procedure for
11//! [`NonlinearConjugateGradient`](`crate::solver::conjugategradient::NonlinearConjugateGradient`).
12//! They are based on the [`NLCGBetaUpdate`] trait which enables users to implement their own beta
13//! update methods.
14//!
15//! # Reference
16//!
17//! \[0\] Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
18//! Springer. ISBN 0-387-30303-0.
19
20use crate::core::ArgminFloat;
21use argmin_math::{ArgminDot, ArgminL2Norm, ArgminSub};
22#[cfg(feature = "serde1")]
23use serde::{Deserialize, Serialize};
24
/// Interface for beta update methods ([`NonlinearConjugateGradient`](`crate::solver::conjugategradient::NonlinearConjugateGradient`))
///
/// An implementor turns two gradients and a third argument `p_k` into the scalar `beta`.
/// The generic parameters are:
///
/// * `G`: type of both gradient arguments
/// * `P`: type of the `p_k` argument
/// * `F`: type of the returned `beta`
///
/// The trait itself places no bounds on `G`, `P` or `F`; each implementation adds the bounds
/// it needs.
///
/// # Example
///
/// ```
/// # use argmin::core::{ArgminFloat, NLCGBetaUpdate};
/// #[cfg(feature = "serde1")]
/// use serde::{Deserialize, Serialize};
///
/// #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
/// struct MyBetaMethod {}
///
/// impl<G, P, F> NLCGBetaUpdate<G, P, F> for MyBetaMethod
/// where
///     F: ArgminFloat,
/// {
///     fn update(&self, dfk: &G, dfk1: &G, p_k: &P) -> F {
///         // Compute updated beta
///         # F::nan()
///     }
/// }
/// ```
pub trait NLCGBetaUpdate<G, P, F> {
    /// Update beta.
    ///
    /// # Parameters
    ///
    /// * `nabla_f_k`: `\nabla f_k`, the gradient at iteration `k`
    /// * `nabla_f_k_p_1`: `\nabla f_{k+1}`, the gradient at iteration `k+1`
    /// * `p_k`: `p_k` (NOTE(review): presumably the search direction of iteration `k`; this
    ///   file only ever names it `p_k` -- confirm against the solver)
    ///
    /// # Returns
    ///
    /// The updated value of beta.
    fn update(&self, nabla_f_k: &G, nabla_f_k_p_1: &G, p_k: &P) -> F;
}
57
/// Fletcher and Reeves (FR) method
///
/// Formula: `<\nabla f_{k+1}, \nabla f_{k+1}> / <\nabla f_k, \nabla f_k>`
///
/// The type carries no state; all inputs are passed to the `update` method of its
/// `NLCGBetaUpdate` implementation.
#[derive(Default, Copy, Clone, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct FletcherReeves {}

impl FletcherReeves {
    /// Create a `FletcherReeves` beta update method.
    ///
    /// # Example
    ///
    /// ```
    /// # use argmin::solver::conjugategradient::beta::FletcherReeves;
    /// let beta_method = FletcherReeves::new();
    /// ```
    pub fn new() -> Self {
        Self {}
    }
}
78
79impl<G, P, F> NLCGBetaUpdate<G, P, F> for FletcherReeves
80where
81 G: ArgminDot<G, F>,
82 F: ArgminFloat,
83{
84 /// Update beta using the Fletcher-Reeves method.
85 ///
86 /// Formula: `<\nabla f_{k+1}, \nabla f_{k+1}> / <\nabla f_k, \nabla f_k>`
87 ///
88 /// # Example
89 ///
90 /// ```
91 /// # extern crate approx;
92 /// # use approx::assert_relative_eq;
93 /// # use argmin::solver::conjugategradient::beta::{NLCGBetaUpdate, FletcherReeves};
94 /// # let dfk = vec![1f64, 2.0];
95 /// # let dfk1 = vec![3f64, 4.0];
96 /// let beta_method = FletcherReeves::new();
97 /// let beta: f64 = beta_method.update(&dfk, &dfk1, &());
98 /// # assert_relative_eq!(beta, 5.0, epsilon = f64::EPSILON);
99 /// ```
100 fn update(&self, dfk: &G, dfk1: &G, _pk: &P) -> F {
101 dfk1.dot(dfk1) / dfk.dot(dfk)
102 }
103}
104
/// Polak and Ribiere (PR) method
///
/// Formula: `<\nabla f_{k+1}, (\nabla f_{k+1} - \nabla f_k)> / ||\nabla f_k||^2`
///
/// The type carries no state; all inputs are passed to the `update` method of its
/// `NLCGBetaUpdate` implementation.
#[derive(Default, Copy, Clone, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct PolakRibiere {}

impl PolakRibiere {
    /// Create a `PolakRibiere` beta update method.
    ///
    /// # Example
    ///
    /// ```
    /// # use argmin::solver::conjugategradient::beta::PolakRibiere;
    /// let beta_method = PolakRibiere::new();
    /// ```
    pub fn new() -> Self {
        Self {}
    }
}
123
124impl<G, P, F> NLCGBetaUpdate<G, P, F> for PolakRibiere
125where
126 G: ArgminDot<G, F> + ArgminSub<G, G> + ArgminL2Norm<F>,
127 F: ArgminFloat,
128{
129 /// Update beta using the Polak-Ribiere method.
130 ///
131 /// Formula: `<\nabla f_{k+1}, (\nabla f_{k+1} - \nabla f_k)> / ||\nabla f_k||^2`
132 ///
133 /// # Example
134 ///
135 /// ```
136 /// # extern crate approx;
137 /// # use approx::assert_relative_eq;
138 /// # use argmin::solver::conjugategradient::beta::{NLCGBetaUpdate, PolakRibiere};
139 /// # let dfk = vec![1f64, 2.0];
140 /// # let dfk1 = vec![3f64, 4.0];
141 /// let beta_method = PolakRibiere::new();
142 /// let beta = beta_method.update(&dfk, &dfk1, &());
143 /// # assert_relative_eq!(beta, 14.0/5.0, epsilon = f64::EPSILON);
144 /// ```
145 fn update(&self, dfk: &G, dfk1: &G, _pk: &P) -> F {
146 let dfk_norm_sq = dfk.l2_norm().powi(2);
147 dfk1.dot(&dfk1.sub(dfk)) / dfk_norm_sq
148 }
149}
150
/// Polak and Ribiere Plus (PR+) method
///
/// Formula: `max(0, <\nabla f_{k+1}, (\nabla f_{k+1} - \nabla f_k)> / ||\nabla f_k||^2)`
///
/// The type carries no state; all inputs are passed to the `update` method of its
/// `NLCGBetaUpdate` implementation.
#[derive(Default, Copy, Clone, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct PolakRibierePlus {}

impl PolakRibierePlus {
    /// Create a `PolakRibierePlus` beta update method.
    ///
    /// # Example
    ///
    /// ```
    /// # use argmin::solver::conjugategradient::beta::PolakRibierePlus;
    /// let beta_method = PolakRibierePlus::new();
    /// ```
    pub fn new() -> Self {
        Self {}
    }
}
171
172impl<G, P, F> NLCGBetaUpdate<G, P, F> for PolakRibierePlus
173where
174 G: ArgminDot<G, F> + ArgminSub<G, G> + ArgminL2Norm<F>,
175 F: ArgminFloat,
176{
177 /// Update beta using the Polak-Ribiere+ (PR+) method.
178 ///
179 /// Formula: `max(0, <\nabla f_{k+1}, (\nabla f_{k+1} - \nabla f_k)> / ||\nabla f_k||^2)`
180 ///
181 /// # Example
182 ///
183 /// ```
184 /// # extern crate approx;
185 /// # use approx::assert_relative_eq;
186 /// # use argmin::solver::conjugategradient::beta::{NLCGBetaUpdate, PolakRibierePlus};
187 /// # let dfk = vec![1f64, 2.0];
188 /// # let dfk1 = vec![3f64, 4.0];
189 /// let beta_method = PolakRibierePlus::new();
190 /// let beta = beta_method.update(&dfk, &dfk1, &());
191 /// # assert_relative_eq!(beta, 14.0/5.0, epsilon = f64::EPSILON);
192 /// #
193 /// # let dfk = vec![5f64, 6.0];
194 /// # let dfk1 = vec![3f64, 4.0];
195 /// # let beta_method = PolakRibierePlus::new();
196 /// # let beta = beta_method.update(&dfk, &dfk1, &());
197 /// # assert_relative_eq!(beta, 0.0, epsilon = f64::EPSILON);
198 /// ```
199 fn update(&self, dfk: &G, dfk1: &G, _pk: &P) -> F {
200 let dfk_norm_sq = dfk.l2_norm().powi(2);
201 let beta = dfk1.dot(&dfk1.sub(dfk)) / dfk_norm_sq;
202 float!(0.0).max(beta)
203 }
204}
205
/// Hestenes and Stiefel (HS) method
///
/// Formula: `<\nabla f_{k+1}, (\nabla f_{k+1} - \nabla f_k)> / <(\nabla f_{k+1} - \nabla f_k), p_k>`
///
/// The type carries no state; all inputs are passed to the `update` method of its
/// `NLCGBetaUpdate` implementation.
#[derive(Default, Copy, Clone, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct HestenesStiefel {}

impl HestenesStiefel {
    /// Create a `HestenesStiefel` beta update method.
    ///
    /// # Example
    ///
    /// ```
    /// # use argmin::solver::conjugategradient::beta::HestenesStiefel;
    /// let beta_method = HestenesStiefel::new();
    /// ```
    pub fn new() -> Self {
        Self {}
    }
}
226
227impl<G, P, F> NLCGBetaUpdate<G, P, F> for HestenesStiefel
228where
229 G: ArgminDot<G, F> + ArgminDot<P, F> + ArgminSub<G, G>,
230 F: ArgminFloat,
231{
232 /// Update beta using the Hestenes-Stiefel method.
233 ///
234 /// Formula: `<\nabla f_{k+1}, (\nabla f_{k+1} - \nabla f_k)> / <(\nabla f_{k+1} - \nabla f_k), p_k>`
235 ///
236 /// # Example
237 ///
238 /// ```
239 /// # extern crate approx;
240 /// # use approx::assert_relative_eq;
241 /// # use argmin::solver::conjugategradient::beta::{NLCGBetaUpdate, HestenesStiefel};
242 /// # let dfk = vec![1f64, 2.0];
243 /// # let dfk1 = vec![3f64, 4.0];
244 /// # let pk = vec![5f64, 6.0];
245 /// let beta_method = HestenesStiefel::new();
246 /// let beta: f64 = beta_method.update(&dfk, &dfk1, &pk);
247 /// # assert_relative_eq!(beta, 14.0/22.0, epsilon = f64::EPSILON);
248 /// ```
249 fn update(&self, dfk: &G, dfk1: &G, pk: &P) -> F {
250 let d = dfk1.sub(dfk);
251 dfk1.dot(&d) / d.dot(pk)
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::*;

    // One `test_trait_impl!` invocation per beta update method defined in this module.
    // The macro is defined elsewhere in the crate.
    // NOTE(review): presumably it generates a test named after the first argument that checks
    // trait implementations of the type given as the second argument -- confirm against the
    // macro definition.
    test_trait_impl!(fletcher_reeves, FletcherReeves);
    test_trait_impl!(polak_ribiere, PolakRibiere);
    test_trait_impl!(polak_ribiere_plus, PolakRibierePlus);
    test_trait_impl!(hestenes_stiefel, HestenesStiefel);
}