argmin/solver/particleswarm/
mod.rs

1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! # Particle Swarm Optimization (PSO)
9//!
10//! Canonical implementation of the particle swarm optimization method as outlined in \[0\] in
11//! chapter II, section A.
12//!
13//! For details see [`ParticleSwarm`].
14//!
15//! ## References
16//!
//! \[0\] Zambrano-Bigiarini, M. et al. (2013): Standard Particle Swarm Optimisation 2011 at
//! CEC-2013: A baseline for future PSO improvements. 2013 IEEE Congress on Evolutionary
//! Computation. <https://doi.org/10.1109/CEC.2013.6557848>
20//!
21//! \[1\] <https://en.wikipedia.org/wiki/Particle_swarm_optimization>
22
23use crate::core::{
24    ArgminFloat, CostFunction, Error, PopulationState, Problem, Solver, SyncAlias, KV,
25};
26use argmin_math::{ArgminAdd, ArgminMinMax, ArgminMul, ArgminRandom, ArgminSub, ArgminZeroLike};
27#[cfg(feature = "rand")]
28use rand::{Rng, SeedableRng};
29#[cfg(feature = "serde1")]
30use serde::{Deserialize, Serialize};
31
/// # Particle Swarm Optimization (PSO)
///
/// Canonical implementation of the particle swarm optimization method as outlined in \[0\] in
/// chapter II, section A.
///
/// The `rayon` feature enables parallel computation of the cost function. This can be beneficial
/// for expensive cost functions, but may cause a drop in performance for cheap cost functions. Be
/// sure to benchmark both parallel and sequential computation.
///
/// ## Requirements on the optimization problem
///
/// The optimization problem is required to implement [`CostFunction`].
///
/// ## References
///
/// \[0\] Zambrano-Bigiarini, M. et al. (2013): Standard Particle Swarm Optimisation 2011 at
/// CEC-2013: A baseline for future PSO improvements. 2013 IEEE Congress on Evolutionary
/// Computation. <https://doi.org/10.1109/CEC.2013.6557848>
///
/// \[1\] <https://en.wikipedia.org/wiki/Particle_swarm_optimization>
// `Debug` is derived for consistency with `Particle`; like `Clone`, it only applies when all
// type parameters (including the RNG type) implement it.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct ParticleSwarm<P, F, R> {
    /// Inertia weight
    weight_inertia: F,
    /// Cognitive acceleration coefficient
    weight_cognitive: F,
    /// Social acceleration coefficient
    weight_social: F,
    /// Bounds on parameter space as `(lower_bound, upper_bound)`
    bounds: (P, P),
    /// Number of particles
    num_particles: usize,
    /// Random number generator
    rng_generator: R,
}
68
69impl<P, F> ParticleSwarm<P, F, rand::rngs::StdRng>
70where
71    P: Clone + SyncAlias + ArgminSub<P, P> + ArgminMul<F, P> + ArgminRandom + ArgminZeroLike,
72    F: ArgminFloat,
73{
74    /// Construct a new instance of `ParticleSwarm`
75    ///
76    /// Takes the number of particles and bounds on the search space as inputs. `bounds` is a tuple
77    /// `(lower_bound, upper_bound)`, where `lower_bound` and `upper_bound` are of the same type as
78    /// the position of a particle (`P`) and of the same length as the problem as dimensions.
79    ///
80    /// The inertia weight on velocity and the social and cognitive acceleration factors can be
81    /// adapted with [`with_inertia_factor`](`ParticleSwarm::with_inertia_factor`),
82    /// [`with_cognitive_factor`](`ParticleSwarm::with_cognitive_factor`) and
83    /// [`with_social_factor`](`ParticleSwarm::with_social_factor`), respectively.
84    ///
85    /// The weights and acceleration factors default to:
86    ///
87    /// * inertia: `1/(2 * ln(2))`
88    /// * cognitive: `0.5 + ln(2)`
89    /// * social: `0.5 + ln(2)`
90    ///
91    /// # Example
92    ///
93    /// ```
94    /// # use argmin::solver::particleswarm::ParticleSwarm;
95    /// # let lower_bound: Vec<f64> = vec![-1.0, -1.0];
96    /// # let upper_bound: Vec<f64> = vec![1.0, 1.0];
97    /// let pso: ParticleSwarm<_, f64, _> = ParticleSwarm::new((lower_bound, upper_bound), 40);
98    /// ```
99    pub fn new(bounds: (P, P), num_particles: usize) -> Self {
100        ParticleSwarm {
101            weight_inertia: float!(1.0f64 / (2.0 * 2.0f64.ln())),
102            weight_cognitive: float!(0.5 + 2.0f64.ln()),
103            weight_social: float!(0.5 + 2.0f64.ln()),
104            bounds,
105            num_particles,
106            rng_generator: rand::rngs::StdRng::from_entropy(),
107        }
108    }
109}
110impl<P, F, R0> ParticleSwarm<P, F, R0>
111where
112    P: Clone + SyncAlias + ArgminSub<P, P> + ArgminMul<F, P> + ArgminRandom + ArgminZeroLike,
113    F: ArgminFloat,
114    R0: Rng,
115{
116    /// Set the random number generator
117    ///
118    /// Defaults to `rand::rngs::StdRng::from_entropy()`
119    ///
120    /// # Example
121    /// ```
122    /// # use argmin::solver::particleswarm::ParticleSwarm;
123    /// # use argmin::core::Error;
124    /// # use rand::SeedableRng;
125    /// # fn main() -> Result<(), Error> {
126    /// # let lower_bound: Vec<f64> = vec![-1.0, -1.0];
127    /// # let upper_bound: Vec<f64> = vec![1.0, 1.0];
128    /// let pso: ParticleSwarm<_, f64, _> =
129    ///     ParticleSwarm::new((lower_bound, upper_bound), 40)
130    ///     .with_rng_generator(rand_xoshiro::Xoroshiro128Plus::seed_from_u64(1729));
131    /// # Ok(())
132    /// # }
133    /// ```
134    pub fn with_rng_generator<R1: Rng>(self, generator: R1) -> ParticleSwarm<P, F, R1> {
135        ParticleSwarm {
136            weight_inertia: self.weight_inertia,
137            weight_cognitive: self.weight_cognitive,
138            weight_social: self.weight_social,
139            bounds: self.bounds,
140            num_particles: self.num_particles,
141            rng_generator: generator,
142        }
143    }
144}
145
146impl<P, F, R> ParticleSwarm<P, F, R>
147where
148    P: Clone + SyncAlias + ArgminSub<P, P> + ArgminMul<F, P> + ArgminRandom + ArgminZeroLike,
149    F: ArgminFloat,
150    R: Rng,
151{
152    /// Set inertia factor on particle velocity
153    ///
154    /// Defaults to `1/(2 * ln(2))`.
155    ///
156    /// # Example
157    ///
158    /// ```
159    /// # use argmin::solver::particleswarm::ParticleSwarm;
160    /// # use argmin::core::Error;
161    /// # fn main() -> Result<(), Error> {
162    /// # let lower_bound: Vec<f64> = vec![-1.0, -1.0];
163    /// # let upper_bound: Vec<f64> = vec![1.0, 1.0];
164    /// let pso: ParticleSwarm<_, f64, _> =
165    ///     ParticleSwarm::new((lower_bound, upper_bound), 40).with_inertia_factor(0.5)?;
166    /// # Ok(())
167    /// # }
168    /// ```
169    pub fn with_inertia_factor(mut self, factor: F) -> Result<Self, Error> {
170        if factor < float!(0.0) {
171            return Err(argmin_error!(
172                InvalidParameter,
173                "`ParticleSwarm`: inertia factor must be >=0."
174            ));
175        }
176        self.weight_inertia = factor;
177        Ok(self)
178    }
179
180    /// Set cognitive acceleration factor
181    ///
182    /// Defaults to `0.5 + ln(2)`.
183    ///
184    /// # Example
185    ///
186    /// ```
187    /// # use argmin::solver::particleswarm::ParticleSwarm;
188    /// # use argmin::core::Error;
189    /// # fn main() -> Result<(), Error> {
190    /// # let lower_bound: Vec<f64> = vec![-1.0, -1.0];
191    /// # let upper_bound: Vec<f64> = vec![1.0, 1.0];
192    /// let pso: ParticleSwarm<_, f64, _> =
193    ///     ParticleSwarm::new((lower_bound, upper_bound), 40).with_cognitive_factor(1.1)?;
194    /// # Ok(())
195    /// # }
196    /// ```
197    pub fn with_cognitive_factor(mut self, factor: F) -> Result<Self, Error> {
198        if factor < float!(0.0) {
199            return Err(argmin_error!(
200                InvalidParameter,
201                "`ParticleSwarm`: cognitive factor must be >=0."
202            ));
203        }
204        self.weight_cognitive = factor;
205        Ok(self)
206    }
207
208    /// Set social acceleration factor
209    ///
210    /// Defaults to `0.5 + ln(2)`.
211    ///
212    /// # Example
213    ///
214    /// ```
215    /// # use argmin::solver::particleswarm::ParticleSwarm;
216    /// # use argmin::core::Error;
217    /// # fn main() -> Result<(), Error> {
218    /// # let lower_bound: Vec<f64> = vec![-1.0, -1.0];
219    /// # let upper_bound: Vec<f64> = vec![1.0, 1.0];
220    /// let pso: ParticleSwarm<_, f64, _> =
221    ///     ParticleSwarm::new((lower_bound, upper_bound), 40).with_social_factor(1.1)?;
222    /// # Ok(())
223    /// # }
224    /// ```
225    pub fn with_social_factor(mut self, factor: F) -> Result<Self, Error> {
226        if factor < float!(0.0) {
227            return Err(argmin_error!(
228                InvalidParameter,
229                "`ParticleSwarm`: social factor must be >=0."
230            ));
231        }
232        self.weight_social = factor;
233        Ok(self)
234    }
235
236    /// Initializes all particles randomly and sorts them by their cost function values
237    fn initialize_particles<O: CostFunction<Param = P, Output = F> + SyncAlias>(
238        &mut self,
239        problem: &mut Problem<O>,
240    ) -> Result<Vec<Particle<P, F>>, Error> {
241        let (positions, velocities) = self.initialize_positions_and_velocities();
242
243        let costs = problem.bulk_cost(&positions)?;
244
245        let mut particles = positions
246            .into_iter()
247            .zip(velocities)
248            .zip(costs)
249            .map(|((p, v), c)| Particle::new(p, c, v))
250            .collect::<Vec<_>>();
251
252        // sort them, such that the first one is the best one
253        particles.sort_by(|a, b| {
254            a.cost
255                .partial_cmp(&b.cost)
256                .unwrap_or(std::cmp::Ordering::Equal)
257        });
258
259        Ok(particles)
260    }
261
262    /// Initializes positions and velocities for all particles
263    fn initialize_positions_and_velocities(&mut self) -> (Vec<P>, Vec<P>) {
264        let (min, max) = &self.bounds;
265        let delta = max.sub(min);
266        let delta_neg = delta.mul(&float!(-1.0));
267
268        (
269            (0..self.num_particles)
270                .map(|_| P::rand_from_range(min, max, &mut self.rng_generator))
271                .collect(),
272            (0..self.num_particles)
273                .map(|_| P::rand_from_range(&delta_neg, &delta, &mut self.rng_generator))
274                .collect(),
275        )
276    }
277}
278
279impl<O, P, F, R> Solver<O, PopulationState<Particle<P, F>, F>> for ParticleSwarm<P, F, R>
280where
281    O: CostFunction<Param = P, Output = F> + SyncAlias,
282    P: Clone
283        + SyncAlias
284        + ArgminAdd<P, P>
285        + ArgminSub<P, P>
286        + ArgminMul<F, P>
287        + ArgminZeroLike
288        + ArgminRandom
289        + ArgminMinMax,
290    F: ArgminFloat,
291    R: Rng,
292{
293    fn name(&self) -> &str {
294        "Particle Swarm Optimization"
295    }
296
297    fn init(
298        &mut self,
299        problem: &mut Problem<O>,
300        mut state: PopulationState<Particle<P, F>, F>,
301    ) -> Result<(PopulationState<Particle<P, F>, F>, Option<KV>), Error> {
302        // Users can provide a population or it will be randomly created.
303        let particles = match state.take_population() {
304            Some(mut particles) if particles.len() == self.num_particles => {
305                // sort them first
306                particles.sort_by(|a, b| {
307                    a.cost
308                        .partial_cmp(&b.cost)
309                        .unwrap_or(std::cmp::Ordering::Equal)
310                });
311                particles
312            }
313            Some(particles) => {
314                return Err(argmin_error!(
315                    InvalidParameter,
316                    format!(
317                        "`ParticleSwarm`: Provided list of particles is of length {}, expected {}",
318                        particles.len(),
319                        self.num_particles
320                    )
321                ))
322            }
323            None => self.initialize_particles(problem)?,
324        };
325
326        Ok((
327            state
328                .individual(particles[0].clone())
329                .cost(particles[0].cost)
330                .population(particles),
331            None,
332        ))
333    }
334
335    /// Perform one iteration of algorithm
336    fn next_iter(
337        &mut self,
338        problem: &mut Problem<O>,
339        mut state: PopulationState<Particle<P, F>, F>,
340    ) -> Result<(PopulationState<Particle<P, F>, F>, Option<KV>), Error> {
341        let mut best_particle = state.take_individual().ok_or_else(argmin_error_closure!(
342            PotentialBug,
343            "`ParticleSwarm`: No current best individual in state."
344        ))?;
345        let mut best_cost = state.get_cost();
346        let mut particles = state.take_population().ok_or_else(argmin_error_closure!(
347            PotentialBug,
348            "`ParticleSwarm`: No population in state."
349        ))?;
350
351        let zero = P::zero_like(&best_particle.position);
352
353        let positions: Vec<_> = particles
354            .iter_mut()
355            .map(|p| {
356                // New velocity is composed of
357                // 1) previous velocity (momentum),
358                // 2) motion toward particle optimum and
359                // 3) motion toward global optimum.
360
361                // ad 1)
362                let momentum = p.velocity.mul(&self.weight_inertia);
363
364                // ad 2)
365                let to_optimum = p.best_position.sub(&p.position);
366                let pull_to_optimum =
367                    P::rand_from_range(&zero, &to_optimum, &mut self.rng_generator);
368                let pull_to_optimum = pull_to_optimum.mul(&self.weight_cognitive);
369
370                // ad 3)
371                let to_global_optimum = best_particle.position.sub(&p.position);
372                let pull_to_global_optimum =
373                    P::rand_from_range(&zero, &to_global_optimum, &mut self.rng_generator)
374                        .mul(&self.weight_social);
375
376                p.velocity = momentum.add(&pull_to_optimum).add(&pull_to_global_optimum);
377                let new_position = p.position.add(&p.velocity);
378
379                // Limit to search window
380                p.position = P::min(&P::max(&new_position, &self.bounds.0), &self.bounds.1);
381                &p.position
382            })
383            .collect();
384
385        let costs = problem.bulk_cost(&positions)?;
386
387        for (p, c) in particles.iter_mut().zip(costs.into_iter()) {
388            p.cost = c;
389
390            if p.cost < p.best_cost {
391                p.best_position = p.position.clone();
392                p.best_cost = p.cost;
393
394                if p.cost < best_cost {
395                    best_particle.position = p.position.clone();
396                    best_particle.best_position = p.position.clone();
397                    best_particle.cost = p.cost;
398                    best_particle.best_cost = p.cost;
399                    best_cost = p.cost;
400                }
401            }
402        }
403
404        Ok((
405            state
406                .individual(best_particle)
407                .cost(best_cost)
408                .population(particles),
409            None,
410        ))
411    }
412}
413
/// A single particle of the swarm.
///
/// Besides its current state (`position`, `velocity`, `cost`), each particle remembers the best
/// position it has visited so far and the cost at that position.
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Particle<T, F> {
    /// Current position of the particle
    pub position: T,
    /// Current velocity of the particle
    velocity: T,
    /// Cost function value at `position`
    pub cost: F,
    /// Best position this particle has visited so far
    best_position: T,
    /// Cost function value at `best_position`
    best_cost: F,
}
429
430impl<T, F> Particle<T, F>
431where
432    T: Clone,
433    F: ArgminFloat,
434{
435    /// Create a new particle with a given position, cost and velocity.
436    ///
437    /// # Example
438    ///
439    /// ```
440    /// # use argmin::solver::particleswarm::Particle;
441    /// let particle: Particle<Vec<f64>, f64> = Particle::new(vec![0.0, 1.4], 12.0, vec![0.1, 0.5]);
442    /// ```
443    pub fn new(position: T, cost: F, velocity: T) -> Particle<T, F> {
444        Particle {
445            position: position.clone(),
446            velocity,
447            cost,
448            best_position: position,
449            best_cost: cost,
450        }
451    }
452}
453
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{test_utils::TestProblem, ArgminError, State};
    use approx::assert_relative_eq;

    test_trait_impl!(
        particleswarm,
        ParticleSwarm<Vec<f64>, f64, rand::rngs::StdRng>
    );

    #[test]
    fn test_new() {
        let lower_bound: Vec<f64> = vec![-1.0, -1.0];
        let upper_bound: Vec<f64> = vec![1.0, 1.0];
        let pso: ParticleSwarm<_, f64, rand::rngs::StdRng> =
            ParticleSwarm::new((lower_bound.clone(), upper_bound.clone()), 40);
        let ParticleSwarm {
            weight_inertia,
            weight_cognitive,
            weight_social,
            bounds,
            num_particles,
            ..
        } = pso;

        assert_relative_eq!(
            weight_inertia,
            (1.0f64 / (2.0 * 2.0f64.ln())),
            epsilon = f64::EPSILON
        );
        assert_relative_eq!(
            weight_cognitive,
            (0.5f64 + 2.0f64.ln()),
            epsilon = f64::EPSILON
        );
        assert_relative_eq!(
            weight_social,
            (0.5f64 + 2.0f64.ln()),
            epsilon = f64::EPSILON
        );
        assert_eq!(lower_bound[0].to_ne_bytes(), bounds.0[0].to_ne_bytes());
        assert_eq!(lower_bound[1].to_ne_bytes(), bounds.0[1].to_ne_bytes());
        assert_eq!(upper_bound[0].to_ne_bytes(), bounds.1[0].to_ne_bytes());
        assert_eq!(upper_bound[1].to_ne_bytes(), bounds.1[1].to_ne_bytes());
        assert_eq!(num_particles, 40);
    }

    #[test]
    fn test_with_inertia_factor() {
        let lower_bound: Vec<f64> = vec![-1.0, -1.0];
        let upper_bound: Vec<f64> = vec![1.0, 1.0];

        for inertia in [0.0, f64::EPSILON, 0.5, 1.0, 1.2, 3.0] {
            let res = ParticleSwarm::new((lower_bound.clone(), upper_bound.clone()), 40)
                .with_inertia_factor(inertia);
            assert!(res.is_ok());
            assert_eq!(
                res.unwrap().weight_inertia.to_ne_bytes(),
                inertia.to_ne_bytes()
            );
        }

        for inertia in [-f64::EPSILON, -0.5, -1.0, -1.2, -3.0] {
            let res = ParticleSwarm::new((lower_bound.clone(), upper_bound.clone()), 40)
                .with_inertia_factor(inertia);
            assert_error!(
                res,
                ArgminError,
                concat!(
                    "Invalid parameter: \"`ParticleSwarm`: ",
                    "inertia factor must be >=0.\""
                )
            );
        }
    }

    #[test]
    fn test_with_cognitive_factor() {
        let lower_bound: Vec<f64> = vec![-1.0, -1.0];
        let upper_bound: Vec<f64> = vec![1.0, 1.0];

        for cognitive in [0.0, f64::EPSILON, 0.5, 1.0, 1.2, 3.0] {
            let res = ParticleSwarm::new((lower_bound.clone(), upper_bound.clone()), 40)
                .with_cognitive_factor(cognitive);
            assert!(res.is_ok());
            assert_eq!(
                res.unwrap().weight_cognitive.to_ne_bytes(),
                cognitive.to_ne_bytes()
            );
        }

        for cognitive in [-f64::EPSILON, -0.5, -1.0, -1.2, -3.0] {
            let res = ParticleSwarm::new((lower_bound.clone(), upper_bound.clone()), 40)
                .with_cognitive_factor(cognitive);
            assert_error!(
                res,
                ArgminError,
                concat!(
                    "Invalid parameter: \"`ParticleSwarm`: ",
                    "cognitive factor must be >=0.\""
                )
            );
        }
    }

    #[test]
    fn test_with_social_factor() {
        let lower_bound: Vec<f64> = vec![-1.0, -1.0];
        let upper_bound: Vec<f64> = vec![1.0, 1.0];

        for social in [0.0, f64::EPSILON, 0.5, 1.0, 1.2, 3.0] {
            let res = ParticleSwarm::new((lower_bound.clone(), upper_bound.clone()), 40)
                .with_social_factor(social);
            assert!(res.is_ok());
            assert_eq!(
                res.unwrap().weight_social.to_ne_bytes(),
                social.to_ne_bytes()
            );
        }

        for social in [-f64::EPSILON, -0.5, -1.0, -1.2, -3.0] {
            let res = ParticleSwarm::new((lower_bound.clone(), upper_bound.clone()), 40)
                .with_social_factor(social);
            assert_error!(
                res,
                ArgminError,
                concat!(
                    "Invalid parameter: \"`ParticleSwarm`: ",
                    "social factor must be >=0.\""
                )
            );
        }
    }

    #[test]
    fn test_initialize_positions_and_velocities() {
        let lower_bound: Vec<f64> = vec![-1.0, -1.0];
        let upper_bound: Vec<f64> = vec![1.0, 1.0];
        let num_particles = 100;
        let mut pso: ParticleSwarm<_, f64, _> =
            ParticleSwarm::new((lower_bound, upper_bound), num_particles);

        let (positions, velocities) = pso.initialize_positions_and_velocities();
        assert_eq!(positions.len(), num_particles);
        assert_eq!(velocities.len(), num_particles);

        for pos in positions {
            for elem in pos {
                assert!(elem <= 1.0f64);
                assert!(elem >= -1.0f64);
            }
        }

        for velo in velocities {
            for elem in velo {
                assert!(elem <= 2.0f64);
                assert!(elem >= -2.0f64);
            }
        }
    }

    #[test]
    fn test_initialize_particles() {
        let lower_bound: Vec<f64> = vec![-1.0, -1.0];
        let upper_bound: Vec<f64> = vec![1.0, 1.0];
        let num_particles = 10;
        let mut pso: ParticleSwarm<_, f64, _> =
            ParticleSwarm::new((lower_bound, upper_bound), num_particles);

        struct PsoProblem {
            counter: std::sync::Arc<std::sync::Mutex<usize>>,
            values: [f64; 10],
        }

        impl CostFunction for PsoProblem {
            type Param = Vec<f64>;
            type Output = f64;

            fn cost(&self, _param: &Self::Param) -> Result<Self::Output, Error> {
                let mut counter = self.counter.lock().unwrap();
                let cost = self.values[*counter];
                *counter += 1;
                Ok(cost)
            }
        }

        let mut values = [1.0, 4.0, 10.0, 2.0, -3.0, 8.0, 4.4, 8.1, 6.4, 4.5];

        let mut problem = Problem::new(PsoProblem {
            counter: std::sync::Arc::new(std::sync::Mutex::new(0)),
            values,
        });

        values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

        let particles = pso.initialize_particles(&mut problem).unwrap();
        assert_eq!(particles.len(), num_particles);

        // at least assure that they are ordered correctly and have the correct cost.
        for (particle, cost) in particles.iter().zip(values.iter()) {
            assert_eq!(particle.cost.to_ne_bytes(), cost.to_ne_bytes());
        }
    }

    #[test]
    fn test_particle_new() {
        let init_position = vec![0.2, 3.0];
        let init_cost = 12.0;
        let init_velocity = vec![1.2, -1.3];

        let particle: Particle<Vec<f64>, f64> =
            Particle::new(init_position.clone(), init_cost, init_velocity.clone());
        let Particle {
            position,
            velocity,
            cost,
            best_position,
            best_cost,
        } = particle;

        assert_eq!(init_position, position);
        assert_eq!(init_position, best_position);
        assert_eq!(init_cost.to_ne_bytes(), cost.to_ne_bytes());
        assert_eq!(init_cost.to_ne_bytes(), best_cost.to_ne_bytes());
        assert_eq!(init_velocity, velocity);
    }

    #[test]
    fn test_init_provided_population_wrong_size() {
        let lower_bound: Vec<f64> = vec![-1.0, -1.0];
        let upper_bound: Vec<f64> = vec![1.0, 1.0];
        let mut pso: ParticleSwarm<_, f64, _> = ParticleSwarm::new((lower_bound, upper_bound), 40);
        let state: PopulationState<Particle<Vec<f64>, f64>, f64> = PopulationState::new()
            .population(vec![Particle::new(vec![1.0, 2.0], 12.0, vec![0.1, 0.3])]);
        let res = pso.init(&mut Problem::new(TestProblem::new()), state);
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Invalid parameter: \"`ParticleSwarm`: ",
                "Provided list of particles is of length 1, expected 40\"",
            )
        );
    }

    #[test]
    fn test_init_provided_population_correct_size() {
        let lower_bound: Vec<f64> = vec![-1.0, -1.0];
        let upper_bound: Vec<f64> = vec![1.0, 1.0];
        let particle_a = Particle::new(vec![1.0, 2.0], 12.0, vec![0.1, 0.3]);
        let particle_b = Particle::new(vec![2.0, 3.0], 10.0, vec![0.2, 0.4]);
        let mut pso: ParticleSwarm<_, f64, _> = ParticleSwarm::new((lower_bound, upper_bound), 2);
        let state: PopulationState<Particle<Vec<f64>, f64>, f64> =
            PopulationState::new().population(vec![particle_a.clone(), particle_b.clone()]);
        let res = pso.init(&mut Problem::new(TestProblem::new()), state);
        assert!(res.is_ok());
        let (mut state, kv) = res.unwrap();
        assert!(kv.is_none());
        assert_eq!(*state.get_param().unwrap(), particle_b);
        let population = state.take_population().unwrap();
        // assert that it was sorted!
        assert_eq!(population[0], particle_b);
        assert_eq!(population[1], particle_a);
    }

    #[test]
    fn test_init_random_population() {
        let lower_bound: Vec<f64> = vec![-1.0, -1.0];
        let upper_bound: Vec<f64> = vec![1.0, 1.0];
        let mut pso: ParticleSwarm<_, f64, _> = ParticleSwarm::new((lower_bound, upper_bound), 40);
        let state: PopulationState<Particle<Vec<f64>, f64>, f64> = PopulationState::new();
        let res = pso.init(&mut Problem::new(TestProblem::new()), state);
        assert!(res.is_ok());
        let (mut state, kv) = res.unwrap();
        assert!(kv.is_none());
        assert!(state.get_param().is_some());
        let population = state.take_population().unwrap();
        assert_eq!(population.len(), 40);
    }

    #[test]
    fn test_next_iter() {
        struct PsoProblem {
            counter: std::sync::Mutex<usize>,
            values: [f64; 10],
        }

        impl CostFunction for PsoProblem {
            type Param = Vec<f64>;
            type Output = f64;

            fn cost(&self, _param: &Self::Param) -> Result<Self::Output, Error> {
                // Read and advance the counter under a single lock guard. Two separate locks
                // would let concurrent evaluations (e.g. with the `rayon` feature) read the same
                // index before either increments it.
                let mut counter = self.counter.lock().unwrap();
                let cost = self.values[*counter % 10];
                *counter += 1;
                Ok(cost)
            }
        }

        let values = [1.0, 4.0, 10.0, 2.0, -3.0, 8.0, 4.4, 8.1, 6.4, 4.4];

        let mut problem = Problem::new(PsoProblem {
            counter: std::sync::Mutex::new(0),
            values,
        });

        // setup
        let lower_bound: Vec<f64> = vec![-1.0, -1.0];
        let upper_bound: Vec<f64> = vec![1.0, 1.0];
        let mut pso: ParticleSwarm<_, f64, _> = ParticleSwarm::new((lower_bound, upper_bound), 100);
        let state: PopulationState<Particle<Vec<f64>, f64>, f64> = PopulationState::new();

        // init
        let (mut state, _) = pso.init(&mut problem, state).unwrap();

        // next_iter
        for _ in 0..200 {
            (state, _) = pso.next_iter(&mut problem, state).unwrap();
            let population = state.get_population().unwrap();
            assert_eq!(population.len(), 100);
            for particle in population {
                for x in particle.position.iter() {
                    assert!(*x <= 1.0);
                    assert!(*x >= -1.0);
                }
            }
            assert_eq!(state.get_cost().to_ne_bytes(), (-3.0f64).to_ne_bytes());
        }
    }
}