1use crate::core::{
9 ArgminFloat, CostFunction, Error, IterState, Problem, Solver, State, TerminationReason, KV,
10};
11#[cfg(feature = "serde1")]
12use serde::{Deserialize, Serialize};
13
/// Brent's method for minimizing a univariate cost function on a bounded interval.
///
/// Each iteration takes one of two steps:
///
/// - a parabolic interpolation step through the points `v`, `w` and `x`, when that step is
///   acceptable;
/// - otherwise a golden-section step into the larger part of the current bracket `[a, b]`.
///
/// The solver reports convergence once both ends of the bracket lie within `2 * tol` of the
/// best point `x`, where `tol = eps * |x| + t`.
///
/// Construct it with [`BrentOpt::new`] and optionally adjust the tolerances with
/// [`BrentOpt::set_tolerance`].
#[derive(Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct BrentOpt<F> {
    /// Relative tolerance; multiplied by `|x|` when computing `tol`.
    eps: F,
    /// Absolute tolerance; added to the relative part of `tol`.
    t: F,
    /// Lower end of the current bracket.
    a: F,
    /// Upper end of the current bracket.
    b: F,
    /// Most recently evaluated trial point.
    u: F,
    /// Third interpolation point: the previous `w`, or a newly evaluated point whose cost did
    /// not exceed `fv`.
    v: F,
    /// Second-best point: the previous `x`, or a newly evaluated point whose cost did not
    /// exceed `fw`.
    w: F,
    /// Point with the lowest cost found so far; this is the parameter reported to the executor.
    x: F,
    /// Cost at `v`.
    fv: F,
    /// Cost at `w`.
    fw: F,
    /// Cost at `x`.
    fx: F,
    /// Step size from the iteration before last. It holds either the previous `d` (after a
    /// parabolic step) or the distance from `x` to the farther bracket end (after a
    /// golden-section step). A parabolic step is rejected unless it is shorter than `|e| / 2`.
    e: F,
    /// Step taken from `x` in the most recent iteration.
    d: F,
    /// Golden-section ratio `(3 - sqrt(5)) / 2`.
    c: F,
}
61
62impl<F: ArgminFloat> BrentOpt<F> {
63 pub fn new(min: F, max: F) -> Self {
67 BrentOpt {
68 eps: F::epsilon().sqrt(),
69 t: float!(1e-5),
70 a: min,
71 b: max,
72 u: F::nan(),
73 v: F::nan(),
74 w: F::nan(),
75 x: F::nan(),
76 fv: F::nan(),
77 fw: F::nan(),
78 fx: F::nan(),
79 e: F::zero(),
80 d: F::zero(),
81 c: float!((3f64 - 5f64.sqrt()) / 2f64),
82 }
83 }
84
85 pub fn set_tolerance(mut self, eps: F, t: F) -> Self {
94 self.eps = eps;
95 self.t = t;
96 self
97 }
98}
99
/// Brent's univariate minimization as an argmin [`Solver`].
///
/// Both the parameter and the cost are of the scalar type `F`. All other `IterState` slots
/// are `()`, because no derivatives are used.
impl<O, F> Solver<O, IterState<F, (), (), (), (), F>> for BrentOpt<F>
where
    O: CostFunction<Param = F, Output = F>,
    F: ArgminFloat,
{
    fn name(&self) -> &str {
        "BrentOpt"
    }

    /// Evaluates the cost at the golden-section point `a + c * (b - a)` and uses it as the
    /// initial value of all three interpolation points `v`, `w` and `x`.
    ///
    /// NOTE(review): `e` and `d` are not reset here; they keep the values set at construction
    /// (zero for a freshly built solver).
    fn init(
        &mut self,
        problem: &mut Problem<O>,
        // The algorithm's state lives in `self`; `state` is only used to report `x` and `fx`.
        state: IterState<F, (), (), (), (), F>,
    ) -> Result<(IterState<F, (), (), (), (), F>, Option<KV>), Error> {
        let u = self.a + self.c * (self.b - self.a);
        self.v = u;
        self.w = u;
        self.x = u;
        let f = problem.cost(&u)?;
        self.fv = f;
        self.fw = f;
        self.fx = f;
        Ok((state.param(self.x).cost(self.fx), None))
    }

    /// Performs one iteration of Brent's method, evaluating the cost function exactly once
    /// unless the convergence test already holds.
    fn next_iter(
        &mut self,
        problem: &mut Problem<O>,
        // The algorithm's state lives in `self`; `state` is only used to report `x` and `fx`.
        state: IterState<F, (), (), (), (), F>,
    ) -> Result<(IterState<F, (), (), (), (), F>, Option<KV>), Error> {
        let two = float!(2f64);
        // Tolerance at the current best point: relative part plus absolute part.
        let tol = self.eps * self.x.abs() + self.t;
        // Midpoint of the current bracket.
        let m = (self.a + self.b) / two;
        // |x - m| + (b - a) / 2 is the distance from x to the farther bracket end, so this test
        // holds once both ends of [a, b] are within 2 * tol of x.
        if (self.x - m).abs() <= two * tol - (self.b - self.a) / two {
            return Ok((
                state
                    .terminate_with(TerminationReason::SolverConverged)
                    .param(self.x)
                    .cost(self.fx),
                None,
            ));
        }
        // Fit a parabola through (v, fv), (w, fw), (x, fx). Its vertex lies at x + p / q.
        let p = (self.x - self.v) * (self.x - self.v) * (self.fx - self.fw)
            - (self.x - self.w) * (self.x - self.w) * (self.fx - self.fv);
        let q = two
            * ((self.x - self.w) * (self.fx - self.fv) - (self.x - self.v) * (self.fx - self.fw));
        // Normalize so that q >= 0; the step p / q is unchanged.
        let (p, q) = if q >= F::zero() { (p, q) } else { (-p, -q) };
        // Fall back to a golden-section step when the parabolic step is not acceptable:
        // - the step before last (e) was already within tolerance;
        // - the parabolic step would land outside [a, b];
        // - the step would not be shorter than |e| / 2.
        // If q == 0 the last condition always holds, so p / q is never evaluated with q == 0.
        self.d = if self.e.abs() <= tol
            || p < q * (self.a - self.x)
            || p > q * (self.b - self.x)
            || two * p.abs() >= q * self.e.abs()
        {
            // Golden-section step into the larger part of the bracket; e is set to the
            // distance from x to the farther end.
            self.e = if self.x < m { self.b } else { self.a } - self.x;
            self.c * self.e
        } else {
            // Parabolic interpolation step; remember the previous step for the next test.
            self.e = self.d;
            let d = p / q;
            // Do not evaluate within 2 * tol of a bracket end; instead step by tol towards
            // the midpoint.
            if self.x + d - self.a < two * tol || self.b - self.x - d < two * tol {
                (m - self.x).signum() * tol
            } else {
                d
            }
        };
        // Never evaluate closer than tol to x: enlarge tiny steps to length tol.
        self.u = self.x
            + if self.d.abs() >= tol {
                self.d
            } else {
                self.d.signum() * tol
            };
        let fu = problem.cost(&self.u)?;
        if fu <= self.fx {
            // u is the new best point: shrink the bracket to the side of x that contains u...
            if self.u < self.x {
                self.b = self.x;
            } else {
                self.a = self.x;
            }
            // ...and shift the points: v <- w, w <- x, x <- u.
            self.v = self.w;
            self.fv = self.fw;
            self.w = self.x;
            self.fw = self.fx;
            self.x = self.u;
            self.fx = fu;
        } else {
            // x remains the best point; u becomes the new bracket end on its side.
            if self.u < self.x {
                self.a = self.u;
            } else {
                self.b = self.u;
            }
            // Keep u as an interpolation point if it improves on w or v, or if the existing
            // points coincide and would otherwise make the parabola degenerate.
            if fu <= self.fw || self.w == self.x {
                self.v = self.w;
                self.fv = self.fw;
                self.w = self.u;
                self.fw = fu;
            } else if fu <= self.fv || self.v == self.x || self.v == self.w {
                self.v = self.u;
                self.fv = fu;
            }
        }
        Ok((state.param(self.x).cost(self.fx), None))
    }
}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{Executor, TerminationStatus};
    use approx::assert_relative_eq;

    test_trait_impl!(brent, BrentOpt<f64>);

    /// Test objective `f(x) = exp(-x) - exp(5 - x / 2)`, minimized over `[-10, 10]`.
    struct TestFunc {}

    impl CostFunction for TestFunc {
        type Param = f64;
        type Output = f64;

        fn cost(&self, x: &Self::Param) -> Result<Self::Output, Error> {
            Ok((-x).exp() - (5. - x / 2.).exp())
        }
    }

    /// Runs the solver for at most 13 iterations and checks the converged state, including the
    /// previous and best values tracked by the executor and the number of cost evaluations.
    #[test]
    fn test_brent() {
        let res = Executor::new(TestFunc {}, BrentOpt::new(-10., 10.))
            .configure(|state| state.counting(true).max_iters(13))
            .run()
            .unwrap();
        let state = res.state();
        let tolerance = f64::EPSILON.sqrt();

        assert_eq!(
            state.termination_status,
            TerminationStatus::Terminated(TerminationReason::SolverConverged)
        );

        // (observed value, expected value) pairs, all compared with the same tolerance.
        let checks = [
            (state.param.unwrap(), -8.613701289624956),
            (state.prev_param.unwrap(), -8.613701289624956),
            (state.best_param.unwrap(), -8.613701289624956),
            (state.prev_best_param.unwrap(), -8.613570813317839),
            (state.cost, -5506.616448675639),
            (state.best_cost, -5506.616448675639),
            (state.prev_cost, -5506.616448675639),
            (state.prev_best_cost, -5506.616423678641),
        ];
        for &(observed, expected) in &checks {
            assert_relative_eq!(observed, expected, epsilon = tolerance);
        }

        assert_eq!(state.iter, 13);
        assert_eq!(state.get_func_counts()["cost_count"], 13);
    }
}