1use crate::core::{
9 ArgminFloat, CostFunction, Error, Executor, Gradient, IterState, LineSearch,
10 OptimizationResult, Problem, Solver, State, TerminationReason, TerminationStatus, KV,
11};
12use argmin_math::{
13 ArgminAdd, ArgminDot, ArgminL1Norm, ArgminL2Norm, ArgminMinMax, ArgminMul, ArgminSignum,
14 ArgminSub, ArgminZeroLike,
15};
16#[cfg(feature = "serde1")]
17use serde::{Deserialize, Serialize};
18use std::collections::VecDeque;
19use std::marker::PhantomData;
20
21fn calculate_pseudo_gradient<P, G, F>(l1_coeff: F, param: &P, gradient: &G) -> G
23where
24 P: ArgminAdd<F, P> + ArgminSub<F, P> + ArgminMul<F, P> + ArgminSignum,
25 G: ArgminAdd<G, G> + ArgminAdd<P, G> + ArgminMinMax + ArgminZeroLike,
26 F: ArgminFloat,
27{
28 let coeff_p = param.add(&F::min_positive_value()).signum().mul(&l1_coeff);
29 let coeff_n = param.sub(&F::min_positive_value()).signum().mul(&l1_coeff);
30 let zeros = gradient.zero_like();
31 G::max(&gradient.add(&coeff_n), &zeros).add(&G::min(&gradient.add(&coeff_p), &zeros))
32}
33
/// Limited-memory BFGS (L-BFGS) solver with optional L1 regularization.
///
/// Holds the configuration (line search, history size, tolerances, optional
/// L1 coefficient) as well as the history of parameter and gradient
/// differences which is built up while iterating.
#[derive(Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct LBFGS<L, P, G, F> {
    /// Line search which is cloned and executed once per iteration.
    linesearch: L,
    /// Number of `(s, y)` pairs to keep; also used as the initial capacity of
    /// both history buffers.
    m: usize,
    /// History of parameter differences (`new_param - old_param`), oldest
    /// entry first.
    s: VecDeque<P>,
    /// History of gradient differences (`new_gradient - old_gradient`), oldest
    /// entry first. With L1 regularization enabled these are differences of
    /// the unregularized gradients.
    y: VecDeque<G>,
    /// The solver terminates once the L2 norm of the gradient stored in the
    /// state falls below this value. Defaults to `sqrt(EPSILON)`.
    tol_grad: F,
    /// The solver terminates once the absolute change of the cost between two
    /// iterations falls below this value. Defaults to `EPSILON`.
    tol_cost: F,
    /// Coefficient of the L1 penalty; `None` disables L1 regularization.
    l1_coeff: Option<F>,
    /// Unregularized gradient of the previous iterate; only populated when L1
    /// regularization is enabled.
    l1_prev_unreg_grad: Option<G>,
}
96
97impl<L, P, G, F> LBFGS<L, P, G, F>
98where
99 F: ArgminFloat,
100{
101 pub fn new(linesearch: L, m: usize) -> Self {
111 LBFGS {
112 linesearch,
113 m,
114 s: VecDeque::with_capacity(m),
115 y: VecDeque::with_capacity(m),
116 tol_grad: F::epsilon().sqrt(),
117 tol_cost: F::epsilon(),
118 l1_coeff: None,
119 l1_prev_unreg_grad: None,
120 }
121 }
122
123 pub fn with_tolerance_grad(mut self, tol_grad: F) -> Result<Self, Error> {
139 if tol_grad < float!(0.0) {
140 return Err(argmin_error!(
141 InvalidParameter,
142 "`L-BFGS`: gradient tolerance must be >= 0."
143 ));
144 }
145 self.tol_grad = tol_grad;
146 Ok(self)
147 }
148
149 pub fn with_tolerance_cost(mut self, tol_cost: F) -> Result<Self, Error> {
165 if tol_cost < float!(0.0) {
166 return Err(argmin_error!(
167 InvalidParameter,
168 "`L-BFGS`: cost tolerance must be >= 0."
169 ));
170 }
171 self.tol_cost = tol_cost;
172 Ok(self)
173 }
174
175 pub fn with_l1_regularization(mut self, l1_coeff: F) -> Result<Self, Error> {
190 if l1_coeff <= float!(0.0) {
191 return Err(argmin_error!(
192 InvalidParameter,
193 "`L-BFGS`: coefficient of L1-regularization must be > 0."
194 ));
195 }
196 self.l1_coeff = Some(l1_coeff);
197 Ok(self)
198 }
199}
200
/// Wrapper around the user-provided problem which is handed to the line
/// search.
///
/// Without an L1 constraint (`xi` is `None`) it forwards cost and gradient
/// evaluations unchanged. With an L1 constraint it projects the parameter
/// vector using `xi` before evaluating the wrapped problem and adds the L1
/// terms to the result.
struct LineSearchProblem<O, P, G, F> {
    /// The wrapped problem.
    problem: O,
    /// Per-entry sign reference used for the projection; `None` if no L1
    /// constraint is active.
    xi: Option<P>,
    /// Coefficient of the L1 penalty; set together with `xi`.
    l1_coeff: Option<F>,
    /// Ties the gradient type `G` to the struct without storing a value of it.
    phantom: PhantomData<G>,
}
208
209impl<O, P, G, F> LineSearchProblem<O, P, G, F>
210where
211 P: ArgminSub<F, P>,
212 F: ArgminFloat,
213{
214 fn new(problem: O) -> Self {
215 Self {
216 problem,
217 xi: None,
218 l1_coeff: None,
219 phantom: PhantomData,
220 }
221 }
222
223 fn with_l1_constraint(&mut self, l1_coeff: F, param: &P, pseudo_gradient: &G)
224 where
225 P: ArgminZeroLike
226 + ArgminMinMax
227 + ArgminSignum
228 + ArgminAdd<P, P>
229 + ArgminAdd<F, P>
230 + ArgminMul<P, P>
231 + ArgminSub<F, P>
232 + ArgminMul<G, P>,
233 {
234 let zeros = param.zero_like();
235 let sig_param = P::max(¶m.sub(&F::min_positive_value()).signum(), &zeros).add(&P::min(
236 ¶m.add(&F::min_positive_value()).signum(),
237 &zeros,
238 ));
239 self.xi = Some(
240 sig_param.add(
241 &sig_param
242 .mul(&sig_param)
243 .sub(&float!(1.0))
244 .mul(pseudo_gradient),
245 ),
246 );
247 self.l1_coeff = Some(l1_coeff);
248 }
249}
250
251impl<O, P, G, F> CostFunction for LineSearchProblem<O, P, G, F>
252where
253 O: CostFunction<Param = P, Output = F>,
254 P: ArgminMul<P, P> + ArgminMinMax + ArgminSignum + ArgminZeroLike + ArgminL1Norm<F>,
255 F: ArgminFloat,
256{
257 type Param = P;
258 type Output = F;
259
260 fn cost(&self, param: &Self::Param) -> Result<Self::Output, Error> {
261 if let Some(xi) = self.xi.as_ref() {
262 let zeros = param.zero_like();
263 let param = P::max(¶m.mul(xi).signum(), &zeros).mul(param);
264 let cost = self.problem.cost(¶m)?;
265 Ok(cost + self.l1_coeff.unwrap() * param.l1_norm())
266 } else {
267 self.problem.cost(param)
268 }
269 }
270}
271
272impl<O, P, G, F> Gradient for LineSearchProblem<O, P, G, F>
273where
274 O: Gradient<Param = P, Gradient = G>,
275 P: ArgminAdd<F, P>
276 + ArgminMul<P, P>
277 + ArgminMul<F, P>
278 + ArgminSub<F, P>
279 + ArgminMinMax
280 + ArgminSignum
281 + ArgminZeroLike,
282 G: ArgminAdd<P, G> + ArgminZeroLike + ArgminMinMax + ArgminAdd<G, G>,
283 F: ArgminFloat,
284{
285 type Param = P;
286 type Gradient = G;
287
288 fn gradient(&self, param: &Self::Param) -> Result<Self::Gradient, Error> {
289 if let Some(xi) = self.xi.as_ref() {
290 let zeros = param.zero_like();
291 let param = P::max(¶m.mul(xi).signum(), &zeros).mul(param);
292 let gradient = self.problem.gradient(¶m)?;
293 Ok(calculate_pseudo_gradient(
294 self.l1_coeff.unwrap(),
295 ¶m,
296 &gradient,
297 ))
298 } else {
299 self.problem.gradient(param)
300 }
301 }
302}
303
304impl<O, L, P, G, F> Solver<O, IterState<P, G, (), (), (), F>> for LBFGS<L, P, G, F>
305where
306 O: CostFunction<Param = P, Output = F> + Gradient<Param = P, Gradient = G>,
307 P: Clone
308 + ArgminSub<P, P>
309 + ArgminSub<F, P>
310 + ArgminAdd<P, P>
311 + ArgminAdd<F, P>
312 + ArgminDot<G, F>
313 + ArgminMul<F, P>
314 + ArgminMul<P, P>
315 + ArgminMul<G, P>
316 + ArgminL1Norm<F>
317 + ArgminSignum
318 + ArgminZeroLike
319 + ArgminMinMax,
320 G: Clone
321 + ArgminL2Norm<F>
322 + ArgminSub<G, G>
323 + ArgminAdd<G, G>
324 + ArgminAdd<P, G>
325 + ArgminDot<G, F>
326 + ArgminDot<P, F>
327 + ArgminMul<F, G>
328 + ArgminMul<F, P>
329 + ArgminZeroLike
330 + ArgminMinMax,
331 L: Clone
332 + LineSearch<P, F>
333 + Solver<LineSearchProblem<O, P, G, F>, IterState<P, G, (), (), (), F>>,
334 F: ArgminFloat,
335{
336 fn name(&self) -> &str {
337 "L-BFGS"
338 }
339
340 fn init(
341 &mut self,
342 problem: &mut Problem<O>,
343 mut state: IterState<P, G, (), (), (), F>,
344 ) -> Result<(IterState<P, G, (), (), (), F>, Option<KV>), Error> {
345 let param = state.take_param().ok_or_else(argmin_error_closure!(
346 NotInitialized,
347 concat!(
348 "`L-BFGS` requires an initial parameter vector. ",
349 "Please provide an initial guess via `Executor`s `configure` method."
350 )
351 ))?;
352
353 let cost = state.get_cost();
354 let cost = if cost.is_infinite() {
355 if let Some(l1_coeff) = self.l1_coeff {
356 problem.cost(¶m)? + l1_coeff * param.l1_norm()
357 } else {
358 problem.cost(¶m)?
359 }
360 } else {
361 cost
362 };
363
364 let grad = state
365 .take_gradient()
366 .map(Result::Ok)
367 .unwrap_or_else(|| problem.gradient(¶m))?;
368
369 Ok((state.param(param).cost(cost).gradient(grad), None))
370 }
371
372 fn next_iter(
373 &mut self,
374 problem: &mut Problem<O>,
375 mut state: IterState<P, G, (), (), (), F>,
376 ) -> Result<(IterState<P, G, (), (), (), F>, Option<KV>), Error> {
377 let param = state.take_param().ok_or_else(argmin_error_closure!(
378 PotentialBug,
379 "`L-BFGS`: Parameter vector in state not set."
380 ))?;
381 let cur_cost = state.get_cost();
382
383 let mut prev_grad = state.take_gradient().ok_or_else(argmin_error_closure!(
385 PotentialBug,
386 "`L-BFGS`: Gradient in state not set."
387 ))?;
388 if let Some(l1_coeff) = self.l1_coeff {
389 if self.l1_prev_unreg_grad.is_none() {
390 self.l1_prev_unreg_grad = Some(prev_grad.clone());
391 prev_grad = calculate_pseudo_gradient(l1_coeff, ¶m, &prev_grad)
392 }
393 }
394
395 let gamma: F = if let (Some(sk), Some(yk)) = (self.s.back(), self.y.back()) {
396 sk.dot(yk) / yk.dot(yk)
397 } else {
398 float!(1.0)
399 };
400
401 #[allow(clippy::redundant_clone)]
403 let mut q = prev_grad.clone();
404 let cur_m = self.s.len();
405 let mut alpha: Vec<F> = vec![float!(0.0); cur_m];
406 let mut rho: Vec<F> = vec![float!(0.0); cur_m];
407 for (i, (sk, yk)) in self.s.iter().rev().zip(self.y.iter().rev()).enumerate() {
408 let yksk: F = yk.dot(sk);
409 let rho_t = float!(1.0) / yksk;
410 let skq: F = sk.dot(&q);
411 let alpha_t = skq.mul(rho_t);
412 q = q.sub(&yk.mul(&alpha_t));
413 rho[cur_m - i - 1] = rho_t;
414 alpha[cur_m - i - 1] = alpha_t;
415 }
416 let mut r: P = q.mul(&gamma);
417 for (i, (sk, yk)) in self.s.iter().zip(self.y.iter()).enumerate() {
418 let beta: F = yk.dot(&r);
419 let beta = beta.mul(rho[i]);
420 r = r.add(&sk.mul(&(alpha[i] - beta)));
421 }
422
423 let mut line_problem = LineSearchProblem::new(problem.take_problem().unwrap());
424 let d = if let Some(l1_coeff) = self.l1_coeff {
425 line_problem.with_l1_constraint(l1_coeff, ¶m, &prev_grad);
426 let zeros = r.zero_like();
427 P::max(
428 &r.mul(&prev_grad).sub(&F::min_positive_value()).signum(),
429 &zeros,
430 )
431 .mul(&r)
432 .mul(&float!(-1.0))
433 } else {
434 r.mul(&float!(-1.0))
435 };
436
437 self.linesearch.search_direction(d);
438
439 let linesearch_result = Executor::new(line_problem, self.linesearch.clone())
441 .configure(|config| {
442 config
443 .param(param.clone())
444 .gradient(prev_grad.clone())
445 .cost(cur_cost)
446 })
447 .ctrlc(false)
448 .run();
449
450 let OptimizationResult {
451 problem: mut line_problem,
452 state: mut linesearch_state,
453 ..
454 } = match linesearch_result {
455 Ok(res) => res,
456 Err(e) => {
457 return Ok((
458 state.terminate_with(TerminationReason::SolverExit(format!(
459 "Line search terminated with: '{e}'",
460 ))),
461 Some(kv!("gamma" => gamma;)),
462 ))
463 }
464 };
465
466 let mut xk1 = linesearch_state.take_param().unwrap();
467 let next_cost = linesearch_state.get_cost();
468
469 let mut internal_line_problem = line_problem.take_problem().unwrap();
471 let xi = internal_line_problem.xi.take();
472 problem.problem = Some(internal_line_problem.problem);
473 problem.consume_func_counts(line_problem);
474 if let Some(xi) = xi {
475 let zeros = xk1.zero_like();
476 xk1 = P::max(&xk1.mul(&xi).signum(), &zeros).mul(&xk1);
477 }
478
479 if state.get_iter() >= self.m as u64 {
480 self.s.pop_front();
481 self.y.pop_front();
482 }
483
484 let grad = problem.gradient(&xk1)?;
485
486 self.s.push_back(xk1.sub(¶m));
487 let grad = if let Some(l1_coeff) = self.l1_coeff {
488 let pseudo_grad = calculate_pseudo_gradient(l1_coeff, &xk1, &grad);
490 self.y
491 .push_back(grad.sub(self.l1_prev_unreg_grad.as_ref().unwrap()));
492 self.l1_prev_unreg_grad = Some(grad);
493 pseudo_grad
494 } else {
495 self.y.push_back(grad.sub(&prev_grad));
496 grad
497 };
498
499 Ok((
500 state.param(xk1).cost(next_cost).gradient(grad),
501 Some(kv!("gamma" => gamma;)),
502 ))
503 }
504
505 fn terminate(&mut self, state: &IterState<P, G, (), (), (), F>) -> TerminationStatus {
506 if state.get_gradient().unwrap().l2_norm() < self.tol_grad {
507 return TerminationStatus::Terminated(TerminationReason::SolverConverged);
508 }
509 if (state.get_prev_cost() - state.get_cost()).abs() < self.tol_cost {
510 return TerminationStatus::Terminated(TerminationReason::SolverConverged);
511 }
512 TerminationStatus::NotTerminated
513 }
514}
515
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{
        test_utils::{TestProblem, TestSparseProblem},
        ArgminError,
    };
    use crate::solver::linesearch::MoreThuenteLineSearch;

    test_trait_impl!(
        lbfgs,
        LBFGS<MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64>, Vec<f64>, Vec<f64>, f64>
    );

    /// `new` stores its arguments and applies the documented defaults.
    #[test]
    fn test_new() {
        #[derive(Eq, PartialEq, Debug)]
        struct MyFakeLineSearch {}

        let lbfgs: LBFGS<_, Vec<f64>, Vec<f64>, f64> = LBFGS::new(MyFakeLineSearch {}, 3);
        let LBFGS {
            linesearch,
            tol_grad,
            tol_cost,
            m,
            s,
            y,
            l1_coeff,
            l1_prev_unreg_grad,
        } = lbfgs;

        assert_eq!(linesearch, MyFakeLineSearch {});
        assert_eq!(tol_grad.to_ne_bytes(), f64::EPSILON.sqrt().to_ne_bytes());
        assert_eq!(tol_cost.to_ne_bytes(), f64::EPSILON.to_ne_bytes());
        assert_eq!(m, 3);
        assert!(s.capacity() >= 3);
        assert!(y.capacity() >= 3);
        assert!(l1_coeff.is_none());
        assert!(l1_prev_unreg_grad.is_none());
    }

    /// Non-negative gradient tolerances are accepted, negative ones rejected.
    #[test]
    fn test_with_tolerance_grad() {
        #[derive(Eq, PartialEq, Debug, Clone, Copy)]
        struct MyFakeLineSearch {}

        for tol in [1e-6, 0.0, 1e-2, 1.0, 2.0] {
            let lbfgs: LBFGS<_, Vec<f64>, Vec<f64>, f64> = LBFGS::new(MyFakeLineSearch {}, 3);
            let res = lbfgs.with_tolerance_grad(tol);
            assert!(res.is_ok());

            let lbfgs = res.unwrap();
            assert_eq!(lbfgs.tol_grad.to_ne_bytes(), tol.to_ne_bytes());
        }

        for tol in [-f64::EPSILON, -1.0, -100.0, -42.0] {
            let lbfgs: LBFGS<_, Vec<f64>, Vec<f64>, f64> = LBFGS::new(MyFakeLineSearch {}, 3);
            let res = lbfgs.with_tolerance_grad(tol);
            assert_error!(
                res,
                ArgminError,
                "Invalid parameter: \"`L-BFGS`: gradient tolerance must be >= 0.\""
            );
        }
    }

    /// Non-negative cost tolerances are accepted, negative ones rejected.
    #[test]
    fn test_with_tolerance_cost() {
        #[derive(Eq, PartialEq, Debug, Clone, Copy)]
        struct MyFakeLineSearch {}

        for tol in [1e-6, 0.0, 1e-2, 1.0, 2.0] {
            let lbfgs: LBFGS<_, Vec<f64>, Vec<f64>, f64> = LBFGS::new(MyFakeLineSearch {}, 3);
            let res = lbfgs.with_tolerance_cost(tol);
            assert!(res.is_ok());

            let lbfgs = res.unwrap();
            assert_eq!(lbfgs.tol_cost.to_ne_bytes(), tol.to_ne_bytes());
        }

        for tol in [-f64::EPSILON, -1.0, -100.0, -42.0] {
            let lbfgs: LBFGS<_, Vec<f64>, Vec<f64>, f64> = LBFGS::new(MyFakeLineSearch {}, 3);
            let res = lbfgs.with_tolerance_cost(tol);
            assert_error!(
                res,
                ArgminError,
                "Invalid parameter: \"`L-BFGS`: cost tolerance must be >= 0.\""
            );
        }
    }

    /// Strictly positive L1 coefficients are accepted, all others rejected.
    #[test]
    fn test_with_l1_regularization() {
        #[derive(Eq, PartialEq, Debug, Clone, Copy)]
        struct MyFakeLineSearch {}

        for coeff in [f64::EPSILON, 1e-6, 1e-2, 1.0, 2.0] {
            let lbfgs: LBFGS<_, Vec<f64>, Vec<f64>, f64> = LBFGS::new(MyFakeLineSearch {}, 3);
            let res = lbfgs.with_l1_regularization(coeff);
            assert!(res.is_ok());

            let lbfgs = res.unwrap();
            assert_eq!(
                lbfgs.l1_coeff.unwrap().to_ne_bytes(),
                coeff.to_ne_bytes()
            );
        }

        for coeff in [0.0, -f64::EPSILON, -1.0, -100.0, -42.0] {
            let lbfgs: LBFGS<_, Vec<f64>, Vec<f64>, f64> = LBFGS::new(MyFakeLineSearch {}, 3);
            let res = lbfgs.with_l1_regularization(coeff);
            assert_error!(
                res,
                ArgminError,
                "Invalid parameter: \"`L-BFGS`: coefficient of L1-regularization must be > 0.\""
            );
        }
    }

    /// `init` fails without a parameter vector and otherwise fills in
    /// parameter vector, gradient and cost.
    #[test]
    fn test_init() {
        let linesearch = MoreThuenteLineSearch::new().with_c(1e-4, 0.9).unwrap();

        let param: Vec<f64> = vec![-1.0, 1.0];

        let mut lbfgs: LBFGS<_, Vec<f64>, Vec<f64>, f64> = LBFGS::new(linesearch, 3);

        // Forgot to initialize the parameter vector
        let state: IterState<Vec<f64>, Vec<f64>, (), (), (), f64> = IterState::new();
        let problem = TestProblem::new();
        let res = lbfgs.init(&mut Problem::new(problem), state);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Not initialized: \"`L-BFGS` requires an initial parameter vector. Please ",
                "provide an initial guess via `Executor`s `configure` method.\""
            )
        );

        // All good.
        let state: IterState<Vec<f64>, Vec<f64>, (), (), (), f64> =
            IterState::new().param(param.clone());
        let problem = TestProblem::new();
        let (mut state_out, kv) = lbfgs.init(&mut Problem::new(problem), state).unwrap();

        assert!(kv.is_none());

        let s_param = state_out.take_param().unwrap();

        for (s, p) in s_param.iter().zip(param.iter()) {
            assert_eq!(s.to_ne_bytes(), p.to_ne_bytes());
        }

        let s_grad = state_out.take_gradient().unwrap();

        for (s, p) in s_grad.iter().zip(param.iter()) {
            assert_eq!(s.to_ne_bytes(), p.to_ne_bytes());
        }

        assert_eq!(state_out.get_cost().to_ne_bytes(), 1.0f64.to_ne_bytes())
    }

    /// A cost provided via the state is kept as is by `init`.
    #[test]
    fn test_init_provided_cost() {
        let linesearch = MoreThuenteLineSearch::new().with_c(1e-4, 0.9).unwrap();

        let param: Vec<f64> = vec![-1.0, 1.0];

        let mut lbfgs: LBFGS<_, Vec<f64>, Vec<f64>, f64> = LBFGS::new(linesearch, 3);

        let state: IterState<Vec<f64>, Vec<f64>, (), (), (), f64> =
            IterState::new().param(param).cost(1234.0);

        let problem = TestProblem::new();
        let (state_out, kv) = lbfgs.init(&mut Problem::new(problem), state).unwrap();

        assert!(kv.is_none());

        assert_eq!(state_out.get_cost().to_ne_bytes(), 1234.0f64.to_ne_bytes())
    }

    /// A gradient provided via the state is kept as is by `init`.
    #[test]
    fn test_init_provided_grad() {
        let linesearch = MoreThuenteLineSearch::new().with_c(1e-4, 0.9).unwrap();

        let param: Vec<f64> = vec![-1.0, 1.0];
        let gradient: Vec<f64> = vec![4.0, 9.0];

        let mut lbfgs: LBFGS<_, Vec<f64>, Vec<f64>, f64> = LBFGS::new(linesearch, 3);

        let state: IterState<Vec<f64>, Vec<f64>, (), (), (), f64> =
            IterState::new().param(param).gradient(gradient.clone());

        let problem = TestProblem::new();
        let (mut state_out, kv) = lbfgs.init(&mut Problem::new(problem), state).unwrap();

        assert!(kv.is_none());

        let s_grad = state_out.take_gradient().unwrap();

        for (s, g) in s_grad.iter().zip(gradient.iter()) {
            assert_eq!(s.to_ne_bytes(), g.to_ne_bytes());
        }
    }

    /// After two iterations on `TestSparseProblem`, only the L1-regularized
    /// run ends up at `[0.5, 0.0, -0.5, 0.0]`.
    #[test]
    fn test_l1_regularization() {
        // Without L1 regularization.
        {
            let linesearch = MoreThuenteLineSearch::new().with_c(1e-4, 0.9).unwrap();

            let param: Vec<f64> = vec![0.0; 4];

            let lbfgs: LBFGS<_, Vec<f64>, Vec<f64>, f64> = LBFGS::new(linesearch, 3);

            let cost = TestSparseProblem::new();
            let res = Executor::new(cost, lbfgs)
                .configure(|state| state.param(param).max_iters(2))
                .run()
                .unwrap();

            let result_param = res.state.param.unwrap();

            assert!((result_param[0] - 0.5).abs() > 1e-6);
            assert!((result_param[1]).abs() > 1e-6);
            assert!((result_param[2] + 0.5).abs() > 1e-6);
            assert!((result_param[3]).abs() > 1e-6);
        }
        // With L1 regularization.
        {
            let linesearch = MoreThuenteLineSearch::new().with_c(1e-4, 0.9).unwrap();

            let param: Vec<f64> = vec![0.0; 4];

            let lbfgs: LBFGS<_, Vec<f64>, Vec<f64>, f64> = LBFGS::new(linesearch, 3)
                .with_l1_regularization(2.0)
                .unwrap();

            let cost = TestSparseProblem::new();
            let res = Executor::new(cost, lbfgs)
                .configure(|state| state.param(param).max_iters(2))
                .run()
                .unwrap();

            let result_param = res.state.param.unwrap();

            assert!((result_param[0] - 0.5).abs() < 1e-6);
            assert!((result_param[1]).abs() < 1e-6);
            assert!((result_param[2] + 0.5).abs() < 1e-6);
            assert!((result_param[3]).abs() < 1e-6);
        }
    }
}