argmin/solver/trustregion/
cauchypoint.rs1use crate::core::{
9 ArgminFloat, Error, Gradient, Hessian, IterState, Problem, Solver, State, TerminationReason,
10 TerminationStatus, TrustRegionRadius, KV,
11};
12use argmin_math::{ArgminL2Norm, ArgminMul, ArgminWeightedDot};
13#[cfg(feature = "serde1")]
14use serde::{Deserialize, Serialize};
15use std::fmt::Debug;
16
/// # Cauchy point
///
/// Trust region subproblem solver.
///
/// It returns a step along the negative gradient whose length is the trust region radius
/// scaled by a factor `tau` in `(0, 1]`. `tau` is derived from the curvature of the
/// quadratic model along the gradient.
///
/// The radius must be provided via `TrustRegionRadius::set_radius` before iterating.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct CauchyPoint<F> {
    /// Current trust region radius (bounds the length of the returned step).
    radius: F,
}
36
37impl<F> CauchyPoint<F>
38where
39 F: ArgminFloat,
40{
41 pub fn new() -> Self {
50 CauchyPoint { radius: F::nan() }
51 }
52}
53
54impl<O, F, P, G, H> Solver<O, IterState<P, G, (), H, (), F>> for CauchyPoint<F>
55where
56 O: Gradient<Param = P, Gradient = G> + Hessian<Param = P, Hessian = H>,
57 P: Clone + ArgminMul<F, P> + ArgminWeightedDot<P, F, H>,
58 G: ArgminMul<F, P> + ArgminWeightedDot<G, F, H> + ArgminL2Norm<F>,
59 F: ArgminFloat,
60{
61 fn name(&self) -> &str {
62 "Cauchy Point"
63 }
64
65 fn next_iter(
66 &mut self,
67 problem: &mut Problem<O>,
68 mut state: IterState<P, G, (), H, (), F>,
69 ) -> Result<(IterState<P, G, (), H, (), F>, Option<KV>), Error> {
70 let param = state.take_param().ok_or_else(argmin_error_closure!(
71 NotInitialized,
72 concat!(
73 "`CauchyPoint` requires an initial parameter vector. ",
74 "Please provide an initial guess via `Executor`s `configure` method."
75 )
76 ))?;
77
78 let grad = state
79 .take_gradient()
80 .map(Result::Ok)
81 .unwrap_or_else(|| problem.gradient(¶m))?;
82
83 let grad_norm = grad.l2_norm();
84
85 let hessian = state
86 .take_hessian()
87 .map(Result::Ok)
88 .unwrap_or_else(|| problem.hessian(¶m))?;
89
90 let wdp = grad.weighted_dot(&hessian, &grad);
91
92 let tau: F = if wdp <= float!(0.0) {
93 float!(1.0)
94 } else {
95 float!(1.0).min(grad_norm.powi(3) / (self.radius * wdp))
96 };
97
98 let new_param = grad.mul(&(-tau * self.radius / grad_norm));
99 Ok((state.param(new_param), None))
100 }
101
102 fn terminate(&mut self, state: &IterState<P, G, (), H, (), F>) -> TerminationStatus {
103 if state.get_iter() >= 1 {
105 TerminationStatus::Terminated(TerminationReason::MaxItersReached)
106 } else {
107 TerminationStatus::NotTerminated
108 }
109 }
110}
111
112impl<F> TrustRegionRadius<F> for CauchyPoint<F>
113where
114 F: ArgminFloat,
115{
116 fn set_radius(&mut self, radius: F) {
128 self.radius = radius;
129 }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{test_utils::TestProblem, ArgminError};
    use approx::assert_relative_eq;

    /// State type used throughout these tests.
    type TestState = IterState<Vec<f64>, Vec<f64>, (), Vec<Vec<f64>>, (), f64>;

    test_trait_impl!(cauchypoint, CauchyPoint<f64>);

    #[test]
    fn test_new() {
        // A freshly constructed solver marks its radius as "unset" with NaN.
        let CauchyPoint { radius } = CauchyPoint::<f64>::new();

        assert_eq!(radius.to_ne_bytes(), f64::NAN.to_ne_bytes());
    }

    #[test]
    fn test_next_iter() {
        let mut solver = CauchyPoint::<f64>::new();
        solver.set_radius(1.0);

        // Without an initial parameter vector the solver must refuse to iterate.
        let empty_state: TestState = IterState::new();
        let outcome = solver.next_iter(&mut Problem::new(TestProblem::new()), empty_state);
        assert_error!(
            outcome,
            ArgminError,
            concat!(
                "Not initialized: \"`CauchyPoint` requires an initial parameter vector. Please ",
                "provide an initial guess via `Executor`s `configure` method.\""
            )
        );

        // With an initial parameter vector a step is produced and no KV is emitted.
        let start_state: TestState = IterState::new().param(vec![-1.0, 1.0]);
        let (mut state_out, kv) = solver
            .next_iter(&mut Problem::new(TestProblem::new()), start_state)
            .unwrap();

        assert!(kv.is_none());

        let step = state_out.take_param().unwrap();
        let component = 1.0f64 / 2.0f64.sqrt();

        assert_relative_eq!(step[0], component, epsilon = f64::EPSILON);
        assert_relative_eq!(step[1], -component, epsilon = f64::EPSILON);
    }
}