argmin/solver/trustregion/
trustregion_method.rs

1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use crate::core::{
9    ArgminFloat, CostFunction, Error, Executor, Gradient, Hessian, IterState, OptimizationResult,
10    Problem, Solver, TerminationStatus, TrustRegionRadius, KV,
11};
12use crate::solver::trustregion::reduction_ratio;
13use argmin_math::{ArgminAdd, ArgminDot, ArgminL2Norm, ArgminWeightedDot};
14#[cfg(feature = "serde1")]
15use serde::{Deserialize, Serialize};
16
/// # Trust region method
///
/// A trust region method works with an approximation of the cost function that is only relied
/// upon inside a region around the current point in parameter space. After every step the
/// quality of that approximation is assessed and the region is expanded or contracted
/// accordingly.
///
/// Step length and direction are not computed by this solver itself. They are delegated to a
/// subproblem solver implementing [`TrustRegionRadius`], for instance:
///
/// * [Cauchy point](`crate::solver::trustregion::CauchyPoint`)
/// * [Dogleg method](`crate::solver::trustregion::Dogleg`)
/// * [Steihaug method](`crate::solver::trustregion::Steihaug`)
///
/// ## Requirements on the optimization problem
///
/// The optimization problem must implement [`CostFunction`], [`Gradient`] and [`Hessian`].
///
/// ## Reference
///
/// Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
/// Springer. ISBN 0-387-30303-0.
#[derive(Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct TrustRegion<R, F> {
    /// Current trust region radius
    radius: F,
    /// Upper bound applied when the trust region radius is enlarged
    max_radius: F,
    /// Threshold eta \in [0, 1/4) which the reduction ratio must exceed for a step to be accepted
    eta: F,
    /// Subproblem solver which computes the step (must implement [`TrustRegionRadius`])
    subproblem: R,
    /// Cost function value f(xk) at the current point
    fxk: F,
    /// Value mk(0) of the model at the current point
    mk0: F,
}
55
56impl<R, F> TrustRegion<R, F>
57where
58    F: ArgminFloat,
59{
60    /// Construct a new instance of [`TrustRegion`]
61    ///
62    /// # Example
63    ///
64    /// ```
65    /// use argmin::solver::trustregion::{CauchyPoint, TrustRegion};
66    /// let cp: CauchyPoint<f64> = CauchyPoint::new();
67    /// let tr: TrustRegion<_, f64> = TrustRegion::new(cp);
68    /// ```
69    pub fn new(subproblem: R) -> Self {
70        TrustRegion {
71            radius: float!(1.0),
72            max_radius: float!(100.0),
73            eta: float!(0.125),
74            subproblem,
75            fxk: F::nan(),
76            mk0: F::nan(),
77        }
78    }
79
80    /// Set radius
81    ///
82    /// Defaults to `1.0`.
83    ///
84    /// # Example
85    ///
86    /// ```
87    /// # use argmin::solver::trustregion::{TrustRegion, CauchyPoint};
88    /// # use argmin::core::Error;
89    /// # fn main() -> Result<(), Error> {
90    /// let cp: CauchyPoint<f64> = CauchyPoint::new();
91    /// let tr: TrustRegion<_, f64> = TrustRegion::new(cp).with_radius(0.8)?;
92    /// # Ok(())
93    /// # }
94    /// ```
95    pub fn with_radius(mut self, radius: F) -> Result<Self, Error> {
96        if radius <= float!(0.0) {
97            return Err(argmin_error!(
98                InvalidParameter,
99                "`TrustRegion`: radius must be > 0."
100            ));
101        }
102        self.radius = radius;
103        Ok(self)
104    }
105
106    /// Set maximum radius
107    ///
108    /// Defaults to `100.0`.
109    ///
110    /// # Example
111    ///
112    /// ```
113    /// # use argmin::solver::trustregion::{TrustRegion, CauchyPoint};
114    /// # use argmin::core::Error;
115    /// # fn main() -> Result<(), Error> {
116    /// let cp: CauchyPoint<f64> = CauchyPoint::new();
117    /// let tr: TrustRegion<_, f64> = TrustRegion::new(cp).with_max_radius(1000.0)?;
118    /// # Ok(())
119    /// # }
120    /// ```
121    pub fn with_max_radius(mut self, max_radius: F) -> Result<Self, Error> {
122        if max_radius <= float!(0.0) {
123            return Err(argmin_error!(
124                InvalidParameter,
125                "`TrustRegion`: maximum radius must be > 0."
126            ));
127        }
128        self.max_radius = max_radius;
129        Ok(self)
130    }
131
132    /// Set eta
133    ///
134    /// Must lie in `[0, 1/4)` and defaults to `0.125`.
135    ///
136    /// # Example
137    ///
138    /// ```
139    /// # use argmin::solver::trustregion::{TrustRegion, CauchyPoint};
140    /// # use argmin::core::Error;
141    /// # fn main() -> Result<(), Error> {
142    /// let cp: CauchyPoint<f64> = CauchyPoint::new();
143    /// let tr: TrustRegion<_, f64> = TrustRegion::new(cp).with_eta(0.2)?;
144    /// # Ok(())
145    /// # }
146    /// ```
147    pub fn with_eta(mut self, eta: F) -> Result<Self, Error> {
148        if eta >= float!(0.25) || eta < float!(0.0) {
149            return Err(argmin_error!(
150                InvalidParameter,
151                "`TrustRegion`: eta must be in [0, 1/4)."
152            ));
153        }
154        self.eta = eta;
155        Ok(self)
156    }
157}
158
159impl<O, R, F, P, G, H> Solver<O, IterState<P, G, (), H, (), F>> for TrustRegion<R, F>
160where
161    O: CostFunction<Param = P, Output = F>
162        + Gradient<Param = P, Gradient = G>
163        + Hessian<Param = P, Hessian = H>,
164    P: Clone + ArgminL2Norm<F> + ArgminDot<P, F> + ArgminDot<G, F> + ArgminAdd<P, P>,
165    G: Clone,
166    H: Clone + ArgminDot<P, P>,
167    R: Clone + TrustRegionRadius<F> + Solver<O, IterState<P, G, (), H, (), F>>,
168    F: ArgminFloat,
169{
170    fn name(&self) -> &str {
171        "Trust region"
172    }
173
174    fn init(
175        &mut self,
176        problem: &mut Problem<O>,
177        mut state: IterState<P, G, (), H, (), F>,
178    ) -> Result<(IterState<P, G, (), H, (), F>, Option<KV>), Error> {
179        let param = state.take_param().ok_or_else(argmin_error_closure!(
180            NotInitialized,
181            concat!(
182                "`TrustRegion` requires an initial parameter vector. ",
183                "Please provide an initial guess via `Executor`s `configure` method."
184            )
185        ))?;
186
187        let grad = state
188            .take_gradient()
189            .map(Result::Ok)
190            .unwrap_or_else(|| problem.gradient(&param))?;
191
192        let hessian = state
193            .take_hessian()
194            .map(Result::Ok)
195            .unwrap_or_else(|| problem.hessian(&param))?;
196
197        let cost = state.get_cost();
198        self.fxk = if cost.is_infinite() && cost.is_sign_positive() {
199            problem.cost(&param)?
200        } else {
201            cost
202        };
203
204        self.mk0 = self.fxk;
205        Ok((
206            state
207                .param(param)
208                .cost(self.fxk)
209                .gradient(grad)
210                .hessian(hessian),
211            None,
212        ))
213    }
214
215    fn next_iter(
216        &mut self,
217        problem: &mut Problem<O>,
218        mut state: IterState<P, G, (), H, (), F>,
219    ) -> Result<(IterState<P, G, (), H, (), F>, Option<KV>), Error> {
220        let param = state.take_param().ok_or_else(argmin_error_closure!(
221            PotentialBug,
222            "`TrustRegion`: Parameter vector in state not set."
223        ))?;
224
225        let grad = state.take_gradient().ok_or_else(argmin_error_closure!(
226            PotentialBug,
227            "`TrustRegion`: Gradient in state not set."
228        ))?;
229
230        let hessian = state.take_hessian().ok_or_else(argmin_error_closure!(
231            PotentialBug,
232            "`TrustRegion`: Hessian in state not set."
233        ))?;
234
235        self.subproblem.set_radius(self.radius);
236
237        let OptimizationResult {
238            problem: sub_problem,
239            state: mut sub_state,
240            ..
241        } = Executor::new(problem.take_problem().unwrap(), self.subproblem.clone())
242            .configure(|config| {
243                config
244                    .param(param.clone())
245                    .gradient(grad.clone())
246                    .hessian(hessian.clone())
247            })
248            .ctrlc(false)
249            .run()?;
250
251        let pk = sub_state.take_param().unwrap();
252
253        // Consume intermediate problem again. This takes care of the function evaluation counts.
254        problem.consume_problem(sub_problem);
255
256        let new_param = pk.add(&param);
257        let fxkpk = problem.cost(&new_param)?;
258        let mkpk = self.fxk + pk.dot(&grad) + float!(0.5) * pk.weighted_dot(&hessian, &pk);
259
260        let rho = reduction_ratio(self.fxk, fxkpk, self.mk0, mkpk);
261
262        let pk_norm = pk.l2_norm();
263
264        let cur_radius = self.radius;
265
266        self.radius = if rho < float!(0.25) {
267            float!(0.25) * pk_norm
268        } else if rho > float!(0.75) && (pk_norm - self.radius).abs() <= float!(10.0) * F::epsilon()
269        {
270            self.max_radius.min(float!(2.0) * self.radius)
271        } else {
272            self.radius
273        };
274
275        Ok((
276            if rho > self.eta {
277                self.fxk = fxkpk;
278                self.mk0 = fxkpk;
279                let grad = problem.gradient(&new_param)?;
280                let hessian = problem.hessian(&new_param)?;
281                state
282                    .param(new_param)
283                    .cost(fxkpk)
284                    .gradient(grad)
285                    .hessian(hessian)
286            } else {
287                state
288                    .param(param)
289                    .cost(self.fxk)
290                    .gradient(grad)
291                    .hessian(hessian)
292            },
293            Some(kv!("radius" => cur_radius;)),
294        ))
295    }
296
297    fn terminate(&mut self, _state: &IterState<P, G, (), H, (), F>) -> TerminationStatus {
298        TerminationStatus::NotTerminated
299    }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::test_utils::TestProblem;
    use crate::core::{ArgminError, State};
    use crate::solver::trustregion::{CauchyPoint, Steihaug};

    test_trait_impl!(trustregion, TrustRegion<Steihaug<TestProblem, f64>, f64>);

    /// A freshly constructed solver carries the documented defaults.
    #[test]
    fn test_new() {
        let cp: CauchyPoint<f64> = CauchyPoint::new();
        let tr: TrustRegion<_, f64> = TrustRegion::new(cp);

        let TrustRegion {
            radius,
            max_radius,
            eta,
            subproblem: _,
            fxk,
            mk0,
        } = tr;

        assert_eq!(radius.to_ne_bytes(), 1.0f64.to_ne_bytes());
        assert_eq!(max_radius.to_ne_bytes(), 100.0f64.to_ne_bytes());
        assert_eq!(eta.to_ne_bytes(), 0.125f64.to_ne_bytes());
        assert_eq!(fxk.to_ne_bytes(), f64::NAN.to_ne_bytes());
        assert_eq!(mk0.to_ne_bytes(), f64::NAN.to_ne_bytes());
    }

    /// Positive radii are stored, zero and negative radii are rejected.
    #[test]
    fn test_with_radius() {
        // correct parameters
        for radius in [f64::EPSILON, 1e-2, 1.0, 2.0, 10.0, 100.0] {
            let cp: CauchyPoint<f64> = CauchyPoint::new();
            let tr: TrustRegion<_, f64> = TrustRegion::new(cp);
            let res = tr.with_radius(radius);
            assert!(res.is_ok());

            let tr = res.unwrap();
            assert_eq!(tr.radius.to_ne_bytes(), radius.to_ne_bytes());
        }

        // incorrect parameters
        for radius in [0.0, -f64::EPSILON, -1.0, -100.0, -42.0] {
            let cp: CauchyPoint<f64> = CauchyPoint::new();
            let tr: TrustRegion<_, f64> = TrustRegion::new(cp);
            let res = tr.with_radius(radius);
            assert_error!(
                res,
                ArgminError,
                "Invalid parameter: \"`TrustRegion`: radius must be > 0.\""
            );
        }
    }

    /// Positive maximum radii are stored, zero and negative ones are rejected.
    #[test]
    fn test_with_max_radius() {
        // correct parameters
        for max_radius in [f64::EPSILON, 1e-2, 1.0, 2.0, 10.0, 100.0, 1000.0] {
            let cp: CauchyPoint<f64> = CauchyPoint::new();
            let tr: TrustRegion<_, f64> = TrustRegion::new(cp);
            let res = tr.with_max_radius(max_radius);
            assert!(res.is_ok());

            let tr = res.unwrap();
            assert_eq!(tr.max_radius.to_ne_bytes(), max_radius.to_ne_bytes());
        }

        // incorrect parameters
        for max_radius in [0.0, -f64::EPSILON, -1.0, -100.0, -42.0] {
            let cp: CauchyPoint<f64> = CauchyPoint::new();
            let tr: TrustRegion<_, f64> = TrustRegion::new(cp);
            let res = tr.with_max_radius(max_radius);
            assert_error!(
                res,
                ArgminError,
                "Invalid parameter: \"`TrustRegion`: maximum radius must be > 0.\""
            );
        }
    }

    /// Values of eta inside `[0, 1/4)` are stored, everything else is rejected.
    #[test]
    fn test_with_eta() {
        // correct parameters
        for eta in [0.0, f64::EPSILON, 1e-2, 0.125, 0.25 - f64::EPSILON] {
            let cp: CauchyPoint<f64> = CauchyPoint::new();
            let tr: TrustRegion<_, f64> = TrustRegion::new(cp);
            let res = tr.with_eta(eta);
            assert!(res.is_ok());

            let tr = res.unwrap();
            assert_eq!(tr.eta.to_ne_bytes(), eta.to_ne_bytes());
        }

        // incorrect parameters
        for eta in [-f64::EPSILON, -1.0, -100.0, -42.0, 0.25, 1.0] {
            let cp: CauchyPoint<f64> = CauchyPoint::new();
            let tr: TrustRegion<_, f64> = TrustRegion::new(cp);
            let res = tr.with_eta(eta);
            assert_error!(
                res,
                ArgminError,
                "Invalid parameter: \"`TrustRegion`: eta must be in [0, 1/4).\""
            );
        }
    }

    /// `init` fails without an initial parameter vector and otherwise fills in the cost while
    /// leaving the configuration untouched.
    #[test]
    fn test_init() {
        let param: Vec<f64> = vec![1.0, 2.0];

        let cp: CauchyPoint<f64> = CauchyPoint::new();
        let mut tr: TrustRegion<_, f64> = TrustRegion::new(cp);

        // Forgot to initialize parameter vector
        let state: IterState<Vec<f64>, Vec<f64>, (), Vec<Vec<f64>>, (), f64> = IterState::new();
        let problem = TestProblem::new();
        let res = tr.init(&mut Problem::new(problem), state);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Not initialized: \"`TrustRegion` requires an initial parameter vector. Please ",
                "provide an initial guess via `Executor`s `configure` method.\""
            )
        );

        // All good.
        let state: IterState<Vec<f64>, Vec<f64>, (), Vec<Vec<f64>>, (), f64> =
            IterState::new().param(param.clone());
        let problem = TestProblem::new();
        let (mut state_out, kv) = tr.init(&mut Problem::new(problem), state).unwrap();

        assert!(kv.is_none());

        let s_param = state_out.take_param().unwrap();

        assert_eq!(s_param[0].to_ne_bytes(), param[0].to_ne_bytes());
        assert_eq!(s_param[1].to_ne_bytes(), param[1].to_ne_bytes());

        let TrustRegion {
            radius,
            max_radius,
            eta,
            subproblem: _,
            fxk,
            mk0,
        } = tr;

        assert_eq!(radius.to_ne_bytes(), 1.0f64.to_ne_bytes());
        assert_eq!(max_radius.to_ne_bytes(), 100.0f64.to_ne_bytes());
        assert_eq!(eta.to_ne_bytes(), 0.125f64.to_ne_bytes());
        // `init` sets `mk0` to `fxk`, so both must hold the same cost value.
        assert_eq!(fxk.to_ne_bytes(), 1.0f64.to_ne_bytes());
        assert_eq!(mk0.to_ne_bytes(), 1.0f64.to_ne_bytes());
    }
}