argmin/solver/trustregion/
cauchypoint.rs

1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use crate::core::{
9    ArgminFloat, Error, Gradient, Hessian, IterState, Problem, Solver, State, TerminationReason,
10    TerminationStatus, TrustRegionRadius, KV,
11};
12use argmin_math::{ArgminL2Norm, ArgminMul, ArgminWeightedDot};
13#[cfg(feature = "serde1")]
14use serde::{Deserialize, Serialize};
15use std::fmt::Debug;
16
/// # Cauchy point method
///
/// Computes the Cauchy point: the minimizer of the quadratic model of the cost function along the
/// steepest-descent direction (the negative gradient), restricted to steps that stay inside the
/// trust region.
///
/// ## Requirements on the optimization problem
///
/// The optimization problem is required to implement [`Gradient`] and [`Hessian`].
///
/// ## Reference
///
/// Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
/// Springer. ISBN 0-387-30303-0.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct CauchyPoint<F> {
    /// Trust region radius (Δ) that bounds the length of the computed step.
    radius: F,
}
36
37impl<F> CauchyPoint<F>
38where
39    F: ArgminFloat,
40{
41    /// Construct a new instance of [`CauchyPoint`]
42    ///
43    /// # Example
44    ///
45    /// ```
46    /// # use argmin::solver::trustregion::CauchyPoint;
47    /// let cp: CauchyPoint<f64> = CauchyPoint::new();
48    /// ```
49    pub fn new() -> Self {
50        CauchyPoint { radius: F::nan() }
51    }
52}
53
54impl<O, F, P, G, H> Solver<O, IterState<P, G, (), H, (), F>> for CauchyPoint<F>
55where
56    O: Gradient<Param = P, Gradient = G> + Hessian<Param = P, Hessian = H>,
57    P: Clone + ArgminMul<F, P> + ArgminWeightedDot<P, F, H>,
58    G: ArgminMul<F, P> + ArgminWeightedDot<G, F, H> + ArgminL2Norm<F>,
59    F: ArgminFloat,
60{
61    fn name(&self) -> &str {
62        "Cauchy Point"
63    }
64
65    fn next_iter(
66        &mut self,
67        problem: &mut Problem<O>,
68        mut state: IterState<P, G, (), H, (), F>,
69    ) -> Result<(IterState<P, G, (), H, (), F>, Option<KV>), Error> {
70        let param = state.take_param().ok_or_else(argmin_error_closure!(
71            NotInitialized,
72            concat!(
73                "`CauchyPoint` requires an initial parameter vector. ",
74                "Please provide an initial guess via `Executor`s `configure` method."
75            )
76        ))?;
77
78        let grad = state
79            .take_gradient()
80            .map(Result::Ok)
81            .unwrap_or_else(|| problem.gradient(&param))?;
82
83        let grad_norm = grad.l2_norm();
84
85        let hessian = state
86            .take_hessian()
87            .map(Result::Ok)
88            .unwrap_or_else(|| problem.hessian(&param))?;
89
90        let wdp = grad.weighted_dot(&hessian, &grad);
91
92        let tau: F = if wdp <= float!(0.0) {
93            float!(1.0)
94        } else {
95            float!(1.0).min(grad_norm.powi(3) / (self.radius * wdp))
96        };
97
98        let new_param = grad.mul(&(-tau * self.radius / grad_norm));
99        Ok((state.param(new_param), None))
100    }
101
102    fn terminate(&mut self, state: &IterState<P, G, (), H, (), F>) -> TerminationStatus {
103        // Not an iterative algorithm
104        if state.get_iter() >= 1 {
105            TerminationStatus::Terminated(TerminationReason::MaxItersReached)
106        } else {
107            TerminationStatus::NotTerminated
108        }
109    }
110}
111
112impl<F> TrustRegionRadius<F> for CauchyPoint<F>
113where
114    F: ArgminFloat,
115{
116    /// Set current radius.
117    ///
118    /// Needed by [`TrustRegion`](`crate::solver::trustregion::TrustRegion`).
119    ///
120    /// # Example
121    ///
122    /// ```
123    /// use argmin::solver::trustregion::{CauchyPoint, TrustRegionRadius};
124    /// let mut cp: CauchyPoint<f64> = CauchyPoint::new();
125    /// cp.set_radius(0.8);
126    /// ```
127    fn set_radius(&mut self, radius: F) {
128        self.radius = radius;
129    }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{test_utils::TestProblem, ArgminError};
    use approx::assert_relative_eq;

    test_trait_impl!(cauchypoint, CauchyPoint<f64>);

    #[test]
    fn test_new() {
        let cp: CauchyPoint<f64> = CauchyPoint::new();

        let CauchyPoint { radius } = cp;

        assert_eq!(radius.to_ne_bytes(), f64::NAN.to_ne_bytes());
    }

    #[test]
    fn test_next_iter() {
        let param: Vec<f64> = vec![-1.0, 1.0];

        let mut cp: CauchyPoint<f64> = CauchyPoint::new();
        cp.set_radius(1.0);

        // Forgot to initialize the parameter vector
        let state: IterState<Vec<f64>, Vec<f64>, (), Vec<Vec<f64>>, (), f64> = IterState::new();
        let problem = TestProblem::new();
        let res = cp.next_iter(&mut Problem::new(problem), state);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Not initialized: \"`CauchyPoint` requires an initial parameter vector. Please ",
                "provide an initial guess via `Executor`s `configure` method.\""
            )
        );

        // All good.
        let state: IterState<Vec<f64>, Vec<f64>, (), Vec<Vec<f64>>, (), f64> =
            IterState::new().param(param);
        let problem = TestProblem::new();
        let (mut state_out, kv) = cp.next_iter(&mut Problem::new(problem), state).unwrap();

        assert!(kv.is_none());

        let s_param = state_out.take_param().unwrap();

        assert_relative_eq!(s_param[0], 1.0f64 / 2.0f64.sqrt(), epsilon = f64::EPSILON);
        assert_relative_eq!(s_param[1], -1.0f64 / 2.0f64.sqrt(), epsilon = f64::EPSILON);
    }

    /// Exercises the `gᵀBg > 0` branch, both with τ < 1 and with τ clipped to 1.
    #[test]
    fn test_next_iter_positive_curvature() {
        // g = [1, 1] and B = 2I give gᵀBg = 4 > 0 and ‖g‖ = √2. Both are injected via the state,
        // so the problem's own gradient/Hessian are not used.
        let grad: Vec<f64> = vec![1.0, 1.0];
        let hessian: Vec<Vec<f64>> = vec![vec![2.0, 0.0], vec![0.0, 2.0]];
        let make_state = || -> IterState<Vec<f64>, Vec<f64>, (), Vec<Vec<f64>>, (), f64> {
            IterState::new()
                .param(vec![0.0, 0.0])
                .gradient(grad.clone())
                .hessian(hessian.clone())
        };

        let mut cp: CauchyPoint<f64> = CauchyPoint::new();

        // Radius 1: τ = ‖g‖³ / (Δ gᵀBg) = √2/2 < 1. The step is the minimizer of the model along
        // -g, p = -(gᵀg / gᵀBg) g = -0.5 g, which lies strictly inside the region.
        cp.set_radius(1.0);
        let (mut state_out, kv) = cp
            .next_iter(&mut Problem::new(TestProblem::new()), make_state())
            .unwrap();
        assert!(kv.is_none());
        let step = state_out.take_param().unwrap();
        assert_relative_eq!(step[0], -0.5, epsilon = 1e-12);
        assert_relative_eq!(step[1], -0.5, epsilon = 1e-12);

        // Radius 0.1: the model minimizer lies outside, so τ = 1 and p = -Δ g / ‖g‖.
        cp.set_radius(0.1);
        let (mut state_out, _) = cp
            .next_iter(&mut Problem::new(TestProblem::new()), make_state())
            .unwrap();
        let step = state_out.take_param().unwrap();
        assert_relative_eq!(step[0], -0.1 / 2.0f64.sqrt(), epsilon = 1e-12);
        assert_relative_eq!(step[1], -0.1 / 2.0f64.sqrt(), epsilon = 1e-12);
    }
}
182}