1#![allow(clippy::nonminimal_bool)]
11
12use crate::core::{
13 ArgminFloat, CostFunction, Error, Gradient, IterState, LineSearch, Problem, Solver, State,
14 TerminationReason, KV,
15};
16use argmin_math::{ArgminDot, ArgminScaledAdd};
17#[cfg(feature = "serde1")]
18use serde::{Deserialize, Serialize};
19
20#[derive(Clone)]
53#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
54pub struct MoreThuenteLineSearch<P, G, F> {
55 search_direction: Option<G>,
57 init_param: Option<P>,
59 finit: F,
61 init_grad: Option<G>,
63 dginit: F,
65 dgtest: F,
67 ftol: F,
69 gtol: F,
71 xtrapf: F,
73 width: F,
75 width1: F,
77 xtol: F,
79 alpha: F,
81 stpmin: F,
83 stpmax: F,
85 stp: Step<F>,
87 stx: Step<F>,
89 sty: Step<F>,
91 f: F,
93 brackt: bool,
95 stage1: bool,
97 infoc: usize,
99}
100
/// One point on the search line: step length `x`, function value `fx` there, and the
/// directional derivative `gx` along the search direction there.
#[derive(Clone, Eq, PartialEq, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
struct Step<F> {
    pub x: F,
    pub fx: F,
    pub gx: F,
}

impl<F> Step<F> {
    /// Bundles step length, function value and directional derivative into a `Step`.
    pub fn new(x: F, fx: F, gx: F) -> Self {
        Self { x, fx, gx }
    }
}
115
116impl<F> Default for Step<F>
117where
118 F: ArgminFloat,
119{
120 fn default() -> Self {
121 Step {
122 x: float!(0.0),
123 fx: float!(0.0),
124 gx: float!(0.0),
125 }
126 }
127}
128
129impl<P, G, F> MoreThuenteLineSearch<P, G, F>
130where
131 F: ArgminFloat,
132{
133 pub fn new() -> Self {
142 MoreThuenteLineSearch {
143 search_direction: None,
144 init_param: None,
145 finit: F::infinity(),
146 init_grad: None,
147 dginit: float!(0.0),
148 dgtest: float!(0.0),
149 ftol: float!(1e-4),
150 gtol: float!(0.9),
151 xtrapf: float!(4.0),
152 width: F::nan(),
153 width1: F::nan(),
154 xtol: float!(1e-10),
155 alpha: float!(1.0),
156 stpmin: F::epsilon().sqrt(),
157 stpmax: F::infinity(),
158 stp: Step::default(),
159 stx: Step::default(),
160 sty: Step::default(),
161 f: F::nan(),
162 brackt: false,
163 stage1: true,
164 infoc: 1,
165 }
166 }
167
168 pub fn with_c(mut self, c1: F, c2: F) -> Result<Self, Error> {
185 if c1 <= float!(0.0) || c1 >= c2 {
186 return Err(argmin_error!(
187 InvalidParameter,
188 "`MoreThuenteLineSearch`: Parameter c1 must be in (0, c2)."
189 ));
190 }
191 if c2 <= c1 || c2 >= float!(1.0) {
192 return Err(argmin_error!(
193 InvalidParameter,
194 "`MoreThuenteLineSearch`: Parameter c2 must be in (c1, 1)."
195 ));
196 }
197 self.ftol = c1;
198 self.gtol = c2;
199 Ok(self)
200 }
201
202 pub fn with_bounds(mut self, step_min: F, step_max: F) -> Result<Self, Error> {
220 if step_min < float!(0.0) {
221 return Err(argmin_error!(
222 InvalidParameter,
223 "`MoreThuenteLineSearch`: step_min must be >= 0.0."
224 ));
225 }
226 if step_max <= step_min {
227 return Err(argmin_error!(
228 InvalidParameter,
229 "`MoreThuenteLineSearch`: step_min must be smaller than step_max."
230 ));
231 }
232 self.stpmin = step_min;
233 self.stpmax = step_max;
234 Ok(self)
235 }
236
237 pub fn with_width_tolerance(mut self, xtol: F) -> Result<Self, Error> {
256 if xtol < float!(0.0) {
257 return Err(argmin_error!(
258 InvalidParameter,
259 "`MoreThuenteLineSearch`: relative width tolerance must be >= 0.0."
260 ));
261 }
262 self.xtol = xtol;
263 Ok(self)
264 }
265}
266
267impl<P, G, F> Default for MoreThuenteLineSearch<P, G, F>
268where
269 F: ArgminFloat,
270{
271 fn default() -> Self {
272 MoreThuenteLineSearch::new()
273 }
274}
275
276impl<P, G, F> LineSearch<G, F> for MoreThuenteLineSearch<P, G, F>
277where
278 F: ArgminFloat,
279{
280 fn search_direction(&mut self, search_direction: G) {
282 self.search_direction = Some(search_direction);
283 }
284
285 fn initial_step_length(&mut self, alpha: F) -> Result<(), Error> {
287 if alpha <= float!(0.0) {
288 return Err(argmin_error!(
289 InvalidParameter,
290 "MoreThuenteLineSearch: Initial alpha must be > 0."
291 ));
292 }
293 self.alpha = alpha;
294 Ok(())
295 }
296}
297
298impl<P, G, O, F> Solver<O, IterState<P, G, (), (), (), F>> for MoreThuenteLineSearch<P, G, F>
299where
300 O: CostFunction<Param = P, Output = F> + Gradient<Param = P, Gradient = G>,
301 P: Clone + ArgminDot<G, F> + ArgminScaledAdd<G, F, P>,
302 G: Clone + ArgminDot<G, F>,
303 F: ArgminFloat,
304{
305 fn name(&self) -> &str {
306 "More-Thuente Line search"
307 }
308
309 fn init(
310 &mut self,
311 problem: &mut Problem<O>,
312 mut state: IterState<P, G, (), (), (), F>,
313 ) -> Result<(IterState<P, G, (), (), (), F>, Option<KV>), Error> {
314 check_param!(
315 self.search_direction,
316 concat!(
317 "`MoreThuenteLineSearch`: Search direction not initialized. ",
318 "Call `search_direction` before executing the solver."
319 )
320 );
321
322 self.init_param = Some(state.take_param().ok_or_else(argmin_error_closure!(
323 NotInitialized,
324 concat!(
325 "`MoreThuenteLineSearch` requires an initial parameter vector. ",
326 "Please provide an initial guess via `Executor`s `configure` method."
327 )
328 ))?);
329
330 let cost = state.get_cost();
331 self.finit = if cost.is_infinite() {
332 problem.cost(self.init_param.as_ref().unwrap())?
333 } else {
334 cost
335 };
336
337 self.init_grad = Some(
338 state
339 .take_gradient()
340 .map(Result::Ok)
341 .unwrap_or_else(|| problem.gradient(self.init_param.as_ref().unwrap()))?,
342 );
343
344 self.dginit = self
345 .init_grad
346 .as_ref()
347 .unwrap()
348 .dot(self.search_direction.as_ref().unwrap());
349
350 if self.dginit >= float!(0.0) {
352 return Err(argmin_error!(
353 ConditionViolated,
354 "`MoreThuenteLineSearch`: Search direction must be a descent direction."
355 ));
356 }
357
358 self.stage1 = true;
359 self.brackt = false;
360
361 self.dgtest = self.ftol * self.dginit;
362 self.width = self.stpmax - self.stpmin;
363 self.width1 = float!(2.0) * self.width;
364 self.f = self.finit;
365
366 self.stp = Step::new(self.alpha, F::nan(), F::nan());
367 self.stx = Step::new(float!(0.0), self.finit, self.dginit);
368 self.sty = Step::new(float!(0.0), self.finit, self.dginit);
369
370 Ok((state, None))
371 }
372
373 fn next_iter(
374 &mut self,
375 problem: &mut Problem<O>,
376 state: IterState<P, G, (), (), (), F>,
377 ) -> Result<(IterState<P, G, (), (), (), F>, Option<KV>), Error> {
378 let mut info = 0;
380 let (stmin, stmax) = if self.brackt {
381 (self.stx.x.min(self.sty.x), self.stx.x.max(self.sty.x))
382 } else {
383 (
384 self.stx.x,
385 self.stp.x + self.xtrapf * (self.stp.x - self.stx.x),
386 )
387 };
388
389 self.stp.x = self.stp.x.max(self.stpmin);
391 self.stp.x = self.stp.x.min(self.stpmax);
392
393 if (self.brackt && (self.stp.x <= stmin || self.stp.x >= stmax))
396 || (self.brackt && (stmax - stmin) <= self.xtol * stmax)
397 || self.infoc == 0
398 {
399 self.stp.x = self.stx.x;
400 }
401
402 let new_param = self
404 .init_param
405 .as_ref()
406 .unwrap()
407 .scaled_add(&self.stp.x, self.search_direction.as_ref().unwrap());
408 self.f = problem.cost(&new_param)?;
409 let new_grad = problem.gradient(&new_param)?;
410 let cur_cost = self.f;
411 let cur_param = new_param;
412 let cur_grad = new_grad.clone();
413 let dg = self.search_direction.as_ref().unwrap().dot(&new_grad);
415 let ftest1 = self.finit + self.stp.x * self.dgtest;
416 if (self.brackt && (self.stp.x <= stmin || self.stp.x >= stmax)) || self.infoc == 0 {
420 info = 6;
421 }
422
423 if (self.stp.x - self.stpmax).abs() < F::epsilon() && self.f <= ftest1 && dg <= self.dgtest
424 {
425 info = 5;
426 }
427
428 if (self.stp.x - self.stpmin).abs() < F::epsilon() && (self.f > ftest1 || dg >= self.dgtest)
429 {
430 info = 4;
431 }
432
433 if self.brackt && stmax - stmin <= self.xtol * stmax {
434 info = 2;
435 }
436
437 if self.f <= ftest1 && dg.abs() <= self.gtol * (-self.dginit) {
438 info = 1;
439 }
440
441 if info != 0 {
442 return Ok((
443 state
444 .param(cur_param)
445 .cost(cur_cost)
446 .gradient(cur_grad)
447 .terminate_with(TerminationReason::SolverConverged),
448 None,
449 ));
450 }
451
452 if self.stage1 && self.f <= ftest1 && dg >= self.ftol.min(self.gtol) * self.dginit {
453 self.stage1 = false;
454 }
455
456 if self.stage1 && self.f <= self.stp.fx && self.f > ftest1 {
457 let fm = self.f - self.stp.x * self.dgtest;
458 let fxm = self.stx.fx - self.stx.x * self.dgtest;
459 let fym = self.sty.fx - self.sty.x * self.dgtest;
460 let dgm = dg - self.dgtest;
461 let dgxm = self.stx.gx - self.dgtest;
462 let dgym = self.sty.gx - self.dgtest;
463
464 let (stx1, sty1, stp1, brackt1, _stmin, _stmax, infoc) = cstep(
465 Step::new(self.stx.x, fxm, dgxm),
466 Step::new(self.sty.x, fym, dgym),
467 Step::new(self.stp.x, fm, dgm),
468 self.brackt,
469 stmin,
470 stmax,
471 )?;
472
473 self.stx.x = stx1.x;
474 self.sty.x = sty1.x;
475 self.stp.x = stp1.x;
476 self.stx.fx = self.stx.fx + stx1.x * self.dgtest;
477 self.sty.fx = self.sty.fx + sty1.x * self.dgtest;
478 self.stx.gx = self.stx.gx + self.dgtest;
479 self.sty.gx = self.sty.gx + self.dgtest;
480 self.brackt = brackt1;
481 self.stp = stp1;
482 self.infoc = infoc;
483 } else {
484 let (stx1, sty1, stp1, brackt1, _stmin, _stmax, infoc) = cstep(
485 self.stx.clone(),
486 self.sty.clone(),
487 Step::new(self.stp.x, self.f, dg),
488 self.brackt,
489 stmin,
490 stmax,
491 )?;
492 self.stx = stx1;
493 self.sty = sty1;
494 self.stp = stp1;
495 self.f = self.stp.fx;
496 self.brackt = brackt1;
498 self.infoc = infoc;
499 }
500
501 if self.brackt {
502 if (self.sty.x - self.stx.x).abs() >= float!(0.66) * self.width1 {
503 self.stp.x = self.stx.x + float!(0.5) * (self.sty.x - self.stx.x);
504 }
505 self.width1 = self.width;
506 self.width = (self.sty.x - self.stx.x).abs();
507 }
508
509 Ok((state, None))
510 }
511}
512
513type CstepReturnValue<F> = (Step<F>, Step<F>, Step<F>, bool, F, F, usize);
514
515fn cstep<F: ArgminFloat>(
516 stx: Step<F>,
517 sty: Step<F>,
518 stp: Step<F>,
519 brackt: bool,
520 stpmin: F,
521 stpmax: F,
522) -> Result<CstepReturnValue<F>, Error> {
523 let mut info: usize = 0;
524 let bound: bool;
525 let mut stpf: F;
526 let stpc: F;
527 let stpq: F;
528 let mut brackt = brackt;
529
530 if (brackt && (stp.x <= stx.x.min(sty.x) || stp.x >= stx.x.max(sty.x)))
532 || stx.gx * (stp.x - stx.x) >= float!(0.0)
533 || stpmax < stpmin
534 {
535 return Ok((stx, sty, stp, brackt, stpmin, stpmax, info));
536 }
537
538 let sgnd = stp.gx * (stx.gx / stx.gx.abs());
540
541 if stp.fx > stx.fx {
542 info = 1;
546 bound = true;
547 let theta = float!(3.0) * (stx.fx - stp.fx) / (stp.x - stx.x) + stx.gx + stp.gx;
548 let tmp = [theta, stx.gx, stp.gx];
549 if tmp.iter().any(|n| n.is_nan() || n.is_infinite()) {
551 return Err(argmin_error!(
552 ConditionViolated,
553 "MoreThuenteLineSearch: NaN or Inf encountered during iteration"
554 ));
555 }
556 let s = tmp.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
557 let mut gamma = *s * ((theta / *s).powi(2) - (stx.gx / *s) * (stp.gx / *s)).sqrt();
558 if stp.x < stx.x {
559 gamma = -gamma;
560 }
561
562 let p = (gamma - stx.gx) + theta;
563 let q = ((gamma - stx.gx) + gamma) + stp.gx;
564 let r = p / q;
565 stpc = stx.x + r * (stp.x - stx.x);
566 stpq = stx.x
567 + ((stx.gx / ((stx.fx - stp.fx) / (stp.x - stx.x) + stx.gx)) / float!(2.0))
568 * (stp.x - stx.x);
569 if (stpc - stx.x).abs() < (stpq - stx.x).abs() {
570 stpf = stpc;
571 } else {
572 stpf = stpc + (stpq - stpc) / float!(2.0);
573 }
574 brackt = true;
575 } else if sgnd < float!(0.0) {
576 info = 2;
580 bound = false;
581 let theta = float!(3.0) * (stx.fx - stp.fx) / (stp.x - stx.x) + stx.gx + stp.gx;
582 let tmp = [theta, stx.gx, stp.gx];
583 if tmp.iter().any(|n| n.is_nan() || n.is_infinite()) {
585 return Err(argmin_error!(
586 ConditionViolated,
587 "MoreThuenteLineSearch: NaN or Inf encountered during iteration"
588 ));
589 }
590 let s = tmp.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
591 let mut gamma = *s * ((theta / *s).powi(2) - (stx.gx / *s) * (stp.gx / *s)).sqrt();
592 if stp.x > stx.x {
593 gamma = -gamma;
594 }
595 let p = (gamma - stp.gx) + theta;
596 let q = ((gamma - stp.gx) + gamma) + stx.gx;
597 let r = p / q;
598 stpc = stp.x + r * (stx.x - stp.x);
599 stpq = stp.x + (stp.gx / (stp.gx - stx.gx)) * (stx.x - stp.x);
600 if (stpc - stp.x).abs() > (stpq - stp.x).abs() {
601 stpf = stpc;
602 } else {
603 stpf = stpq;
604 }
605 brackt = true;
606 } else if stp.gx.abs() < stx.gx.abs() {
607 info = 3;
614 bound = true;
615 let theta = float!(3.0) * (stx.fx - stp.fx) / (stp.x - stx.x) + stx.gx + stp.gx;
616 let tmp = [theta, stx.gx, stp.gx];
617 if tmp.iter().any(|n| n.is_nan() || n.is_infinite()) {
619 return Err(argmin_error!(
620 ConditionViolated,
621 "`MoreThuenteLineSearch`: NaN or Inf encountered during iteration"
622 ));
623 }
624 let s = tmp.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
625 let mut gamma = *s
629 * float!(0.0)
630 .max((theta / *s).powi(2) - (stx.gx / *s) * (stp.gx / *s))
631 .sqrt();
632 if stp.x > stx.x {
633 gamma = -gamma;
634 }
635
636 let p = (gamma - stp.gx) + theta;
637 let q = (gamma + (stx.gx - stp.gx)) + gamma;
638 let r = p / q;
639 if r < float!(0.0) && gamma != float!(0.0) {
640 stpc = stp.x + r * (stx.x - stp.x);
641 } else if stp.x > stx.x {
642 stpc = stpmax;
643 } else {
644 stpc = stpmin;
645 }
646 stpq = stp.x + (stp.gx / (stp.gx - stx.gx)) * (stx.x - stp.x);
647 if brackt {
648 if (stp.x - stpc).abs() < (stp.x - stpq).abs() {
649 stpf = stpc;
650 } else {
651 stpf = stpq;
652 }
653 } else if (stp.x - stpc).abs() > (stp.x - stpq).abs() {
654 stpf = stpc;
655 } else {
656 stpf = stpq;
657 }
658 } else {
659 info = 4;
663 bound = false;
664 if brackt {
665 let theta = float!(3.0) * (stp.fx - sty.fx) / (sty.x - stp.x) + sty.gx + stp.gx;
666 let tmp = [theta, sty.gx, stp.gx];
667 if tmp.iter().any(|n| n.is_nan() || n.is_infinite()) {
669 return Err(argmin_error!(
670 ConditionViolated,
671 "MoreThuenteLineSearch: NaN or Inf encountered during iteration"
672 ));
673 }
674 let s = tmp.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
675 let mut gamma = *s * ((theta / *s).powi(2) - (sty.gx / *s) * (stp.gx / *s)).sqrt();
676 if stp.x > sty.x {
677 gamma = -gamma;
678 }
679 let p = (gamma - stp.gx) + theta;
680 let q = ((gamma - stp.gx) + gamma) + sty.gx;
681 let r = p / q;
682 stpc = stp.x + r * (sty.x - stp.x);
683 stpf = stpc;
684 } else if stp.x > stx.x {
685 stpf = stpmax;
686 } else {
687 stpf = stpmin;
688 }
689 }
690 let mut stx_o = stx;
694 let mut sty_o = sty;
695 let mut stp_o = stp;
696 if stp_o.fx > stx_o.fx {
697 sty_o = Step::new(stp_o.x, stp_o.fx, stp_o.gx);
698 } else {
699 if sgnd < float!(0.0) {
700 sty_o = Step::new(stx_o.x, stx_o.fx, stx_o.gx);
701 }
702 stx_o = Step::new(stp_o.x, stp_o.fx, stp_o.gx);
703 }
704
705 stpf = stpmax.min(stpf);
708 stpf = stpmin.max(stpf);
709
710 stp_o.x = stpf;
711 if brackt && bound {
712 if sty_o.x > stx_o.x {
713 stp_o.x = stp_o.x.min(stx_o.x + float!(0.66) * (sty_o.x - stx_o.x));
714 } else {
715 stp_o.x = stp_o.x.max(stx_o.x + float!(0.66) * (sty_o.x - stx_o.x));
716 }
717 }
718
719 Ok((stx_o, sty_o, stp_o, brackt, stpmin, stpmax, info))
720}
721
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{test_utils::TestProblem, ArgminError};

    /// Concrete line search type exercised by all tests in this module.
    type Mtls = MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64>;

    /// A freshly constructed line search with default settings.
    fn default_mtls() -> Mtls {
        MoreThuenteLineSearch::new()
    }

    test_trait_impl!(morethuente, MoreThuenteLineSearch<Vec<f64>, Vec<f64>, f64>);

    #[test]
    fn test_new() {
        let MoreThuenteLineSearch {
            search_direction,
            init_param,
            finit,
            init_grad,
            dginit,
            dgtest,
            ftol,
            gtol,
            xtrapf,
            width,
            width1,
            xtol,
            alpha,
            stpmin,
            stpmax,
            stp,
            stx,
            sty,
            f,
            brackt,
            stage1,
            infoc,
        } = default_mtls();

        // Nothing is set before the search is configured.
        assert!(search_direction.is_none());
        assert!(init_param.is_none());
        assert!(init_grad.is_none());

        // Quantities computed in `init` start out as sentinels.
        assert!(finit.is_infinite() && finit.is_sign_positive());
        assert_eq!(dginit.to_ne_bytes(), 0.0f64.to_ne_bytes());
        assert_eq!(dgtest.to_ne_bytes(), 0.0f64.to_ne_bytes());
        assert!(width.is_nan());
        assert!(width1.is_nan());
        assert!(f.is_nan());

        // Default tuning parameters.
        assert_eq!(ftol.to_ne_bytes(), 1e-4f64.to_ne_bytes());
        assert_eq!(gtol.to_ne_bytes(), 0.9f64.to_ne_bytes());
        assert_eq!(xtrapf.to_ne_bytes(), 4.0f64.to_ne_bytes());
        assert_eq!(xtol.to_ne_bytes(), 1e-10f64.to_ne_bytes());
        assert_eq!(alpha.to_ne_bytes(), 1.0f64.to_ne_bytes());
        assert_eq!(stpmin.to_ne_bytes(), f64::EPSILON.sqrt().to_ne_bytes());
        assert!(stpmax.is_infinite() && stpmax.is_sign_positive());

        // Search state.
        assert_eq!(stp, Step::default());
        assert_eq!(stx, Step::default());
        assert_eq!(sty, Step::default());
        assert!(!brackt);
        assert!(stage1);
        assert_eq!(infoc, 1);
    }

    #[test]
    fn test_with_c_correct() {
        let mtls = default_mtls()
            .with_c(0.1, 0.9)
            .expect("0 < 0.1 < 0.9 < 1 must be accepted");
        assert_eq!(mtls.ftol.to_ne_bytes(), 0.1f64.to_ne_bytes());
        assert_eq!(mtls.gtol.to_ne_bytes(), 0.9f64.to_ne_bytes());
    }

    #[test]
    fn test_with_c_c1_larger_than_c2() {
        let res = default_mtls().with_c(0.9, 0.1);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Invalid parameter: \"`MoreThuenteLineSearch`: ",
                "Parameter c1 must be in (0, c2).\""
            )
        );
    }

    #[test]
    fn test_with_c_c1_smaller_than_0() {
        let res = default_mtls().with_c(-0.9, 0.99);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Invalid parameter: \"`MoreThuenteLineSearch`: ",
                "Parameter c1 must be in (0, c2).\""
            )
        );
    }

    #[test]
    fn test_with_c_c2_larger_than_1() {
        let res = default_mtls().with_c(0.1, 1.01);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Invalid parameter: \"`MoreThuenteLineSearch`: ",
                "Parameter c2 must be in (c1, 1).\""
            )
        );
    }

    #[test]
    fn test_with_bounds_correct() {
        let mtls = default_mtls()
            .with_bounds(0.1, 0.9)
            .expect("0 <= 0.1 < 0.9 must be accepted");
        assert_eq!(mtls.stpmin.to_ne_bytes(), 0.1f64.to_ne_bytes());
        assert_eq!(mtls.stpmax.to_ne_bytes(), 0.9f64.to_ne_bytes());
    }

    #[test]
    fn test_with_bounds_step_min_smaller_than_0() {
        let res = default_mtls().with_bounds(-0.1, 0.99);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Invalid parameter: \"`MoreThuenteLineSearch`: ",
                "step_min must be >= 0.0.\""
            )
        );
    }

    #[test]
    fn test_with_bounds_step_min_larger_than_step_max() {
        let res = default_mtls().with_bounds(10.0, 0.99);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Invalid parameter: \"`MoreThuenteLineSearch`: ",
                "step_min must be smaller than step_max.\""
            )
        );
    }

    #[test]
    fn test_with_width_tolerance_correct() {
        let mtls = default_mtls()
            .with_width_tolerance(1e-9)
            .expect("a positive tolerance must be accepted");
        assert_eq!(mtls.xtol.to_ne_bytes(), 1e-9f64.to_ne_bytes());
    }

    #[test]
    fn test_with_width_tolerance_negative_xtol() {
        let res = default_mtls().with_width_tolerance(-1e-10);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Invalid parameter: \"`MoreThuenteLineSearch`: ",
                "relative width tolerance must be >= 0.0.\""
            )
        );
    }

    #[test]
    fn test_init_search_direction_not_set() {
        let mut mtls = default_mtls();
        let res = mtls.init(&mut Problem::new(TestProblem::new()), IterState::new());
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Not initialized: \"`MoreThuenteLineSearch`: Search direction not initialized. ",
                "Call `search_direction` before executing the solver.\""
            )
        );
    }

    #[test]
    fn test_init_param_not_set() {
        let mut mtls = default_mtls();
        mtls.search_direction(vec![1.0f64]);
        let res = mtls.init(&mut Problem::new(TestProblem::new()), IterState::new());
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Not initialized: \"`MoreThuenteLineSearch` requires an initial parameter vector. ",
                "Please provide an initial guess via `Executor`s `configure` method.\""
            )
        );
    }
}