argmin/solver/brent/
brentopt.rs

// Copyright 2018-2024 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.
7
8use crate::core::{
9    ArgminFloat, CostFunction, Error, IterState, Problem, Solver, State, TerminationReason, KV,
10};
11#[cfg(feature = "serde1")]
12use serde::{Deserialize, Serialize};
13
/// # Brent's method
///
/// A minimization algorithm combining parabolic interpolation and the
/// golden-section method.  It has the reliability of the golden-section
/// method, but can be faster thanks to the parabolic interpolation steps.
///
/// The solver keeps the complete algorithm state (bracketing interval, the
/// three best abscissae with their function values, and the last two step
/// sizes) in its own fields.
///
/// ## Requirements on the optimization problem
///
/// The optimization problem is required to implement [`CostFunction`].
///
/// ## Reference
///
/// "An algorithm with guaranteed convergence for finding a minimum of
/// a function of one variable", _Algorithms for minimization without
/// derivatives_, Richard P. Brent, 1973, Prentice-Hall.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct BrentOpt<F> {
    /// relative tolerance
    eps: F,
    /// absolute tolerance
    t: F,
    /// left boundary of current interval
    a: F,
    /// right boundary of current interval
    b: F,
    /// last point where f was evaluated (NaN until the first iteration has run)
    u: F,
    /// previous value of w
    v: F,
    /// point with the current second lowest value of f
    w: F,
    /// point with the current lowest value of f
    x: F,
    /// value of f in v
    fv: F,
    /// value of f in w
    fw: F,
    /// value of f in x
    fx: F,
    /// step bookkeeping from the step before the last one: after a parabolic
    /// step this is the previous `d`; after a golden-section step it is the
    /// signed distance from `x` to the farther interval boundary
    e: F,
    /// step computed in the last iteration (`p/q` for a parabolic step,
    /// `c * e` for a golden-section step)
    d: F,
    /// (3-sqrt(5)) / 2, the golden-section fraction
    c: F,
}
61
62impl<F: ArgminFloat> BrentOpt<F> {
63    /// Constructor
64    ///
65    /// The values `min` and `max` must bracket the minimum of the function.
66    pub fn new(min: F, max: F) -> Self {
67        BrentOpt {
68            eps: F::epsilon().sqrt(),
69            t: float!(1e-5),
70            a: min,
71            b: max,
72            u: F::nan(),
73            v: F::nan(),
74            w: F::nan(),
75            x: F::nan(),
76            fv: F::nan(),
77            fw: F::nan(),
78            fx: F::nan(),
79            e: F::zero(),
80            d: F::zero(),
81            c: float!((3f64 - 5f64.sqrt()) / 2f64),
82        }
83    }
84
85    /// Set the tolerance to the value required.
86    ///
87    /// The algorithm will return an approximation `x` of a local
88    /// minimum of the function, with an accuracy smaller than `3 tol`,
89    /// where `tol = eps*abs(x) + t`.
90    /// It is useless to set `eps` to less than the square root of the
91    /// machine precision (`F::epsilon().sqrt()`), which is its default
92    /// value.  The default value of `t` is `1e-5`.
93    pub fn set_tolerance(mut self, eps: F, t: F) -> Self {
94        self.eps = eps;
95        self.t = t;
96        self
97    }
98}
99
100impl<O, F> Solver<O, IterState<F, (), (), (), (), F>> for BrentOpt<F>
101where
102    O: CostFunction<Param = F, Output = F>,
103    F: ArgminFloat,
104{
105    fn name(&self) -> &str {
106        "BrentOpt"
107    }
108
109    fn init(
110        &mut self,
111        problem: &mut Problem<O>,
112        // BrentOpt maintains its own state
113        state: IterState<F, (), (), (), (), F>,
114    ) -> Result<(IterState<F, (), (), (), (), F>, Option<KV>), Error> {
115        let u = self.a + self.c * (self.b - self.a);
116        self.v = u;
117        self.w = u;
118        self.x = u;
119        let f = problem.cost(&u)?;
120        self.fv = f;
121        self.fw = f;
122        self.fx = f;
123        Ok((state.param(self.x).cost(self.fx), None))
124    }
125
126    fn next_iter(
127        &mut self,
128        problem: &mut Problem<O>,
129        // BrentOpt maintains its own state
130        state: IterState<F, (), (), (), (), F>,
131    ) -> Result<(IterState<F, (), (), (), (), F>, Option<KV>), Error> {
132        let two = float!(2f64);
133        let tol = self.eps * self.x.abs() + self.t;
134        let m = (self.a + self.b) / two;
135        if (self.x - m).abs() <= two * tol - (self.b - self.a) / two {
136            return Ok((
137                state
138                    .terminate_with(TerminationReason::SolverConverged)
139                    .param(self.x)
140                    .cost(self.fx),
141                None,
142            ));
143        }
144        let p = (self.x - self.v) * (self.x - self.v) * (self.fx - self.fw)
145            - (self.x - self.w) * (self.x - self.w) * (self.fx - self.fv);
146        let q = two
147            * ((self.x - self.w) * (self.fx - self.fv) - (self.x - self.v) * (self.fx - self.fw));
148        let (p, q) = if q >= F::zero() { (p, q) } else { (-p, -q) };
149        self.d = if self.e.abs() <= tol
150            || p < q * (self.a - self.x)
151            || p > q * (self.b - self.x)
152            || two * p.abs() >= q * self.e.abs()
153        {
154            // golden section step
155            self.e = if self.x < m { self.b } else { self.a } - self.x;
156            self.c * self.e
157        } else {
158            // parabolic interpolation step
159            self.e = self.d;
160            let d = p / q;
161            // f must not be evaluated too close from a and b
162            if self.x + d - self.a < two * tol || self.b - self.x - d < two * tol {
163                (m - self.x).signum() * tol
164            } else {
165                d
166            }
167        };
168        // f must not be evaluated too close from x
169        self.u = self.x
170            + if self.d.abs() >= tol {
171                self.d
172            } else {
173                self.d.signum() * tol
174            };
175        let fu = problem.cost(&self.u)?;
176        if fu <= self.fx {
177            if self.u < self.x {
178                self.b = self.x;
179            } else {
180                self.a = self.x;
181            }
182            // v is the previous w
183            self.v = self.w;
184            self.fv = self.fw;
185            // w is the second lowest value (former x)
186            self.w = self.x;
187            self.fw = self.fx;
188            // x is the lowest value (u)
189            self.x = self.u;
190            self.fx = fu;
191        } else {
192            if self.u < self.x {
193                self.a = self.u;
194            } else {
195                self.b = self.u;
196            }
197            if fu <= self.fw || self.w == self.x {
198                self.v = self.w;
199                self.fv = self.fw;
200                self.w = self.u;
201                self.fw = fu;
202            } else if fu <= self.fv || self.v == self.x || self.v == self.w {
203                self.v = self.u;
204                self.fv = fu;
205            }
206        }
207        Ok((state.param(self.x).cost(self.fx), None))
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{Executor, TerminationStatus};
    use approx::assert_relative_eq;

    test_trait_impl!(brent, BrentOpt<f64>);

    /// Cost function `exp(-x) - exp(5 - x/2)`.
    ///
    /// Its derivative vanishes at `x = 2*ln(2) - 10` (about `-8.6137`), which
    /// lies inside the interval `[-10, 10]` used by the test below.
    struct TestFunc {}

    impl CostFunction for TestFunc {
        type Param = f64;
        type Output = f64;

        fn cost(&self, x: &Self::Param) -> Result<Self::Output, Error> {
            Ok((-x).exp() - (5. - x / 2.).exp())
        }
    }

    /// Runs the solver on `TestFunc` over `[-10, 10]` with at most 13
    /// iterations and checks that it reports convergence, the expected
    /// parameters and costs, and the iteration and cost-evaluation counts.
    #[test]
    fn test_brent() {
        // Comparison tolerance shared by all floating-point assertions.
        let tolerance = f64::EPSILON.sqrt();

        let res = Executor::new(TestFunc {}, BrentOpt::new(-10., 10.))
            .configure(|state| state.counting(true).max_iters(13))
            .run()
            .unwrap();
        let final_state = res.state();

        assert_eq!(
            final_state.termination_status,
            TerminationStatus::Terminated(TerminationReason::SolverConverged)
        );

        // Parameters: current, previous, best and previous best.
        assert_relative_eq!(
            final_state.param.unwrap(),
            -8.613701289624956,
            epsilon = tolerance
        );
        assert_relative_eq!(
            final_state.prev_param.unwrap(),
            -8.613701289624956,
            epsilon = tolerance
        );
        assert_relative_eq!(
            final_state.best_param.unwrap(),
            -8.613701289624956,
            epsilon = tolerance
        );
        assert_relative_eq!(
            final_state.prev_best_param.unwrap(),
            -8.613570813317839,
            epsilon = tolerance
        );

        // Costs: current, best, previous and previous best.
        assert_relative_eq!(final_state.cost, -5506.616448675639, epsilon = tolerance);
        assert_relative_eq!(
            final_state.best_cost,
            -5506.616448675639,
            epsilon = tolerance
        );
        assert_relative_eq!(
            final_state.prev_cost,
            -5506.616448675639,
            epsilon = tolerance
        );
        assert_relative_eq!(
            final_state.prev_best_cost,
            -5506.616423678641,
            epsilon = tolerance
        );

        // Bookkeeping: iteration count and number of cost evaluations.
        assert_eq!(final_state.iter, 13);
        assert_eq!(final_state.get_func_counts()["cost_count"], 13);
    }
}