1use crate::core::{
9 ArgminFloat, CostFunction, Error, Executor, Gradient, Hessian, IterState, OptimizationResult,
10 Problem, Solver, TerminationStatus, TrustRegionRadius, KV,
11};
12use crate::solver::trustregion::reduction_ratio;
13use argmin_math::{ArgminAdd, ArgminDot, ArgminL2Norm, ArgminWeightedDot};
14#[cfg(feature = "serde1")]
15use serde::{Deserialize, Serialize};
16
/// Trust region solver.
///
/// Each iteration hands the current parameter vector, gradient and Hessian to the configured
/// `subproblem` solver (restricted to the current `radius`) to obtain a step, evaluates the cost
/// at the stepped point and then, based on the reduction ratio of that step, decides whether to
/// accept it and how to adapt the radius for the next iteration.
///
/// `R` is the type of the subproblem solver, `F` the floating point type.
#[derive(Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct TrustRegion<R, F> {
    /// Current trust region radius; passed to the subproblem at the start of every iteration
    /// and adapted at the end of it.
    radius: F,
    /// Upper bound applied whenever the radius is enlarged.
    max_radius: F,
    /// Acceptance threshold: a step is only taken if its reduction ratio exceeds this value.
    eta: F,
    /// Solver used to compute the step inside the trust region; cloned for every iteration.
    subproblem: R,
    /// Cost function value at the current parameter vector (`NaN` until `init` has run).
    fxk: F,
    /// Reference model value handed to `reduction_ratio`; kept equal to `fxk` by `init` and by
    /// every accepted step (`NaN` until `init` has run).
    mk0: F,
}
55
56impl<R, F> TrustRegion<R, F>
57where
58 F: ArgminFloat,
59{
60 pub fn new(subproblem: R) -> Self {
70 TrustRegion {
71 radius: float!(1.0),
72 max_radius: float!(100.0),
73 eta: float!(0.125),
74 subproblem,
75 fxk: F::nan(),
76 mk0: F::nan(),
77 }
78 }
79
80 pub fn with_radius(mut self, radius: F) -> Result<Self, Error> {
96 if radius <= float!(0.0) {
97 return Err(argmin_error!(
98 InvalidParameter,
99 "`TrustRegion`: radius must be > 0."
100 ));
101 }
102 self.radius = radius;
103 Ok(self)
104 }
105
106 pub fn with_max_radius(mut self, max_radius: F) -> Result<Self, Error> {
122 if max_radius <= float!(0.0) {
123 return Err(argmin_error!(
124 InvalidParameter,
125 "`TrustRegion`: maximum radius must be > 0."
126 ));
127 }
128 self.max_radius = max_radius;
129 Ok(self)
130 }
131
132 pub fn with_eta(mut self, eta: F) -> Result<Self, Error> {
148 if eta >= float!(0.25) || eta < float!(0.0) {
149 return Err(argmin_error!(
150 InvalidParameter,
151 "`TrustRegion`: eta must be in [0, 1/4)."
152 ));
153 }
154 self.eta = eta;
155 Ok(self)
156 }
157}
158
159impl<O, R, F, P, G, H> Solver<O, IterState<P, G, (), H, (), F>> for TrustRegion<R, F>
160where
161 O: CostFunction<Param = P, Output = F>
162 + Gradient<Param = P, Gradient = G>
163 + Hessian<Param = P, Hessian = H>,
164 P: Clone + ArgminL2Norm<F> + ArgminDot<P, F> + ArgminDot<G, F> + ArgminAdd<P, P>,
165 G: Clone,
166 H: Clone + ArgminDot<P, P>,
167 R: Clone + TrustRegionRadius<F> + Solver<O, IterState<P, G, (), H, (), F>>,
168 F: ArgminFloat,
169{
170 fn name(&self) -> &str {
171 "Trust region"
172 }
173
174 fn init(
175 &mut self,
176 problem: &mut Problem<O>,
177 mut state: IterState<P, G, (), H, (), F>,
178 ) -> Result<(IterState<P, G, (), H, (), F>, Option<KV>), Error> {
179 let param = state.take_param().ok_or_else(argmin_error_closure!(
180 NotInitialized,
181 concat!(
182 "`TrustRegion` requires an initial parameter vector. ",
183 "Please provide an initial guess via `Executor`s `configure` method."
184 )
185 ))?;
186
187 let grad = state
188 .take_gradient()
189 .map(Result::Ok)
190 .unwrap_or_else(|| problem.gradient(¶m))?;
191
192 let hessian = state
193 .take_hessian()
194 .map(Result::Ok)
195 .unwrap_or_else(|| problem.hessian(¶m))?;
196
197 let cost = state.get_cost();
198 self.fxk = if cost.is_infinite() && cost.is_sign_positive() {
199 problem.cost(¶m)?
200 } else {
201 cost
202 };
203
204 self.mk0 = self.fxk;
205 Ok((
206 state
207 .param(param)
208 .cost(self.fxk)
209 .gradient(grad)
210 .hessian(hessian),
211 None,
212 ))
213 }
214
215 fn next_iter(
216 &mut self,
217 problem: &mut Problem<O>,
218 mut state: IterState<P, G, (), H, (), F>,
219 ) -> Result<(IterState<P, G, (), H, (), F>, Option<KV>), Error> {
220 let param = state.take_param().ok_or_else(argmin_error_closure!(
221 PotentialBug,
222 "`TrustRegion`: Parameter vector in state not set."
223 ))?;
224
225 let grad = state.take_gradient().ok_or_else(argmin_error_closure!(
226 PotentialBug,
227 "`TrustRegion`: Gradient in state not set."
228 ))?;
229
230 let hessian = state.take_hessian().ok_or_else(argmin_error_closure!(
231 PotentialBug,
232 "`TrustRegion`: Hessian in state not set."
233 ))?;
234
235 self.subproblem.set_radius(self.radius);
236
237 let OptimizationResult {
238 problem: sub_problem,
239 state: mut sub_state,
240 ..
241 } = Executor::new(problem.take_problem().unwrap(), self.subproblem.clone())
242 .configure(|config| {
243 config
244 .param(param.clone())
245 .gradient(grad.clone())
246 .hessian(hessian.clone())
247 })
248 .ctrlc(false)
249 .run()?;
250
251 let pk = sub_state.take_param().unwrap();
252
253 problem.consume_problem(sub_problem);
255
256 let new_param = pk.add(¶m);
257 let fxkpk = problem.cost(&new_param)?;
258 let mkpk = self.fxk + pk.dot(&grad) + float!(0.5) * pk.weighted_dot(&hessian, &pk);
259
260 let rho = reduction_ratio(self.fxk, fxkpk, self.mk0, mkpk);
261
262 let pk_norm = pk.l2_norm();
263
264 let cur_radius = self.radius;
265
266 self.radius = if rho < float!(0.25) {
267 float!(0.25) * pk_norm
268 } else if rho > float!(0.75) && (pk_norm - self.radius).abs() <= float!(10.0) * F::epsilon()
269 {
270 self.max_radius.min(float!(2.0) * self.radius)
271 } else {
272 self.radius
273 };
274
275 Ok((
276 if rho > self.eta {
277 self.fxk = fxkpk;
278 self.mk0 = fxkpk;
279 let grad = problem.gradient(&new_param)?;
280 let hessian = problem.hessian(&new_param)?;
281 state
282 .param(new_param)
283 .cost(fxkpk)
284 .gradient(grad)
285 .hessian(hessian)
286 } else {
287 state
288 .param(param)
289 .cost(self.fxk)
290 .gradient(grad)
291 .hessian(hessian)
292 },
293 Some(kv!("radius" => cur_radius;)),
294 ))
295 }
296
297 fn terminate(&mut self, _state: &IterState<P, G, (), H, (), F>) -> TerminationStatus {
298 TerminationStatus::NotTerminated
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::test_utils::TestProblem;
    use crate::core::{ArgminError, State};
    use crate::solver::trustregion::{CauchyPoint, Steihaug};

    test_trait_impl!(trustregion, TrustRegion<Steihaug<TestProblem, f64>, f64>);

    /// `new` must set the documented defaults and leave the cached values as `NaN`.
    #[test]
    fn test_new() {
        let cp: CauchyPoint<f64> = CauchyPoint::new();
        let tr: TrustRegion<_, f64> = TrustRegion::new(cp);

        let TrustRegion {
            radius,
            max_radius,
            eta,
            subproblem: _,
            fxk,
            mk0,
        } = tr;

        // Values are compared bytewise so that `NaN` can be checked for equality as well.
        assert_eq!(radius.to_ne_bytes(), 1.0f64.to_ne_bytes());
        assert_eq!(max_radius.to_ne_bytes(), 100.0f64.to_ne_bytes());
        assert_eq!(eta.to_ne_bytes(), 0.125f64.to_ne_bytes());
        assert_eq!(fxk.to_ne_bytes(), f64::NAN.to_ne_bytes());
        assert_eq!(mk0.to_ne_bytes(), f64::NAN.to_ne_bytes());
    }

    /// `with_radius` accepts strictly positive values and rejects everything else.
    #[test]
    fn test_with_radius() {
        for radius in [f64::EPSILON, 1e-2, 1.0, 2.0, 10.0, 100.0] {
            let cp: CauchyPoint<f64> = CauchyPoint::new();
            let tr: TrustRegion<_, f64> = TrustRegion::new(cp);
            let res = tr.with_radius(radius);
            assert!(res.is_ok());

            let nm = res.unwrap();
            assert_eq!(nm.radius.to_ne_bytes(), radius.to_ne_bytes());
        }

        for radius in [0.0, -f64::EPSILON, -1.0, -100.0, -42.0] {
            let cp: CauchyPoint<f64> = CauchyPoint::new();
            let tr: TrustRegion<_, f64> = TrustRegion::new(cp);
            let res = tr.with_radius(radius);
            assert_error!(
                res,
                ArgminError,
                "Invalid parameter: \"`TrustRegion`: radius must be > 0.\""
            );
        }
    }

    /// `with_max_radius` accepts strictly positive values and rejects everything else.
    #[test]
    fn test_with_max_radius() {
        for max_radius in [f64::EPSILON, 1e-2, 1.0, 2.0, 10.0, 100.0] {
            let cp: CauchyPoint<f64> = CauchyPoint::new();
            let tr: TrustRegion<_, f64> = TrustRegion::new(cp);
            let res = tr.with_max_radius(max_radius);
            assert!(res.is_ok());

            let nm = res.unwrap();
            assert_eq!(nm.max_radius.to_ne_bytes(), max_radius.to_ne_bytes());
        }

        for max_radius in [0.0, -f64::EPSILON, -1.0, -100.0, -42.0] {
            let cp: CauchyPoint<f64> = CauchyPoint::new();
            let tr: TrustRegion<_, f64> = TrustRegion::new(cp);
            let res = tr.with_max_radius(max_radius);
            assert_error!(
                res,
                ArgminError,
                "Invalid parameter: \"`TrustRegion`: maximum radius must be > 0.\""
            );
        }
    }

    /// `with_eta` accepts values in `[0, 1/4)` and rejects everything else.
    #[test]
    fn test_with_eta() {
        for eta in [0.0, f64::EPSILON, 1e-2, 0.125, 0.25 - f64::EPSILON] {
            let cp: CauchyPoint<f64> = CauchyPoint::new();
            let tr: TrustRegion<_, f64> = TrustRegion::new(cp);
            let res = tr.with_eta(eta);
            assert!(res.is_ok());

            let nm = res.unwrap();
            assert_eq!(nm.eta.to_ne_bytes(), eta.to_ne_bytes());
        }

        for eta in [-f64::EPSILON, -1.0, -100.0, -42.0, 0.25, 1.0] {
            let cp: CauchyPoint<f64> = CauchyPoint::new();
            let tr: TrustRegion<_, f64> = TrustRegion::new(cp);
            let res = tr.with_eta(eta);
            assert_error!(
                res,
                ArgminError,
                "Invalid parameter: \"`TrustRegion`: eta must be in [0, 1/4).\""
            );
        }
    }

    /// `init` fails without an initial parameter vector and otherwise keeps the parameter
    /// vector and fills in the cached cost and model values.
    #[test]
    fn test_init() {
        let param: Vec<f64> = vec![1.0, 2.0];

        let cp: CauchyPoint<f64> = CauchyPoint::new();
        let mut tr: TrustRegion<_, f64> = TrustRegion::new(cp);

        // Without a parameter vector in the state, `init` must fail.
        let state: IterState<Vec<f64>, Vec<f64>, (), Vec<Vec<f64>>, (), f64> = IterState::new();
        let problem = TestProblem::new();
        let res = tr.init(&mut Problem::new(problem), state);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Not initialized: \"`TrustRegion` requires an initial parameter vector. Please ",
                "provide an initial guess via `Executor`s `configure` method.\""
            )
        );

        // With a parameter vector, `init` succeeds and reports no key-value pairs.
        let state: IterState<Vec<f64>, Vec<f64>, (), Vec<Vec<f64>>, (), f64> =
            IterState::new().param(param.clone());
        let problem = TestProblem::new();
        let (mut state_out, kv) = tr.init(&mut Problem::new(problem), state).unwrap();

        assert!(kv.is_none());

        let s_param = state_out.take_param().unwrap();

        assert_eq!(s_param[0].to_ne_bytes(), param[0].to_ne_bytes());
        assert_eq!(s_param[1].to_ne_bytes(), param[1].to_ne_bytes());

        let TrustRegion {
            radius,
            max_radius,
            eta,
            subproblem: _,
            fxk,
            mk0,
        } = tr;

        assert_eq!(radius.to_ne_bytes(), 1.0f64.to_ne_bytes());
        assert_eq!(max_radius.to_ne_bytes(), 100.0f64.to_ne_bytes());
        assert_eq!(eta.to_ne_bytes(), 0.125f64.to_ne_bytes());
        assert_eq!(fxk.to_ne_bytes(), 1.0f64.sqrt().to_ne_bytes());
        assert_eq!(mk0.to_ne_bytes(), 1.0f64.to_ne_bytes());
    }
}