argmin/solver/brent/
brentroot.rs

// Copyright 2018-2024 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.
7
8use crate::core::{
9    ArgminFloat, CostFunction, Error, IterState, Problem, Solver, State, TerminationReason, KV,
10};
11#[cfg(feature = "serde1")]
12use serde::{Deserialize, Serialize};
13use thiserror::Error;
14
15/// Error to be thrown if Brent is initialized with improper parameters.
16#[derive(Debug, Error)]
17pub enum BrentRootError {
18    /// f(min) and f(max) must have different signs
19    #[error("BrentRoot error: f(min) and f(max) must have different signs.")]
20    WrongSign,
21    // tol must be positive
22    #[error("BrentRoot error: tol must be positive.")]
23    NegativeTol,
24}
25
/// # Brent's method
///
/// A root-finding algorithm combining the bisection method, the secant method
/// and inverse quadratic interpolation. It has the reliability of bisection
/// but it can be as quick as some of the less-reliable methods.
///
/// ## Requirements on the optimization problem
///
/// The optimization problem is required to implement [`CostFunction`].
///
/// ## Reference
///
/// <https://en.wikipedia.org/wiki/Brent%27s_method>
#[derive(Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct BrentRoot<F> {
    /// Requested accuracy of the root location.
    ///
    /// It enters the stopping criterion as an absolute term: the iteration
    /// stops once half of the bracketing interval is no larger than
    /// `2 * epsilon * |b| + tol / 2`.
    tol: F,
    /// Previous value of `b` (starts out as the `min` boundary).
    a: F,
    /// Currently proposed best guess (starts out as the `max` boundary).
    b: F,
    /// Counterpart of `b`: reset to `a` whenever `f(b)` and `f(c)` share the
    /// same sign, so that the root stays bracketed between `b` and `c`.
    c: F,
    /// Step that is applied to `b` in the current iteration.
    d: F,
    /// Step remembered from the preceding iteration; interpolation is only
    /// attempted while its magnitude is not below the effective tolerance.
    e: F,
    /// function value at `a`
    fa: F,
    /// function value at `b`
    fb: F,
    /// function value at `c`
    fc: F,
}
61
62impl<F: ArgminFloat> BrentRoot<F> {
63    /// Constructor
64    /// The values `min` and `max` must bracketing the root of the function.
65    /// The parameter `tol` specifies the relative error to be targeted.
66    pub fn new(min: F, max: F, tol: F) -> Self {
67        BrentRoot {
68            tol,
69            a: min,
70            b: max,
71            c: max,
72            d: F::nan(),
73            e: F::nan(),
74            fa: F::nan(),
75            fb: F::nan(),
76            fc: F::nan(),
77        }
78    }
79}
80
81impl<O, F> Solver<O, IterState<F, (), (), (), (), F>> for BrentRoot<F>
82where
83    O: CostFunction<Param = F, Output = F>,
84    F: ArgminFloat,
85{
86    fn name(&self) -> &str {
87        "BrentRoot"
88    }
89
90    fn init(
91        &mut self,
92        problem: &mut Problem<O>,
93        // BrentRoot maintains its own state
94        state: IterState<F, (), (), (), (), F>,
95    ) -> Result<(IterState<F, (), (), (), (), F>, Option<KV>), Error> {
96        self.fa = problem.cost(&self.a)?;
97        self.fb = problem.cost(&self.b)?;
98        if self.fa * self.fb > float!(0.0) {
99            return Err(BrentRootError::WrongSign.into());
100        }
101        if self.tol < F::zero() {
102            return Err(BrentRootError::NegativeTol.into());
103        }
104        self.fc = self.fb;
105        Ok((state.param(self.b).cost(self.fb.abs()), None))
106    }
107
108    fn next_iter(
109        &mut self,
110        problem: &mut Problem<O>,
111        // BrentRoot maintains its own state
112        state: IterState<F, (), (), (), (), F>,
113    ) -> Result<(IterState<F, (), (), (), (), F>, Option<KV>), Error> {
114        if (self.fb > float!(0.0) && self.fc > float!(0.0))
115            || self.fb < float!(0.0) && self.fc < float!(0.0)
116        {
117            self.c = self.a;
118            self.fc = self.fa;
119            self.d = self.b - self.a;
120            self.e = self.d;
121        }
122        if self.fc.abs() < self.fb.abs() {
123            self.a = self.b;
124            self.b = self.c;
125            self.c = self.a;
126            self.fa = self.fb;
127            self.fb = self.fc;
128            self.fc = self.fa;
129        }
130        // effective tolerance is double machine precision plus half tolerance as given.
131        let eff_tol = float!(2.0) * F::epsilon() * self.b.abs() + float!(0.5) * self.tol;
132        let mid = float!(0.5) * (self.c - self.b);
133        if mid.abs() <= eff_tol || self.fb == float!(0.0) {
134            return Ok((
135                state
136                    .terminate_with(TerminationReason::SolverConverged)
137                    .param(self.b)
138                    .cost(self.fb.abs()),
139                None,
140            ));
141        }
142        if self.e.abs() >= eff_tol && self.fa.abs() > self.fb.abs() {
143            let s = self.fb / self.fa;
144            let (mut p, mut q) = if self.a == self.c {
145                (float!(2.0) * mid * s, float!(1.0) - s)
146            } else {
147                let q = self.fa / self.fc;
148                let r = self.fb / self.fc;
149                (
150                    s * (float!(2.0) * mid * q * (q - r) - (self.b - self.a) * (r - float!(1.0))),
151                    (q - float!(1.0)) * (r - float!(1.0)) * (s - float!(1.0)),
152                )
153            };
154            if p > float!(0.0) {
155                q = -q;
156            }
157            p = p.abs();
158            let min1 = float!(3.0) * mid * q - (eff_tol * q).abs();
159            let min2 = (self.e * q).abs();
160            if float!(2.0) * p < if min1 < min2 { min1 } else { min2 } {
161                self.e = self.d;
162                self.d = p / q;
163            } else {
164                self.d = mid;
165                self.e = self.d;
166            };
167        } else {
168            self.d = mid;
169            self.e = self.d;
170        };
171        self.a = self.b;
172        self.fa = self.fb;
173        if self.d.abs() > eff_tol {
174            self.b = self.b + self.d;
175        } else {
176            self.b = self.b
177                + if mid >= float!(0.0) {
178                    eff_tol.abs()
179                } else {
180                    -eff_tol.abs()
181                };
182        }
183
184        self.fb = problem.cost(&self.b)?;
185        Ok((state.param(self.b).cost(self.fb.abs()), None))
186    }
187}
188
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::Executor;
    use approx::assert_relative_eq;

    /// Test problem `f(x) = x^2 - 1` with roots at `x = -1` and `x = 1`.
    #[derive(Clone)]
    struct Quadratic {}

    impl CostFunction for Quadratic {
        type Param = f64;
        type Output = f64;

        fn cost(&self, param: &Self::Param) -> Result<Self::Output, Error> {
            Ok(param.powi(2) - 1.0) // x^2 - 1
        }
    }

    /// A negative tolerance must be rejected during initialization.
    #[test]
    fn test_brent_negative_tol() {
        let min: f64 = 0.0;
        let max: f64 = 2.0;
        let tol: f64 = -1e-6;

        let mut solver: BrentRoot<f64> = BrentRoot::new(min, max, tol);
        let mut problem: Problem<Quadratic> = Problem::new(Quadratic {});

        let result: Result<(IterState<f64, (), (), (), (), f64>, Option<KV>), Error> =
            solver.init(&mut problem, IterState::new());

        // Check if the initialization fails and we get the correct error message
        assert!(result.is_err());
        assert_eq!(
            result.err().unwrap().to_string(),
            "BrentRoot error: tol must be positive."
        );
    }

    /// An interval without a sign change must be rejected during initialization.
    #[test]
    fn test_brent_invalid_range() {
        let min: f64 = 2.0;
        let max: f64 = 3.0;
        let tol: f64 = 1e-6;

        let mut solver: BrentRoot<f64> = BrentRoot::new(min, max, tol);
        let mut problem: Problem<Quadratic> = Problem::new(Quadratic {});

        let result: Result<(IterState<f64, (), (), (), (), f64>, Option<KV>), Error> =
            solver.init(&mut problem, IterState::new());

        // Check if the initialization fails and we get the correct error message
        assert!(result.is_err());
        assert_eq!(
            result.err().unwrap().to_string(),
            "BrentRoot error: f(min) and f(max) must have different signs."
        );
    }

    /// An interval with a sign change and a positive tolerance initializes fine.
    #[test]
    fn test_brent_valid_range() {
        let min: f64 = 0.0;
        let max: f64 = 2.0;
        let tol: f64 = 1e-6;

        let mut solver: BrentRoot<f64> = BrentRoot::new(min, max, tol);
        let mut problem: Problem<Quadratic> = Problem::new(Quadratic {});

        let result: Result<(IterState<f64, (), (), (), (), f64>, Option<KV>), Error> =
            solver.init(&mut problem, IterState::new());

        // Check if the initialization is successful
        assert!(result.is_ok());
    }

    /// The solver finds the root `x = 1` inside `[0, 2]`.
    #[test]
    fn test_brent_find_root() {
        let min: f64 = 0.0;
        let max: f64 = 2.0;
        let tol: f64 = 1e-6;
        let init_param: f64 = 1.5;

        let solver: BrentRoot<f64> = BrentRoot::new(min, max, tol);
        let problem: Quadratic = Quadratic {};

        let res = Executor::new(problem, solver)
            .configure(|state| state.param(init_param).max_iters(100))
            .run()
            .unwrap();

        // Check if the result is close to the real root
        assert_relative_eq!(res.state.best_param.unwrap(), 1.0, epsilon = tol);
    }

    /// Swapping `min` and `max` yields the same root after the same number of
    /// iterations.
    #[test]
    fn test_brent_symmetry() {
        let min: f64 = 0.0;
        let max: f64 = 2.0;
        let tol: f64 = 1e-6;
        let init_param: f64 = 1.5;

        let problem: Quadratic = Quadratic {};

        // First run with [min, max] interval
        let solver1: BrentRoot<f64> = BrentRoot::new(min, max, tol);
        let res1 = Executor::new(problem.clone(), solver1)
            .configure(|state| state.param(init_param).max_iters(100))
            .run()
            .unwrap();

        // Second run with [max, min] interval (swapped inputs)
        let solver2: BrentRoot<f64> = BrentRoot::new(max, min, tol);
        let res2 = Executor::new(problem, solver2)
            .configure(|state| state.param(init_param).max_iters(100))
            .run()
            .unwrap();

        // Check if the results are the same
        assert_relative_eq!(
            res1.state.param.unwrap(),
            res2.state.param.unwrap(),
            epsilon = tol,
        );

        // Check if the number of iterations is the same
        assert_eq!(res1.state.get_iter(), res2.state.get_iter());
    }
}
315}