// argmin/solver/trustregion/steihaug.rs

1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use crate::core::{
9    ArgminFloat, Error, IterState, Problem, Solver, State, TerminationReason, TerminationStatus,
10    TrustRegionRadius, KV,
11};
12use argmin_math::{
13    ArgminAdd, ArgminDot, ArgminL2Norm, ArgminMul, ArgminWeightedDot, ArgminZeroLike,
14};
15#[cfg(feature = "serde1")]
16use serde::{Deserialize, Serialize};
17
/// # Steihaug method
///
/// The Steihaug method is a conjugate gradients based approach for finding an approximate solution
/// to the second order approximation of the cost function within the trust region.
///
/// ## Reference
///
/// Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
/// Springer. ISBN 0-387-30303-0.
#[derive(Clone, Default)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct Steihaug<P, F> {
    /// Trust region radius; NaN until set via `TrustRegionRadius::set_radius`
    radius: F,
    /// Stopping tolerance: the solver terminates once the residual norm is below `epsilon`
    epsilon: F,
    /// Current iterate (the approximate step within the trust region)
    p: Option<P>,
    /// Current residual of the CG iteration
    r: Option<P>,
    /// Squared residual norm `r^T r` of the current iteration
    rtr: F,
    /// Norm of the initial residual, used in the relative stopping criterion
    r_0_norm: F,
    /// Current CG search direction
    d: Option<P>,
    /// Maximum number of iterations (defaults to `u64::MAX`)
    max_iters: u64,
}
47
48impl<P, F> Steihaug<P, F>
49where
50    P: ArgminMul<F, P> + ArgminDot<P, F> + ArgminAdd<P, P>,
51    F: ArgminFloat,
52{
53    /// Construct a new instance of [`Steihaug`]
54    ///
55    /// # Example
56    ///
57    /// ```
58    /// # use argmin::solver::trustregion::Steihaug;
59    /// let sh: Steihaug<Vec<f64>, f64> = Steihaug::new();
60    /// ```
61    pub fn new() -> Self {
62        Steihaug {
63            radius: F::nan(),
64            epsilon: float!(10e-10),
65            p: None,
66            r: None,
67            rtr: F::nan(),
68            r_0_norm: F::nan(),
69            d: None,
70            max_iters: u64::MAX,
71        }
72    }
73
74    /// Set epsilon
75    ///
76    /// The algorithm stops when the residual is smaller than `epsilon`.
77    ///
78    /// Must be larger than 0 and defaults to 10^-10.
79    ///
80    /// # Example
81    ///
82    /// ```
83    /// # use argmin::solver::trustregion::Steihaug;
84    /// # use argmin::core::Error;
85    /// # fn main() -> Result<(), Error> {
86    /// let sh: Steihaug<Vec<f64>, f64> = Steihaug::new().with_epsilon(10e-9)?;
87    /// # Ok(())
88    /// # }
89    /// ```
90    pub fn with_epsilon(mut self, epsilon: F) -> Result<Self, Error> {
91        if epsilon <= float!(0.0) {
92            return Err(argmin_error!(
93                InvalidParameter,
94                "`Steihaug`: epsilon must be > 0.0."
95            ));
96        }
97        self.epsilon = epsilon;
98        Ok(self)
99    }
100
101    /// Set maximum number of iterations
102    ///
103    /// The algorithm stops after `iter` iterations.
104    ///
105    /// Defaults to `u64::MAX`.
106    ///
107    /// # Example
108    ///
109    /// ```
110    /// # use argmin::solver::trustregion::Steihaug;
111    /// # use argmin::core::Error;
112    /// let sh: Steihaug<Vec<f64>, f64> = Steihaug::new().with_max_iters(100);
113    /// ```
114    #[must_use]
115    pub fn with_max_iters(mut self, iters: u64) -> Self {
116        self.max_iters = iters;
117        self
118    }
119
120    /// evaluate m(p) (without considering f_init because it is not available)
121    fn eval_m<H>(&self, p: &P, g: &P, h: &H) -> F
122    where
123        P: ArgminWeightedDot<P, F, H>,
124    {
125        g.dot(p) + float!(0.5) * p.weighted_dot(h, p)
126    }
127
128    /// calculate all possible step lengths
129    #[allow(clippy::many_single_char_names)]
130    fn tau<G, H>(&self, filter_func: G, eval: bool, g: &P, h: &H) -> F
131    where
132        G: Fn(F) -> bool,
133        P: ArgminWeightedDot<P, F, H>,
134    {
135        let p = self.p.as_ref().unwrap();
136        let d = self.d.as_ref().unwrap();
137        let a = p.dot(p);
138        let b = d.dot(d);
139        let c = p.dot(d);
140        let delta = self.radius.powi(2);
141        let t1 = (-a * b + b * delta + c.powi(2)).sqrt();
142        let tau1 = -(t1 + c) / b;
143        let tau2 = (t1 - c) / b;
144        let mut t = vec![tau1, tau2];
145        // Maybe calculating tau3 should only be done if b is close to zero?
146        if tau1.is_nan() || tau2.is_nan() || tau1.is_infinite() || tau2.is_infinite() {
147            let tau3 = (delta - a) / (float!(2.0) * c);
148            t.push(tau3);
149        }
150        let v = if eval {
151            // remove NAN taus and calculate m (without f_init) for all taus, then sort them based
152            // on their result and return the tau which corresponds to the lowest m
153            let mut v = t
154                .iter()
155                .cloned()
156                .enumerate()
157                .filter(|(_, tau)| (!tau.is_nan() || !tau.is_infinite()) && filter_func(*tau))
158                .map(|(i, tau)| {
159                    let p_local = p.add(&d.mul(&tau));
160                    (i, self.eval_m(&p_local, g, h))
161                })
162                .filter(|(_, m)| !m.is_nan() || !m.is_infinite())
163                .collect::<Vec<(usize, F)>>();
164            v.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
165            v
166        } else {
167            let mut v = t
168                .iter()
169                .cloned()
170                .enumerate()
171                .filter(|(_, tau)| (!tau.is_nan() || !tau.is_infinite()) && filter_func(*tau))
172                .collect::<Vec<(usize, F)>>();
173            v.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
174            v
175        };
176
177        t[v[0].0]
178    }
179}
180
impl<P, O, F, H> Solver<O, IterState<P, P, (), H, (), F>> for Steihaug<P, F>
where
    P: Clone
        + ArgminMul<F, P>
        + ArgminL2Norm<F>
        + ArgminDot<P, F>
        + ArgminAdd<P, P>
        + ArgminZeroLike,
    H: ArgminDot<P, P>,
    F: ArgminFloat,
{
    fn name(&self) -> &str {
        "Steihaug"
    }

    /// Initialize the CG iteration from the gradient and Hessian in `state`.
    ///
    /// Sets the initial residual `r = grad`, the initial search direction
    /// `d = -grad` and the initial iterate `p = 0`.
    ///
    /// # Errors
    ///
    /// Returns a `NotInitialized` error if the gradient or the Hessian is
    /// missing from `state`.
    fn init(
        &mut self,
        _problem: &mut Problem<O>,
        state: IterState<P, P, (), H, (), F>,
    ) -> Result<(IterState<P, P, (), H, (), F>, Option<KV>), Error> {
        let r = state
            .get_gradient()
            .ok_or_else(argmin_error_closure!(
                NotInitialized,
                concat!(
                    "`Steihaug` requires an initial gradient. ",
                    "Please provide an initial gradient via `Executor`s `configure` method."
                )
            ))?
            .clone();

        if state.get_hessian().is_none() {
            return Err(argmin_error!(
                NotInitialized,
                concat!(
                    "`Steihaug` requires an initial Hessian. ",
                    "Please provide an initial Hessian via `Executor`s `configure` method."
                )
            ));
        }

        // r_0 = grad; its norm is the reference for the relative stopping test.
        self.r_0_norm = r.l2_norm();
        self.rtr = r.dot(&r);
        // Initial direction is steepest descent: d = -grad.
        self.d = Some(r.mul(&float!(-1.0)));
        // Initial iterate p = 0 (same shape as the gradient).
        let p = r.zero_like();
        self.p = Some(p.clone());

        self.r = Some(r);

        Ok((state.param(p), None))
    }

    /// Perform one CG-Steihaug iteration.
    ///
    /// Terminates early (with the step moved to the trust region boundary)
    /// when a direction of non-positive curvature is encountered or when the
    /// tentative step would leave the trust region; terminates normally when
    /// the residual is sufficiently small. Otherwise performs a standard CG
    /// update of iterate, residual and direction.
    fn next_iter(
        &mut self,
        _problem: &mut Problem<O>,
        mut state: IterState<P, P, (), H, (), F>,
    ) -> Result<(IterState<P, P, (), H, (), F>, Option<KV>), Error> {
        let grad = state.take_gradient().ok_or_else(argmin_error_closure!(
            PotentialBug,
            "`Steihaug`: Gradient in state not set."
        ))?;

        let h = state.take_hessian().ok_or_else(argmin_error_closure!(
            PotentialBug,
            "`Steihaug`: Hessian in state not set."
        ))?;

        // Curvature of the model along the current direction: d^T H d.
        let d = self.d.as_ref().unwrap();
        let dhd = d.weighted_dot(&h, d);

        // Current search direction d is a direction of zero curvature or negative curvature
        let p = self.p.as_ref().unwrap();
        if dhd <= float!(0.0) {
            // Move to the trust region boundary, picking the tau that
            // minimizes the model value m.
            let tau = self.tau(|_| true, true, &grad, &h);
            return Ok((
                state
                    .param(p.add(&d.mul(&tau)))
                    .terminate_with(TerminationReason::SolverConverged),
                None,
            ));
        }

        // Standard CG step length: alpha = r^T r / d^T H d.
        let alpha = self.rtr / dhd;
        let p_n = p.add(&d.mul(&alpha));

        // new p violates trust region bound
        if p_n.l2_norm() >= self.radius {
            // Intersect the current direction with the boundary (tau >= 0).
            let tau = self.tau(|x| x >= float!(0.0), false, &grad, &h);
            return Ok((
                state
                    .param(p.add(&d.mul(&tau)))
                    .terminate_with(TerminationReason::SolverConverged),
                None,
            ));
        }

        // Residual update: r_{n+1} = r_n + alpha * H d.
        let r = self.r.as_ref().unwrap();
        let r_n = r.add(&h.dot(d).mul(&alpha));

        // Converged: residual small relative to the initial residual norm.
        if r_n.l2_norm() < self.epsilon * self.r_0_norm {
            return Ok((
                state
                    .param(p_n)
                    .terminate_with(TerminationReason::SolverConverged),
                None,
            ));
        }

        // CG direction update: d_{n+1} = -r_{n+1} + beta * d_n
        // with beta = r_{n+1}^T r_{n+1} / r_n^T r_n.
        let rjtrj = r_n.dot(&r_n);
        let beta = rjtrj / self.rtr;
        self.d = Some(r_n.mul(&float!(-1.0)).add(&d.mul(&beta)));
        self.r = Some(r_n);
        self.p = Some(p_n.clone());
        self.rtr = rjtrj;

        Ok((
            // The squared residual norm is recorded as the cost value;
            // gradient and Hessian are put back for the next iteration.
            state.param(p_n).cost(self.rtr).gradient(grad).hessian(h),
            None,
        ))
    }

    /// Terminate when the initial residual is already below `epsilon` or when
    /// the iteration limit is reached.
    fn terminate(&mut self, state: &IterState<P, P, (), H, (), F>) -> TerminationStatus {
        if self.r_0_norm < self.epsilon {
            return TerminationStatus::Terminated(TerminationReason::SolverConverged);
        }
        if state.get_iter() >= self.max_iters {
            return TerminationStatus::Terminated(TerminationReason::MaxItersReached);
        }
        TerminationStatus::NotTerminated
    }
}
312
impl<P, F: ArgminFloat> TrustRegionRadius<F> for Steihaug<P, F> {
    /// Set current radius.
    ///
    /// Needed by [`TrustRegion`](`crate::solver::trustregion::TrustRegion`).
    ///
    /// # Example
    ///
    /// ```
    /// use argmin::solver::trustregion::{Steihaug, TrustRegionRadius};
    /// let mut sh: Steihaug<Vec<f64>, f64> = Steihaug::new();
    /// sh.set_radius(0.8);
    /// ```
    fn set_radius(&mut self, radius: F) {
        // Overwrites the NaN placeholder set in `new`; must be called before
        // the solver runs, as `tau` uses the radius for the boundary equation.
        self.radius = radius;
    }
}
329
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::test_utils::TestProblem;
    use crate::core::ArgminError;
    use approx::assert_relative_eq;

    test_trait_impl!(steihaug, Steihaug<TestProblem, f64>);

    // `new` populates the NaN placeholders and defaults documented on the struct.
    #[test]
    fn test_new() {
        let sh: Steihaug<Vec<f64>, f64> = Steihaug::new();

        let Steihaug {
            radius,
            epsilon,
            p,
            r,
            rtr,
            r_0_norm,
            d,
            max_iters,
        } = sh;

        // NaN != NaN, so compare byte representations to check for NaN placeholders.
        assert_eq!(radius.to_ne_bytes(), f64::NAN.to_ne_bytes());
        assert_eq!(epsilon.to_ne_bytes(), 10e-10f64.to_ne_bytes());
        assert!(p.is_none());
        assert!(r.is_none());
        assert_eq!(rtr.to_ne_bytes(), f64::NAN.to_ne_bytes());
        assert_eq!(r_0_norm.to_ne_bytes(), f64::NAN.to_ne_bytes());
        assert!(d.is_none());
        assert_eq!(max_iters, u64::MAX);
    }

    // `with_epsilon` accepts any positive tolerance and rejects zero/negative ones.
    #[test]
    fn test_with_tolerance() {
        for tolerance in [f64::EPSILON, 1e-10, 1e-12, 1e-6, 1.0, 10.0, 100.0] {
            let sh: Steihaug<Vec<f64>, f64> = Steihaug::new().with_epsilon(tolerance).unwrap();
            assert_eq!(sh.epsilon.to_ne_bytes(), tolerance.to_ne_bytes());
        }

        for tolerance in [-f64::EPSILON, 0.0, -1.0] {
            let res: Result<Steihaug<Vec<f64>, f64>, _> = Steihaug::new().with_epsilon(tolerance);
            assert_error!(
                res,
                ArgminError,
                "Invalid parameter: \"`Steihaug`: epsilon must be > 0.0.\""
            );
        }
    }

    // `with_max_iters` stores the given limit; default is u64::MAX.
    #[test]
    fn test_max_iters() {
        let sh: Steihaug<Vec<f64>, f64> = Steihaug::new();

        let Steihaug { max_iters, .. } = sh;

        assert_eq!(max_iters, u64::MAX);

        for iters in [1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144] {
            let sh: Steihaug<Vec<f64>, f64> = Steihaug::new().with_max_iters(iters);

            let Steihaug { max_iters, .. } = sh;

            assert_eq!(max_iters, iters);
        }
    }

    // `init` errors on missing gradient/Hessian, and on success sets
    // p = 0, r = grad, d = -grad, rtr = grad^T grad, r_0_norm = ||grad||.
    #[test]
    fn test_init() {
        let grad: Vec<f64> = vec![1.0, 2.0];
        let hessian: Vec<Vec<f64>> = vec![vec![4.0, 3.0], vec![2.0, 1.0]];

        let mut sh: Steihaug<Vec<f64>, f64> = Steihaug::new();
        sh.set_radius(1.0);

        // Forgot to initialize gradient
        let state: IterState<Vec<f64>, Vec<f64>, (), Vec<Vec<f64>>, (), f64> = IterState::new();
        let problem = TestProblem::new();
        let res = sh.init(&mut Problem::new(problem), state);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Not initialized: \"`Steihaug` requires an initial gradient. Please ",
                "provide an initial gradient via `Executor`s `configure` method.\""
            )
        );

        // Forgot to initialize Hessian
        let state: IterState<Vec<f64>, Vec<f64>, (), Vec<Vec<f64>>, (), f64> =
            IterState::new().gradient(grad.clone());
        let problem = TestProblem::new();
        let res = sh.init(&mut Problem::new(problem), state);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Not initialized: \"`Steihaug` requires an initial Hessian. Please ",
                "provide an initial Hessian via `Executor`s `configure` method.\""
            )
        );

        // All good.
        let state: IterState<Vec<f64>, Vec<f64>, (), Vec<Vec<f64>>, (), f64> =
            IterState::new().gradient(grad.clone()).hessian(hessian);
        let problem = TestProblem::new();
        let (mut state_out, kv) = sh.init(&mut Problem::new(problem), state).unwrap();

        assert!(kv.is_none());

        // The returned state carries the zero initial parameter vector.
        let s_param = state_out.take_param().unwrap();

        assert_relative_eq!(s_param[0], 0.0f64.sqrt(), epsilon = f64::EPSILON);
        assert_relative_eq!(s_param[1], 0.0f64.sqrt(), epsilon = f64::EPSILON);

        let Steihaug {
            radius,
            epsilon,
            p,
            r,
            rtr,
            r_0_norm,
            d,
            max_iters,
        } = sh;

        assert_eq!(radius.to_ne_bytes(), 1.0f64.to_ne_bytes());
        assert_eq!(epsilon.to_ne_bytes(), 10e-10f64.to_ne_bytes());
        assert_relative_eq!(p.as_ref().unwrap()[0], 0.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(p.as_ref().unwrap()[1], 0.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(r.as_ref().unwrap()[0], grad[0], epsilon = f64::EPSILON);
        assert_relative_eq!(r.as_ref().unwrap()[1], grad[1], epsilon = f64::EPSILON);
        // rtr = 1^2 + 2^2 = 5, r_0_norm = sqrt(5).
        assert_eq!(rtr.to_ne_bytes(), 5.0f64.to_ne_bytes());
        assert_eq!(r_0_norm.to_ne_bytes(), (5.0f64).sqrt().to_ne_bytes());
        assert_relative_eq!(d.as_ref().unwrap()[0], -grad[0], epsilon = f64::EPSILON);
        assert_relative_eq!(d.as_ref().unwrap()[1], -grad[1], epsilon = f64::EPSILON);
        assert_eq!(max_iters, u64::MAX);
    }
}
469}