use crate::core::{
    ArgminFloat, CostFunction, Error, Executor, Gradient, IterState, LineSearch,
    OptimizationResult, Problem, Solver, TerminationReason, TerminationStatus, KV,
};
use argmin_math::{ArgminAdd, ArgminDot, ArgminL2Norm, ArgminMul, ArgminSub};
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
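
/// # SR1 method
///
/// A quasi-Newton method which maintains an approximation of the inverse Hessian and improves
/// it in every iteration by means of a symmetric rank-one (SR1) update. The search direction is
/// the approximate Newton direction, and the step length along it is determined by the supplied
/// line search. The update is skipped whenever its denominator becomes too small (see
/// [`SR1::with_denominator_factor`]).
///
/// Requires an initial parameter vector and an initial inverse Hessian (provided via
/// [`Executor`]s `configure` method).
///
/// ## Example (sketch)
///
/// A minimal, illustrative setup which minimizes a simple quadratic. The problem type
/// `Quadratic` is hypothetical and not part of this crate; the remaining calls follow the
/// patterns used in this module and assume the usual crate layout and the `Vec<f64>` math
/// backend.
///
/// ```ignore
/// use argmin::core::{CostFunction, Error, Executor, Gradient};
/// use argmin::solver::linesearch::MoreThuenteLineSearch;
/// use argmin::solver::quasinewton::SR1;
///
/// // Hypothetical problem: f(x) = x . x with gradient 2x.
/// struct Quadratic {}
///
/// impl CostFunction for Quadratic {
///     type Param = Vec<f64>;
///     type Output = f64;
///
///     fn cost(&self, p: &Self::Param) -> Result<Self::Output, Error> {
///         Ok(p.iter().map(|x| x * x).sum())
///     }
/// }
///
/// impl Gradient for Quadratic {
///     type Param = Vec<f64>;
///     type Gradient = Vec<f64>;
///
///     fn gradient(&self, p: &Self::Param) -> Result<Self::Gradient, Error> {
///         Ok(p.iter().map(|x| 2.0 * x).collect())
///     }
/// }
///
/// let linesearch = MoreThuenteLineSearch::new().with_c(1e-4, 0.9)?;
/// let solver: SR1<_, f64> = SR1::new(linesearch);
///
/// // Initial guess and initial inverse Hessian (identity).
/// let init_param: Vec<f64> = vec![1.5, -0.5];
/// let init_inv_hessian: Vec<Vec<f64>> = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
///
/// let res = Executor::new(Quadratic {}, solver)
///     .configure(|state| {
///         state
///             .param(init_param)
///             .inv_hessian(init_inv_hessian)
///             .max_iters(20)
///     })
///     .run()?;
///
/// println!("{res}");
/// # Ok::<(), Error>(())
/// ```
///
/// ## Reference
///
/// Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization.
/// Springer. ISBN: 978-0387303031.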
#[derive(Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct SR1<L, F> {
    denominator_factor: F,
    linesearch: L,
    tol_grad: F,
    tol_cost: F,
}

impl<L, F> SR1<L, F>
where
    F: ArgminFloat,
{
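    /// Constructs a new instance of [`SR1`] which uses the provided line search.
    ///
    /// Defaults: `denominator_factor = 1e-8`, `tol_grad = EPSILON.sqrt()` and
    /// `tol_cost = EPSILON`.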
    pub fn new(linesearch: L) -> Self {
        SR1 {
            denominator_factor: float!(1e-8),
            linesearch,
            tol_grad: F::epsilon().sqrt(),
            tol_cost: F::epsilon(),
        }
    }

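    /// Sets the factor used to decide whether the rank-one update is applied.
    ///
    /// With `u = y_k - H_k * s_k`, the update is only performed if
    /// `|u^T s_k| >= denominator_factor * ||s_k|| * ||u||`; otherwise it is skipped for this
    /// iteration. Must lie in the open interval (0, 1) (defaults to `1e-8`).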
    pub fn with_denominator_factor(mut self, denominator_factor: F) -> Result<Self, Error> {
        if denominator_factor <= float!(0.0) || denominator_factor >= float!(1.0) {
            Err(argmin_error!(
                InvalidParameter,
                "`SR1`: denominator_factor must be in (0, 1)."
            ))
        } else {
            self.denominator_factor = denominator_factor;
            Ok(self)
        }
    }

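    /// Sets the tolerance on the L2 norm of the gradient used as a stopping criterion.
    ///
    /// Must be non-negative (defaults to `EPSILON.sqrt()`).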
    pub fn with_tolerance_grad(mut self, tol_grad: F) -> Result<Self, Error> {
        if tol_grad < float!(0.0) {
            return Err(argmin_error!(
                InvalidParameter,
                "`SR1`: gradient tolerance must be >= 0."
            ));
        }
        self.tol_grad = tol_grad;
        Ok(self)
    }

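    /// Sets the tolerance on the change of the cost function value used as a stopping criterion.
    ///
    /// Must be non-negative (defaults to `EPSILON`).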
    pub fn with_tolerance_cost(mut self, tol_cost: F) -> Result<Self, Error> {
        if tol_cost < float!(0.0) {
            return Err(argmin_error!(
                InvalidParameter,
                "`SR1`: cost tolerance must be >= 0."
            ));
        }
        self.tol_cost = tol_cost;
        Ok(self)
    }
}

impl<O, L, P, G, H, F> Solver<O, IterState<P, G, (), H, (), F>> for SR1<L, F>
where
    O: CostFunction<Param = P, Output = F> + Gradient<Param = P, Gradient = G>,
    P: Clone
        + ArgminSub<P, P>
        + ArgminDot<G, F>
        + ArgminDot<P, F>
        + ArgminDot<P, H>
        + ArgminL2Norm<F>
        + ArgminMul<F, P>,
    G: Clone + ArgminSub<P, P> + ArgminL2Norm<F> + ArgminSub<G, G>,
    H: ArgminDot<G, P> + ArgminDot<P, P> + ArgminAdd<H, H> + ArgminMul<F, H>,
    L: Clone + LineSearch<P, F> + Solver<O, IterState<P, G, (), (), (), F>>,
    F: ArgminFloat,
{
    fn name(&self) -> &str {
        "SR1"
    }

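    // Validates and completes the initial state: an initial parameter vector and an initial
    // inverse Hessian are required, while the initial cost and gradient are computed here if
    // they were not provided.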
    fn init(
        &mut self,
        problem: &mut Problem<O>,
        mut state: IterState<P, G, (), H, (), F>,
    ) -> Result<(IterState<P, G, (), H, (), F>, Option<KV>), Error> {
        let param = state.take_param().ok_or_else(argmin_error_closure!(
            NotInitialized,
            concat!(
                "`SR1` requires an initial parameter vector. ",
                "Please provide an initial guess via `Executor`s `configure` method."
            )
        ))?;

        let inv_hessian = state.take_inv_hessian().ok_or_else(argmin_error_closure!(
            NotInitialized,
            concat!(
                "`SR1` requires an initial inverse Hessian. ",
                "Please provide an initial guess via `Executor`s `configure` method."
            )
        ))?;

        let cost = state.get_cost();
        let cost = if cost.is_infinite() {
            problem.cost(&param)?
        } else {
            cost
        };

        let grad = state
            .take_gradient()
            .map(Result::Ok)
            .unwrap_or_else(|| problem.gradient(&param))?;
        Ok((
            state
                .param(param)
                .cost(cost)
                .gradient(grad)
                .inv_hessian(inv_hessian),
            None,
        ))
    }

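    // One SR1 iteration:
    //   1. compute the search direction p = -H_k * grad_k,
    //   2. run the line search along p to obtain the next iterate x_{k+1},
    //   3. compute s_k = x_{k+1} - x_k and y_k = grad_{k+1} - grad_k,
    //   4. apply the rank-one update to the approximation unless its denominator is too small.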
    fn next_iter(
        &mut self,
        problem: &mut Problem<O>,
        mut state: IterState<P, G, (), H, (), F>,
    ) -> Result<(IterState<P, G, (), H, (), F>, Option<KV>), Error> {
        let param = state.take_param().ok_or_else(argmin_error_closure!(
            PotentialBug,
            "`SR1`: Parameter vector in state not set."
        ))?;
        let cost = state.get_cost();

        let prev_grad = state.take_gradient().ok_or_else(argmin_error_closure!(
            PotentialBug,
            "`SR1`: Gradient in state not set."
        ))?;

        let mut inv_hessian = state.take_inv_hessian().ok_or_else(argmin_error_closure!(
            PotentialBug,
            "`SR1`: Inverse Hessian in state not set."
        ))?;

        let p = inv_hessian.dot(&prev_grad).mul(&float!(-1.0));

        self.linesearch.search_direction(p);

        let OptimizationResult {
            problem: line_problem,
            state: mut linesearch_state,
            ..
        } = Executor::new(problem.take_problem().unwrap(), self.linesearch.clone())
            .configure(|config| {
                config
                    .param(param.clone())
                    .gradient(prev_grad.clone())
                    .cost(cost)
            })
            .ctrlc(false)
            .run()?;

        let xk1 = linesearch_state.take_param().unwrap();
        let next_cost = linesearch_state.get_cost();

        problem.consume_problem(line_problem);

        let grad = problem.gradient(&xk1)?;
        let yk = grad.sub(&prev_grad);

        let sk = xk1.sub(&param);

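        // Rank-one update: with u = y_k - H_k * s_k, the approximation becomes
        // H_{k+1} = H_k + (u * u^T) / (u^T * s_k). The update is skipped whenever the
        // denominator u^T * s_k is small relative to ||s_k|| * ||u||, as controlled by
        // `denominator_factor`.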
        let ykmbksk: P = yk.sub(&inv_hessian.dot(&sk));
        let a: H = ykmbksk.dot(&ykmbksk);
        let b: F = ykmbksk.dot(&sk);

        let hessian_update = b.abs() >= self.denominator_factor * sk.l2_norm() * ykmbksk.l2_norm();

        if hessian_update {
            inv_hessian = inv_hessian.add(&a.mul(&(float!(1.0) / b)));
        }

        Ok((
            state
                .param(xk1)
                .cost(next_cost)
                .gradient(grad)
                .inv_hessian(inv_hessian),
            Some(kv!["denominator" => b; "hessian_update" => hessian_update;]),
        ))
    }

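    // Converged as soon as either the gradient norm or the absolute change in cost falls
    // below the respective tolerance.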
    fn terminate(&mut self, state: &IterState<P, G, (), H, (), F>) -> TerminationStatus {
        if state.get_gradient().unwrap().l2_norm() < self.tol_grad {
            return TerminationStatus::Terminated(TerminationReason::SolverConverged);
        }
        if (state.get_prev_cost() - state.cost).abs() < self.tol_cost {
            return TerminationStatus::Terminated(TerminationReason::SolverConverged);
        }
        TerminationStatus::NotTerminated
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{test_utils::TestProblem, ArgminError, State};
    use crate::solver::linesearch::MoreThuenteLineSearch;

    test_trait_impl!(
        sr1,
        SR1<MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64>, f64>
    );

    #[test]
    fn test_new() {
        #[derive(Eq, PartialEq, Debug)]
        struct MyFakeLineSearch {}

        let sr1: SR1<_, f64> = SR1::new(MyFakeLineSearch {});
        let SR1 {
            denominator_factor,
            linesearch,
            tol_grad,
            tol_cost,
        } = sr1;

        assert_eq!(linesearch, MyFakeLineSearch {});
        assert_eq!(tol_grad.to_ne_bytes(), f64::EPSILON.sqrt().to_ne_bytes());
        assert_eq!(tol_cost.to_ne_bytes(), f64::EPSILON.to_ne_bytes());
        assert_eq!(denominator_factor.to_ne_bytes(), 1e-8f64.to_ne_bytes());
    }

    #[test]
    fn test_with_denominator_factor() {
        #[derive(Eq, PartialEq, Debug, Clone, Copy)]
        struct MyFakeLineSearch {}

        for tol in [f64::EPSILON, 1e-8, 1e-6, 1e-2, 1.0 - f64::EPSILON] {
            let sr1: SR1<_, f64> = SR1::new(MyFakeLineSearch {});
            let res = sr1.with_denominator_factor(tol);
            assert!(res.is_ok());

            let nm = res.unwrap();
            assert_eq!(nm.denominator_factor.to_ne_bytes(), tol.to_ne_bytes());
        }

        for tol in [-f64::EPSILON, 0.0, -1.0, 1.0] {
            let sr1: SR1<_, f64> = SR1::new(MyFakeLineSearch {});
            let res = sr1.with_denominator_factor(tol);
            assert_error!(
                res,
                ArgminError,
                "Invalid parameter: \"`SR1`: denominator_factor must be in (0, 1).\""
            );
        }
    }

    #[test]
    fn test_with_tolerance_grad() {
        #[derive(Eq, PartialEq, Debug, Clone, Copy)]
        struct MyFakeLineSearch {}

        for tol in [1e-6, 0.0, 1e-2, 1.0, 2.0] {
            let sr1: SR1<_, f64> = SR1::new(MyFakeLineSearch {});
            let res = sr1.with_tolerance_grad(tol);
            assert!(res.is_ok());

            let nm = res.unwrap();
            assert_eq!(nm.tol_grad.to_ne_bytes(), tol.to_ne_bytes());
        }

        for tol in [-f64::EPSILON, -1.0, -100.0, -42.0] {
            let sr1: SR1<_, f64> = SR1::new(MyFakeLineSearch {});
            let res = sr1.with_tolerance_grad(tol);
            assert_error!(
                res,
                ArgminError,
                "Invalid parameter: \"`SR1`: gradient tolerance must be >= 0.\""
            );
        }
    }

    #[test]
    fn test_with_tolerance_cost() {
        #[derive(Eq, PartialEq, Debug, Clone, Copy)]
        struct MyFakeLineSearch {}

        for tol in [1e-6, 0.0, 1e-2, 1.0, 2.0] {
            let sr1: SR1<_, f64> = SR1::new(MyFakeLineSearch {});
            let res = sr1.with_tolerance_cost(tol);
            assert!(res.is_ok());

            let nm = res.unwrap();
            assert_eq!(nm.tol_cost.to_ne_bytes(), tol.to_ne_bytes());
        }

        for tol in [-f64::EPSILON, -1.0, -100.0, -42.0] {
            let sr1: SR1<_, f64> = SR1::new(MyFakeLineSearch {});
            let res = sr1.with_tolerance_cost(tol);
            assert_error!(
                res,
                ArgminError,
                "Invalid parameter: \"`SR1`: cost tolerance must be >= 0.\""
            );
        }
    }

    #[test]
    fn test_init() {
        let linesearch = MoreThuenteLineSearch::new().with_c(1e-4, 0.9).unwrap();

        let param: Vec<f64> = vec![-1.0, 1.0];
        let inv_hessian: Vec<Vec<f64>> = vec![vec![1.0, 2.0], vec![3.0, 4.0]];

        let mut sr1: SR1<_, f64> = SR1::new(linesearch);

        let state: IterState<Vec<f64>, Vec<f64>, (), Vec<Vec<f64>>, (), f64> =
            IterState::new();
        let problem = TestProblem::new();
        let res = sr1.init(&mut Problem::new(problem), state);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Not initialized: \"`SR1` requires an initial parameter vector. Please ",
                "provide an initial guess via `Executor`s `configure` method.\""
            )
        );

        let state: IterState<Vec<f64>, Vec<f64>, (), Vec<Vec<f64>>, (), f64> =
            IterState::new().param(param.clone());
        let problem = TestProblem::new();
        let res = sr1.init(&mut Problem::new(problem), state);

        assert_error!(
            res,
            ArgminError,
            concat!(
                "Not initialized: \"`SR1` requires an initial inverse Hessian. Please ",
                "provide an initial guess via `Executor`s `configure` method.\""
            )
        );

        let state: IterState<Vec<f64>, Vec<f64>, (), Vec<Vec<f64>>, (), f64> = IterState::new()
            .param(param.clone())
            .inv_hessian(inv_hessian.clone());
        let problem = TestProblem::new();
        let (mut state_out, kv) = sr1.init(&mut Problem::new(problem), state).unwrap();

        assert!(kv.is_none());

        let s_param = state_out.take_param().unwrap();

        for (s, p) in s_param.iter().zip(param.iter()) {
            assert_eq!(s.to_ne_bytes(), p.to_ne_bytes());
        }

        let s_grad = state_out.take_gradient().unwrap();

        for (s, p) in s_grad.iter().zip(param.iter()) {
            assert_eq!(s.to_ne_bytes(), p.to_ne_bytes());
        }

        let s_inv_hessian = state_out.take_inv_hessian().unwrap();

        for (s, h) in s_inv_hessian
            .iter()
            .flatten()
            .zip(inv_hessian.iter().flatten())
        {
            assert_eq!(s.to_ne_bytes(), h.to_ne_bytes());
        }

        assert_eq!(state_out.get_cost().to_ne_bytes(), 1.0f64.to_ne_bytes())
    }

    #[test]
    fn test_init_provided_cost() {
        let linesearch = MoreThuenteLineSearch::new().with_c(1e-4, 0.9).unwrap();

        let param: Vec<f64> = vec![-1.0, 1.0];
        let inv_hessian: Vec<Vec<f64>> = vec![vec![1.0, 2.0], vec![3.0, 4.0]];

        let mut sr1: SR1<_, f64> = SR1::new(linesearch);

        let state: IterState<Vec<f64>, Vec<f64>, (), Vec<Vec<f64>>, (), f64> = IterState::new()
            .param(param)
            .inv_hessian(inv_hessian)
            .cost(1234.0);

        let problem = TestProblem::new();
        let (state_out, kv) = sr1.init(&mut Problem::new(problem), state).unwrap();

        assert!(kv.is_none());

        assert_eq!(state_out.get_cost().to_ne_bytes(), 1234.0f64.to_ne_bytes())
    }

    #[test]
    fn test_init_provided_grad() {
        let linesearch = MoreThuenteLineSearch::new().with_c(1e-4, 0.9).unwrap();

        let param: Vec<f64> = vec![-1.0, 1.0];
        let gradient: Vec<f64> = vec![4.0, 9.0];
        let inv_hessian: Vec<Vec<f64>> = vec![vec![1.0, 2.0], vec![3.0, 4.0]];

        let mut sr1: SR1<_, f64> = SR1::new(linesearch);

        let state: IterState<Vec<f64>, Vec<f64>, (), Vec<Vec<f64>>, (), f64> = IterState::new()
            .param(param)
            .inv_hessian(inv_hessian)
            .gradient(gradient.clone());

        let problem = TestProblem::new();
        let (mut state_out, kv) = sr1.init(&mut Problem::new(problem), state).unwrap();

        assert!(kv.is_none());

        let s_grad = state_out.take_gradient().unwrap();

        for (s, g) in s_grad.iter().zip(gradient.iter()) {
            assert_eq!(s.to_ne_bytes(), g.to_ne_bytes());
        }
    }
}