1use crate::core::{ArgminFloat, Error, IterState, Operator, Problem, Solver, State, KV};
9use argmin_math::{ArgminConj, ArgminDot, ArgminL2Norm, ArgminMul, ArgminScaledAdd, ArgminSub};
10#[cfg(feature = "serde1")]
11use serde::{Deserialize, Serialize};
12
/// Conjugate gradient method for solving the linear system `A x = b`.
///
/// The operator `A` is provided by the problem's [`Operator`] implementation: the solver only
/// ever interacts with `A` through `apply`, which is expected to return the product of `A` with
/// the given vector. The right-hand side `b` is passed to [`ConjugateGradient::new`]; the
/// initial guess `x0` has to be supplied via the `Executor`'s `configure` method.
///
/// NOTE(review): the classical convergence guarantees of CG assume that `A` is symmetric
/// (Hermitian) positive definite; this implementation does not check that property.
#[derive(Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct ConjugateGradient<P, F> {
    /// Right-hand side `b` of the system `A x = b`.
    b: P,
    /// Search direction that the next iteration will step along.
    p: Option<P>,
    /// Search direction used by the most recently completed iteration.
    p_prev: Option<P>,
    /// `rᵀr` of the current residual (computed as `r.dot(&r.conj())`); `NaN` until `init` runs.
    rtr: F,
}
42
43impl<P, F> ConjugateGradient<P, F>
44where
45 F: ArgminFloat,
46{
47 pub fn new(b: P) -> Self {
59 ConjugateGradient {
60 b,
61 p: None,
62 p_prev: None,
63 rtr: F::nan(),
64 }
65 }
66
67 pub fn get_prev_p(&self) -> Result<&P, Error> {
80 self.p_prev.as_ref().ok_or_else(argmin_error_closure!(
81 NotInitialized,
82 "Field `p_prev` of `ConjugateGradient` not initialized."
83 ))
84 }
85}
86
87impl<P, O, R, F> Solver<O, IterState<P, (), (), (), R, F>> for ConjugateGradient<P, F>
88where
89 O: Operator<Param = P, Output = P>,
90 P: Clone + ArgminDot<P, F> + ArgminSub<P, R> + ArgminScaledAdd<P, F, P> + ArgminConj,
91 R: ArgminMul<F, R> + ArgminMul<F, P> + ArgminConj + ArgminDot<R, F> + ArgminScaledAdd<P, F, R>,
92 F: ArgminFloat + ArgminL2Norm<F>,
93{
94 fn name(&self) -> &str {
95 "Conjugate Gradient"
96 }
97
98 fn init(
99 &mut self,
100 problem: &mut Problem<O>,
101 state: IterState<P, (), (), (), R, F>,
102 ) -> Result<(IterState<P, (), (), (), R, F>, Option<KV>), Error> {
103 let init_param = state.get_param().ok_or_else(argmin_error_closure!(
104 NotInitialized,
105 concat!(
106 "`ConjugateGradient` requires an initial parameter vector. ",
107 "Please provide an initial guess via `Executor`s `configure` method."
108 )
109 ))?;
110 let ap = problem.apply(init_param)?;
111 let r0: R = self.b.sub(&ap).mul(&(float!(-1.0)));
112 self.p = Some(r0.mul(&(float!(-1.0))));
113 self.rtr = r0.dot(&r0.conj());
114 Ok((state.residuals(r0), None))
115 }
116
117 fn next_iter(
119 &mut self,
120 problem: &mut Problem<O>,
121 mut state: IterState<P, (), (), (), R, F>,
122 ) -> Result<(IterState<P, (), (), (), R, F>, Option<KV>), Error> {
123 let p = self.p.take().ok_or_else(argmin_error_closure!(
124 PotentialBug,
125 "`ConjugateGradient`: Field `p` not set"
126 ))?;
127 let r = state.take_residuals().ok_or_else(argmin_error_closure!(
128 PotentialBug,
129 "`ConjugateGradient`: Residuals in `state` not set"
130 ))?;
131
132 let apk = problem.apply(&p)?;
133 let alpha = self.rtr.div(p.dot(&apk.conj()));
134 let state_param = state.get_param().ok_or_else(argmin_error_closure!(
135 PotentialBug,
136 "`ConjugateGradient`: Parameter vector in `state` not set"
137 ))?;
138 let new_param = state_param.scaled_add(&alpha, &p);
139 let r = r.scaled_add(&alpha, &apk);
140 let rtr_n = r.dot(&r.conj());
141 let beta = rtr_n.div(self.rtr);
142 self.rtr = rtr_n;
143 let p_n = <R as ArgminMul<F, P>>::mul(&r, &(float!(-1.0))).scaled_add(&beta, &p);
144 let norm = r.dot(&r.conj()).l2_norm();
145
146 self.p = Some(p_n);
147 self.p_prev = Some(p);
148
149 Ok((
150 state.param(new_param).residuals(r).cost(norm),
151 Some(kv!("alpha" => alpha; "beta" => beta;)),
152 ))
153 }
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{test_utils::TestProblem, ArgminError};
    use approx::assert_relative_eq;

    test_trait_impl!(conjugate_gradient, ConjugateGradient<Vec<f64>, f64>);

    /// `new` stores `b` and leaves every other field unset.
    #[test]
    fn test_new() {
        let cg: ConjugateGradient<_, f64> = ConjugateGradient::new(vec![1.0f64, 2.0]);
        let ConjugateGradient { b, p, p_prev, rtr } = cg;
        assert_eq!(b[0].to_ne_bytes(), 1.0f64.to_ne_bytes());
        assert_eq!(b[1].to_ne_bytes(), 2.0f64.to_ne_bytes());
        assert!(p.is_none());
        assert!(p_prev.is_none());
        assert!(rtr.is_nan());
    }

    #[test]
    fn test_get_prev_p_not_initialized() {
        let cg: ConjugateGradient<_, f64> = ConjugateGradient::new(vec![1.0f64, 2.0]);
        let res: Result<_, _> = cg.get_prev_p();
        assert_error!(
            res,
            ArgminError,
            "Not initialized: \"Field `p_prev` of `ConjugateGradient` not initialized.\""
        );
    }

    #[test]
    fn test_get_prev_p() {
        let mut cg: ConjugateGradient<_, f64> = ConjugateGradient::new(vec![1.0f64, 2.0]);
        cg.p_prev = Some(vec![3.0f64, 4.0]);
        let res: Result<_, _> = cg.get_prev_p();
        assert!(res.is_ok());
        let p_prev = res.unwrap();
        assert_eq!(p_prev[0].to_ne_bytes(), 3.0f64.to_ne_bytes());
        assert_eq!(p_prev[1].to_ne_bytes(), 4.0f64.to_ne_bytes());
    }

    #[test]
    fn test_init_param_not_initialized() {
        let mut cg: ConjugateGradient<_, f64> = ConjugateGradient::new(vec![1.0f64, 2.0]);
        let res = cg.init(&mut Problem::new(TestProblem::new()), IterState::new());
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Not initialized: \"`ConjugateGradient` requires an initial parameter vector. ",
                "Please provide an initial guess via `Executor`s `configure` method.\""
            )
        );
    }

    /// With `TestProblem` (identity operator), `r0 = x0 - b` and `p0 = -r0`.
    #[test]
    fn test_init() {
        let mut cg: ConjugateGradient<_, f64> = ConjugateGradient::new(vec![1.0f64, 2.0]);
        let state: IterState<Vec<f64>, (), (), (), Vec<f64>, f64> =
            IterState::new().param(vec![3.0, 4.0]);
        let (state_out, kv) = cg
            .init(&mut Problem::new(TestProblem::new()), state.clone())
            .unwrap();
        assert!(kv.is_none());

        let ConjugateGradient { b, p, p_prev, rtr } = cg;

        assert_relative_eq!(b[0], 1.0, epsilon = f64::EPSILON);
        assert_relative_eq!(b[1], 2.0, epsilon = f64::EPSILON);
        let r0 = [2.0f64, 2.0];
        assert_relative_eq!(
            r0[0],
            state_out.get_residuals().as_ref().unwrap()[0],
            epsilon = f64::EPSILON
        );
        assert_relative_eq!(
            r0[1],
            state_out.get_residuals().as_ref().unwrap()[1],
            epsilon = f64::EPSILON
        );
        let pp = [-2.0f64, -2.0];
        assert_relative_eq!(pp[0], p.as_ref().unwrap()[0], epsilon = f64::EPSILON);
        assert_relative_eq!(pp[1], p.as_ref().unwrap()[1], epsilon = f64::EPSILON);
        assert_relative_eq!(rtr, 8.0, epsilon = f64::EPSILON);
        assert!(p_prev.is_none());
    }

    #[test]
    fn test_next_iter_p_not_set() {
        let mut cg: ConjugateGradient<_, f64> = ConjugateGradient::new(vec![1.0f64, 2.0]);
        let state = IterState::new().param(vec![1.0f64]);
        assert!(cg.p.is_none());
        let res = cg.next_iter(&mut Problem::new(TestProblem::new()), state);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Potential bug: \"`ConjugateGradient`: ",
                "Field `p` not set\". This is potentially a bug. ",
                "Please file a report on https://github.com/argmin-rs/argmin/issues"
            )
        );
    }

    #[test]
    fn test_next_iter_r_not_set() {
        let mut cg: ConjugateGradient<_, f64> = ConjugateGradient::new(vec![1.0f64, 2.0]);
        let state = IterState::new().param(vec![1.0f64]);
        cg.p = Some(vec![]);
        let res = cg.next_iter(&mut Problem::new(TestProblem::new()), state);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Potential bug: \"`ConjugateGradient`: ",
                "Residuals in `state` not set\". This is potentially a bug. ",
                "Please file a report on https://github.com/argmin-rs/argmin/issues"
            )
        );
    }

    #[test]
    fn test_next_iter_state_param_not_set() {
        let mut cg: ConjugateGradient<_, f64> = ConjugateGradient::new(vec![1.0f64, 2.0]);
        let state = IterState::new().residuals(vec![]);
        cg.p = Some(vec![]);
        assert!(state.param.is_none());
        let res = cg.next_iter(&mut Problem::new(TestProblem::new()), state);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Potential bug: \"`ConjugateGradient`: ",
                "Parameter vector in `state` not set\". This is potentially a bug. ",
                "Please file a report on https://github.com/argmin-rs/argmin/issues"
            )
        );
    }

    /// One step on the 1-D identity problem, recomputed by hand.
    #[test]
    fn test_next_iter() {
        let mut cg: ConjugateGradient<_, f64> = ConjugateGradient::new(vec![2.0f64]);
        let state = IterState::new().param(vec![1.0f64]);
        let mut problem = Problem::new(TestProblem::new());
        let (state, _) = cg.init(&mut problem, state).unwrap();
        let rtr = cg.rtr;
        let p = cg.p.clone().unwrap()[0];
        let r = state.get_residuals().unwrap()[0];

        // `TestProblem` acts as the identity operator (see `test_init`), hence `A p = p`.
        let apk = p;
        let alpha = rtr / (p * apk);
        let new_param = 1.0 + alpha * p;
        let r = r + alpha * apk;
        let rtr_n = r * r;
        let beta = rtr_n / rtr;
        let p_n = -r + beta * p;

        // With the identity operator CG converges in a single step, so the new residual,
        // and therefore the reported cost, is exactly zero.
        assert_relative_eq!(r, 0.0);

        let (state, kv) = cg.next_iter(&mut problem, state).unwrap();
        assert!(kv.is_some());

        assert_relative_eq!(r, state.get_residuals().unwrap()[0]);
        assert_relative_eq!(p_n, cg.p.as_ref().unwrap()[0]);
        assert_relative_eq!(p, cg.p_prev.as_ref().unwrap()[0]);
        assert_relative_eq!(rtr_n, cg.rtr);

        assert_relative_eq!(state.get_cost(), 0.0);
        assert_relative_eq!(new_param, state.get_param().unwrap()[0]);
    }

    /// CG solves a 2x2 SPD system (A = diag(1, 2), b = [1, 2], x* = [1, 1]) in two steps.
    #[test]
    fn test_solves_2d_diagonal_system() {
        struct Diagonal(Vec<f64>);

        impl Operator for Diagonal {
            type Param = Vec<f64>;
            type Output = Vec<f64>;

            fn apply(&self, param: &Self::Param) -> Result<Self::Output, Error> {
                Ok(self.0.iter().zip(param).map(|(d, x)| d * x).collect())
            }
        }

        let mut cg: ConjugateGradient<_, f64> = ConjugateGradient::new(vec![1.0f64, 2.0]);
        let mut problem = Problem::new(Diagonal(vec![1.0, 2.0]));
        let state: IterState<Vec<f64>, (), (), (), Vec<f64>, f64> =
            IterState::new().param(vec![0.0, 0.0]);
        let (mut state, _) = cg.init(&mut problem, state).unwrap();
        // In exact arithmetic CG solves an n x n SPD system in at most n iterations.
        for _ in 0..2 {
            state = cg.next_iter(&mut problem, state).unwrap().0;
        }

        let x = state.get_param().unwrap();
        assert_relative_eq!(x[0], 1.0, epsilon = 1e-12);
        assert_relative_eq!(x[1], 1.0, epsilon = 1e-12);
        let r = state.get_residuals().unwrap();
        assert_relative_eq!(r[0], 0.0, epsilon = 1e-12);
        assert_relative_eq!(r[1], 0.0, epsilon = 1e-12);
    }
}