1use crate::core::{
9 ArgminFloat, Error, IterState, Problem, Solver, State, TerminationReason, TerminationStatus,
10 TrustRegionRadius, KV,
11};
12use argmin_math::{
13 ArgminAdd, ArgminDot, ArgminL2Norm, ArgminMul, ArgminWeightedDot, ArgminZeroLike,
14};
15#[cfg(feature = "serde1")]
16use serde::{Deserialize, Serialize};
17
/// Solver state for a truncated conjugate-gradient iteration that is confined to a trust region
/// of a given radius.
///
/// `P` is the parameter (vector) type and `F` the floating point type.
///
/// Note that the derived `Default` fills every field with its type's default value (`None`,
/// `F::default()`, `0` for `max_iters`). This is *not* the configuration produced by
/// `Steihaug::new`, which uses NaN placeholders, a non-zero `epsilon` and `u64::MAX` iterations.
#[derive(Clone, Default)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct Steihaug<P, F> {
    /// Trust-region radius; assigned through `TrustRegionRadius::set_radius`.
    radius: F,
    /// Tolerance used by the residual-based stopping tests.
    epsilon: F,
    /// Current iterate `p` (`None` until the solver has been initialized).
    p: Option<P>,
    /// Current residual `r` (`None` until the solver has been initialized).
    r: Option<P>,
    /// Cached inner product `r . r` of the current residual.
    rtr: F,
    /// L2 norm of the residual the iteration started from.
    r_0_norm: F,
    /// Current search direction `d` (`None` until the solver has been initialized).
    d: Option<P>,
    /// Upper bound on the number of iterations.
    max_iters: u64,
}
47
48impl<P, F> Steihaug<P, F>
49where
50 P: ArgminMul<F, P> + ArgminDot<P, F> + ArgminAdd<P, P>,
51 F: ArgminFloat,
52{
53 pub fn new() -> Self {
62 Steihaug {
63 radius: F::nan(),
64 epsilon: float!(10e-10),
65 p: None,
66 r: None,
67 rtr: F::nan(),
68 r_0_norm: F::nan(),
69 d: None,
70 max_iters: u64::MAX,
71 }
72 }
73
74 pub fn with_epsilon(mut self, epsilon: F) -> Result<Self, Error> {
91 if epsilon <= float!(0.0) {
92 return Err(argmin_error!(
93 InvalidParameter,
94 "`Steihaug`: epsilon must be > 0.0."
95 ));
96 }
97 self.epsilon = epsilon;
98 Ok(self)
99 }
100
101 #[must_use]
115 pub fn with_max_iters(mut self, iters: u64) -> Self {
116 self.max_iters = iters;
117 self
118 }
119
120 fn eval_m<H>(&self, p: &P, g: &P, h: &H) -> F
122 where
123 P: ArgminWeightedDot<P, F, H>,
124 {
125 g.dot(p) + float!(0.5) * p.weighted_dot(h, p)
126 }
127
128 #[allow(clippy::many_single_char_names)]
130 fn tau<G, H>(&self, filter_func: G, eval: bool, g: &P, h: &H) -> F
131 where
132 G: Fn(F) -> bool,
133 P: ArgminWeightedDot<P, F, H>,
134 {
135 let p = self.p.as_ref().unwrap();
136 let d = self.d.as_ref().unwrap();
137 let a = p.dot(p);
138 let b = d.dot(d);
139 let c = p.dot(d);
140 let delta = self.radius.powi(2);
141 let t1 = (-a * b + b * delta + c.powi(2)).sqrt();
142 let tau1 = -(t1 + c) / b;
143 let tau2 = (t1 - c) / b;
144 let mut t = vec![tau1, tau2];
145 if tau1.is_nan() || tau2.is_nan() || tau1.is_infinite() || tau2.is_infinite() {
147 let tau3 = (delta - a) / (float!(2.0) * c);
148 t.push(tau3);
149 }
150 let v = if eval {
151 let mut v = t
154 .iter()
155 .cloned()
156 .enumerate()
157 .filter(|(_, tau)| (!tau.is_nan() || !tau.is_infinite()) && filter_func(*tau))
158 .map(|(i, tau)| {
159 let p_local = p.add(&d.mul(&tau));
160 (i, self.eval_m(&p_local, g, h))
161 })
162 .filter(|(_, m)| !m.is_nan() || !m.is_infinite())
163 .collect::<Vec<(usize, F)>>();
164 v.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
165 v
166 } else {
167 let mut v = t
168 .iter()
169 .cloned()
170 .enumerate()
171 .filter(|(_, tau)| (!tau.is_nan() || !tau.is_infinite()) && filter_func(*tau))
172 .collect::<Vec<(usize, F)>>();
173 v.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
174 v
175 };
176
177 t[v[0].0]
178 }
179}
180
181impl<P, O, F, H> Solver<O, IterState<P, P, (), H, (), F>> for Steihaug<P, F>
182where
183 P: Clone
184 + ArgminMul<F, P>
185 + ArgminL2Norm<F>
186 + ArgminDot<P, F>
187 + ArgminAdd<P, P>
188 + ArgminZeroLike,
189 H: ArgminDot<P, P>,
190 F: ArgminFloat,
191{
192 fn name(&self) -> &str {
193 "Steihaug"
194 }
195
196 fn init(
197 &mut self,
198 _problem: &mut Problem<O>,
199 state: IterState<P, P, (), H, (), F>,
200 ) -> Result<(IterState<P, P, (), H, (), F>, Option<KV>), Error> {
201 let r = state
202 .get_gradient()
203 .ok_or_else(argmin_error_closure!(
204 NotInitialized,
205 concat!(
206 "`Steihaug` requires an initial gradient. ",
207 "Please provide an initial gradient via `Executor`s `configure` method."
208 )
209 ))?
210 .clone();
211
212 if state.get_hessian().is_none() {
213 return Err(argmin_error!(
214 NotInitialized,
215 concat!(
216 "`Steihaug` requires an initial Hessian. ",
217 "Please provide an initial Hessian via `Executor`s `configure` method."
218 )
219 ));
220 }
221
222 self.r_0_norm = r.l2_norm();
223 self.rtr = r.dot(&r);
224 self.d = Some(r.mul(&float!(-1.0)));
225 let p = r.zero_like();
226 self.p = Some(p.clone());
227
228 self.r = Some(r);
229
230 Ok((state.param(p), None))
231 }
232
233 fn next_iter(
234 &mut self,
235 _problem: &mut Problem<O>,
236 mut state: IterState<P, P, (), H, (), F>,
237 ) -> Result<(IterState<P, P, (), H, (), F>, Option<KV>), Error> {
238 let grad = state.take_gradient().ok_or_else(argmin_error_closure!(
239 PotentialBug,
240 "`Steihaug`: Gradient in state not set."
241 ))?;
242
243 let h = state.take_hessian().ok_or_else(argmin_error_closure!(
244 PotentialBug,
245 "`Steihaug`: Hessian in state not set."
246 ))?;
247
248 let d = self.d.as_ref().unwrap();
249 let dhd = d.weighted_dot(&h, d);
250
251 let p = self.p.as_ref().unwrap();
253 if dhd <= float!(0.0) {
254 let tau = self.tau(|_| true, true, &grad, &h);
255 return Ok((
256 state
257 .param(p.add(&d.mul(&tau)))
258 .terminate_with(TerminationReason::SolverConverged),
259 None,
260 ));
261 }
262
263 let alpha = self.rtr / dhd;
264 let p_n = p.add(&d.mul(&alpha));
265
266 if p_n.l2_norm() >= self.radius {
268 let tau = self.tau(|x| x >= float!(0.0), false, &grad, &h);
269 return Ok((
270 state
271 .param(p.add(&d.mul(&tau)))
272 .terminate_with(TerminationReason::SolverConverged),
273 None,
274 ));
275 }
276
277 let r = self.r.as_ref().unwrap();
278 let r_n = r.add(&h.dot(d).mul(&alpha));
279
280 if r_n.l2_norm() < self.epsilon * self.r_0_norm {
281 return Ok((
282 state
283 .param(p_n)
284 .terminate_with(TerminationReason::SolverConverged),
285 None,
286 ));
287 }
288
289 let rjtrj = r_n.dot(&r_n);
290 let beta = rjtrj / self.rtr;
291 self.d = Some(r_n.mul(&float!(-1.0)).add(&d.mul(&beta)));
292 self.r = Some(r_n);
293 self.p = Some(p_n.clone());
294 self.rtr = rjtrj;
295
296 Ok((
297 state.param(p_n).cost(self.rtr).gradient(grad).hessian(h),
298 None,
299 ))
300 }
301
302 fn terminate(&mut self, state: &IterState<P, P, (), H, (), F>) -> TerminationStatus {
303 if self.r_0_norm < self.epsilon {
304 return TerminationStatus::Terminated(TerminationReason::SolverConverged);
305 }
306 if state.get_iter() >= self.max_iters {
307 return TerminationStatus::Terminated(TerminationReason::MaxItersReached);
308 }
309 TerminationStatus::NotTerminated
310 }
311}
312
313impl<P, F: ArgminFloat> TrustRegionRadius<F> for Steihaug<P, F> {
314 fn set_radius(&mut self, radius: F) {
326 self.radius = radius;
327 }
328}
329
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::test_utils::TestProblem;
    use crate::core::ArgminError;
    use approx::assert_relative_eq;

    /// Solver instantiation shared by the tests below.
    type TestSolver = Steihaug<Vec<f64>, f64>;
    /// State type matching `TestSolver` (vector parameters/gradient, nested-`Vec` Hessian).
    type TestState = IterState<Vec<f64>, Vec<f64>, (), Vec<Vec<f64>>, (), f64>;

    test_trait_impl!(steihaug, Steihaug<TestProblem, f64>);

    /// `new` must produce the documented placeholder values (compared bit for bit).
    #[test]
    fn test_new() {
        let sh = TestSolver::new();

        assert_eq!(sh.radius.to_bits(), f64::NAN.to_bits());
        assert_eq!(sh.epsilon.to_bits(), 10e-10f64.to_bits());
        assert!(sh.p.is_none());
        assert!(sh.r.is_none());
        assert_eq!(sh.rtr.to_bits(), f64::NAN.to_bits());
        assert_eq!(sh.r_0_norm.to_bits(), f64::NAN.to_bits());
        assert!(sh.d.is_none());
        assert_eq!(sh.max_iters, u64::MAX);
    }

    /// Positive tolerances are stored unchanged; zero and negative ones are rejected.
    #[test]
    fn test_with_tolerance() {
        let accepted = [f64::EPSILON, 1e-10, 1e-12, 1e-6, 1.0, 10.0, 100.0];
        for tolerance in accepted {
            let sh = TestSolver::new().with_epsilon(tolerance).unwrap();
            assert_eq!(sh.epsilon.to_bits(), tolerance.to_bits());
        }

        let rejected = [-f64::EPSILON, 0.0, -1.0];
        for tolerance in rejected {
            let res = TestSolver::new().with_epsilon(tolerance);
            assert_error!(
                res,
                ArgminError,
                "Invalid parameter: \"`Steihaug`: epsilon must be > 0.0.\""
            );
        }
    }

    /// The iteration limit defaults to `u64::MAX` and can be overridden.
    #[test]
    fn test_max_iters() {
        assert_eq!(TestSolver::new().max_iters, u64::MAX);

        for iters in [1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144] {
            let sh = TestSolver::new().with_max_iters(iters);
            assert_eq!(sh.max_iters, iters);
        }
    }

    /// `init` requires gradient and Hessian and sets up the initial solver state.
    #[test]
    fn test_init() {
        let grad: Vec<f64> = vec![1.0, 2.0];
        let hessian: Vec<Vec<f64>> = vec![vec![4.0, 3.0], vec![2.0, 1.0]];

        let mut sh = TestSolver::new();
        sh.set_radius(1.0);

        // Neither gradient nor Hessian available.
        let res = sh.init(&mut Problem::new(TestProblem::new()), TestState::new());
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Not initialized: \"`Steihaug` requires an initial gradient. Please ",
                "provide an initial gradient via `Executor`s `configure` method.\""
            )
        );

        // Gradient available, Hessian still missing.
        let res = sh.init(
            &mut Problem::new(TestProblem::new()),
            TestState::new().gradient(grad.clone()),
        );
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Not initialized: \"`Steihaug` requires an initial Hessian. Please ",
                "provide an initial Hessian via `Executor`s `configure` method.\""
            )
        );

        // Fully configured state.
        let complete = TestState::new().gradient(grad.clone()).hessian(hessian);
        let (mut state_out, kv) = sh
            .init(&mut Problem::new(TestProblem::new()), complete)
            .unwrap();

        assert!(kv.is_none());

        let s_param = state_out.take_param().unwrap();
        assert_relative_eq!(s_param[0], 0.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(s_param[1], 0.0f64, epsilon = f64::EPSILON);

        assert_eq!(sh.radius.to_bits(), 1.0f64.to_bits());
        assert_eq!(sh.epsilon.to_bits(), 10e-10f64.to_bits());

        let p = sh.p.as_ref().unwrap();
        assert_relative_eq!(p[0], 0.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(p[1], 0.0f64, epsilon = f64::EPSILON);

        let r = sh.r.as_ref().unwrap();
        assert_relative_eq!(r[0], grad[0], epsilon = f64::EPSILON);
        assert_relative_eq!(r[1], grad[1], epsilon = f64::EPSILON);

        assert_eq!(sh.rtr.to_bits(), 5.0f64.to_bits());
        assert_eq!(sh.r_0_norm.to_bits(), (5.0f64).sqrt().to_bits());

        let d = sh.d.as_ref().unwrap();
        assert_relative_eq!(d[0], -grad[0], epsilon = f64::EPSILON);
        assert_relative_eq!(d[1], -grad[1], epsilon = f64::EPSILON);

        assert_eq!(sh.max_iters, u64::MAX);
    }
}