// argmin/solver/neldermead/mod.rs

// Copyright 2018-2024 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.
7
//! # Nelder-Mead method
//!
//! The Nelder-Mead method is a heuristic search method for nonlinear optimization problems which
//! does not require derivatives.
//!
//! See [`NelderMead`] for details.
//!
//! ## References
//!
//! <https://en.wikipedia.org/wiki/Nelder%E2%80%93Mead_method>
//!
//! <http://www.scholarpedia.org/article/Nelder-Mead_algorithm#Simplex_transformation_algorithm>
20
21use crate::core::{
22    ArgminFloat, CostFunction, Error, IterState, Problem, Solver, TerminationReason,
23    TerminationStatus, KV,
24};
25use argmin_math::{ArgminAdd, ArgminMul, ArgminSub};
26#[cfg(feature = "serde1")]
27use serde::{Deserialize, Serialize};
28use std::fmt;
29
/// # Nelder-Mead method
///
/// The Nelder-Mead method is a heuristic search method for nonlinear optimization problems which
/// does not require derivatives.
///
/// The method is based on simplices which consist of n+1 vertices for an optimization problem with
/// n dimensions.
/// The function to be optimized is evaluated at all vertices. Based on these cost function values
/// the behavior of the cost function is extrapolated in order to find the next point to be
/// evaluated.
///
/// The following actions are possible:
///
/// 1) Reflection (Parameter `alpha`, defaults to `1`, configurable via
///    [`with_alpha`](`NelderMead::with_alpha`))
/// 2) Expansion (Parameter `gamma`, defaults to `2`, configurable via
///    [`with_gamma`](`NelderMead::with_gamma`))
/// 3) Contraction inside or outside (Parameter `rho`, defaults to `0.5`, configurable via
///    [`with_rho`](`NelderMead::with_rho`))
/// 4) Shrink (Parameter `sigma`, defaults to `0.5`, configurable via
///    [`with_sigma`](`NelderMead::with_sigma`))
///
/// The solver reports convergence once the sample standard deviation of the cost function values
/// at the vertices drops below a tolerance (defaults to `EPSILON`, configurable via
/// [`with_sd_tolerance`](`NelderMead::with_sd_tolerance`)).
///
/// ## Requirements on the optimization problem
///
/// The optimization problem is required to implement [`CostFunction`].
///
/// ## References
///
/// <https://en.wikipedia.org/wiki/Nelder%E2%80%93Mead_method>
///
/// <http://www.scholarpedia.org/article/Nelder-Mead_algorithm#Simplex_transformation_algorithm>
#[derive(Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct NelderMead<P, F> {
    /// Reflection coefficient (`alpha`)
    alpha: F,
    /// Expansion coefficient (`gamma`)
    gamma: F,
    /// Contraction coefficient (`rho`), used for both inside and outside contraction
    rho: F,
    /// Shrink coefficient (`sigma`)
    sigma: F,
    /// Simplex vertices, each paired with its cost function value
    /// (`NaN` until the solver has been initialized)
    params: Vec<(P, F)>,
    /// Sample standard deviation tolerance on the vertex costs (termination criterion)
    sd_tolerance: F,
}
77
78impl<P, F> NelderMead<P, F>
79where
80    P: Clone + ArgminAdd<P, P> + ArgminSub<P, P> + ArgminMul<F, P>,
81    F: ArgminFloat,
82{
83    /// Construct a new instance of `NelderMead`
84    ///
85    /// Takes a vector of parameter vectors. The number of parameter vectors must be `n + 1` where
86    /// `n` is the number of optimization parameters.
87    ///
88    /// # Example
89    ///
90    /// ```
91    /// # use argmin::solver::neldermead::NelderMead;
92    /// # let vec_of_parameters = vec![vec![1.0], vec![2.0], vec![3.0]];
93    /// let nm: NelderMead<Vec<f64>, f64> = NelderMead::new(vec_of_parameters);
94    /// ```
95    pub fn new(params: Vec<P>) -> Self {
96        NelderMead {
97            alpha: float!(1.0),
98            gamma: float!(2.0),
99            rho: float!(0.5),
100            sigma: float!(0.5),
101            params: params.into_iter().map(|p| (p, F::nan())).collect(),
102            sd_tolerance: F::epsilon(),
103        }
104    }
105
106    /// Set sample standard deviation tolerance
107    ///
108    /// Must be non-negative and defaults to `EPSILON`.
109    ///
110    /// # Example
111    ///
112    /// ```
113    /// # use argmin::solver::neldermead::NelderMead;
114    /// # use argmin::core::Error;
115    /// # fn main() -> Result<(), Error> {
116    /// # let vec_of_parameters = vec![vec![1.0], vec![2.0], vec![3.0]];
117    /// let nm: NelderMead<Vec<f64>, f64> =
118    ///     NelderMead::new(vec_of_parameters).with_sd_tolerance(1e-6)?;
119    /// # Ok(())
120    /// # }
121    /// ```
122    pub fn with_sd_tolerance(mut self, tol: F) -> Result<Self, Error> {
123        if tol < float!(0.0) {
124            return Err(argmin_error!(
125                InvalidParameter,
126                "`Nelder-Mead`: sd_tolerance must be >= 0."
127            ));
128        }
129        self.sd_tolerance = tol;
130        Ok(self)
131    }
132
133    /// Set alpha parameter for reflection
134    ///
135    /// Must be larger than 0 and defaults to 1.
136    ///
137    /// # Example
138    ///
139    /// ```
140    /// # use argmin::solver::neldermead::NelderMead;
141    /// # use argmin::core::Error;
142    /// # fn main() -> Result<(), Error> {
143    /// # let vec_of_parameters = vec![vec![1.0], vec![2.0], vec![3.0]];
144    /// let nm: NelderMead<Vec<f64>, f64> =
145    ///     NelderMead::new(vec_of_parameters).with_alpha(0.9)?;
146    /// # Ok(())
147    /// # }
148    /// ```
149    pub fn with_alpha(mut self, alpha: F) -> Result<Self, Error> {
150        if alpha <= float!(0.0) {
151            return Err(argmin_error!(
152                InvalidParameter,
153                "`Nelder-Mead`: alpha must be > 0."
154            ));
155        }
156        self.alpha = alpha;
157        Ok(self)
158    }
159
160    /// Set gamma for expansion
161    ///
162    /// Must be larger than 1 and defaults to 2.
163    ///
164    /// # Example
165    ///
166    /// ```
167    /// # use argmin::solver::neldermead::NelderMead;
168    /// # use argmin::core::Error;
169    /// # fn main() -> Result<(), Error> {
170    /// # let vec_of_parameters = vec![vec![1.0], vec![2.0], vec![3.0]];
171    /// let nm: NelderMead<Vec<f64>, f64> =
172    ///     NelderMead::new(vec_of_parameters).with_gamma(1.9)?;
173    /// # Ok(())
174    /// # }
175    /// ```
176    pub fn with_gamma(mut self, gamma: F) -> Result<Self, Error> {
177        if gamma <= float!(1.0) {
178            return Err(argmin_error!(
179                InvalidParameter,
180                "`Nelder-Mead`: gamma must be > 1."
181            ));
182        }
183        self.gamma = gamma;
184        Ok(self)
185    }
186
187    /// Set rho for contraction
188    ///
189    /// Must be in (0, 0.5] and defaults to 0.5.
190    ///
191    /// # Example
192    ///
193    /// ```
194    /// # use argmin::solver::neldermead::NelderMead;
195    /// # use argmin::core::Error;
196    /// # fn main() -> Result<(), Error> {
197    /// # let vec_of_parameters = vec![vec![1.0], vec![2.0], vec![3.0]];
198    /// let nm: NelderMead<Vec<f64>, f64> =
199    ///     NelderMead::new(vec_of_parameters).with_rho(0.4)?;
200    /// # Ok(())
201    /// # }
202    /// ```
203    pub fn with_rho(mut self, rho: F) -> Result<Self, Error> {
204        if rho <= float!(0.0) || rho > float!(0.5) {
205            return Err(argmin_error!(
206                InvalidParameter,
207                "`Nelder-Mead`: rho must be in (0, 0.5]."
208            ));
209        }
210        self.rho = rho;
211        Ok(self)
212    }
213
214    /// Set sigma for shrinking
215    ///
216    /// Must be in (0, 1] and defaults to 0.5.
217    ///
218    /// # Example
219    ///
220    /// ```
221    /// # use argmin::solver::neldermead::NelderMead;
222    /// # use argmin::core::Error;
223    /// # fn main() -> Result<(), Error> {
224    /// # let vec_of_parameters = vec![vec![1.0], vec![2.0], vec![3.0]];
225    /// let nm: NelderMead<Vec<f64>, f64> =
226    ///     NelderMead::new(vec_of_parameters).with_sigma(0.4)?;
227    /// # Ok(())
228    /// # }
229    /// ```
230    pub fn with_sigma(mut self, sigma: F) -> Result<Self, Error> {
231        if sigma <= float!(0.0) || sigma > float!(1.0) {
232            return Err(argmin_error!(
233                InvalidParameter,
234                "`Nelder-Mead`: sigma must be in (0, 1]."
235            ));
236        }
237        self.sigma = sigma;
238        Ok(self)
239    }
240
241    /// Sort parameters vectors based on their cost function values
242    fn sort_param_vecs(&mut self) {
243        self.params
244            .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
245    }
246
247    /// Calculate centroid of all vectors but the worst
248    fn calculate_centroid(&self) -> P {
249        // Number of parameters is number of parameter vectors minus 1
250        let num_param = self.params.len() - 1;
251        self.params
252            .iter()
253            // Avoid the worst vector
254            .take(num_param)
255            // First one is used as the accumulator, therefore exclude it from the iterator
256            .skip(1)
257            // Add all vectors to the first
258            .fold(self.params[0].0.clone(), |acc, p| acc.add(&p.0))
259            // Scale
260            .mul(&(float!(1.0) / (float!(num_param as f64))))
261    }
262
263    /// Reflect
264    fn reflect(&self, x0: &P, x: &P) -> P {
265        x0.add(&x0.sub(x).mul(&self.alpha))
266    }
267
268    /// Expand
269    fn expand(&self, x0: &P, x: &P) -> P {
270        x0.add(&x.sub(x0).mul(&self.gamma))
271    }
272
273    /// Contract
274    fn contract(&self, x0: &P, x: &P) -> P {
275        x0.add(&x.sub(x0).mul(&self.rho))
276    }
277
278    /// Shrink
279    fn shrink<S>(&mut self, mut cost: S) -> Result<(), Error>
280    where
281        S: FnMut(&P) -> Result<F, Error>,
282    {
283        // The best parameter vector unfortunately has to be cloned once.
284        let x0 = self.params[0].0.clone();
285        self.params
286            .iter_mut()
287            // Best one is not modified
288            .skip(1)
289            .try_for_each(|(p, c)| -> Result<(), Error> {
290                *p = x0.add(&p.sub(&x0).mul(&self.sigma));
291                *c = (cost)(p)?;
292                Ok(())
293            })?;
294        Ok(())
295    }
296}
297
/// Simplex transformation carried out in a single Nelder-Mead iteration.
///
/// The chosen variant is reported through the `"action"` entry of the iteration's KV store.
#[derive(Debug)]
enum Action {
    /// The worst vertex was replaced by its reflection through the centroid
    Reflection,
    /// The reflected point was pushed further away from the centroid
    Expansion,
    /// Contraction from the centroid towards the reflected point
    ContractionOutside,
    /// Contraction from the centroid towards the worst vertex
    ContractionInside,
    /// All vertices but the best one were moved towards the best one
    Shrink,
}

impl fmt::Display for Action {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let label = match self {
            Self::Reflection => "Reflection",
            Self::Expansion => "Expansion",
            Self::ContractionOutside => "ContractionOutside",
            Self::ContractionInside => "ContractionInside",
            Self::Shrink => "Shrink",
        };
        f.write_str(label)
    }
}
318
319impl<O, P, F> Solver<O, IterState<P, (), (), (), (), F>> for NelderMead<P, F>
320where
321    O: CostFunction<Param = P, Output = F>,
322    P: Clone + ArgminSub<P, P> + ArgminAdd<P, P> + ArgminMul<F, P>,
323    F: ArgminFloat + std::iter::Sum<F>,
324{
325    fn name(&self) -> &str {
326        "Nelder-Mead method"
327    }
328
329    fn init(
330        &mut self,
331        problem: &mut Problem<O>,
332        state: IterState<P, (), (), (), (), F>,
333    ) -> Result<(IterState<P, (), (), (), (), F>, Option<KV>), Error> {
334        self.params
335            .iter_mut()
336            .for_each(|(p, c)| *c = problem.cost(p).unwrap());
337
338        self.sort_param_vecs();
339
340        Ok((
341            state.param(self.params[0].0.clone()).cost(self.params[0].1),
342            None,
343        ))
344    }
345
346    fn next_iter(
347        &mut self,
348        problem: &mut Problem<O>,
349        state: IterState<P, (), (), (), (), F>,
350    ) -> Result<(IterState<P, (), (), (), (), F>, Option<KV>), Error> {
351        let num_param_vecs = self.params.len();
352
353        let x0 = self.calculate_centroid();
354
355        let p_best = &self.params[0];
356        let p_worst = &self.params[num_param_vecs - 1];
357        let p_second_worst = &self.params[num_param_vecs - 2];
358
359        let xr = self.reflect(&x0, &p_worst.0);
360        let xr_cost = problem.cost(&xr)?;
361
362        let action = if xr_cost < p_second_worst.1 && xr_cost >= p_best.1 {
363            // reflection
364            *self.params.last_mut().unwrap() = (xr, xr_cost);
365            Action::Reflection
366        } else if xr_cost < p_best.1 {
367            // expansion
368            let xe = self.expand(&x0, &xr);
369            let xe_cost = problem.cost(&xe)?;
370            *self.params.last_mut().unwrap() = if xe_cost < xr_cost {
371                (xe, xe_cost)
372            } else {
373                (xr, xr_cost)
374            };
375            Action::Expansion
376        } else if xr_cost >= p_second_worst.1 {
377            // contraction
378            if xr_cost < p_worst.1 {
379                // Outside
380                let xc = self.contract(&x0, &xr);
381                let xc_cost = problem.cost(&xc)?;
382                if xc_cost <= xr_cost {
383                    *self.params.last_mut().unwrap() = (xc, xc_cost);
384                    Action::ContractionOutside
385                } else {
386                    // shrink
387                    self.shrink(|x| problem.cost(x))?;
388                    Action::Shrink
389                }
390            } else {
391                // Inside
392                let xc = self.contract(&x0, &p_worst.0);
393                let xc_cost = problem.cost(&xc)?;
394                if xc_cost < p_worst.1 {
395                    *self.params.last_mut().unwrap() = (xc, xc_cost);
396                    Action::ContractionInside
397                } else {
398                    // shrink
399                    self.shrink(|x| problem.cost(x))?;
400                    Action::Shrink
401                }
402            }
403        } else {
404            return Err(argmin_error!(
405                PotentialBug,
406                "`NelderMead`: Reached unreachable point."
407            ));
408        };
409
410        self.sort_param_vecs();
411
412        Ok((
413            state.param(self.params[0].0.clone()).cost(self.params[0].1),
414            Some(kv!("action" => format!("{action}");)),
415        ))
416    }
417
418    fn terminate(&mut self, _state: &IterState<P, (), (), (), (), F>) -> TerminationStatus {
419        let n = float!(self.params.len() as f64);
420        let c0: F = self.params.iter().map(|(_, c)| *c).sum::<F>() / n;
421        let s: F = (float!(1.0) / (n - float!(1.0))
422            * self
423                .params
424                .iter()
425                .map(|(_, c)| (*c - c0).powi(2))
426                .sum::<F>())
427        .sqrt();
428        if s < self.sd_tolerance {
429            return TerminationStatus::Terminated(TerminationReason::SolverConverged);
430        }
431        TerminationStatus::NotTerminated
432    }
433}
434
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{test_utils::TestProblem, ArgminError, State};
    use approx::assert_relative_eq;

    test_trait_impl!(nelder_mead, NelderMead<TestProblem, f64>);

    /// Test problem whose cost is the sum of squares of the parameters.
    struct MwProblem {}

    impl CostFunction for MwProblem {
        type Param = Vec<f64>;
        type Output = f64;

        fn cost(&self, p: &Self::Param) -> Result<Self::Output, Error> {
            Ok(p.iter().fold(0.0, |acc, x| acc + x.powi(2)))
        }
    }

    #[test]
    fn test_new() {
        let params = vec![vec![1.0], vec![2.0]];
        let nm: NelderMead<Vec<f64>, f64> = NelderMead::new(params);

        let NelderMead {
            alpha,
            gamma,
            rho,
            sigma,
            params,
            sd_tolerance,
        } = nm;

        assert_eq!(alpha.to_ne_bytes(), 1.0f64.to_ne_bytes());
        assert_eq!(gamma.to_ne_bytes(), 2.0f64.to_ne_bytes());
        assert_eq!(rho.to_ne_bytes(), 0.5f64.to_ne_bytes());
        assert_eq!(sigma.to_ne_bytes(), 0.5f64.to_ne_bytes());
        assert_eq!(params[0].0[0].to_ne_bytes(), 1.0f64.to_ne_bytes());
        assert_eq!(params[1].0[0].to_ne_bytes(), 2.0f64.to_ne_bytes());
        assert_eq!(params[0].1.to_ne_bytes(), f64::NAN.to_ne_bytes());
        assert_eq!(params[1].1.to_ne_bytes(), f64::NAN.to_ne_bytes());
        assert_eq!(sd_tolerance.to_ne_bytes(), f64::EPSILON.to_ne_bytes());
    }

    #[test]
    fn test_with_sd_tolerance() {
        // correct parameters
        for tol in [1e-6, 0.0, 1e-2, 1.0, 2.0] {
            let params = vec![vec![1.0], vec![2.0]];
            let nm: NelderMead<Vec<f64>, f64> = NelderMead::new(params);
            let res = nm.with_sd_tolerance(tol);
            assert!(res.is_ok());

            let nm = res.unwrap();
            assert_eq!(nm.sd_tolerance.to_ne_bytes(), tol.to_ne_bytes());
        }

        // incorrect parameters
        for tol in [-f64::EPSILON, -1.0, -100.0, -42.0] {
            let params = vec![vec![1.0], vec![2.0]];
            let nm: NelderMead<Vec<f64>, f64> = NelderMead::new(params);
            let res = nm.with_sd_tolerance(tol);
            assert_error!(
                res,
                ArgminError,
                concat!(
                    "Invalid parameter: \"`Nelder-Mead`: ",
                    "sd_tolerance must be >= 0.\""
                )
            );
        }
    }

    #[test]
    fn test_with_alpha() {
        // correct parameters
        for alpha in [f64::EPSILON, 1e-6, 1e-2, 1.0, 2.0] {
            let params = vec![vec![1.0], vec![2.0]];
            let nm: NelderMead<Vec<f64>, f64> = NelderMead::new(params);
            let res = nm.with_alpha(alpha);
            assert!(res.is_ok());

            let nm = res.unwrap();
            assert_eq!(nm.alpha.to_ne_bytes(), alpha.to_ne_bytes());
        }

        // incorrect parameters
        for alpha in [-f64::EPSILON, -1.0, -100.0, -42.0] {
            let params = vec![vec![1.0], vec![2.0]];
            let nm: NelderMead<Vec<f64>, f64> = NelderMead::new(params);
            let res = nm.with_alpha(alpha);
            assert_error!(
                res,
                ArgminError,
                concat!(
                    "Invalid parameter: \"`Nelder-Mead`: ",
                    "alpha must be > 0.\""
                )
            );
        }
    }

    #[test]
    fn test_with_gamma() {
        // correct parameters
        for gamma in [1.0 + f64::EPSILON, 1.5, 2.0, 10.0] {
            let params = vec![vec![1.0], vec![2.0]];
            let nm: NelderMead<Vec<f64>, f64> = NelderMead::new(params);
            let res = nm.with_gamma(gamma);
            assert!(res.is_ok());

            let nm = res.unwrap();
            assert_eq!(nm.gamma.to_ne_bytes(), gamma.to_ne_bytes());
        }

        // incorrect parameters
        for gamma in [-1.0, 0.0, 0.5, 1.0] {
            let params = vec![vec![1.0], vec![2.0]];
            let nm: NelderMead<Vec<f64>, f64> = NelderMead::new(params);
            let res = nm.with_gamma(gamma);
            assert_error!(
                res,
                ArgminError,
                concat!(
                    "Invalid parameter: \"`Nelder-Mead`: ",
                    "gamma must be > 1.\""
                )
            );
        }
    }

    #[test]
    fn test_with_rho() {
        // correct parameters
        for rho in [f64::EPSILON, 0.1, 0.3, 0.5] {
            let params = vec![vec![1.0], vec![2.0]];
            let nm: NelderMead<Vec<f64>, f64> = NelderMead::new(params);
            let res = nm.with_rho(rho);
            assert!(res.is_ok());

            let nm = res.unwrap();
            assert_eq!(nm.rho.to_ne_bytes(), rho.to_ne_bytes());
        }

        // incorrect parameters
        for rho in [-1.0, 0.0, 0.5 + f64::EPSILON, 1.0] {
            let params = vec![vec![1.0], vec![2.0]];
            let nm: NelderMead<Vec<f64>, f64> = NelderMead::new(params);
            let res = nm.with_rho(rho);
            assert_error!(
                res,
                ArgminError,
                concat!(
                    "Invalid parameter: \"`Nelder-Mead`: ",
                    "rho must be in (0, 0.5].\""
                )
            );
        }
    }

    #[test]
    fn test_with_sigma() {
        // correct parameters
        for sigma in [f64::EPSILON, 0.3, 0.5, 0.9, 1.0 - f64::EPSILON] {
            let params = vec![vec![1.0], vec![2.0]];
            let nm: NelderMead<Vec<f64>, f64> = NelderMead::new(params);
            let res = nm.with_sigma(sigma);
            assert!(res.is_ok());

            let nm = res.unwrap();
            assert_eq!(nm.sigma.to_ne_bytes(), sigma.to_ne_bytes());
        }

        // incorrect parameters
        for sigma in [-1.0, 0.0, 1.0 + f64::EPSILON, 10.0] {
            let params = vec![vec![1.0], vec![2.0]];
            let nm: NelderMead<Vec<f64>, f64> = NelderMead::new(params);
            let res = nm.with_sigma(sigma);
            assert_error!(
                res,
                ArgminError,
                concat!(
                    "Invalid parameter: \"`Nelder-Mead`: ",
                    "sigma must be in (0, 1].\""
                )
            );
        }
    }

    #[test]
    fn test_sort_param_vecs() {
        let params: Vec<Vec<f64>> = vec![vec![2.0], vec![1.0], vec![3.0]];
        let params_sorted: Vec<Vec<f64>> = vec![vec![1.0], vec![2.0], vec![3.0]];
        let mut nm: NelderMead<_, f64> = NelderMead::new(params);
        nm.params.iter_mut().for_each(|(p, c)| *c = p[0]);
        nm.sort_param_vecs();
        for ((p, c), ps) in nm.params.iter().zip(params_sorted.iter()) {
            assert_eq!(p[0].to_ne_bytes(), ps[0].to_ne_bytes());
            assert_eq!(c.to_ne_bytes(), ps[0].to_ne_bytes());
        }
    }

    #[test]
    fn test_calculate_centroid() {
        let params: Vec<Vec<f64>> = vec![vec![0.2, 0.0], vec![0.4, 1.0], vec![1.0, 0.0]];
        let mut nm: NelderMead<_, f64> = NelderMead::new(params);
        nm.params
            .iter_mut()
            .enumerate()
            .for_each(|(i, (_, c))| *c = i as f64);
        nm.sort_param_vecs();
        let centroid = nm.calculate_centroid();
        assert_relative_eq!(centroid[0], 0.3f64, epsilon = f64::EPSILON);
        assert_relative_eq!(centroid[1], 0.5f64, epsilon = f64::EPSILON);
    }

    #[test]
    fn test_reflect() {
        let params: Vec<Vec<f64>> = vec![vec![0.0, 1.0], vec![1.0, 0.0], vec![0.0, 0.0]];
        let mut nm: NelderMead<_, f64> = NelderMead::new(params);
        nm.params
            .iter_mut()
            .enumerate()
            .for_each(|(i, (_, c))| *c = i as f64);
        nm.sort_param_vecs();
        let centroid = nm.calculate_centroid();
        let reflected = nm.reflect(&centroid, &vec![0.0, 0.0]);
        assert_relative_eq!(reflected[0], 1.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(reflected[1], 1.0f64, epsilon = f64::EPSILON);
    }

    #[test]
    fn test_expand() {
        let params: Vec<Vec<f64>> = vec![vec![0.0, 1.0], vec![1.0, 0.0], vec![0.0, 0.0]];
        let mut nm: NelderMead<_, f64> = NelderMead::new(params);
        nm.params
            .iter_mut()
            .enumerate()
            .for_each(|(i, (_, c))| *c = i as f64);
        nm.sort_param_vecs();
        let centroid = nm.calculate_centroid();
        let expanded = nm.expand(&centroid, &vec![1.0, 1.0]);
        assert_relative_eq!(expanded[0], 1.5f64, epsilon = f64::EPSILON);
        assert_relative_eq!(expanded[1], 1.5f64, epsilon = f64::EPSILON);
    }

    #[test]
    fn test_contract() {
        let params: Vec<Vec<f64>> = vec![vec![0.0, 1.0], vec![1.0, 0.0], vec![0.0, 0.0]];
        let mut nm: NelderMead<_, f64> = NelderMead::new(params);
        nm.params
            .iter_mut()
            .enumerate()
            .for_each(|(i, (_, c))| *c = i as f64);
        nm.sort_param_vecs();
        let centroid = nm.calculate_centroid();
        let contracted = nm.contract(&centroid, &vec![1.0, 1.0]);
        assert_relative_eq!(contracted[0], 0.75f64, epsilon = f64::EPSILON);
        assert_relative_eq!(contracted[1], 0.75f64, epsilon = f64::EPSILON);
    }

    #[test]
    fn test_shrink() {
        let params: Vec<Vec<f64>> = vec![vec![0.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0]];
        let params_shrunk: Vec<Vec<f64>> = vec![vec![0.0, 0.0], vec![0.0, 0.5], vec![0.5, 0.0]];
        let mut nm: NelderMead<_, f64> = NelderMead::new(params);
        nm.params
            .iter_mut()
            .enumerate()
            .for_each(|(i, (_, c))| *c = i as f64);
        nm.sort_param_vecs();
        nm.shrink(|_| Ok(1.0f64)).unwrap();

        for ((p, _), ps) in nm.params.iter().zip(params_shrunk.iter()) {
            assert_eq!(p[0].to_ne_bytes(), ps[0].to_ne_bytes());
            assert_eq!(p[1].to_ne_bytes(), ps[1].to_ne_bytes());
        }
    }

    #[test]
    fn test_init() {
        let params: Vec<Vec<f64>> = vec![vec![-1.0, 1.0], vec![-0.5, 2.0], vec![0.7, -1.0]];
        let params_sorted: Vec<(Vec<f64>, f64)> = vec![
            (vec![0.7, -1.0], 0.7f64.powi(2) + 1.0f64.powi(2)),
            (vec![-1.0, 1.0], 2.0),
            (vec![-0.5, 2.0], 0.5f64.powi(2) + 2.0f64.powi(2)),
        ];
        let mut nm: NelderMead<_, f64> = NelderMead::new(params);
        let state: IterState<Vec<f64>, (), (), (), (), f64> = IterState::new();
        let problem = MwProblem {};
        let (state_out, kv) = nm.init(&mut Problem::new(problem), state).unwrap();

        assert!(kv.is_none());

        for ((p, c), (ps, cs)) in nm.params.iter().zip(params_sorted.iter()) {
            assert_relative_eq!(c, cs, epsilon = f64::EPSILON);
            assert_eq!(p[0].to_ne_bytes(), ps[0].to_ne_bytes());
            assert_eq!(p[1].to_ne_bytes(), ps[1].to_ne_bytes());
        }

        for i in 0..2 {
            assert_relative_eq!(
                state_out.get_param().unwrap()[i],
                params_sorted[0].0[i],
                epsilon = f64::EPSILON
            );
        }

        assert_relative_eq!(
            state_out.get_cost(),
            0.7f64.powi(2) + 1.0f64.powi(2),
            epsilon = f64::EPSILON
        );
    }

    #[test]
    fn test_next_iter_reflection() {
        let params: Vec<Vec<f64>> = vec![vec![-1.0, 0.0], vec![-0.1, 0.65], vec![-0.1, -0.95]];
        let mut nm: NelderMead<_, f64> = NelderMead::new(params);
        let state: IterState<Vec<f64>, (), (), (), (), f64> = IterState::new();
        let mut problem = Problem::new(MwProblem {});
        let (state, _) = nm.init(&mut problem, state).unwrap();

        let (state, kv) = nm.next_iter(&mut problem, state).unwrap();

        assert_eq!(
            format!("{}", kv.unwrap().get("action").unwrap()),
            "Reflection"
        );

        let param = state.get_param().unwrap();

        assert_relative_eq!(param[0], -0.1f64, epsilon = f64::EPSILON);
        assert_relative_eq!(param[1], 0.65f64, epsilon = f64::EPSILON);

        let cost = state.get_cost();
        assert_relative_eq!(cost, 0.4325f64, epsilon = f64::EPSILON);

        assert_relative_eq!(nm.params[0].0[0], -0.1f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[0].0[1], 0.65f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[0].1, 0.4325f64, epsilon = f64::EPSILON);

        assert_relative_eq!(nm.params[1].0[0], 0.8f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[1].0[1], -0.3f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[1].1, 0.73f64, epsilon = f64::EPSILON);

        assert_relative_eq!(nm.params[2].0[0], -0.1f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[2].0[1], -0.95f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[2].1, 0.9125f64, epsilon = f64::EPSILON);
    }

    #[test]
    fn test_next_iter_expansion() {
        let params: Vec<Vec<f64>> = vec![
            vec![-2.0, 0.0],
            vec![-1.0, 1.0],
            // make sure that the last two vectors don't evaluate to the same cost function value
            // which may cause strangeness in the sorting.
            // Check this again if this test starts failing randomly...
            vec![-1.0, -1.0 - f64::EPSILON],
        ];
        let mut nm: NelderMead<_, f64> = NelderMead::new(params);
        let state: IterState<Vec<f64>, (), (), (), (), f64> = IterState::new();
        let mut problem = Problem::new(MwProblem {});
        let (state, _) = nm.init(&mut problem, state).unwrap();

        let (state, kv) = nm.next_iter(&mut problem, state).unwrap();

        assert_eq!(
            format!("{}", kv.unwrap().get("action").unwrap()),
            "Expansion"
        );

        let param = state.get_param().unwrap();

        assert_relative_eq!(param[0], 0.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(param[1], 0.0f64, epsilon = f64::EPSILON);

        let cost = state.get_cost();
        assert_relative_eq!(cost, 0.0f64, epsilon = f64::EPSILON);

        assert_relative_eq!(nm.params[0].0[0], 0.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[0].0[1], 0.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[0].1, 0.0f64, epsilon = f64::EPSILON);

        assert_relative_eq!(nm.params[1].0[0], -1.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[1].0[1], 1.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[1].1, 2.0f64, epsilon = f64::EPSILON);

        assert_relative_eq!(nm.params[2].0[0], -1.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[2].0[1], -1.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[2].1, 2.0f64, epsilon = f64::EPSILON);
    }

    #[test]
    fn test_next_iter_contraction_outside() {
        let params: Vec<Vec<f64>> = vec![vec![-1.1, 0.0], vec![-0.1, 1.0], vec![-0.1, -0.5]];
        let mut nm: NelderMead<_, f64> = NelderMead::new(params);
        let state: IterState<Vec<f64>, (), (), (), (), f64> = IterState::new();
        let mut problem = Problem::new(MwProblem {});
        let (state, _) = nm.init(&mut problem, state).unwrap();

        let (state, kv) = nm.next_iter(&mut problem, state).unwrap();

        assert_eq!(
            format!("{}", kv.unwrap().get("action").unwrap()),
            "ContractionOutside"
        );

        let param = state.get_param().unwrap();

        assert_relative_eq!(param[0], -0.1f64, epsilon = f64::EPSILON);
        assert_relative_eq!(param[1], -0.5f64, epsilon = f64::EPSILON);

        let cost = state.get_cost();
        assert_relative_eq!(cost, 0.26f64, epsilon = f64::EPSILON);

        assert_relative_eq!(nm.params[0].0[0], -0.1f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[0].0[1], -0.5f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[0].1, 0.26f64, epsilon = f64::EPSILON);

        assert_relative_eq!(nm.params[1].0[0], 0.4f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[1].0[1], 0.375f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[1].1, 0.300625f64, epsilon = f64::EPSILON);

        assert_relative_eq!(nm.params[2].0[0], -0.1f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[2].0[1], 1.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[2].1, 1.01f64, epsilon = f64::EPSILON);
    }

    #[test]
    fn test_next_iter_contraction_inside() {
        let params: Vec<Vec<f64>> = vec![vec![-1.0, 0.0], vec![0.0, 1.0], vec![0.0, -0.5]];
        let mut nm: NelderMead<_, f64> = NelderMead::new(params);
        let state: IterState<Vec<f64>, (), (), (), (), f64> = IterState::new();
        let mut problem = Problem::new(MwProblem {});
        let (state, _) = nm.init(&mut problem, state).unwrap();

        let (state, kv) = nm.next_iter(&mut problem, state).unwrap();

        assert_eq!(
            format!("{}", kv.unwrap().get("action").unwrap()),
            "ContractionInside"
        );

        let param = state.get_param().unwrap();

        assert_relative_eq!(param[0], -0.25f64, epsilon = f64::EPSILON);
        assert_relative_eq!(param[1], 0.375f64, epsilon = f64::EPSILON);

        let cost = state.get_cost();
        assert_relative_eq!(cost, 0.203125f64, epsilon = f64::EPSILON);

        assert_relative_eq!(nm.params[0].0[0], -0.25f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[0].0[1], 0.375f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[0].1, 0.203125f64, epsilon = f64::EPSILON);

        assert_relative_eq!(nm.params[1].0[0], 0.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[1].0[1], -0.5f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[1].1, 0.25, epsilon = f64::EPSILON);

        assert_relative_eq!(nm.params[2].0[0], -1.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[2].0[1], 0.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(nm.params[2].1, 1.00f64, epsilon = f64::EPSILON);
    }
}