1use crate::core::{
22 ArgminFloat, CostFunction, Error, IterState, Problem, Solver, TerminationReason,
23 TerminationStatus, KV,
24};
25use argmin_math::{ArgminAdd, ArgminMul, ArgminSub};
26#[cfg(feature = "serde1")]
27use serde::{Deserialize, Serialize};
28use std::fmt;
29
/// Nelder-Mead downhill simplex solver state.
///
/// Holds the four step coefficients, the simplex vertices together with their
/// cached cost values, and the termination tolerance.
#[derive(Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct NelderMead<P, F> {
    /// Reflection coefficient (default `1.0`, validated to be `> 0`).
    alpha: F,
    /// Expansion coefficient (default `2.0`, validated to be `> 1`).
    gamma: F,
    /// Contraction coefficient (default `0.5`, validated to be in `(0, 0.5]`).
    rho: F,
    /// Shrink coefficient (default `0.5`, validated to be in `(0, 1]`).
    sigma: F,
    /// Simplex vertices paired with their cost. Costs are NaN until the solver's
    /// `init` evaluates them; the vector is re-sorted by cost after `init` and
    /// after every iteration, so index 0 holds the best vertex.
    params: Vec<(P, F)>,
    /// Threshold on the sample standard deviation of the vertex costs below
    /// which the solver reports convergence (default `F::epsilon()`, must be `>= 0`).
    sd_tolerance: F,
}
77
78impl<P, F> NelderMead<P, F>
79where
80 P: Clone + ArgminAdd<P, P> + ArgminSub<P, P> + ArgminMul<F, P>,
81 F: ArgminFloat,
82{
83 pub fn new(params: Vec<P>) -> Self {
96 NelderMead {
97 alpha: float!(1.0),
98 gamma: float!(2.0),
99 rho: float!(0.5),
100 sigma: float!(0.5),
101 params: params.into_iter().map(|p| (p, F::nan())).collect(),
102 sd_tolerance: F::epsilon(),
103 }
104 }
105
106 pub fn with_sd_tolerance(mut self, tol: F) -> Result<Self, Error> {
123 if tol < float!(0.0) {
124 return Err(argmin_error!(
125 InvalidParameter,
126 "`Nelder-Mead`: sd_tolerance must be >= 0."
127 ));
128 }
129 self.sd_tolerance = tol;
130 Ok(self)
131 }
132
133 pub fn with_alpha(mut self, alpha: F) -> Result<Self, Error> {
150 if alpha <= float!(0.0) {
151 return Err(argmin_error!(
152 InvalidParameter,
153 "`Nelder-Mead`: alpha must be > 0."
154 ));
155 }
156 self.alpha = alpha;
157 Ok(self)
158 }
159
160 pub fn with_gamma(mut self, gamma: F) -> Result<Self, Error> {
177 if gamma <= float!(1.0) {
178 return Err(argmin_error!(
179 InvalidParameter,
180 "`Nelder-Mead`: gamma must be > 1."
181 ));
182 }
183 self.gamma = gamma;
184 Ok(self)
185 }
186
187 pub fn with_rho(mut self, rho: F) -> Result<Self, Error> {
204 if rho <= float!(0.0) || rho > float!(0.5) {
205 return Err(argmin_error!(
206 InvalidParameter,
207 "`Nelder-Mead`: rho must be in (0, 0.5]."
208 ));
209 }
210 self.rho = rho;
211 Ok(self)
212 }
213
214 pub fn with_sigma(mut self, sigma: F) -> Result<Self, Error> {
231 if sigma <= float!(0.0) || sigma > float!(1.0) {
232 return Err(argmin_error!(
233 InvalidParameter,
234 "`Nelder-Mead`: sigma must be in (0, 1]."
235 ));
236 }
237 self.sigma = sigma;
238 Ok(self)
239 }
240
241 fn sort_param_vecs(&mut self) {
243 self.params
244 .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
245 }
246
247 fn calculate_centroid(&self) -> P {
249 let num_param = self.params.len() - 1;
251 self.params
252 .iter()
253 .take(num_param)
255 .skip(1)
257 .fold(self.params[0].0.clone(), |acc, p| acc.add(&p.0))
259 .mul(&(float!(1.0) / (float!(num_param as f64))))
261 }
262
263 fn reflect(&self, x0: &P, x: &P) -> P {
265 x0.add(&x0.sub(x).mul(&self.alpha))
266 }
267
268 fn expand(&self, x0: &P, x: &P) -> P {
270 x0.add(&x.sub(x0).mul(&self.gamma))
271 }
272
273 fn contract(&self, x0: &P, x: &P) -> P {
275 x0.add(&x.sub(x0).mul(&self.rho))
276 }
277
278 fn shrink<S>(&mut self, mut cost: S) -> Result<(), Error>
280 where
281 S: FnMut(&P) -> Result<F, Error>,
282 {
283 let x0 = self.params[0].0.clone();
285 self.params
286 .iter_mut()
287 .skip(1)
289 .try_for_each(|(p, c)| -> Result<(), Error> {
290 *p = x0.add(&p.sub(&x0).mul(&self.sigma));
291 *c = (cost)(p)?;
292 Ok(())
293 })?;
294 Ok(())
295 }
296}
297
/// The kind of simplex update performed by one Nelder-Mead iteration.
///
/// Its `Display` text is what gets reported under the `action` key.
#[derive(Debug)]
enum Action {
    Reflection,
    Expansion,
    ContractionOutside,
    ContractionInside,
    Shrink,
}

impl fmt::Display for Action {
    /// Writes the variant name verbatim (no padding or alignment is applied).
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let label = match self {
            Self::Reflection => "Reflection",
            Self::Expansion => "Expansion",
            Self::ContractionOutside => "ContractionOutside",
            Self::ContractionInside => "ContractionInside",
            Self::Shrink => "Shrink",
        };
        f.write_str(label)
    }
}
318
319impl<O, P, F> Solver<O, IterState<P, (), (), (), (), F>> for NelderMead<P, F>
320where
321 O: CostFunction<Param = P, Output = F>,
322 P: Clone + ArgminSub<P, P> + ArgminAdd<P, P> + ArgminMul<F, P>,
323 F: ArgminFloat + std::iter::Sum<F>,
324{
325 fn name(&self) -> &str {
326 "Nelder-Mead method"
327 }
328
329 fn init(
330 &mut self,
331 problem: &mut Problem<O>,
332 state: IterState<P, (), (), (), (), F>,
333 ) -> Result<(IterState<P, (), (), (), (), F>, Option<KV>), Error> {
334 self.params
335 .iter_mut()
336 .for_each(|(p, c)| *c = problem.cost(p).unwrap());
337
338 self.sort_param_vecs();
339
340 Ok((
341 state.param(self.params[0].0.clone()).cost(self.params[0].1),
342 None,
343 ))
344 }
345
346 fn next_iter(
347 &mut self,
348 problem: &mut Problem<O>,
349 state: IterState<P, (), (), (), (), F>,
350 ) -> Result<(IterState<P, (), (), (), (), F>, Option<KV>), Error> {
351 let num_param_vecs = self.params.len();
352
353 let x0 = self.calculate_centroid();
354
355 let p_best = &self.params[0];
356 let p_worst = &self.params[num_param_vecs - 1];
357 let p_second_worst = &self.params[num_param_vecs - 2];
358
359 let xr = self.reflect(&x0, &p_worst.0);
360 let xr_cost = problem.cost(&xr)?;
361
362 let action = if xr_cost < p_second_worst.1 && xr_cost >= p_best.1 {
363 *self.params.last_mut().unwrap() = (xr, xr_cost);
365 Action::Reflection
366 } else if xr_cost < p_best.1 {
367 let xe = self.expand(&x0, &xr);
369 let xe_cost = problem.cost(&xe)?;
370 *self.params.last_mut().unwrap() = if xe_cost < xr_cost {
371 (xe, xe_cost)
372 } else {
373 (xr, xr_cost)
374 };
375 Action::Expansion
376 } else if xr_cost >= p_second_worst.1 {
377 if xr_cost < p_worst.1 {
379 let xc = self.contract(&x0, &xr);
381 let xc_cost = problem.cost(&xc)?;
382 if xc_cost <= xr_cost {
383 *self.params.last_mut().unwrap() = (xc, xc_cost);
384 Action::ContractionOutside
385 } else {
386 self.shrink(|x| problem.cost(x))?;
388 Action::Shrink
389 }
390 } else {
391 let xc = self.contract(&x0, &p_worst.0);
393 let xc_cost = problem.cost(&xc)?;
394 if xc_cost < p_worst.1 {
395 *self.params.last_mut().unwrap() = (xc, xc_cost);
396 Action::ContractionInside
397 } else {
398 self.shrink(|x| problem.cost(x))?;
400 Action::Shrink
401 }
402 }
403 } else {
404 return Err(argmin_error!(
405 PotentialBug,
406 "`NelderMead`: Reached unreachable point."
407 ));
408 };
409
410 self.sort_param_vecs();
411
412 Ok((
413 state.param(self.params[0].0.clone()).cost(self.params[0].1),
414 Some(kv!("action" => format!("{action}");)),
415 ))
416 }
417
418 fn terminate(&mut self, _state: &IterState<P, (), (), (), (), F>) -> TerminationStatus {
419 let n = float!(self.params.len() as f64);
420 let c0: F = self.params.iter().map(|(_, c)| *c).sum::<F>() / n;
421 let s: F = (float!(1.0) / (n - float!(1.0))
422 * self
423 .params
424 .iter()
425 .map(|(_, c)| (*c - c0).powi(2))
426 .sum::<F>())
427 .sqrt();
428 if s < self.sd_tolerance {
429 return TerminationStatus::Terminated(TerminationReason::SolverConverged);
430 }
431 TerminationStatus::NotTerminated
432 }
433}
434
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{test_utils::TestProblem, ArgminError, State};
    use approx::assert_relative_eq;

    test_trait_impl!(nelder_mead, NelderMead<TestProblem, f64>);

    /// Sum of squares; minimum 0 at the origin.
    struct MwProblem {}

    impl CostFunction for MwProblem {
        type Param = Vec<f64>;
        type Output = f64;

        fn cost(&self, p: &Self::Param) -> Result<Self::Output, Error> {
            Ok(p.iter().fold(0.0, |acc, x| acc + x.powi(2)))
        }
    }

    #[test]
    fn test_new() {
        let params = vec![vec![1.0], vec![2.0]];
        let nm: NelderMead<Vec<f64>, f64> = NelderMead::new(params);

        let NelderMead {
            alpha,
            gamma,
            rho,
            sigma,
            params,
            sd_tolerance,
        } = nm;

        assert_eq!(alpha.to_ne_bytes(), 1.0f64.to_ne_bytes());
        assert_eq!(gamma.to_ne_bytes(), 2.0f64.to_ne_bytes());
        assert_eq!(rho.to_ne_bytes(), 0.5f64.to_ne_bytes());
        assert_eq!(sigma.to_ne_bytes(), 0.5f64.to_ne_bytes());
        assert_eq!(params[0].0[0].to_ne_bytes(), 1.0f64.to_ne_bytes());
        assert_eq!(params[1].0[0].to_ne_bytes(), 2.0f64.to_ne_bytes());
        assert_eq!(params[0].1.to_ne_bytes(), f64::NAN.to_ne_bytes());
        assert_eq!(params[1].1.to_ne_bytes(), f64::NAN.to_ne_bytes());
        assert_eq!(sd_tolerance.to_ne_bytes(), f64::EPSILON.to_ne_bytes());
    }

    #[test]
    fn test_with_sd_tolerance() {
        for tol in [1e-6, 0.0, 1e-2, 1.0, 2.0] {
            let params = vec![vec![1.0], vec![2.0]];
            let nm: NelderMead<Vec<f64>, f64> = NelderMead::new(params);
            let res = nm.with_sd_tolerance(tol);
            assert!(res.is_ok());

            let nm = res.unwrap();
            assert_eq!(nm.sd_tolerance.to_ne_bytes(), tol.to_ne_bytes());
        }

        for tol in [-f64::EPSILON, -1.0, -100.0, -42.0] {
            let params = vec![vec![1.0], vec![2.0]];
            let nm: NelderMead<Vec<f64>, f64> = NelderMead::new(params);
            let res = nm.with_sd_tolerance(tol);
            assert_error!(
                res,
                ArgminError,
                concat!(
                    "Invalid parameter: \"`Nelder-Mead`: ",
                    "sd_tolerance must be >= 0.\""
                )
            );
        }
    }

    #[test]
    fn test_with_alpha() {
        for alpha in [f64::EPSILON, 1e-6, 1e-2, 1.0, 2.0] {
            let params = vec![vec![1.0], vec![2.0]];
            let nm: NelderMead<Vec<f64>, f64> = NelderMead::new(params);
            let res = nm.with_alpha(alpha);
            assert!(res.is_ok());

            let nm = res.unwrap();
            assert_eq!(nm.alpha.to_ne_bytes(), alpha.to_ne_bytes());
        }

        for alpha in [-f64::EPSILON, -1.0, -100.0, -42.0] {
            let params = vec![vec![1.0], vec![2.0]];
            let nm: NelderMead<Vec<f64>, f64> = NelderMead::new(params);
            let res = nm.with_alpha(alpha);
            assert_error!(
                res,
                ArgminError,
                concat!(
                    "Invalid parameter: \"`Nelder-Mead`: ",
                    "alpha must be > 0.\""
                )
            );
        }
    }

    #[test]
    fn test_with_gamma() {
        for gamma in [1.0 + f64::EPSILON, 1.5, 2.0, 10.0] {
            let params = vec![vec![1.0], vec![2.0]];
            let nm: NelderMead<Vec<f64>, f64> = NelderMead::new(params);
            let res = nm.with_gamma(gamma);
            assert!(res.is_ok());

            let nm = res.unwrap();
            assert_eq!(nm.gamma.to_ne_bytes(), gamma.to_ne_bytes());
        }

        for gamma in [-1.0, 0.0, 0.5, 1.0] {
            let params = vec![vec![1.0], vec![2.0]];
            let nm: NelderMead<Vec<f64>, f64> = NelderMead::new(params);
            let res = nm.with_gamma(gamma);
            assert_error!(
                res,
                ArgminError,
                concat!(
                    "Invalid parameter: \"`Nelder-Mead`: ",
                    "gamma must be > 1.\""
                )
            );
        }
    }

    #[test]
    fn test_with_rho() {
        for rho in [f64::EPSILON, 0.1, 0.3, 0.5] {
            let params = vec![vec![1.0], vec![2.0]];
            let nm: NelderMead<Vec<f64>, f64> = NelderMead::new(params);
            let res = nm.with_rho(rho);
            assert!(res.is_ok());

            let nm = res.unwrap();
            assert_eq!(nm.rho.to_ne_bytes(), rho.to_ne_bytes());
        }

        for rho in [-1.0, 0.0, 0.5 + f64::EPSILON, 1.0] {
            let params = vec![vec![1.0], vec![2.0]];
            let nm: NelderMead<Vec<f64>, f64> = NelderMead::new(params);
            let res = nm.with_rho(rho);
            assert_error!(
                res,
                ArgminError,
                concat!(
                    "Invalid parameter: \"`Nelder-Mead`: ",
                    "rho must be in (0, 0.5].\""
                )
            );
        }
    }

    #[test]
    fn test_with_sigma() {
        for sigma in [f64::EPSILON, 0.3, 0.5, 0.9, 1.0 - f64::EPSILON] {
            let params = vec![vec![1.0], vec![2.0]];
            let nm: NelderMead<Vec<f64>, f64> = NelderMead::new(params);
            let res = nm.with_sigma(sigma);
            assert!(res.is_ok());

            let nm = res.unwrap();
            assert_eq!(nm.sigma.to_ne_bytes(), sigma.to_ne_bytes());
        }

        for sigma in [-1.0, 0.0, 1.0 + f64::EPSILON, 10.0] {
            let params = vec![vec![1.0], vec![2.0]];
            let nm: NelderMead<Vec<f64>, f64> = NelderMead::new(params);
            let res = nm.with_sigma(sigma);
            assert_error!(
                res,
                ArgminError,
                concat!(
                    "Invalid parameter: \"`Nelder-Mead`: ",
                    "sigma must be in (0, 1].\""
                )
            );
        }
    }

    #[test]
    fn test_sort_param_vecs() {
        let params: Vec<Vec<f64>> = vec![vec![2.0], vec![1.0], vec![3.0]];
        let params_sorted: Vec<Vec<f64>> = vec![vec![1.0], vec![2.0], vec![3.0]];
        let mut nm: NelderMead<_, f64> = NelderMead::new(params);
        nm.params.iter_mut().for_each(|(p, c)| *c = p[0]);
        nm.sort_param_vecs();
        for ((p, c), ps) in nm.params.iter().zip(params_sorted.iter()) {
            assert_eq!(p[0].to_ne_bytes(), ps[0].to_ne_bytes());
            assert_eq!(c.to_ne_bytes(), ps[0].to_ne_bytes());
        }
    }

    #[test]
    fn test_calculate_centroid() {
        let params: Vec<Vec<f64>> = vec![vec![0.2, 0.0], vec![0.4, 1.0], vec![1.0, 0.0]];
        let mut nm: NelderMead<_, f64> = NelderMead::new(params);
        nm.params
            .iter_mut()
            .enumerate()
            .for_each(|(i, (_, c))| *c = i as f64);
        nm.sort_param_vecs();
        let centroid = nm.calculate_centroid();
        assert_relative_eq!(centroid[0], 0.3f64, epsilon = f64::EPSILON);
        assert_relative_eq!(centroid[1], 0.5f64, epsilon = f64::EPSILON);
    }

    #[test]
    fn test_reflect() {
        let params: Vec<Vec<f64>> = vec![vec![0.0, 1.0], vec![1.0, 0.0], vec![0.0, 0.0]];
        let mut nm: NelderMead<_, f64> = NelderMead::new(params);
        nm.params
            .iter_mut()
            .enumerate()
            .for_each(|(i, (_, c))| *c = i as f64);
        nm.sort_param_vecs();
        let centroid = nm.calculate_centroid();
        let reflected = nm.reflect(&centroid, &vec![0.0, 0.0]);
        assert_relative_eq!(reflected[0], 1.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(reflected[1], 1.0f64, epsilon = f64::EPSILON);
    }

    #[test]
    fn test_expand() {
        let params: Vec<Vec<f64>> = vec![vec![0.0, 1.0], vec![1.0, 0.0], vec![0.0, 0.0]];
        let mut nm: NelderMead<_, f64> = NelderMead::new(params);
        nm.params
            .iter_mut()
            .enumerate()
            .for_each(|(i, (_, c))| *c = i as f64);
        nm.sort_param_vecs();
        let centroid = nm.calculate_centroid();
        let expanded = nm.expand(&centroid, &vec![1.0, 1.0]);
        assert_relative_eq!(expanded[0], 1.5f64, epsilon = f64::EPSILON);
        assert_relative_eq!(expanded[1], 1.5f64, epsilon = f64::EPSILON);
    }

    #[test]
    fn test_contract() {
        let params: Vec<Vec<f64>> = vec![vec![0.0, 1.0], vec![1.0, 0.0], vec![0.0, 0.0]];
        let mut nm: NelderMead<_, f64> = NelderMead::new(params);
        nm.params
            .iter_mut()
            .enumerate()
            .for_each(|(i, (_, c))| *c = i as f64);
        nm.sort_param_vecs();
        let centroid = nm.calculate_centroid();
        let contracted = nm.contract(&centroid, &vec![1.0, 1.0]);
        assert_relative_eq!(contracted[0], 0.75f64, epsilon = f64::EPSILON);
        assert_relative_eq!(contracted[1], 0.75f64, epsilon = f64::EPSILON);
    }

    #[test]
    fn test_shrink() {
        let params: Vec<Vec<f64>> = vec![vec![0.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]];
        let params_shrunk: Vec<Vec<f64>> = vec![vec![0.0, 0.0], vec![0.0, 0.5], vec![0.5, 0.0]];
        let mut nm: NelderMead<_, f64> = NelderMead::new(params);
        nm.params
            .iter_mut()
            .enumerate()
            .for_each(|(i, (_, c))| *c = i as f64);
        nm.sort_param_vecs();
        nm.shrink(|_| Ok(1.0f64)).unwrap();

        for ((p, _), ps) in nm.params.iter().zip(params_shrunk.iter()) {
            assert_eq!(p[0].to_ne_bytes(), ps[0].to_ne_bytes());
            assert_eq!(p[1].to_ne_bytes(), ps[1].to_ne_bytes());
        }
    }

    #[test]
    fn test_init() {
        let params: Vec<Vec<f64>> = vec![vec![-1.0, 1.0], vec![-0.5, 2.0], vec![0.7, -1.0]];
        let params_sorted: Vec<(Vec<f64>, f64)> = vec![
            (vec![0.7, -1.0], 0.7f64.powi(2) + 1.0f64.powi(2)),
            (vec![-1.0, 1.0], 2.0),
            (vec![-0.5, 2.0], 0.5f64.powi(2) + 2.0f64.powi(2)),
        ];
        let mut nm: NelderMead<_, f64> = NelderMead::new(params);
        let state: IterState<Vec<f64>, (), (), (), (), f64> = IterState::new();
        let problem = MwProblem {};
        let (state_out, kv) = nm.init(&mut Problem::new(problem), state).unwrap();

        assert!(kv.is_none());

        for ((p, c), (ps, cs)) in nm.params.iter().zip(params_sorted.iter()) {
            assert_relative_eq!(c, cs, epsilon = f64::EPSILON);
            assert_eq!(p[0].to_ne_bytes(), ps[0].to_ne_bytes());
            assert_eq!(p[1].to_ne_bytes(), ps[1].to_ne_bytes());
        }

        for i in 0..2 {
            assert_relative_eq!(
                state_out.get_param().unwrap()[i],
                params_sorted[0].0[i],
                epsilon = f64::EPSILON
            );
        }

        assert_relative_eq!(
            state_out.get_cost(),
            0.7f64.powi(2) + 1.0f64.powi(2),
            epsilon = f64::EPSILON
        );
    }

    #[test]
    fn test_next_iter_reflection() {
        let params: Vec<Vec<f64>> = vec![vec![-1.0, 0.0], vec![-0.1, 0.65], vec![-0.1, -0.95]];
        let mut nm: NelderMead<_, f64> = NelderMead::new(params);
        let state: IterState<Vec<f64>, (), (), (), (), f64> = IterState::new();
        let mut problem = Problem::new(MwProblem {});
        let (state, _) = nm.init(&mut problem, state).unwrap();

        let (state, kv) = nm.next_iter(&mut problem, state).unwrap();

        assert_eq!(
            format!("{}", kv.unwrap().get("action").unwrap()),
            "Reflection"
        );

        let param = state.get_param().unwrap();

        assert_relative_eq!(param[0], -0.1f64, epsilon = f64::EPSILON);
        assert_relative_eq!(param[1], 0.65f64, epsilon = f64::EPSILON);

        let cost = state.get_cost();
        assert_relative_eq!(cost, 0.4325f64, epsilon = f64::EPSILON);

        assert_relative_eq!(nm.params[0].0[0], -0.1f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[0].0[1], 0.65f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[0].1, 0.4325f64, epsilon = f64::EPSILON);

        assert_relative_eq!(nm.params[1].0[0], 0.8f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[1].0[1], -0.3f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[1].1, 0.73f64, epsilon = f64::EPSILON);

        assert_relative_eq!(nm.params[2].0[0], -0.1f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[2].0[1], -0.95f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[2].1, 0.9125f64, epsilon = f64::EPSILON);
    }

    #[test]
    fn test_next_iter_expansion() {
        let params: Vec<Vec<f64>> = vec![
            vec![-2.0, 0.0],
            vec![-1.0, 1.0],
            vec![-1.0, -1.0 - f64::EPSILON],
        ];
        let mut nm: NelderMead<_, f64> = NelderMead::new(params);
        let state: IterState<Vec<f64>, (), (), (), (), f64> = IterState::new();
        let mut problem = Problem::new(MwProblem {});
        let (state, _) = nm.init(&mut problem, state).unwrap();

        let (state, kv) = nm.next_iter(&mut problem, state).unwrap();

        assert_eq!(
            format!("{}", kv.unwrap().get("action").unwrap()),
            "Expansion"
        );

        let param = state.get_param().unwrap();

        assert_relative_eq!(param[0], 0.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(param[1], 0.0f64, epsilon = f64::EPSILON);

        let cost = state.get_cost();
        assert_relative_eq!(cost, 0.0f64, epsilon = f64::EPSILON);

        assert_relative_eq!(nm.params[0].0[0], 0.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[0].0[1], 0.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[0].1, 0.0f64, epsilon = f64::EPSILON);

        assert_relative_eq!(nm.params[1].0[0], -1.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[1].0[1], 1.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[1].1, 2.0f64, epsilon = f64::EPSILON);

        assert_relative_eq!(nm.params[2].0[0], -1.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[2].0[1], -1.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[2].1, 2.0f64, epsilon = f64::EPSILON);
    }

    #[test]
    fn test_next_iter_contraction_outside() {
        let params: Vec<Vec<f64>> = vec![vec![-1.1, 0.0], vec![-0.1, 1.0], vec![-0.1, -0.5]];
        let mut nm: NelderMead<_, f64> = NelderMead::new(params);
        let state: IterState<Vec<f64>, (), (), (), (), f64> = IterState::new();
        let mut problem = Problem::new(MwProblem {});
        let (state, _) = nm.init(&mut problem, state).unwrap();

        let (state, kv) = nm.next_iter(&mut problem, state).unwrap();

        assert_eq!(
            format!("{}", kv.unwrap().get("action").unwrap()),
            "ContractionOutside"
        );

        let param = state.get_param().unwrap();

        assert_relative_eq!(param[0], -0.1f64, epsilon = f64::EPSILON);
        assert_relative_eq!(param[1], -0.5f64, epsilon = f64::EPSILON);

        let cost = state.get_cost();
        assert_relative_eq!(cost, 0.26f64, epsilon = f64::EPSILON);

        assert_relative_eq!(nm.params[0].0[0], -0.1f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[0].0[1], -0.5f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[0].1, 0.26f64, epsilon = f64::EPSILON);

        assert_relative_eq!(nm.params[1].0[0], 0.4f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[1].0[1], 0.375f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[1].1, 0.300625f64, epsilon = f64::EPSILON);

        assert_relative_eq!(nm.params[2].0[0], -0.1f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[2].0[1], 1.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[2].1, 1.01f64, epsilon = f64::EPSILON);
    }

    #[test]
    fn test_next_iter_contraction_inside() {
        let params: Vec<Vec<f64>> = vec![vec![-1.0, 0.0], vec![0.0, 1.0], vec![0.0, -0.5]];
        let mut nm: NelderMead<_, f64> = NelderMead::new(params);
        let state: IterState<Vec<f64>, (), (), (), (), f64> = IterState::new();
        let mut problem = Problem::new(MwProblem {});
        let (state, _) = nm.init(&mut problem, state).unwrap();

        let (state, kv) = nm.next_iter(&mut problem, state).unwrap();

        assert_eq!(
            format!("{}", kv.unwrap().get("action").unwrap()),
            "ContractionInside"
        );

        let param = state.get_param().unwrap();

        assert_relative_eq!(param[0], -0.25f64, epsilon = f64::EPSILON);
        assert_relative_eq!(param[1], 0.375f64, epsilon = f64::EPSILON);

        let cost = state.get_cost();
        assert_relative_eq!(cost, 0.203125f64, epsilon = f64::EPSILON);

        assert_relative_eq!(nm.params[0].0[0], -0.25f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[0].0[1], 0.375f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[0].1, 0.203125f64, epsilon = f64::EPSILON);

        assert_relative_eq!(nm.params[1].0[0], 0.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[1].0[1], -0.5f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[1].1, 0.25, epsilon = f64::EPSILON);

        assert_relative_eq!(nm.params[2].0[0], -1.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[2].0[1], 0.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[2].1, 1.00f64, epsilon = f64::EPSILON);
    }
}