argmin/solver/goldensectionsearch/
mod.rs

1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! # Golden-section search
9//!
//! The golden-section search is a technique for finding an extremum (minimum or maximum) of a
//! function inside a specified interval. This implementation searches for a minimum.
12//!
13//! See [`GoldenSectionSearch`] for details.
14//!
15//! ## Reference
16//!
17//! <https://en.wikipedia.org/wiki/Golden-section_search>
18
19use crate::core::{
20    ArgminFloat, CostFunction, Error, IterState, Problem, Solver, TerminationReason,
21    TerminationStatus, KV,
22};
23#[cfg(feature = "serde1")]
24use serde::{Deserialize, Serialize};
25
// The golden ratio φ = (1 + √5) / 2 ≈ 1.61803398874989484820…; the literal below keeps only as
// many digits as an `f64` can actually represent.
const GOLDEN_RATIO: f64 = 1.618_033_988_749_895;
// φ - 1 (≈ 0.618): weight of the nearer point when a new interior point is placed.
const G1: f64 = GOLDEN_RATIO - 1.0;
// 1 - (φ - 1) (≈ 0.382): the complementary weight, so that G1 + G2 == 1.
const G2: f64 = 1.0 - G1;
30
/// # Golden-section search
///
/// The golden-section search is a technique for finding an extremum (minimum or maximum) of a
/// function inside a specified interval. This implementation searches for a **minimum**; to
/// locate a maximum, minimize the negated function instead.
///
/// The method operates by successively narrowing the range of values on the specified interval,
/// which makes it relatively slow, but very robust. The technique derives its name from the fact
/// that the algorithm maintains the function values for four points whose three interval widths
/// are in the ratio 2-φ:2φ-3:2-φ where φ is the golden ratio. These ratios are maintained for each
/// iteration and are maximally efficient.
///
/// The `min_bound` and `max_bound` arguments define values that bracket the expected minimum.
///
/// Requires an initial guess inside `[min_bound, max_bound]`, which is to be provided via the
/// `configure` method of [`Executor`](`crate::core::Executor`).
///
/// ## Requirements on the optimization problem
///
/// The optimization problem is required to implement [`CostFunction`].
///
/// ## Reference
///
/// <https://en.wikipedia.org/wiki/Golden-section_search>
#[derive(Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct GoldenSectionSearch<F> {
    /// Golden-section factor `φ - 1` (≈ 0.618).
    g1: F,
    /// Complementary factor `1 - g1` (≈ 0.382).
    g2: F,
    /// Lower end of the user-supplied search interval.
    min_bound: F,
    /// Upper end of the user-supplied search interval.
    max_bound: F,
    /// Relative convergence tolerance.
    tolerance: F,

    /// Lower outer point of the current bracket.
    x0: F,
    /// Lower interior point of the current bracket.
    x1: F,
    /// Upper interior point of the current bracket.
    x2: F,
    /// Upper outer point of the current bracket.
    x3: F,
    /// Cost evaluated at `x1`.
    f1: F,
    /// Cost evaluated at `x2`.
    f2: F,
}
70
71impl<F> GoldenSectionSearch<F>
72where
73    F: ArgminFloat,
74{
75    /// Construct a new instance of [`GoldenSectionSearch`].
76    ///
77    /// The `min_bound` and `max_bound` arguments define values that bracket the expected minimum.
78    ///
79    /// # Example
80    ///
81    /// ```
82    /// # use argmin::solver::goldensectionsearch::GoldenSectionSearch;
83    /// # use argmin::core::Error;
84    /// # fn main() -> Result<(), Error> {
85    /// let gss = GoldenSectionSearch::new(-2.5f64, 3.0f64)?;
86    /// # Ok(())
87    /// # }
88    /// ```
89    pub fn new(min_bound: F, max_bound: F) -> Result<Self, Error> {
90        if max_bound <= min_bound {
91            return Err(argmin_error!(
92                InvalidParameter,
93                "`GoldenSectionSearch`: `min_bound` must be smaller than `max_bound`."
94            ));
95        }
96        Ok(GoldenSectionSearch {
97            g1: F::from(G1).unwrap(),
98            g2: F::from(G2).unwrap(),
99            min_bound,
100            max_bound,
101            tolerance: F::from(0.01).unwrap(),
102            x0: min_bound,
103            x1: F::zero(),
104            x2: F::zero(),
105            x3: max_bound,
106            f1: F::zero(),
107            f2: F::zero(),
108        })
109    }
110
111    /// Set tolerance.
112    ///
113    /// Must be larger than `0` and defaults to `0.01`.
114    ///
115    /// # Example
116    ///
117    /// ```
118    /// # use argmin::solver::goldensectionsearch::GoldenSectionSearch;
119    /// # use argmin::core::Error;
120    /// # fn main() -> Result<(), Error> {
121    /// let gss = GoldenSectionSearch::new(-2.5f64, 3.0f64)?.with_tolerance(0.0001)?;
122    /// # Ok(())
123    /// # }
124    /// ```
125    pub fn with_tolerance(mut self, tolerance: F) -> Result<Self, Error> {
126        if tolerance <= float!(0.0) {
127            return Err(argmin_error!(
128                InvalidParameter,
129                "`GoldenSectionSearch`: Tolerance must be larger than 0."
130            ));
131        }
132        self.tolerance = tolerance;
133        Ok(self)
134    }
135}
136
137impl<O, F> Solver<O, IterState<F, (), (), (), (), F>> for GoldenSectionSearch<F>
138where
139    O: CostFunction<Param = F, Output = F>,
140    F: ArgminFloat,
141{
142    fn name(&self) -> &str {
143        "Golden-section search"
144    }
145
146    fn init(
147        &mut self,
148        problem: &mut Problem<O>,
149        mut state: IterState<F, (), (), (), (), F>,
150    ) -> Result<(IterState<F, (), (), (), (), F>, Option<KV>), Error> {
151        let init_estimate = state.take_param().ok_or_else(argmin_error_closure!(
152            NotInitialized,
153            concat!(
154                "`GoldenSectionSearch` requires an initial estimate. ",
155                "Please provide an initial guess via `Executor`s `configure` method."
156            )
157        ))?;
158        if init_estimate < self.min_bound || init_estimate > self.max_bound {
159            Err(argmin_error!(
160                InvalidParameter,
161                "`GoldenSectionSearch`: Initial estimate must be ∈ [min_bound, max_bound]."
162            ))
163        } else {
164            let ie_min = init_estimate - self.min_bound;
165            let max_ie = self.max_bound - init_estimate;
166            let (x1, x2) = if max_ie.abs() > ie_min.abs() {
167                (init_estimate, init_estimate + self.g2 * max_ie)
168            } else {
169                (init_estimate - self.g2 * ie_min, init_estimate)
170            };
171            self.x1 = x1;
172            self.x2 = x2;
173            self.f1 = problem.cost(&self.x1)?;
174            self.f2 = problem.cost(&self.x2)?;
175            if self.f1 < self.f2 {
176                Ok((state.param(self.x1).cost(self.f1), None))
177            } else {
178                Ok((state.param(self.x2).cost(self.f2), None))
179            }
180        }
181    }
182
183    fn next_iter(
184        &mut self,
185        problem: &mut Problem<O>,
186        state: IterState<F, (), (), (), (), F>,
187    ) -> Result<(IterState<F, (), (), (), (), F>, Option<KV>), Error> {
188        if self.f2 < self.f1 {
189            self.x0 = self.x1;
190            self.x1 = self.x2;
191            self.x2 = self.g1 * self.x1 + self.g2 * self.x3;
192            self.f1 = self.f2;
193            self.f2 = problem.cost(&self.x2)?;
194        } else {
195            self.x3 = self.x2;
196            self.x2 = self.x1;
197            self.x1 = self.g1 * self.x2 + self.g2 * self.x0;
198            self.f2 = self.f1;
199            self.f1 = problem.cost(&self.x1)?;
200        }
201        if self.f1 < self.f2 {
202            Ok((state.param(self.x1).cost(self.f1), None))
203        } else {
204            Ok((state.param(self.x2).cost(self.f2), None))
205        }
206    }
207
208    fn terminate(&mut self, _state: &IterState<F, (), (), (), (), F>) -> TerminationStatus {
209        if self.tolerance * (self.x1.abs() + self.x2.abs()) >= (self.x3 - self.x0).abs() {
210            return TerminationStatus::Terminated(TerminationReason::SolverConverged);
211        }
212        TerminationStatus::NotTerminated
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{ArgminError, State};
    use approx::assert_relative_eq;

    /// Test cost `(x + 3)(x - 1)^2`, which has a local minimum at `x = 1`.
    #[derive(Clone)]
    struct GssTestProblem {}

    impl CostFunction for GssTestProblem {
        type Param = f64;
        type Output = f64;

        fn cost(&self, x: &Self::Param) -> Result<Self::Output, Error> {
            Ok((x + 3.0) * (x - 1.0).powi(2))
        }
    }

    /// Interval shared by every test below.
    const MIN_BOUND: f64 = -2.5;
    const MAX_BOUND: f64 = 3.0;

    test_trait_impl!(golden_section_search, GoldenSectionSearch<f64>);

    /// Asserts that two floats are bit-for-bit identical.
    fn assert_same_bits(actual: f64, expected: f64) {
        assert_eq!(actual.to_ne_bytes(), expected.to_ne_bytes());
    }

    /// The solver on `[MIN_BOUND, MAX_BOUND]` with default settings.
    fn solver() -> GoldenSectionSearch<f64> {
        GoldenSectionSearch::new(MIN_BOUND, MAX_BOUND).unwrap()
    }

    /// The default solver with its bracket forced to `(0, 1, 2, 3)` and the given interior costs.
    fn solver_at(f1: f64, f2: f64) -> GoldenSectionSearch<f64> {
        GoldenSectionSearch {
            x0: 0.0,
            x1: 1.0,
            x2: 2.0,
            x3: 3.0,
            f1,
            f2,
            ..solver()
        }
    }

    /// Asserts that the configuration still holds its construction-time defaults.
    fn assert_config_untouched(gss: &GoldenSectionSearch<f64>) {
        assert_same_bits(gss.g1, G1);
        assert_same_bits(gss.g2, G2);
        assert_same_bits(gss.min_bound, MIN_BOUND);
        assert_same_bits(gss.max_bound, MAX_BOUND);
        assert_same_bits(gss.tolerance, 0.01);
    }

    /// Asserts that `state` reports the interior point with the lower cost (ties go to `x2`).
    fn assert_reports_best(
        state: &IterState<f64, (), (), (), (), f64>,
        gss: &GoldenSectionSearch<f64>,
    ) {
        let (x_best, f_best) = if gss.f1 < gss.f2 {
            (gss.x1, gss.f1)
        } else {
            (gss.x2, gss.f2)
        };
        assert_relative_eq!(*state.param.as_ref().unwrap(), x_best, epsilon = f64::EPSILON);
        assert_relative_eq!(state.cost, f_best, epsilon = f64::EPSILON);
    }

    #[test]
    fn test_new() {
        // Exhaustive destructuring: adding a field without updating this test fails to compile.
        let GoldenSectionSearch {
            g1,
            g2,
            min_bound,
            max_bound,
            tolerance,
            x0,
            x1,
            x2,
            x3,
            f1,
            f2,
        } = solver();

        assert_same_bits(g1, G1);
        assert_same_bits(g2, G2);
        assert_same_bits(min_bound, MIN_BOUND);
        assert_same_bits(max_bound, MAX_BOUND);
        assert_same_bits(tolerance, 0.01);
        assert_same_bits(x0, min_bound);
        assert_same_bits(x3, max_bound);
        for placeholder in [x1, x2, f1, f2] {
            assert_same_bits(placeholder, 0.0);
        }
    }

    #[test]
    fn test_new_errors() {
        for (lower, upper) in [(2.5f64, -3.0f64), (2.5f64, 2.5f64)] {
            let res = GoldenSectionSearch::new(lower, upper);
            assert_error!(
                res,
                ArgminError,
                concat!(
                    "Invalid parameter: \"`GoldenSectionSearch`: ",
                    "`min_bound` must be smaller than `max_bound`.\""
                )
            );
        }
    }

    #[test]
    fn test_tolerance() {
        let gss = solver().with_tolerance(0.001).unwrap();
        assert_same_bits(gss.tolerance, 0.001);
    }

    #[test]
    fn test_tolerance_errors() {
        for bad_tolerance in [0.0f64, -1.0f64] {
            let res = solver().with_tolerance(bad_tolerance);
            assert_error!(
                res,
                ArgminError,
                "Invalid parameter: \"`GoldenSectionSearch`: Tolerance must be larger than 0.\""
            );
        }
    }

    #[test]
    fn test_init_param_not_initialized() {
        let mut gss = solver();
        let res = gss.init(&mut Problem::new(GssTestProblem {}), IterState::new());
        assert_error!(
            res,
            ArgminError,
            concat!(
                "Not initialized: \"`GoldenSectionSearch` requires an initial estimate. ",
                "Please provide an initial guess via `Executor`s `configure` method.\""
            )
        );
    }

    #[test]
    fn test_init_param_outside_bounds() {
        let mut gss = solver();
        let res = gss.init(
            &mut Problem::new(GssTestProblem {}),
            IterState::new().param(5.0f64),
        );
        assert_error!(
            res,
            ArgminError,
            "Invalid parameter: \"`GoldenSectionSearch`: Initial estimate must be ∈ [min_bound, max_bound].\""
        );
    }

    #[test]
    fn test_init() {
        let mut gss = solver();
        let problem = GssTestProblem {};
        let (state, kv) = gss
            .init(
                &mut Problem::new(problem.clone()),
                IterState::new().param(-0.5f64),
            )
            .unwrap();
        assert!(kv.is_none());

        // -0.5 is closer to the lower bound, so it becomes x1 and x2 is placed to its right.
        assert_relative_eq!(gss.x1, -0.5f64, epsilon = f64::EPSILON);
        assert_relative_eq!(gss.x2, -0.5f64 + gss.g2 * 3.5f64, epsilon = f64::EPSILON);
        assert_relative_eq!(gss.f1, problem.cost(&gss.x1).unwrap(), epsilon = f64::EPSILON);
        assert_relative_eq!(gss.f2, problem.cost(&gss.x2).unwrap(), epsilon = f64::EPSILON);
        assert_reports_best(&state, &gss);

        assert_config_untouched(&gss);
        assert_same_bits(gss.x0, gss.min_bound);
        assert_same_bits(gss.x3, gss.max_bound);
    }

    #[test]
    fn test_next_iter_1() {
        // f2 < f1: the bracket shifts to the right.
        let mut gss = solver_at(10.0, 5.0);
        let mut problem = Problem::new(GssTestProblem {});

        let (state, kv) = gss
            .next_iter(&mut problem, IterState::new().param(-0.5f64))
            .unwrap();
        assert!(kv.is_none());

        assert_relative_eq!(gss.x0, 1.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(gss.x1, 2.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(gss.x2, gss.g1 * 2.0f64 + gss.g2 * gss.x3, epsilon = f64::EPSILON);
        assert_relative_eq!(gss.f1, 5.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(gss.f2, problem.cost(&gss.x2).unwrap(), epsilon = f64::EPSILON);
        assert_config_untouched(&gss);
        assert_reports_best(&state, &gss);
    }

    #[test]
    fn test_next_iter_2() {
        // f1 <= f2: the bracket shifts to the left.
        let mut gss = solver_at(5.0, 10.0);
        let mut problem = Problem::new(GssTestProblem {});

        let (state, kv) = gss
            .next_iter(&mut problem, IterState::new().param(-0.5f64))
            .unwrap();
        assert!(kv.is_none());

        assert_relative_eq!(gss.x0, 0.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(gss.x1, gss.g1 * gss.x2 + gss.g2 * gss.x0, epsilon = f64::EPSILON);
        assert_relative_eq!(gss.x2, 1.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(gss.x3, 2.0f64, epsilon = f64::EPSILON);
        assert_relative_eq!(gss.f1, problem.cost(&gss.x1).unwrap(), epsilon = f64::EPSILON);
        assert_relative_eq!(gss.f2, 5.0f64, epsilon = f64::EPSILON);
        assert_config_untouched(&gss);
        assert_reports_best(&state, &gss);
    }
}