argmin/solver/simulatedannealing/
mod.rs

1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! # Simulated Annealing
9//!
10//! Simulated Annealing (SA) is a stochastic optimization method which imitates annealing in
11//! metallurgy. For details see [`SimulatedAnnealing`].
12//!
13//! ## References
14//!
15//! [Wikipedia](https://en.wikipedia.org/wiki/Simulated_annealing)
16//!
17//! S Kirkpatrick, CD Gelatt Jr, MP Vecchi. (1983). "Optimization by Simulated Annealing".
18//! Science 13 May 1983, Vol. 220, Issue 4598, pp. 671-680
19//! DOI: 10.1126/science.220.4598.671
20
21use crate::core::{
22    ArgminFloat, CostFunction, Error, IterState, Problem, Solver, TerminationReason,
23    TerminationStatus, KV,
24};
25use rand::prelude::*;
26use rand_xoshiro::Xoshiro256PlusPlus;
27#[cfg(feature = "serde1")]
28use serde::{Deserialize, Serialize};
29
/// This trait handles the annealing of a parameter vector. Problems which are to be solved using
/// [`SimulatedAnnealing`] must implement this trait.
///
/// In every iteration, [`SimulatedAnnealing`] calls `anneal` once with the currently held
/// parameter vector and the current temperature (passed as `extent`). The returned parameter
/// vector is the candidate which is then evaluated and possibly accepted. How strongly the
/// parameters are modified is up to the implementor; the solver is designed around larger
/// modifications at higher temperatures.
///
/// When used with [`SimulatedAnnealing`], `Output` must be the same type as `Param`, and `Float`
/// must be the same type as the output of the problem's [`CostFunction`].
pub trait Anneal {
    /// Type of the parameter vector
    type Param;
    /// Return type of the anneal function (the modified parameter vector)
    type Output;
    /// Precision of floats
    type Float;

    /// Anneal a parameter vector
    ///
    /// `param` is the current parameter vector and `extent` the current temperature.
    ///
    /// # Errors
    ///
    /// Any returned error is propagated out of the solver iteration that called this method.
    fn anneal(&self, param: &Self::Param, extent: Self::Float) -> Result<Self::Output, Error>;
}
43
44/// Wraps a call to `anneal` defined in the `Anneal` trait and as such allows to call `anneal` on
45/// an instance of `Problem`. Internally, the number of evaluations of `anneal` is counted.
46impl<O: Anneal> Problem<O> {
47    /// Calls `anneal` defined in the `Anneal` trait and keeps track of the number of evaluations.
48    ///
49    /// # Example
50    ///
51    /// ```
52    /// # use argmin::core::{Problem, Error};
53    /// # use argmin::solver::simulatedannealing::Anneal;
54    /// #
55    /// # #[derive(Eq, PartialEq, Debug, Clone)]
56    /// # struct UserDefinedProblem {};
57    /// #
58    /// # impl Anneal for UserDefinedProblem {
59    /// #     type Param = Vec<f64>;
60    /// #     type Output = Vec<f64>;
61    /// #     type Float = f64;
62    /// #
63    /// #     fn anneal(&self, param: &Self::Param, extent: Self::Float) -> Result<Self::Output, Error> {
64    /// #         Ok(vec![1.0f64, 1.0f64])
65    /// #     }
66    /// # }
67    /// // `UserDefinedProblem` implements `Anneal`.
68    /// let mut problem1 = Problem::new(UserDefinedProblem {});
69    ///
70    /// let param = vec![2.0f64, 1.0f64];
71    ///
72    /// let res = problem1.anneal(&param, 1.0);
73    ///
74    /// assert_eq!(problem1.counts["anneal_count"], 1);
75    /// # assert_eq!(res.unwrap(), vec![1.0f64, 1.0f64]);
76    /// ```
77    pub fn anneal(&mut self, param: &O::Param, extent: O::Float) -> Result<O::Output, Error> {
78        self.problem("anneal_count", |problem| problem.anneal(param, extent))
79    }
80}
81
/// Temperature functions (cooling schedules) for Simulated Annealing.
///
/// Given the initial temperature `t_init` and the temperature iteration number `i`, the current
/// temperature `t_i` is given as follows:
///
/// * `SATempFunc::TemperatureFast`: `t_i = t_init / i`
/// * `SATempFunc::Boltzmann`: `t_i = t_init / ln(i)`
/// * `SATempFunc::Exponential(x)`: `t_i = t_init * x^i`
///
/// The iteration number `i` is maintained by [`SimulatedAnnealing`], which also resets it when
/// reannealing is performed. With `Exponential(x)`, the temperature only decreases over the
/// iterations if `0 < x < 1`.
///
/// Note that the [`Default`] of this enum is `Boltzmann`, whereas a freshly constructed
/// [`SimulatedAnnealing`] uses `TemperatureFast` unless
/// [`SimulatedAnnealing::with_temp_func`] is called.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub enum SATempFunc<F> {
    /// `t_i = t_init / i`
    TemperatureFast,
    /// `t_i = t_init / ln(i)`
    #[default]
    Boltzmann,
    /// `t_i = t_init * x^i`, where `x` is the wrapped cooling factor
    Exponential(F),
}
104
/// # Simulated Annealing
///
/// Simulated Annealing (SA) is a stochastic optimization method which imitates annealing in
/// metallurgy. Parameter vectors are randomly modified in each iteration, where the degree of
/// modification depends on the current temperature. The algorithm starts with a high temperature
/// (a lot of modification and hence movement in parameter space) and continuously cools down as
/// the iterations progress, hence narrowing down the search. Under certain conditions,
/// reannealing (restarting the cooling schedule) can be performed. Solutions which are better
/// than the previous one are always accepted. Solutions which are worse (or equally good) are
/// accepted with probability `1 / (1 + exp((new_cost - prev_cost) / current_temperature))`,
/// which shrinks as the cost increase grows and as the temperature drops.
/// These measures allow the algorithm to explore the parameter space on a large and a small scale
/// and hence it is able to overcome local minima.
///
/// The initial temperature has to be provided by the user, as well as an initial parameter
/// vector (via [`configure`](`crate::core::Executor::configure`) of
/// [`Executor`](`crate::core::Executor`)).
///
/// The cooling schedule can be set with [`SimulatedAnnealing::with_temp_func`]. For the available
/// choices please see [`SATempFunc`].
///
/// Reannealing can be performed if no new best solution was found for `N` iterations
/// ([`SimulatedAnnealing::with_reannealing_best`]), or if no new accepted solution was found for
/// `N` iterations ([`SimulatedAnnealing::with_reannealing_accepted`]) or every `N` iterations
/// without any other conditions ([`SimulatedAnnealing::with_reannealing_fixed`]).
///
/// The user-provided problem must implement [`Anneal`] which defines how parameter vectors are
/// modified. Please see the Simulated Annealing example for one approach to do so for floating
/// point parameters.
///
/// ## Requirements on the optimization problem
///
/// The optimization problem is required to implement [`CostFunction`] (with the float type as
/// output) and [`Anneal`] (returning a parameter vector of the same type it receives).
///
/// ## References
///
/// [Wikipedia](https://en.wikipedia.org/wiki/Simulated_annealing)
///
/// S Kirkpatrick, CD Gelatt Jr, MP Vecchi. (1983). "Optimization by Simulated Annealing".
/// Science 13 May 1983, Vol. 220, Issue 4598, pp. 671-680
/// DOI: 10.1126/science.220.4598.671
#[derive(Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct SimulatedAnnealing<F, R> {
    /// Initial temperature (validated to be > 0 by the constructors)
    init_temp: F,
    /// Temperature function used for decreasing the temperature
    temp_func: SATempFunc<F>,
    /// Iteration number at which the temperature function is evaluated. Reset to 0 when
    /// reannealing is performed.
    temp_iter: u64,
    /// Number of consecutive iterations without an accepted solution
    stall_iter_accepted: u64,
    /// Terminate once `stall_iter_accepted` exceeds this number
    stall_iter_accepted_limit: u64,
    /// Number of consecutive iterations without a new best solution
    stall_iter_best: u64,
    /// Terminate once `stall_iter_best` exceeds this number
    stall_iter_best_limit: u64,
    /// Reanneal once `reanneal_iter_fixed` reaches this number
    reanneal_fixed: u64,
    /// Number of iterations since the beginning or the last reannealing
    reanneal_iter_fixed: u64,
    /// Reanneal once `reanneal_iter_accepted` reaches this number
    reanneal_accepted: u64,
    /// Like `stall_iter_accepted`, but also reset to 0 when reannealing is performed
    reanneal_iter_accepted: u64,
    /// Reanneal once `reanneal_iter_best` reaches this number
    reanneal_best: u64,
    /// Like `stall_iter_best`, but also reset to 0 when reannealing is performed
    reanneal_iter_best: u64,
    /// Current temperature; passed to `Anneal::anneal` as `extent`
    cur_temp: F,
    /// Random number generator used in the acceptance test for non-improving solutions
    rng: R,
}
179
180impl<F> SimulatedAnnealing<F, Xoshiro256PlusPlus>
181where
182    F: ArgminFloat,
183{
184    /// Construct a new instance of [`SimulatedAnnealing`]
185    ///
186    /// Takes the initial temperature as input, which must be >0.
187    ///
188    /// Uses the `Xoshiro256PlusPlus` RNG internally. For use of another RNG, consider using
189    /// [`SimulatedAnnealing::new_with_rng`].
190    ///
191    /// # Example
192    ///
193    /// ```
194    /// # use argmin::solver::simulatedannealing::SimulatedAnnealing;
195    /// # use argmin::core::Error;
196    /// # fn main() -> Result<(), Error> {
197    /// let sa = SimulatedAnnealing::new(100.0f64)?;
198    /// # Ok(())
199    /// # }
200    /// ```
201    pub fn new(initial_temperature: F) -> Result<Self, Error> {
202        SimulatedAnnealing::new_with_rng(
203            initial_temperature,
204            Xoshiro256PlusPlus::try_from_os_rng()?,
205        )
206    }
207}
208
209impl<F, R> SimulatedAnnealing<F, R>
210where
211    F: ArgminFloat,
212{
213    /// Construct a new instance of [`SimulatedAnnealing`]
214    ///
215    /// Takes the initial temperature as input, which must be >0.
216    /// Requires a RNG which must implement `rand::Rng` (and `serde::Serialize` if the `serde1`
217    /// feature is enabled).
218    ///
219    /// # Example
220    ///
221    /// ```
222    /// # use argmin::solver::simulatedannealing::SimulatedAnnealing;
223    /// # use argmin::core::Error;
224    /// # fn main() -> Result<(), Error> {
225    /// # let my_rng = ();
226    /// let sa = SimulatedAnnealing::new_with_rng(100.0f64, my_rng)?;
227    /// # Ok(())
228    /// # }
229    /// ```
230    pub fn new_with_rng(init_temp: F, rng: R) -> Result<Self, Error> {
231        if init_temp <= float!(0.0) {
232            Err(argmin_error!(
233                InvalidParameter,
234                "`SimulatedAnnealing`: Initial temperature must be > 0."
235            ))
236        } else {
237            Ok(SimulatedAnnealing {
238                init_temp,
239                temp_func: SATempFunc::TemperatureFast,
240                temp_iter: 0,
241                stall_iter_accepted: 0,
242                stall_iter_accepted_limit: u64::MAX,
243                stall_iter_best: 0,
244                stall_iter_best_limit: u64::MAX,
245                reanneal_fixed: u64::MAX,
246                reanneal_iter_fixed: 0,
247                reanneal_accepted: u64::MAX,
248                reanneal_iter_accepted: 0,
249                reanneal_best: u64::MAX,
250                reanneal_iter_best: 0,
251                cur_temp: init_temp,
252                rng,
253            })
254        }
255    }
256
257    /// Set temperature function
258    ///
259    /// The temperature function defines how the temperature is decreased over the course of the
260    /// iterations.
261    /// See [`SATempFunc`] for the available options. Defaults to [`SATempFunc::TemperatureFast`].
262    ///
263    /// # Example
264    ///
265    /// ```
266    /// # use argmin::solver::simulatedannealing::{SimulatedAnnealing, SATempFunc};
267    /// # use argmin::core::Error;
268    /// # fn main() -> Result<(), Error> {
269    /// let sa = SimulatedAnnealing::new(100.0f64)?.with_temp_func(SATempFunc::Boltzmann);
270    /// # Ok(())
271    /// # }
272    /// ```
273    #[must_use]
274    pub fn with_temp_func(mut self, temperature_func: SATempFunc<F>) -> Self {
275        self.temp_func = temperature_func;
276        self
277    }
278
279    /// If there are no accepted solutions for `iter` iterations, the algorithm stops.
280    ///
281    /// Defaults to `u64::MAX`.
282    ///
283    /// # Example
284    ///
285    /// ```
286    /// # use argmin::solver::simulatedannealing::{SimulatedAnnealing, SATempFunc};
287    /// # use argmin::core::Error;
288    /// # fn main() -> Result<(), Error> {
289    /// let sa = SimulatedAnnealing::new(100.0f64)?.with_stall_accepted(1000);
290    /// # Ok(())
291    /// # }
292    /// ```
293    #[must_use]
294    pub fn with_stall_accepted(mut self, iter: u64) -> Self {
295        self.stall_iter_accepted_limit = iter;
296        self
297    }
298
299    /// If there are no new best solutions for `iter` iterations, the algorithm stops.
300    ///
301    /// Defaults to `u64::MAX`.
302    ///
303    /// # Example
304    ///
305    /// ```
306    /// # use argmin::solver::simulatedannealing::{SimulatedAnnealing, SATempFunc};
307    /// # use argmin::core::Error;
308    /// # fn main() -> Result<(), Error> {
309    /// let sa = SimulatedAnnealing::new(100.0f64)?.with_stall_best(2000);
310    /// # Ok(())
311    /// # }
312    /// ```
313    #[must_use]
314    pub fn with_stall_best(mut self, iter: u64) -> Self {
315        self.stall_iter_best_limit = iter;
316        self
317    }
318
319    /// Set number of iterations after which reannealing is performed
320    ///
321    /// Every `iter` iterations, reannealing (resetting temperature to its initial value) will be
322    /// performed. This may help in overcoming local minima.
323    ///
324    /// Defaults to `u64::MAX`.
325    ///
326    /// # Example
327    ///
328    /// ```
329    /// # use argmin::solver::simulatedannealing::{SimulatedAnnealing, SATempFunc};
330    /// # use argmin::core::Error;
331    /// # fn main() -> Result<(), Error> {
332    /// let sa = SimulatedAnnealing::new(100.0f64)?.with_reannealing_fixed(5000);
333    /// # Ok(())
334    /// # }
335    /// ```
336    #[must_use]
337    pub fn with_reannealing_fixed(mut self, iter: u64) -> Self {
338        self.reanneal_fixed = iter;
339        self
340    }
341
342    /// Set the number of iterations that need to pass after the last accepted solution was found
343    /// for reannealing to be performed.
344    ///
345    /// If no new accepted solution is found for `iter` iterations, reannealing (resetting
346    /// temperature to its initial value) is performed. This may help in overcoming local minima.
347    ///
348    /// Defaults to `u64::MAX`.
349    ///
350    /// # Example
351    ///
352    /// ```
353    /// # use argmin::solver::simulatedannealing::{SimulatedAnnealing, SATempFunc};
354    /// # use argmin::core::Error;
355    /// # fn main() -> Result<(), Error> {
356    /// let sa = SimulatedAnnealing::new(100.0f64)?.with_reannealing_accepted(5000);
357    /// # Ok(())
358    /// # }
359    /// ```
360    #[must_use]
361    pub fn with_reannealing_accepted(mut self, iter: u64) -> Self {
362        self.reanneal_accepted = iter;
363        self
364    }
365
366    /// Set the number of iterations that need to pass after the last best solution was found
367    /// for reannealing to be performed.
368    ///
369    /// If no new best solution is found for `iter` iterations, reannealing (resetting temperature
370    /// to its initial value) is performed. This may help in overcoming local minima.
371    ///
372    /// Defaults to `u64::MAX`.
373    ///
374    /// # Example
375    ///
376    /// ```
377    /// # use argmin::solver::simulatedannealing::{SimulatedAnnealing, SATempFunc};
378    /// # use argmin::core::Error;
379    /// # fn main() -> Result<(), Error> {
380    /// let sa = SimulatedAnnealing::new(100.0f64)?.with_reannealing_best(5000);
381    /// # Ok(())
382    /// # }
383    /// ```
384    #[must_use]
385    pub fn with_reannealing_best(mut self, iter: u64) -> Self {
386        self.reanneal_best = iter;
387        self
388    }
389
390    /// Update the temperature based on the current iteration number.
391    ///
392    /// Updates are performed based on specific update functions. See `SATempFunc` for details.
393    fn update_temperature(&mut self) {
394        self.cur_temp = match self.temp_func {
395            SATempFunc::TemperatureFast => {
396                self.init_temp / F::from_u64(self.temp_iter + 1).unwrap()
397            }
398            SATempFunc::Boltzmann => self.init_temp / F::from_u64(self.temp_iter + 1).unwrap().ln(),
399            SATempFunc::Exponential(x) => {
400                self.init_temp * x.powf(F::from_u64(self.temp_iter + 1).unwrap())
401            }
402        };
403    }
404
405    /// Perform reannealing
406    fn reanneal(&mut self) -> (bool, bool, bool) {
407        let out = (
408            self.reanneal_iter_fixed >= self.reanneal_fixed,
409            self.reanneal_iter_accepted >= self.reanneal_accepted,
410            self.reanneal_iter_best >= self.reanneal_best,
411        );
412        if out.0 || out.1 || out.2 {
413            self.reanneal_iter_fixed = 0;
414            self.reanneal_iter_accepted = 0;
415            self.reanneal_iter_best = 0;
416            self.cur_temp = self.init_temp;
417            self.temp_iter = 0;
418        }
419        out
420    }
421
422    /// Update the stall iter variables
423    fn update_stall_and_reanneal_iter(&mut self, accepted: bool, new_best: bool) {
424        (self.stall_iter_accepted, self.reanneal_iter_accepted) = if accepted {
425            (0, 0)
426        } else {
427            (
428                self.stall_iter_accepted + 1,
429                self.reanneal_iter_accepted + 1,
430            )
431        };
432
433        (self.stall_iter_best, self.reanneal_iter_best) = if new_best {
434            (0, 0)
435        } else {
436            (self.stall_iter_best + 1, self.reanneal_iter_best + 1)
437        };
438    }
439}
440
/// [`Solver`] implementation of Simulated Annealing.
///
/// The problem must provide a cost function whose output is the float type `F`, and an
/// [`Anneal`] implementation which maps a parameter vector of type `P` to a new one of the same
/// type. The RNG `R` drives the probabilistic acceptance of non-improving moves.
impl<O, P, F, R> Solver<O, IterState<P, (), (), (), (), F>> for SimulatedAnnealing<F, R>
where
    O: CostFunction<Param = P, Output = F> + Anneal<Param = P, Output = P, Float = F>,
    P: Clone,
    F: ArgminFloat,
    R: Rng,
{
    fn name(&self) -> &str {
        "Simulated Annealing"
    }

    /// Takes the initial parameter vector from `state` and computes its cost.
    ///
    /// The cost is only computed if the cost stored in `state` is infinite; presumably infinite
    /// means "not provided by the user" (TODO confirm against `IterState`'s default). Returns the
    /// solver's configuration as key-value pairs.
    ///
    /// # Errors
    ///
    /// Returns a `NotInitialized` error if no initial parameter vector was provided, and
    /// propagates errors from the cost function.
    fn init(
        &mut self,
        problem: &mut Problem<O>,
        mut state: IterState<P, (), (), (), (), F>,
    ) -> Result<(IterState<P, (), (), (), (), F>, Option<KV>), Error> {
        let param = state.take_param().ok_or_else(argmin_error_closure!(
            NotInitialized,
            concat!(
                "`SimulatedAnnealing` requires an initial parameter vector. ",
                "Please provide an initial guess via `Executor`s `configure` method."
            )
        ))?;

        // Reuse a cost supplied through the state; only evaluate if none is set (infinite).
        let cost = state.get_cost();
        let cost = if cost.is_infinite() {
            problem.cost(&param)?
        } else {
            cost
        };

        Ok((
            state.param(param).cost(cost),
            Some(kv!(
                "initial_temperature" => self.init_temp;
                "stall_iter_accepted_limit" => self.stall_iter_accepted_limit;
                "stall_iter_best_limit" => self.stall_iter_best_limit;
                "reanneal_fixed" => self.reanneal_fixed;
                "reanneal_accepted" => self.reanneal_accepted;
                "reanneal_best" => self.reanneal_best;
            )),
        ))
    }

    /// Perform one iteration of SA algorithm
    ///
    /// Steps, in order:
    /// 1. Anneal the current parameter vector at the current temperature.
    /// 2. Evaluate the candidate and decide whether to accept it.
    /// 3. Update the stall and reannealing counters, then reanneal if a condition is met.
    /// 4. Advance the temperature schedule for the next iteration.
    ///
    /// The returned state holds the candidate if it was accepted, otherwise the previous
    /// parameter vector and cost.
    fn next_iter(
        &mut self,
        problem: &mut Problem<O>,
        mut state: IterState<P, (), (), (), (), F>,
    ) -> Result<(IterState<P, (), (), (), (), F>, Option<KV>), Error> {
        // Careful: The order in here is *very* important, even if it may not seem so. Everything
        // is linked to the iteration number, and getting things mixed up may lead to unexpected
        // behavior.

        let prev_param = state.take_param().ok_or_else(argmin_error_closure!(
            PotentialBug,
            "`SimulatedAnnealing`: Parameter vector in state not set."
        ))?;
        // Cost of the currently held (last accepted) parameter vector.
        let prev_cost = state.get_cost();

        // Make a move
        let new_param = problem.anneal(&prev_param, self.cur_temp)?;

        // Evaluate cost function with new parameter vector
        let new_cost = problem.cost(&new_param)?;

        // Acceptance function
        //
        // Decide whether new parameter vector should be accepted.
        // If no, move on with old parameter vector.
        //
        // Any solution which satisfies `new_cost < prev_cost` will be accepted. Solutions which
        // are not better than the previous one are accepted with a probability given as:
        //
        // `1 / (1 + exp((new_cost - prev_cost) / current_temperature))`,
        //
        // which lies in [0, 0.5] whenever `new_cost >= prev_cost` (0.5 for equal costs).
        // A random number is drawn in every iteration, even when the move improves the cost.
        // A NaN `new_cost` is never accepted because both comparisons evaluate to false.
        let prob: f64 = self.rng.random();
        let prob = float!(prob);
        let accepted = (new_cost < prev_cost)
            || (float!(1.0) / (float!(1.0) + ((new_cost - prev_cost) / self.cur_temp).exp())
                > prob);

        let new_best_found = new_cost < state.best_cost;

        // Update stall iter variables
        self.update_stall_and_reanneal_iter(accepted, new_best_found);

        let (r_fixed, r_accepted, r_best) = self.reanneal();

        // Update temperature for next iteration.
        // NOTE(review): after a reannealing, `temp_iter` is 0 here and becomes 1, so
        // `update_temperature` evaluates the schedule at i = 2 (e.g. `init_temp / 2` for
        // `TemperatureFast`). The reset of `cur_temp` to `init_temp` inside `reanneal()` is
        // therefore overwritten before it is used.
        self.temp_iter += 1;
        // Actually not necessary as it does the same as `temp_iter`, but I'll leave it here for
        // better readability.
        self.reanneal_iter_fixed += 1;

        // `cur_temp` now holds the temperature used by the *next* call to `anneal`.
        self.update_temperature();

        Ok((
            if accepted {
                state.param(new_param).cost(new_cost)
            } else {
                state.param(prev_param).cost(prev_cost)
            },
            Some(kv!(
                "t" => self.cur_temp;
                "new_be" => new_best_found;
                "acc" => accepted;
                "st_i_be" => self.stall_iter_best;
                "st_i_ac" => self.stall_iter_accepted;
                "ra_i_fi" => self.reanneal_iter_fixed;
                "ra_i_be" => self.reanneal_iter_best;
                "ra_i_ac" => self.reanneal_iter_accepted;
                "ra_fi" => r_fixed;
                "ra_be" => r_best;
                "ra_ac" => r_accepted;
            )),
        ))
    }

    /// Terminates once either stall counter strictly exceeds its configured limit.
    fn terminate(&mut self, _state: &IterState<P, (), (), (), (), F>) -> TerminationStatus {
        if self.stall_iter_accepted > self.stall_iter_accepted_limit {
            return TerminationStatus::Terminated(TerminationReason::SolverExit(
                "AcceptedStallIterExceeded".to_string(),
            ));
        }
        if self.stall_iter_best > self.stall_iter_best_limit {
            return TerminationStatus::Terminated(TerminationReason::SolverExit(
                "BestStallIterExceeded".to_string(),
            ));
        }
        TerminationStatus::NotTerminated
    }
}
574
/// Unit tests for the simulated annealing solver: constructors, builder methods, the
/// temperature schedule, reannealing, counter bookkeeping and `init`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{test_utils::TestProblem, ArgminError, State};
    use approx::assert_relative_eq;

    test_trait_impl!(sa, SimulatedAnnealing<f64, StdRng>);

    // `new` must set all defaults and reject non-positive initial temperatures.
    #[test]
    fn test_new() {
        let sa: SimulatedAnnealing<f64, Xoshiro256PlusPlus> =
            SimulatedAnnealing::new(100.0).unwrap();
        // Destructure exhaustively so that adding a field forces this test to be updated.
        let SimulatedAnnealing {
            init_temp,
            temp_func,
            temp_iter,
            stall_iter_accepted,
            stall_iter_accepted_limit,
            stall_iter_best,
            stall_iter_best_limit,
            reanneal_fixed,
            reanneal_iter_fixed,
            reanneal_accepted,
            reanneal_iter_accepted,
            reanneal_best,
            reanneal_iter_best,
            cur_temp,
            rng: _rng,
        } = sa;

        // Floats are compared bitwise for exact equality.
        assert_eq!(init_temp.to_ne_bytes(), 100.0f64.to_ne_bytes());
        assert_eq!(temp_func, SATempFunc::TemperatureFast);
        assert_eq!(temp_iter, 0);
        assert_eq!(stall_iter_accepted, 0);
        assert_eq!(stall_iter_accepted_limit, u64::MAX);
        assert_eq!(stall_iter_best, 0);
        assert_eq!(stall_iter_best_limit, u64::MAX);
        assert_eq!(reanneal_fixed, u64::MAX);
        assert_eq!(reanneal_iter_fixed, 0);
        assert_eq!(reanneal_accepted, u64::MAX);
        assert_eq!(reanneal_iter_accepted, 0);
        assert_eq!(reanneal_best, u64::MAX);
        assert_eq!(reanneal_iter_best, 0);
        assert_eq!(cur_temp.to_ne_bytes(), 100.0f64.to_ne_bytes());

        for temp in [0.0, -1.0, -f64::EPSILON, -100.0] {
            let res = SimulatedAnnealing::new(temp);
            assert_error!(
                res,
                ArgminError,
                "Invalid parameter: \"`SimulatedAnnealing`: Initial temperature must be > 0.\""
            );
        }
    }

    // `new_with_rng` must store the provided RNG untouched and apply the same defaults and
    // validation as `new`.
    #[test]
    fn test_new_with_rng() {
        // Dummy "RNG": `new_with_rng` places no trait bound on `R`.
        #[derive(Eq, PartialEq, Debug)]
        struct MyRng {}

        let sa: SimulatedAnnealing<f64, MyRng> =
            SimulatedAnnealing::new_with_rng(100.0, MyRng {}).unwrap();
        let SimulatedAnnealing {
            init_temp,
            temp_func,
            temp_iter,
            stall_iter_accepted,
            stall_iter_accepted_limit,
            stall_iter_best,
            stall_iter_best_limit,
            reanneal_fixed,
            reanneal_iter_fixed,
            reanneal_accepted,
            reanneal_iter_accepted,
            reanneal_best,
            reanneal_iter_best,
            cur_temp,
            rng,
        } = sa;

        assert_eq!(init_temp.to_ne_bytes(), 100.0f64.to_ne_bytes());
        assert_eq!(temp_func, SATempFunc::TemperatureFast);
        assert_eq!(temp_iter, 0);
        assert_eq!(stall_iter_accepted, 0);
        assert_eq!(stall_iter_accepted_limit, u64::MAX);
        assert_eq!(stall_iter_best, 0);
        assert_eq!(stall_iter_best_limit, u64::MAX);
        assert_eq!(reanneal_fixed, u64::MAX);
        assert_eq!(reanneal_iter_fixed, 0);
        assert_eq!(reanneal_accepted, u64::MAX);
        assert_eq!(reanneal_iter_accepted, 0);
        assert_eq!(reanneal_best, u64::MAX);
        assert_eq!(reanneal_iter_best, 0);
        assert_eq!(cur_temp.to_ne_bytes(), 100.0f64.to_ne_bytes());
        // important part
        assert_eq!(rng, MyRng {});

        for temp in [0.0, -1.0, -f64::EPSILON, -100.0] {
            let res = SimulatedAnnealing::new_with_rng(temp, MyRng {});
            assert_error!(
                res,
                ArgminError,
                "Invalid parameter: \"`SimulatedAnnealing`: Initial temperature must be > 0.\""
            );
        }
    }

    // Each builder method must store its argument in the corresponding field.
    #[test]
    fn test_with_temp_func() {
        for func in [
            SATempFunc::TemperatureFast,
            SATempFunc::Boltzmann,
            SATempFunc::Exponential(2.0),
        ] {
            let sa = SimulatedAnnealing::new(100.0f64).unwrap();
            let sa = sa.with_temp_func(func);

            assert_eq!(sa.temp_func, func);
        }
    }

    #[test]
    fn test_with_stall_accepted() {
        for iter in [0, 1, 5, 10, 100, 100000] {
            let sa = SimulatedAnnealing::new(100.0f64).unwrap();
            let sa = sa.with_stall_accepted(iter);

            assert_eq!(sa.stall_iter_accepted_limit, iter);
        }
    }

    #[test]
    fn test_with_stall_best() {
        for iter in [0, 1, 5, 10, 100, 100000] {
            let sa = SimulatedAnnealing::new(100.0f64).unwrap();
            let sa = sa.with_stall_best(iter);

            assert_eq!(sa.stall_iter_best_limit, iter);
        }
    }

    #[test]
    fn test_with_reannealing_fixed() {
        for iter in [0, 1, 5, 10, 100, 100000] {
            let sa = SimulatedAnnealing::new(100.0f64).unwrap();
            let sa = sa.with_reannealing_fixed(iter);

            assert_eq!(sa.reanneal_fixed, iter);
        }
    }

    #[test]
    fn test_with_reannealing_accepted() {
        for iter in [0, 1, 5, 10, 100, 100000] {
            let sa = SimulatedAnnealing::new(100.0f64).unwrap();
            let sa = sa.with_reannealing_accepted(iter);

            assert_eq!(sa.reanneal_accepted, iter);
        }
    }

    #[test]
    fn test_with_reannealing_best() {
        for iter in [0, 1, 5, 10, 100, 100000] {
            let sa = SimulatedAnnealing::new(100.0f64).unwrap();
            let sa = sa.with_reannealing_best(iter);

            assert_eq!(sa.reanneal_best, iter);
        }
    }

    // With `temp_iter = 1`, the temperature functions are evaluated at i = temp_iter + 1 = 2.
    #[test]
    fn test_update_temperature() {
        for (func, val) in [
            (SATempFunc::TemperatureFast, 100.0f64 / 2.0),
            (SATempFunc::Boltzmann, 100.0f64 / 2.0f64.ln()),
            (SATempFunc::Exponential(3.0), 100.0 * 3.0f64.powi(2)),
        ] {
            let mut sa = SimulatedAnnealing::new(100.0f64)
                .unwrap()
                .with_temp_func(func);
            sa.temp_iter = 1;

            sa.update_temperature();

            assert_relative_eq!(sa.cur_temp, val, epsilon = f64::EPSILON);
        }
    }

    // Each reannealing condition fires when its counter reaches (>=) its limit. Whenever any
    // fires, all counters, `temp_iter` and the temperature must be reset.
    #[test]
    fn test_reanneal() {
        let mut sa_t = SimulatedAnnealing::new(100.0f64).unwrap();

        sa_t.reanneal_fixed = 10;
        sa_t.reanneal_accepted = 20;
        sa_t.reanneal_best = 30;
        sa_t.temp_iter = 40;
        sa_t.cur_temp = 50.0;

        for ((f, a, b), expected) in [
            ((0, 0, 0), (false, false, false)),
            ((10, 0, 0), (true, false, false)),
            ((11, 0, 0), (true, false, false)),
            ((0, 20, 0), (false, true, false)),
            ((0, 21, 0), (false, true, false)),
            ((0, 0, 30), (false, false, true)),
            ((0, 0, 31), (false, false, true)),
            ((10, 20, 0), (true, true, false)),
            ((10, 0, 30), (true, false, true)),
            ((0, 20, 30), (false, true, true)),
            ((10, 20, 30), (true, true, true)),
        ] {
            let mut sa = sa_t.clone();

            sa.reanneal_iter_fixed = f;
            sa.reanneal_iter_accepted = a;
            sa.reanneal_iter_best = b;

            assert_eq!(sa.reanneal(), expected);

            if expected.0 || expected.1 || expected.2 {
                assert_eq!(sa.reanneal_iter_fixed, 0);
                assert_eq!(sa.reanneal_iter_accepted, 0);
                assert_eq!(sa.reanneal_iter_best, 0);
                assert_eq!(sa.temp_iter, 0);
                assert_eq!(sa.cur_temp.to_ne_bytes(), sa.init_temp.to_ne_bytes());
            }
        }
    }

    // Counters are reset on `true` and incremented by one on `false`, independently for the
    // "accepted" and "best" pairs.
    #[test]
    fn test_update_stall_and_reanneal_iter() {
        let mut sa_t = SimulatedAnnealing::new(100.0f64).unwrap();

        sa_t.stall_iter_accepted = 10;
        sa_t.reanneal_iter_accepted = 20;
        sa_t.stall_iter_best = 30;
        sa_t.reanneal_iter_best = 40;

        for ((a, b), (sia, ria, sib, rib)) in [
            ((false, false), (11, 21, 31, 41)),
            ((false, true), (11, 21, 0, 0)),
            ((true, false), (0, 0, 31, 41)),
            ((true, true), (0, 0, 0, 0)),
        ] {
            let mut sa = sa_t.clone();

            sa.update_stall_and_reanneal_iter(a, b);

            assert_eq!(sa.stall_iter_accepted, sia);
            assert_eq!(sa.reanneal_iter_accepted, ria);
            assert_eq!(sa.stall_iter_best, sib);
            assert_eq!(sa.reanneal_iter_best, rib);
        }
    }

    // `init` must fail without an initial parameter vector. Otherwise it must keep the
    // parameter vector, compute its cost and report the configuration as key-value pairs.
    #[test]
    fn test_init() {
        let param: Vec<f64> = vec![-1.0, 1.0];

        let stall_iter_accepted_limit = 10;
        let stall_iter_best_limit = 20;
        let reanneal_fixed = 30;
        let reanneal_accepted = 40;
        let reanneal_best = 50;

        let mut sa = SimulatedAnnealing::new(100.0f64)
            .unwrap()
            .with_stall_accepted(stall_iter_accepted_limit)
            .with_stall_best(stall_iter_best_limit)
            .with_reannealing_fixed(reanneal_fixed)
            .with_reannealing_accepted(reanneal_accepted)
            .with_reannealing_best(reanneal_best);

        // Forgot to initialize the parameter vector
        let state: IterState<Vec<f64>, (), (), (), (), f64> = IterState::new();
        let problem = TestProblem::new();
        let res = sa.init(&mut Problem::new(problem), state);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Not initialized: \"`SimulatedAnnealing` requires an initial parameter vector. ",
                "Please provide an initial guess via `Executor`s `configure` method.\""
            )
        );

        // All good.
        let state: IterState<Vec<f64>, (), (), (), (), f64> = IterState::new().param(param.clone());
        let problem = TestProblem::new();
        let (mut state_out, kv) = sa.init(&mut Problem::new(problem), state).unwrap();

        let kv_expected = kv!(
            "initial_temperature" => 100.0f64;
            "stall_iter_accepted_limit" => stall_iter_accepted_limit;
            "stall_iter_best_limit" => stall_iter_best_limit;
            "reanneal_fixed" => reanneal_fixed;
            "reanneal_accepted" => reanneal_accepted;
            "reanneal_best" => reanneal_best;
        );

        assert_eq!(kv.unwrap(), kv_expected);

        let s_param = state_out.take_param().unwrap();

        for (s, p) in s_param.iter().zip(param.iter()) {
            assert_eq!(s.to_ne_bytes(), p.to_ne_bytes());
        }

        // Cost computed by `TestProblem` for this parameter vector.
        assert_eq!(state_out.get_cost().to_ne_bytes(), 1.0f64.to_ne_bytes())
    }
}