argmin/solver/conjugategradient/
cg.rs

1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use crate::core::{ArgminFloat, Error, IterState, Operator, Problem, Solver, State, KV};
9use argmin_math::{ArgminConj, ArgminDot, ArgminL2Norm, ArgminMul, ArgminScaledAdd, ArgminSub};
10#[cfg(feature = "serde1")]
11use serde::{Deserialize, Serialize};
12
/// # Conjugate Gradient method
///
/// Iterative solver for linear systems `A * x = b` in which `A` is a symmetric,
/// positive-definite matrix and both `x` and `b` are vectors.
///
/// An initial guess for `x` (the parameter vector) has to be supplied.
///
/// ## Requirements on the optimization problem
///
/// The problem has to implement [`Operator`], which supplies the product of `A` with a vector.
///
/// ## Reference
///
/// Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
/// Springer. ISBN 0-387-30303-0.
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[derive(Clone)]
pub struct ConjugateGradient<P, F> {
    /// Right-hand side `b` of the linear system `A * x = b`
    b: P,
    /// Search direction for the upcoming iteration (`None` until initialized)
    p: Option<P>,
    /// Search direction that was used in the most recent iteration
    p_prev: Option<P>,
    /// Squared residual `r^T * r` of the current iterate
    rtr: F,
}
42
43impl<P, F> ConjugateGradient<P, F>
44where
45    F: ArgminFloat,
46{
47    /// Constructs an instance of [`ConjugateGradient`]
48    ///
49    /// Takes `b`, the right hand side of `A * x = b` as input.
50    ///
51    /// # Example
52    ///
53    /// ```
54    /// # use argmin::solver::conjugategradient::ConjugateGradient;
55    /// # let b = vec![1.0f64, 1.0];
56    /// let cg: ConjugateGradient<_, f64> = ConjugateGradient::new(b);
57    /// ```
58    pub fn new(b: P) -> Self {
59        ConjugateGradient {
60            b,
61            p: None,
62            p_prev: None,
63            rtr: F::nan(),
64        }
65    }
66
67    /// Return the previous search direction (Needed by [`NewtonCG`](`crate::solver::newton::NewtonCG`))
68    ///
69    /// Returns an error if the field `p_prev` is not initialized.
70    ///
71    /// # Example
72    ///
73    /// ```
74    /// # use argmin::solver::conjugategradient::ConjugateGradient;
75    /// # use argmin::core::Error;
76    /// # let cg: ConjugateGradient<_, f64> = ConjugateGradient::new(vec![1.0f64, 1.0]);
77    /// let p_prev: Result<_, _> = cg.get_prev_p();
78    /// ```
79    pub fn get_prev_p(&self) -> Result<&P, Error> {
80        self.p_prev.as_ref().ok_or_else(argmin_error_closure!(
81            NotInitialized,
82            "Field `p_prev` of `ConjugateGradient` not initialized."
83        ))
84    }
85}
86
87impl<P, O, R, F> Solver<O, IterState<P, (), (), (), R, F>> for ConjugateGradient<P, F>
88where
89    O: Operator<Param = P, Output = P>,
90    P: Clone + ArgminDot<P, F> + ArgminSub<P, R> + ArgminScaledAdd<P, F, P> + ArgminConj,
91    R: ArgminMul<F, R> + ArgminMul<F, P> + ArgminConj + ArgminDot<R, F> + ArgminScaledAdd<P, F, R>,
92    F: ArgminFloat + ArgminL2Norm<F>,
93{
94    fn name(&self) -> &str {
95        "Conjugate Gradient"
96    }
97
98    fn init(
99        &mut self,
100        problem: &mut Problem<O>,
101        state: IterState<P, (), (), (), R, F>,
102    ) -> Result<(IterState<P, (), (), (), R, F>, Option<KV>), Error> {
103        let init_param = state.get_param().ok_or_else(argmin_error_closure!(
104            NotInitialized,
105            concat!(
106                "`ConjugateGradient` requires an initial parameter vector. ",
107                "Please provide an initial guess via `Executor`s `configure` method."
108            )
109        ))?;
110        let ap = problem.apply(init_param)?;
111        let r0: R = self.b.sub(&ap).mul(&(float!(-1.0)));
112        self.p = Some(r0.mul(&(float!(-1.0))));
113        self.rtr = r0.dot(&r0.conj());
114        Ok((state.residuals(r0), None))
115    }
116
117    /// Perform one iteration of CG algorithm
118    fn next_iter(
119        &mut self,
120        problem: &mut Problem<O>,
121        mut state: IterState<P, (), (), (), R, F>,
122    ) -> Result<(IterState<P, (), (), (), R, F>, Option<KV>), Error> {
123        let p = self.p.take().ok_or_else(argmin_error_closure!(
124            PotentialBug,
125            "`ConjugateGradient`: Field `p` not set"
126        ))?;
127        let r = state.take_residuals().ok_or_else(argmin_error_closure!(
128            PotentialBug,
129            "`ConjugateGradient`: Residuals in `state` not set"
130        ))?;
131
132        let apk = problem.apply(&p)?;
133        let alpha = self.rtr.div(p.dot(&apk.conj()));
134        let state_param = state.get_param().ok_or_else(argmin_error_closure!(
135            PotentialBug,
136            "`ConjugateGradient`: Parameter vector in `state` not set"
137        ))?;
138        let new_param = state_param.scaled_add(&alpha, &p);
139        let r = r.scaled_add(&alpha, &apk);
140        let rtr_n = r.dot(&r.conj());
141        let beta = rtr_n.div(self.rtr);
142        self.rtr = rtr_n;
143        let p_n = <R as ArgminMul<F, P>>::mul(&r, &(float!(-1.0))).scaled_add(&beta, &p);
144        let norm = r.dot(&r.conj()).l2_norm();
145
146        self.p = Some(p_n);
147        self.p_prev = Some(p);
148
149        Ok((
150            state.param(new_param).residuals(r).cost(norm),
151            Some(kv!("alpha" => alpha; "beta" => beta;)),
152        ))
153    }
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{test_utils::TestProblem, ArgminError};
    use approx::assert_relative_eq;

    test_trait_impl!(conjugate_gradient, ConjugateGradient<Vec<f64>, f64>);

    #[test]
    fn test_new() {
        let cg: ConjugateGradient<_, f64> = ConjugateGradient::new(vec![1.0f64, 2.0]);
        let ConjugateGradient { b, p, p_prev, rtr } = cg;
        assert_eq!(b[0].to_ne_bytes(), 1.0f64.to_ne_bytes());
        assert_eq!(b[1].to_ne_bytes(), 2.0f64.to_ne_bytes());
        assert!(p.is_none());
        assert!(p_prev.is_none());
        assert!(rtr.is_nan());
    }

    #[test]
    fn test_get_prev_p_not_initialized() {
        let cg: ConjugateGradient<_, f64> = ConjugateGradient::new(vec![1.0f64, 2.0]);
        let res: Result<_, _> = cg.get_prev_p();
        assert_error!(
            res,
            ArgminError,
            "Not initialized: \"Field `p_prev` of `ConjugateGradient` not initialized.\""
        );
    }

    #[test]
    fn test_get_prev_p() {
        let mut cg: ConjugateGradient<_, f64> = ConjugateGradient::new(vec![1.0f64, 2.0]);
        cg.p_prev = Some(vec![3.0f64, 4.0]);
        let res: Result<_, _> = cg.get_prev_p();
        assert!(res.is_ok());
        let p_prev = res.unwrap();
        assert_eq!(p_prev[0].to_ne_bytes(), 3.0f64.to_ne_bytes());
        assert_eq!(p_prev[1].to_ne_bytes(), 4.0f64.to_ne_bytes());
    }

    #[test]
    fn test_init_param_not_initialized() {
        let mut cg: ConjugateGradient<_, f64> = ConjugateGradient::new(vec![1.0f64, 2.0]);
        let res = cg.init(&mut Problem::new(TestProblem::new()), IterState::new());
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Not initialized: \"`ConjugateGradient` requires an initial parameter vector. ",
                "Please provide an initial guess via `Executor`s `configure` method.\""
            )
        );
    }

    #[test]
    fn test_init() {
        let mut cg: ConjugateGradient<_, f64> = ConjugateGradient::new(vec![1.0f64, 2.0]);
        let state: IterState<Vec<f64>, (), (), (), Vec<f64>, f64> =
            IterState::new().param(vec![3.0, 4.0]);
        let (state_out, kv) = cg
            .init(&mut Problem::new(TestProblem::new()), state.clone())
            .unwrap();
        assert!(kv.is_none());

        let ConjugateGradient { b, p, p_prev, rtr } = cg;

        assert_relative_eq!(b[0], 1.0, epsilon = f64::EPSILON);
        assert_relative_eq!(b[1], 2.0, epsilon = f64::EPSILON);
        let r0 = [2.0f64, 2.0];
        assert_relative_eq!(
            r0[0],
            state_out.get_residuals().as_ref().unwrap()[0],
            epsilon = f64::EPSILON
        );
        assert_relative_eq!(
            r0[1],
            state_out.get_residuals().as_ref().unwrap()[1],
            epsilon = f64::EPSILON
        );
        let pp = [-2.0f64, -2.0];
        assert_relative_eq!(pp[0], p.as_ref().unwrap()[0], epsilon = f64::EPSILON);
        assert_relative_eq!(pp[1], p.as_ref().unwrap()[1], epsilon = f64::EPSILON);
        assert_relative_eq!(rtr, 8.0, epsilon = f64::EPSILON);
        assert!(p_prev.is_none());
    }

    #[test]
    fn test_next_iter_p_not_set() {
        let mut cg: ConjugateGradient<_, f64> = ConjugateGradient::new(vec![1.0f64, 2.0]);
        let state = IterState::new().param(vec![1.0f64]);
        assert!(cg.p.is_none());
        let res = cg.next_iter(&mut Problem::new(TestProblem::new()), state);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Potential bug: \"`ConjugateGradient`: ",
                "Field `p` not set\". This is potentially a bug. ",
                "Please file a report on https://github.com/argmin-rs/argmin/issues"
            )
        );
    }

    #[test]
    fn test_next_iter_r_not_set() {
        let mut cg: ConjugateGradient<_, f64> = ConjugateGradient::new(vec![1.0f64, 2.0]);
        let state = IterState::new().param(vec![1.0f64]);
        cg.p = Some(vec![]);
        let res = cg.next_iter(&mut Problem::new(TestProblem::new()), state);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Potential bug: \"`ConjugateGradient`: ",
                "Residuals in `state` not set\". This is potentially a bug. ",
                "Please file a report on https://github.com/argmin-rs/argmin/issues"
            )
        );
    }

    #[test]
    fn test_next_iter_state_param_not_set() {
        let mut cg: ConjugateGradient<_, f64> = ConjugateGradient::new(vec![1.0f64, 2.0]);
        let state = IterState::new().residuals(vec![]);
        cg.p = Some(vec![]);
        assert!(state.param.is_none());
        let res = cg.next_iter(&mut Problem::new(TestProblem::new()), state);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Potential bug: \"`ConjugateGradient`: ",
                "Parameter vector in `state` not set\". This is potentially a bug. ",
                "Please file a report on https://github.com/argmin-rs/argmin/issues"
            )
        );
    }

    #[test]
    fn test_next_iter() {
        let mut cg: ConjugateGradient<_, f64> = ConjugateGradient::new(vec![2.0f64]);
        let state = IterState::new().param(vec![1.0f64]);
        let mut problem = Problem::new(TestProblem::new());
        let (state, _) = cg.init(&mut problem, state).unwrap();
        let rtr = cg.rtr;
        let p = cg.p.clone().unwrap()[0];
        let r = state.get_residuals().unwrap()[0];

        // Reference computation of one CG step. `TestProblem` acts as the identity operator
        // here (see `test_init`: `A * x - b == x - b`), hence `A * p == p`.
        let apk = p;
        let alpha = rtr / (p * apk);
        let new_param = 1.0 + alpha * p;
        let r = r + alpha * apk;
        // r_{k+1}^T r_{k+1} -- must be non-negative.
        let rtr_n = r * r;
        let beta = rtr_n / rtr;
        let p_n = -r + beta * p;
        let norm = (r * r).l2_norm();

        let (state, kv) = cg.next_iter(&mut problem, state).unwrap();
        assert!(kv.is_some());

        assert_relative_eq!(r, state.get_residuals().unwrap()[0]);
        assert_relative_eq!(p_n, cg.p.as_ref().unwrap()[0]);
        assert_relative_eq!(p, cg.p_prev.as_ref().unwrap()[0]);
        assert_relative_eq!(rtr_n, cg.rtr);

        assert_relative_eq!(norm, state.get_cost());
        assert_relative_eq!(new_param, state.get_param().unwrap()[0]);
    }
}