1use crate::core::{
9 ArgminFloat, CostFunction, Error, IterState, Problem, Solver, State, TerminationReason, KV,
10};
11#[cfg(feature = "serde1")]
12use serde::{Deserialize, Serialize};
13use thiserror::Error;
14
15#[derive(Debug, Error)]
17pub enum BrentRootError {
18 #[error("BrentRoot error: f(min) and f(max) must have different signs.")]
20 WrongSign,
21 #[error("BrentRoot error: tol must be positive.")]
23 NegativeTol,
24}
25
/// Brent's method for finding a root of a scalar function inside a bracketing interval.
///
/// The method combines bisection with secant and inverse quadratic interpolation. All
/// bookkeeping lives in this struct rather than in the argmin state.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct BrentRoot<F> {
    /// Tolerance requested by the user. The effective tolerance used during iteration also adds
    /// a term proportional to machine epsilon and `|b|`.
    tol: F,
    /// Previous iterate. Before the first iteration it is the lower end `min` of the interval.
    a: F,
    /// Current best estimate of the root. Before the first iteration it is `max`.
    b: F,
    /// Other end of the bracket `[b, c]` that encloses the root during iteration.
    c: F,
    /// Length of the most recent step.
    d: F,
    /// Remembered step length used to decide whether an interpolated step is trusted.
    e: F,
    /// Function value at `a`.
    fa: F,
    /// Function value at `b`.
    fb: F,
    /// Function value at `c`.
    fc: F,
}
61
62impl<F: ArgminFloat> BrentRoot<F> {
63 pub fn new(min: F, max: F, tol: F) -> Self {
67 BrentRoot {
68 tol,
69 a: min,
70 b: max,
71 c: max,
72 d: F::nan(),
73 e: F::nan(),
74 fa: F::nan(),
75 fb: F::nan(),
76 fc: F::nan(),
77 }
78 }
79}
80
81impl<O, F> Solver<O, IterState<F, (), (), (), (), F>> for BrentRoot<F>
82where
83 O: CostFunction<Param = F, Output = F>,
84 F: ArgminFloat,
85{
86 fn name(&self) -> &str {
87 "BrentRoot"
88 }
89
90 fn init(
91 &mut self,
92 problem: &mut Problem<O>,
93 state: IterState<F, (), (), (), (), F>,
95 ) -> Result<(IterState<F, (), (), (), (), F>, Option<KV>), Error> {
96 self.fa = problem.cost(&self.a)?;
97 self.fb = problem.cost(&self.b)?;
98 if self.fa * self.fb > float!(0.0) {
99 return Err(BrentRootError::WrongSign.into());
100 }
101 if self.tol < F::zero() {
102 return Err(BrentRootError::NegativeTol.into());
103 }
104 self.fc = self.fb;
105 Ok((state.param(self.b).cost(self.fb.abs()), None))
106 }
107
108 fn next_iter(
109 &mut self,
110 problem: &mut Problem<O>,
111 state: IterState<F, (), (), (), (), F>,
113 ) -> Result<(IterState<F, (), (), (), (), F>, Option<KV>), Error> {
114 if (self.fb > float!(0.0) && self.fc > float!(0.0))
115 || self.fb < float!(0.0) && self.fc < float!(0.0)
116 {
117 self.c = self.a;
118 self.fc = self.fa;
119 self.d = self.b - self.a;
120 self.e = self.d;
121 }
122 if self.fc.abs() < self.fb.abs() {
123 self.a = self.b;
124 self.b = self.c;
125 self.c = self.a;
126 self.fa = self.fb;
127 self.fb = self.fc;
128 self.fc = self.fa;
129 }
130 let eff_tol = float!(2.0) * F::epsilon() * self.b.abs() + float!(0.5) * self.tol;
132 let mid = float!(0.5) * (self.c - self.b);
133 if mid.abs() <= eff_tol || self.fb == float!(0.0) {
134 return Ok((
135 state
136 .terminate_with(TerminationReason::SolverConverged)
137 .param(self.b)
138 .cost(self.fb.abs()),
139 None,
140 ));
141 }
142 if self.e.abs() >= eff_tol && self.fa.abs() > self.fb.abs() {
143 let s = self.fb / self.fa;
144 let (mut p, mut q) = if self.a == self.c {
145 (float!(2.0) * mid * s, float!(1.0) - s)
146 } else {
147 let q = self.fa / self.fc;
148 let r = self.fb / self.fc;
149 (
150 s * (float!(2.0) * mid * q * (q - r) - (self.b - self.a) * (r - float!(1.0))),
151 (q - float!(1.0)) * (r - float!(1.0)) * (s - float!(1.0)),
152 )
153 };
154 if p > float!(0.0) {
155 q = -q;
156 }
157 p = p.abs();
158 let min1 = float!(3.0) * mid * q - (eff_tol * q).abs();
159 let min2 = (self.e * q).abs();
160 if float!(2.0) * p < if min1 < min2 { min1 } else { min2 } {
161 self.e = self.d;
162 self.d = p / q;
163 } else {
164 self.d = mid;
165 self.e = self.d;
166 };
167 } else {
168 self.d = mid;
169 self.e = self.d;
170 };
171 self.a = self.b;
172 self.fa = self.fb;
173 if self.d.abs() > eff_tol {
174 self.b = self.b + self.d;
175 } else {
176 self.b = self.b
177 + if mid >= float!(0.0) {
178 eff_tol.abs()
179 } else {
180 -eff_tol.abs()
181 };
182 }
183
184 self.fb = problem.cost(&self.b)?;
185 Ok((state.param(self.b).cost(self.fb.abs()), None))
186 }
187}
188
#[cfg(test)]
mod tests {
    //! Unit tests for `BrentRoot`: argument validation in `init`, the state returned by `init`,
    //! and end-to-end root finding on f(x) = x^2 - 1, whose roots are x = -1 and x = 1.
    use super::*;
    use crate::core::Executor;
    use approx::assert_relative_eq;

    /// State type used by `BrentRoot` with `f64` parameters.
    type BrentState = IterState<f64, (), (), (), (), f64>;

    /// f(x) = x^2 - 1.
    #[derive(Clone)]
    struct Quadratic {}

    impl CostFunction for Quadratic {
        type Param = f64;
        type Output = f64;

        fn cost(&self, param: &Self::Param) -> Result<Self::Output, Error> {
            Ok(param.powi(2) - 1.0)
        }
    }

    #[test]
    fn test_brent_negative_tol() {
        let min: f64 = 0.0;
        let max: f64 = 2.0;
        let tol: f64 = -1e-6;

        let mut solver: BrentRoot<f64> = BrentRoot::new(min, max, tol);
        let mut problem: Problem<Quadratic> = Problem::new(Quadratic {});

        let result: Result<(BrentState, Option<KV>), Error> =
            solver.init(&mut problem, IterState::new());

        assert!(result.is_err());
        assert_eq!(
            result.err().unwrap().to_string(),
            "BrentRoot error: tol must be positive."
        );
    }

    #[test]
    fn test_brent_invalid_range() {
        let min: f64 = 2.0;
        let max: f64 = 3.0;
        let tol: f64 = 1e-6;

        let mut solver: BrentRoot<f64> = BrentRoot::new(min, max, tol);
        let mut problem: Problem<Quadratic> = Problem::new(Quadratic {});

        let result: Result<(BrentState, Option<KV>), Error> =
            solver.init(&mut problem, IterState::new());

        assert!(result.is_err());
        assert_eq!(
            result.err().unwrap().to_string(),
            "BrentRoot error: f(min) and f(max) must have different signs."
        );
    }

    #[test]
    fn test_brent_valid_range() {
        let min: f64 = 0.0;
        let max: f64 = 2.0;
        let tol: f64 = 1e-6;

        let mut solver: BrentRoot<f64> = BrentRoot::new(min, max, tol);
        let mut problem: Problem<Quadratic> = Problem::new(Quadratic {});

        let result: Result<(BrentState, Option<KV>), Error> =
            solver.init(&mut problem, IterState::new());

        assert!(result.is_ok());
        // `init` reports the right end of the interval, with cost |f(2)| = 3.
        let (state, kv) = result.unwrap();
        assert_eq!(state.param, Some(max));
        assert_eq!(state.get_cost(), 3.0);
        assert!(kv.is_none());
    }

    #[test]
    fn test_brent_root_at_endpoint() {
        // f(1) = 0 exactly, so the root sits on the left end of [1, 2]. The solver should land
        // on it in the first iteration.
        let mut solver: BrentRoot<f64> = BrentRoot::new(1.0, 2.0, 1e-6);
        let mut problem: Problem<Quadratic> = Problem::new(Quadratic {});

        let (state, _): (BrentState, Option<KV>) =
            solver.init(&mut problem, IterState::new()).unwrap();
        let (state, _): (BrentState, Option<KV>) =
            solver.next_iter(&mut problem, state).unwrap();

        assert_eq!(state.param, Some(1.0));
        assert_eq!(state.get_cost(), 0.0);
    }

    #[test]
    fn test_brent_find_root() {
        let min: f64 = 0.0;
        let max: f64 = 2.0;
        let tol: f64 = 1e-6;
        let init_param: f64 = 1.5;

        let solver: BrentRoot<f64> = BrentRoot::new(min, max, tol);
        let problem: Quadratic = Quadratic {};

        let res = Executor::new(problem, solver)
            .configure(|state| state.param(init_param).max_iters(100))
            .run()
            .unwrap();

        assert_relative_eq!(res.state.best_param.unwrap(), 1.0, epsilon = tol);
    }

    #[test]
    fn test_brent_symmetry() {
        let min: f64 = 0.0;
        let max: f64 = 2.0;
        let tol: f64 = 1e-6;
        let init_param: f64 = 1.5;

        let problem: Quadratic = Quadratic {};

        let solver1: BrentRoot<f64> = BrentRoot::new(min, max, tol);
        let res1 = Executor::new(problem.clone(), solver1)
            .configure(|state| state.param(init_param).max_iters(100))
            .run()
            .unwrap();

        // Same interval with the ends swapped: result and iteration count should match.
        let solver2: BrentRoot<f64> = BrentRoot::new(max, min, tol);
        let res2 = Executor::new(problem, solver2)
            .configure(|state| state.param(init_param).max_iters(100))
            .run()
            .unwrap();

        assert_relative_eq!(
            res1.state.param.unwrap(),
            res2.state.param.unwrap(),
            epsilon = tol,
        );

        assert_eq!(res1.state.get_iter(), res2.state.get_iter());
    }
}