use crate::core::{
    ArgminFloat, CostFunction, Error, Executor, Gradient, IterState, LineSearch,
    OptimizationResult, Problem, Solver, TerminationReason, TerminationStatus, KV,
};
use argmin_math::{
    ArgminAdd, ArgminDot, ArgminEye, ArgminL2Norm, ArgminMul, ArgminSub, ArgminTranspose,
};
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};

/// # BFGS method
///
/// Quasi-Newton method which maintains an approximation of the inverse Hessian. In each
/// iteration this approximation is used to compute a search direction, which is handed to the
/// provided line search. Requires an initial parameter vector and an initial inverse Hessian,
/// both provided via `Executor`'s `configure` method.
#[derive(Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct BFGS<L, F> {
    /// Line search used to determine the step length along the search direction
    linesearch: L,
    /// Tolerance for the stopping criterion based on the L2 norm of the gradient
    tol_grad: F,
    /// Tolerance for the stopping criterion based on the change of the cost function value
    tol_cost: F,
}

impl<L, F> BFGS<L, F>
where
    F: ArgminFloat,
{
    /// Constructs a new instance of `BFGS` with the given line search.
    pub fn new(linesearch: L) -> Self {
        BFGS {
            linesearch,
            tol_grad: F::epsilon().sqrt(),
            tol_cost: F::epsilon(),
        }
    }

    /// Sets the tolerance for the stopping criterion based on the L2 norm of the gradient.
    ///
    /// Must be non-negative. Defaults to `sqrt(EPSILON)`.
    pub fn with_tolerance_grad(mut self, tol_grad: F) -> Result<Self, Error> {
        if tol_grad < float!(0.0) {
            return Err(argmin_error!(
                InvalidParameter,
                "`BFGS`: gradient tolerance must be >= 0."
            ));
        }
        self.tol_grad = tol_grad;
        Ok(self)
    }

    /// Sets the tolerance for the stopping criterion based on the change of the cost function value.
    ///
    /// Must be non-negative. Defaults to `EPSILON`.
    pub fn with_tolerance_cost(mut self, tol_cost: F) -> Result<Self, Error> {
        if tol_cost < float!(0.0) {
            return Err(argmin_error!(
                InvalidParameter,
                "`BFGS`: cost tolerance must be >= 0."
            ));
        }
        self.tol_cost = tol_cost;
        Ok(self)
    }
}

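// Usage sketch: a `BFGS` instance is typically driven through an `Executor` roughly as
// follows. `MyProblem` stands in for any user-defined type implementing `CostFunction` and
// `Gradient`; the initial parameter vector, identity inverse Hessian, and iteration limit are
// illustrative values only.
//
//     let linesearch = MoreThuenteLineSearch::new();
//     let solver = BFGS::new(linesearch).with_tolerance_grad(1e-6)?;
//     let res = Executor::new(MyProblem {}, solver)
//         .configure(|state| {
//             state
//                 .param(vec![-1.2, 1.0])
//                 .inv_hessian(vec![vec![1.0, 0.0], vec![0.0, 1.0]])
//                 .max_iters(60)
//         })
//         .run()?;
//     println!("{res}");
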
impl<O, L, P, G, H, F> Solver<O, IterState<P, G, (), H, (), F>> for BFGS<L, F>
where
    O: CostFunction<Param = P, Output = F> + Gradient<Param = P, Gradient = G>,
    P: Clone + ArgminSub<P, P> + ArgminDot<G, H> + ArgminDot<P, H>,
    G: Clone
        + ArgminL2Norm<F>
        + ArgminMul<F, P>
        + ArgminMul<F, G>
        + ArgminDot<P, F>
        + ArgminSub<G, G>,
    H: ArgminSub<H, H>
        + ArgminDot<G, G>
        + ArgminDot<H, H>
        + ArgminAdd<H, H>
        + ArgminMul<F, H>
        + ArgminTranspose<H>
        + ArgminEye,
    L: Clone + LineSearch<G, F> + Solver<O, IterState<P, G, (), (), (), F>>,
    F: ArgminFloat,
{
    fn name(&self) -> &str {
        "BFGS"
    }

    fn init(
        &mut self,
        problem: &mut Problem<O>,
        mut state: IterState<P, G, (), H, (), F>,
    ) -> Result<(IterState<P, G, (), H, (), F>, Option<KV>), Error> {
        let param = state.take_param().ok_or_else(argmin_error_closure!(
            NotInitialized,
            concat!(
                "`BFGS` requires an initial parameter vector. ",
                "Please provide an initial guess via `Executor`s `configure` method."
            )
        ))?;

        let inv_hessian = state.take_inv_hessian().ok_or_else(argmin_error_closure!(
            NotInitialized,
            concat!(
                "`BFGS` requires an initial inverse Hessian. ",
                "Please provide an initial guess via `Executor`s `configure` method."
            )
        ))?;

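        // Compute the cost at the initial parameters unless it was already provided
        // (a freshly initialized state reports an infinite cost).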
        let cost = state.get_cost();
        let cost = if cost.is_infinite() {
            problem.cost(&param)?
        } else {
            cost
        };

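        // Use the provided gradient if present, otherwise evaluate it at the initial parameters.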
        let grad = state
            .take_gradient()
            .map(Result::Ok)
            .unwrap_or_else(|| problem.gradient(&param))?;

        Ok((
            state
                .param(param)
                .cost(cost)
                .gradient(grad)
                .inv_hessian(inv_hessian),
            None,
        ))
    }

    fn next_iter(
        &mut self,
        problem: &mut Problem<O>,
        mut state: IterState<P, G, (), H, (), F>,
    ) -> Result<(IterState<P, G, (), H, (), F>, Option<KV>), Error> {
        let param = state.take_param().ok_or_else(argmin_error_closure!(
            PotentialBug,
            "`BFGS`: Parameter vector in state not set."
        ))?;

        let cur_cost = state.get_cost();

        let prev_grad = state.take_gradient().ok_or_else(argmin_error_closure!(
            PotentialBug,
            "`BFGS`: Gradient in state not set."
        ))?;

        let inv_hessian = state.take_inv_hessian().ok_or_else(argmin_error_closure!(
            PotentialBug,
            "`BFGS`: Inverse Hessian in state not set."
        ))?;

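        // Search direction p_k = -H_k * g_k, where H_k is the current inverse Hessian
        // approximation and g_k the gradient at the current parameter vector.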
        let g: G = inv_hessian.dot(&prev_grad).mul(&float!(-1.0));

        self.linesearch.search_direction(g);

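        // Run the line search along this direction on a nested `Executor`, starting from the
        // current parameter vector, gradient and cost.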
        let OptimizationResult {
            problem: line_problem,
            state: mut sub_state,
            ..
        } = Executor::new(problem.take_problem().unwrap(), self.linesearch.clone())
            .configure(|config| {
                config
                    .param(param.clone())
                    .gradient(prev_grad.clone())
                    .cost(cur_cost)
            })
            .ctrlc(false)
            .run()?;

        let xk1 = sub_state.take_param().ok_or_else(argmin_error_closure!(
            PotentialBug,
            "`BFGS`: No parameters returned by line search."
        ))?;

        let next_cost = sub_state.get_cost();

        // Hand the problem (and its function evaluation counts) back to the main executor.
        problem.consume_problem(line_problem);

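        // Evaluate the gradient at the accepted point and form the BFGS update vectors
        // y_k = grad(x_{k+1}) - grad(x_k) and s_k = x_{k+1} - x_k.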
        let grad = problem.gradient(&xk1)?;

        let yk = grad.sub(&prev_grad);

        let sk = xk1.sub(&param);

        let yksk: F = yk.dot(&sk);
        let rhok = float!(1.0) / yksk;

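        // BFGS update of the inverse Hessian approximation:
        // H_{k+1} = (I - rho_k * s_k * y_k^T) * H_k * (I - rho_k * y_k * s_k^T) + rho_k * s_k * s_k^T,
        // with rho_k = 1 / (y_k^T * s_k).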
        let e = inv_hessian.eye_like();
        let mat1: H = sk.dot(&yk);
        let mat1 = mat1.mul(&rhok);

        let tmp1 = e.sub(&mat1);

        let mat2 = mat1.t();
        let tmp2 = e.sub(&mat2);

        let sksk: H = sk.dot(&sk);
        let sksk = sksk.mul(&rhok);

        let inv_hessian = tmp1.dot(&inv_hessian.dot(&tmp2)).add(&sksk);

        Ok((
            state
                .param(xk1)
                .cost(next_cost)
                .gradient(grad)
                .inv_hessian(inv_hessian),
            None,
        ))
    }

    fn terminate(&mut self, state: &IterState<P, G, (), H, (), F>) -> TerminationStatus {
        // Stop as soon as the L2 norm of the gradient drops below the gradient tolerance.
        if state.get_gradient().unwrap().l2_norm() < self.tol_grad {
            return TerminationStatus::Terminated(TerminationReason::SolverConverged);
        }
        // Stop when the cost no longer changes by more than the cost tolerance.
        if (state.get_prev_cost() - state.cost).abs() < self.tol_cost {
            return TerminationStatus::Terminated(TerminationReason::SolverConverged);
        }
        TerminationStatus::NotTerminated
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{test_utils::TestProblem, ArgminError, State};
    use crate::solver::linesearch::MoreThuenteLineSearch;

    test_trait_impl!(
        bfgs,
        BFGS<MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64>, f64>
    );

    #[test]
    fn test_new() {
        #[derive(Eq, PartialEq, Debug)]
        struct MyFakeLineSearch {}

        let bfgs: BFGS<_, f64> = BFGS::new(MyFakeLineSearch {});
        let BFGS {
            linesearch,
            tol_grad,
            tol_cost,
        } = bfgs;

        assert_eq!(linesearch, MyFakeLineSearch {});
        assert_eq!(tol_grad.to_ne_bytes(), f64::EPSILON.sqrt().to_ne_bytes());
        assert_eq!(tol_cost.to_ne_bytes(), f64::EPSILON.to_ne_bytes());
    }

    #[test]
    fn test_with_tolerance_grad() {
        #[derive(Eq, PartialEq, Debug, Clone, Copy)]
        struct MyFakeLineSearch {}

        for tol in [1e-6, 0.0, 1e-2, 1.0, 2.0] {
            let bfgs: BFGS<_, f64> = BFGS::new(MyFakeLineSearch {});
            let res = bfgs.with_tolerance_grad(tol);
            assert!(res.is_ok());

            let nm = res.unwrap();
            assert_eq!(nm.tol_grad.to_ne_bytes(), tol.to_ne_bytes());
        }

        for tol in [-f64::EPSILON, -1.0, -100.0, -42.0] {
            let bfgs: BFGS<_, f64> = BFGS::new(MyFakeLineSearch {});
            let res = bfgs.with_tolerance_grad(tol);
            assert_error!(
                res,
                ArgminError,
                "Invalid parameter: \"`BFGS`: gradient tolerance must be >= 0.\""
            );
        }
    }

    #[test]
    fn test_with_tolerance_cost() {
        #[derive(Eq, PartialEq, Debug, Clone, Copy)]
        struct MyFakeLineSearch {}

        for tol in [1e-6, 0.0, 1e-2, 1.0, 2.0] {
            let bfgs: BFGS<_, f64> = BFGS::new(MyFakeLineSearch {});
            let res = bfgs.with_tolerance_cost(tol);
            assert!(res.is_ok());

            let nm = res.unwrap();
            assert_eq!(nm.tol_cost.to_ne_bytes(), tol.to_ne_bytes());
        }

        for tol in [-f64::EPSILON, -1.0, -100.0, -42.0] {
            let bfgs: BFGS<_, f64> = BFGS::new(MyFakeLineSearch {});
            let res = bfgs.with_tolerance_cost(tol);
            assert_error!(
                res,
                ArgminError,
                "Invalid parameter: \"`BFGS`: cost tolerance must be >= 0.\""
            );
        }
    }

    #[test]
    fn test_init() {
        let linesearch = MoreThuenteLineSearch::new().with_c(1e-4, 0.9).unwrap();

        let param: Vec<f64> = vec![-1.0, 1.0];
        let inv_hessian: Vec<Vec<f64>> = vec![vec![1.0, 2.0], vec![3.0, 4.0]];

        let mut bfgs: BFGS<_, f64> = BFGS::new(linesearch);

        let state: IterState<Vec<f64>, Vec<f64>, (), Vec<Vec<f64>>, (), f64> = IterState::new();
        let problem = TestProblem::new();
        let res = bfgs.init(&mut Problem::new(problem), state);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Not initialized: \"`BFGS` requires an initial parameter vector. Please ",
                "provide an initial guess via `Executor`s `configure` method.\""
            )
        );

        let state: IterState<Vec<f64>, Vec<f64>, (), Vec<Vec<f64>>, (), f64> =
            IterState::new().param(param.clone());
        let problem = TestProblem::new();
        let res = bfgs.init(&mut Problem::new(problem), state);

        assert_error!(
            res,
            ArgminError,
            concat!(
                "Not initialized: \"`BFGS` requires an initial inverse Hessian. Please ",
                "provide an initial guess via `Executor`s `configure` method.\""
            )
        );

        let state: IterState<Vec<f64>, Vec<f64>, (), Vec<Vec<f64>>, (), f64> = IterState::new()
            .param(param.clone())
            .inv_hessian(inv_hessian.clone());
        let problem = TestProblem::new();
        let (mut state_out, kv) = bfgs.init(&mut Problem::new(problem), state).unwrap();

        assert!(kv.is_none());

        let s_param = state_out.take_param().unwrap();

        for (s, p) in s_param.iter().zip(param.iter()) {
            assert_eq!(s.to_ne_bytes(), p.to_ne_bytes());
        }

        let s_grad = state_out.take_gradient().unwrap();

        for (s, p) in s_grad.iter().zip(param.iter()) {
            assert_eq!(s.to_ne_bytes(), p.to_ne_bytes());
        }

        let s_inv_hessian = state_out.take_inv_hessian().unwrap();

        for (s, h) in s_inv_hessian
            .iter()
            .flatten()
            .zip(inv_hessian.iter().flatten())
        {
            assert_eq!(s.to_ne_bytes(), h.to_ne_bytes());
        }

        assert_eq!(state_out.get_cost().to_ne_bytes(), 1.0f64.to_ne_bytes())
    }

    #[test]
    fn test_init_provided_cost() {
        let linesearch = MoreThuenteLineSearch::new().with_c(1e-4, 0.9).unwrap();

        let param: Vec<f64> = vec![-1.0, 1.0];
        let inv_hessian: Vec<Vec<f64>> = vec![vec![1.0, 2.0], vec![3.0, 4.0]];

        let mut bfgs: BFGS<_, f64> = BFGS::new(linesearch);

        let state: IterState<Vec<f64>, Vec<f64>, (), Vec<Vec<f64>>, (), f64> = IterState::new()
            .param(param)
            .inv_hessian(inv_hessian)
            .cost(1234.0);

        let problem = TestProblem::new();
        let (state_out, kv) = bfgs.init(&mut Problem::new(problem), state).unwrap();

        assert!(kv.is_none());

        assert_eq!(state_out.get_cost().to_ne_bytes(), 1234.0f64.to_ne_bytes())
    }

    #[test]
    fn test_init_provided_grad() {
        let linesearch = MoreThuenteLineSearch::new().with_c(1e-4, 0.9).unwrap();

        let param: Vec<f64> = vec![-1.0, 1.0];
        let gradient: Vec<f64> = vec![4.0, 9.0];
        let inv_hessian: Vec<Vec<f64>> = vec![vec![1.0, 2.0], vec![3.0, 4.0]];

        let mut bfgs: BFGS<_, f64> = BFGS::new(linesearch);

        let state: IterState<Vec<f64>, Vec<f64>, (), Vec<Vec<f64>>, (), f64> = IterState::new()
            .param(param)
            .inv_hessian(inv_hessian)
            .gradient(gradient.clone());

        let problem = TestProblem::new();
        let (mut state_out, kv) = bfgs.init(&mut Problem::new(problem), state).unwrap();

        assert!(kv.is_none());

        let s_grad = state_out.take_gradient().unwrap();

        for (s, g) in s_grad.iter().zip(gradient.iter()) {
            assert_eq!(s.to_ne_bytes(), g.to_ne_bytes());
        }
    }
}