1use crate::core::{
20 ArgminFloat, CostFunction, Error, IterState, Problem, Solver, TerminationReason,
21 TerminationStatus, KV,
22};
23#[cfg(feature = "serde1")]
24use serde::{Deserialize, Serialize};
25
/// The golden ratio φ = (1 + √5) / 2.
const GOLDEN_RATIO: f64 = 1.618_033_988_749_895;
/// Larger golden-section fraction: φ − 1 (= 1/φ ≈ 0.618).
const G1: f64 = GOLDEN_RATIO - 1.0;
/// Smaller golden-section fraction: 1 − G1 (≈ 0.382).
const G2: f64 = 1.0 - G1;
30
/// Golden-section search for a minimum of a one-dimensional cost function.
///
/// The search is confined to the closed interval `[min_bound, max_bound]` and needs an initial
/// estimate inside that interval. The solver keeps a bracket of four points
/// `x0 <= x1 <= x2 <= x3`. Each iteration evaluates the cost function at one new interior point.
/// It then shrinks the bracket by dropping the outer segment next to the worse interior point:
/// `[x0, x1]` when `f2 < f1`, otherwise `[x2, x3]`.
///
/// Convergence is reported once the bracket width `|x3 - x0|` is no larger than
/// `tolerance * (|x1| + |x2|)`.
#[derive(Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct GoldenSectionSearch<F> {
    /// Larger golden-section fraction (`G1` ≈ 0.618) used to place new interior points.
    g1: F,
    /// Smaller golden-section fraction (`G2 = 1 - G1` ≈ 0.382).
    g2: F,
    /// Lower end of the search interval.
    min_bound: F,
    /// Upper end of the search interval.
    max_bound: F,
    /// Relative tolerance of the termination criterion (defaults to 0.01 in `new`).
    tolerance: F,

    /// Lower end of the current bracket (starts at `min_bound`).
    x0: F,
    /// Left interior point of the bracket.
    x1: F,
    /// Right interior point of the bracket.
    x2: F,
    /// Upper end of the current bracket (starts at `max_bound`).
    x3: F,
    /// Cost at `x1`.
    f1: F,
    /// Cost at `x2`.
    f2: F,
}
70
71impl<F> GoldenSectionSearch<F>
72where
73 F: ArgminFloat,
74{
75 pub fn new(min_bound: F, max_bound: F) -> Result<Self, Error> {
90 if max_bound <= min_bound {
91 return Err(argmin_error!(
92 InvalidParameter,
93 "`GoldenSectionSearch`: `min_bound` must be smaller than `max_bound`."
94 ));
95 }
96 Ok(GoldenSectionSearch {
97 g1: F::from(G1).unwrap(),
98 g2: F::from(G2).unwrap(),
99 min_bound,
100 max_bound,
101 tolerance: F::from(0.01).unwrap(),
102 x0: min_bound,
103 x1: F::zero(),
104 x2: F::zero(),
105 x3: max_bound,
106 f1: F::zero(),
107 f2: F::zero(),
108 })
109 }
110
111 pub fn with_tolerance(mut self, tolerance: F) -> Result<Self, Error> {
126 if tolerance <= float!(0.0) {
127 return Err(argmin_error!(
128 InvalidParameter,
129 "`GoldenSectionSearch`: Tolerance must be larger than 0."
130 ));
131 }
132 self.tolerance = tolerance;
133 Ok(self)
134 }
135}
136
137impl<O, F> Solver<O, IterState<F, (), (), (), (), F>> for GoldenSectionSearch<F>
138where
139 O: CostFunction<Param = F, Output = F>,
140 F: ArgminFloat,
141{
142 fn name(&self) -> &str {
143 "Golden-section search"
144 }
145
146 fn init(
147 &mut self,
148 problem: &mut Problem<O>,
149 mut state: IterState<F, (), (), (), (), F>,
150 ) -> Result<(IterState<F, (), (), (), (), F>, Option<KV>), Error> {
151 let init_estimate = state.take_param().ok_or_else(argmin_error_closure!(
152 NotInitialized,
153 concat!(
154 "`GoldenSectionSearch` requires an initial estimate. ",
155 "Please provide an initial guess via `Executor`s `configure` method."
156 )
157 ))?;
158 if init_estimate < self.min_bound || init_estimate > self.max_bound {
159 Err(argmin_error!(
160 InvalidParameter,
161 "`GoldenSectionSearch`: Initial estimate must be ∈ [min_bound, max_bound]."
162 ))
163 } else {
164 let ie_min = init_estimate - self.min_bound;
165 let max_ie = self.max_bound - init_estimate;
166 let (x1, x2) = if max_ie.abs() > ie_min.abs() {
167 (init_estimate, init_estimate + self.g2 * max_ie)
168 } else {
169 (init_estimate - self.g2 * ie_min, init_estimate)
170 };
171 self.x1 = x1;
172 self.x2 = x2;
173 self.f1 = problem.cost(&self.x1)?;
174 self.f2 = problem.cost(&self.x2)?;
175 if self.f1 < self.f2 {
176 Ok((state.param(self.x1).cost(self.f1), None))
177 } else {
178 Ok((state.param(self.x2).cost(self.f2), None))
179 }
180 }
181 }
182
183 fn next_iter(
184 &mut self,
185 problem: &mut Problem<O>,
186 state: IterState<F, (), (), (), (), F>,
187 ) -> Result<(IterState<F, (), (), (), (), F>, Option<KV>), Error> {
188 if self.f2 < self.f1 {
189 self.x0 = self.x1;
190 self.x1 = self.x2;
191 self.x2 = self.g1 * self.x1 + self.g2 * self.x3;
192 self.f1 = self.f2;
193 self.f2 = problem.cost(&self.x2)?;
194 } else {
195 self.x3 = self.x2;
196 self.x2 = self.x1;
197 self.x1 = self.g1 * self.x2 + self.g2 * self.x0;
198 self.f2 = self.f1;
199 self.f1 = problem.cost(&self.x1)?;
200 }
201 if self.f1 < self.f2 {
202 Ok((state.param(self.x1).cost(self.f1), None))
203 } else {
204 Ok((state.param(self.x2).cost(self.f2), None))
205 }
206 }
207
208 fn terminate(&mut self, _state: &IterState<F, (), (), (), (), F>) -> TerminationStatus {
209 if self.tolerance * (self.x1.abs() + self.x2.abs()) >= (self.x3 - self.x0).abs() {
210 return TerminationStatus::Terminated(TerminationReason::SolverConverged);
211 }
212 TerminationStatus::NotTerminated
213 }
214}
215
216#[cfg(test)]
217mod tests {
218 use super::*;
219 use crate::core::{ArgminError, State};
220 use approx::assert_relative_eq;
221
222 #[derive(Clone)]
223 struct GssTestProblem {}
224
225 impl CostFunction for GssTestProblem {
226 type Param = f64;
227 type Output = f64;
228
229 fn cost(&self, x: &Self::Param) -> Result<Self::Output, Error> {
230 Ok((x + 3.0) * (x - 1.0).powi(2))
231 }
232 }
233
234 test_trait_impl!(golden_section_search, GoldenSectionSearch<f64>);
235
236 #[test]
237 fn test_new() {
238 let GoldenSectionSearch {
239 g1,
240 g2,
241 min_bound,
242 max_bound,
243 tolerance,
244 x0,
245 x1,
246 x2,
247 x3,
248 f1,
249 f2,
250 } = GoldenSectionSearch::new(-2.5f64, 3.0f64).unwrap();
251
252 assert_eq!(g1.to_ne_bytes(), G1.to_ne_bytes());
253 assert_eq!(g2.to_ne_bytes(), G2.to_ne_bytes());
254 assert_eq!(min_bound.to_ne_bytes(), (-2.5f64).to_ne_bytes());
255 assert_eq!(max_bound.to_ne_bytes(), 3.0f64.to_ne_bytes());
256 assert_eq!(tolerance.to_ne_bytes(), 0.01f64.to_ne_bytes());
257 assert_eq!(x0.to_ne_bytes(), min_bound.to_ne_bytes());
258 assert_eq!(x1.to_ne_bytes(), 0f64.to_ne_bytes());
259 assert_eq!(x2.to_ne_bytes(), 0f64.to_ne_bytes());
260 assert_eq!(x3.to_ne_bytes(), max_bound.to_ne_bytes());
261 assert_eq!(f1.to_ne_bytes(), 0f64.to_ne_bytes());
262 assert_eq!(f2.to_ne_bytes(), 0f64.to_ne_bytes());
263 }
264
265 #[test]
266 fn test_new_errors() {
267 let res = GoldenSectionSearch::new(2.5f64, -3.0f64);
268
269 assert_error!(
270 res,
271 ArgminError,
272 concat!(
273 "Invalid parameter: \"`GoldenSectionSearch`: ",
274 "`min_bound` must be smaller than `max_bound`.\""
275 )
276 );
277
278 let res = GoldenSectionSearch::new(2.5f64, 2.5f64);
279
280 assert_error!(
281 res,
282 ArgminError,
283 concat!(
284 "Invalid parameter: \"`GoldenSectionSearch`: ",
285 "`min_bound` must be smaller than `max_bound`.\""
286 )
287 );
288 }
289
290 #[test]
291 fn test_tolerance() {
292 let GoldenSectionSearch { tolerance, .. } = GoldenSectionSearch::new(-2.5f64, 3.0f64)
293 .unwrap()
294 .with_tolerance(0.001)
295 .unwrap();
296
297 assert_eq!(tolerance.to_ne_bytes(), 0.001f64.to_ne_bytes());
298 }
299
300 #[test]
301 fn test_tolerance_errors() {
302 let res = GoldenSectionSearch::new(-2.5f64, 3.0f64)
303 .unwrap()
304 .with_tolerance(0.0);
305 assert_error!(
306 res,
307 ArgminError,
308 "Invalid parameter: \"`GoldenSectionSearch`: Tolerance must be larger than 0.\""
309 );
310
311 let res = GoldenSectionSearch::new(-2.5f64, 3.0f64)
312 .unwrap()
313 .with_tolerance(-1.0);
314 assert_error!(
315 res,
316 ArgminError,
317 "Invalid parameter: \"`GoldenSectionSearch`: Tolerance must be larger than 0.\""
318 );
319 }
320
321 #[test]
322 fn test_init_param_not_initialized() {
323 let mut gss = GoldenSectionSearch::new(-2.5f64, 3.0f64).unwrap();
324 let res = gss.init(&mut Problem::new(GssTestProblem {}), IterState::new());
325 assert_error!(
326 res,
327 ArgminError,
328 concat!(
329 "Not initialized: \"`GoldenSectionSearch` requires an initial estimate. ",
330 "Please provide an initial guess via `Executor`s `configure` method.\""
331 )
332 );
333 }
334
335 #[test]
336 fn test_init_param_outside_bounds() {
337 let mut gss = GoldenSectionSearch::new(-2.5f64, 3.0f64).unwrap();
338 let res = gss.init(
339 &mut Problem::new(GssTestProblem {}),
340 IterState::new().param(5.0f64),
341 );
342 assert_error!(
343 res,
344 ArgminError,
345 "Invalid parameter: \"`GoldenSectionSearch`: Initial estimate must be ∈ [min_bound, max_bound].\""
346 );
347 }
348
349 #[test]
350 fn test_init() {
351 let mut gss = GoldenSectionSearch::new(-2.5f64, 3.0f64).unwrap();
352 let problem = GssTestProblem {};
353 let (state, kv) = gss
354 .init(
355 &mut Problem::new(problem.clone()),
356 IterState::new().param(-0.5f64),
357 )
358 .unwrap();
359
360 assert!(kv.is_none());
361
362 let GoldenSectionSearch {
363 g1,
364 g2,
365 min_bound,
366 max_bound,
367 tolerance,
368 x0,
369 x1,
370 x2,
371 x3,
372 f1,
373 f2,
374 } = gss.clone();
375
376 assert_relative_eq!(x1, -0.5f64, epsilon = f64::EPSILON);
377 assert_relative_eq!(x2, -0.5f64 + g2 * 3.5f64, epsilon = f64::EPSILON);
378 assert_relative_eq!(f1, problem.cost(&x1).unwrap(), epsilon = f64::EPSILON);
379 assert_relative_eq!(f2, problem.cost(&x2).unwrap(), epsilon = f64::EPSILON);
380 if f1 < f2 {
381 assert_relative_eq!(*state.param.as_ref().unwrap(), x1, epsilon = f64::EPSILON);
382 assert_relative_eq!(state.cost, f1, epsilon = f64::EPSILON);
383 } else {
384 assert_relative_eq!(*state.param.as_ref().unwrap(), x2, epsilon = f64::EPSILON);
385 assert_relative_eq!(state.cost, f2, epsilon = f64::EPSILON);
386 }
387
388 assert_eq!(g1.to_ne_bytes(), G1.to_ne_bytes());
389 assert_eq!(g2.to_ne_bytes(), G2.to_ne_bytes());
390 assert_eq!(min_bound.to_ne_bytes(), (-2.5f64).to_ne_bytes());
391 assert_eq!(max_bound.to_ne_bytes(), 3.0f64.to_ne_bytes());
392 assert_eq!(tolerance.to_ne_bytes(), 0.01f64.to_ne_bytes());
393 assert_eq!(x0.to_ne_bytes(), min_bound.to_ne_bytes());
394 assert_eq!(x3.to_ne_bytes(), max_bound.to_ne_bytes());
395 }
396
397 #[test]
398 fn test_next_iter_1() {
399 let mut gss = GoldenSectionSearch::new(-2.5f64, 3.0f64).unwrap();
400 let mut problem = Problem::new(GssTestProblem {});
401
402 gss.f1 = 10.0f64;
403 gss.f2 = 5.0f64;
404 gss.x0 = 0.0f64;
405 gss.x1 = 1.0f64;
406 gss.x2 = 2.0f64;
407 gss.x3 = 3.0f64;
408
409 let (state, kv) = gss
410 .next_iter(&mut problem, IterState::new().param(-0.5f64))
411 .unwrap();
412
413 assert!(kv.is_none());
414
415 let GoldenSectionSearch {
416 g1,
417 g2,
418 min_bound,
419 max_bound,
420 tolerance,
421 x0,
422 x1,
423 x2,
424 x3,
425 f1,
426 f2,
427 } = gss.clone();
428
429 assert_relative_eq!(x0, 1.0f64, epsilon = f64::EPSILON);
430 assert_relative_eq!(x1, 2.0f64, epsilon = f64::EPSILON);
431 assert_relative_eq!(x2, g1 * 2.0f64 + g2 * x3, epsilon = f64::EPSILON);
432 assert_relative_eq!(f1, 5.0f64, epsilon = f64::EPSILON);
433 assert_relative_eq!(f2, problem.cost(&x2).unwrap(), epsilon = f64::EPSILON);
434 assert_eq!(g1.to_ne_bytes(), G1.to_ne_bytes());
435 assert_eq!(g2.to_ne_bytes(), G2.to_ne_bytes());
436 assert_eq!(min_bound.to_ne_bytes(), (-2.5f64).to_ne_bytes());
437 assert_eq!(max_bound.to_ne_bytes(), 3.0f64.to_ne_bytes());
438 assert_eq!(tolerance.to_ne_bytes(), 0.01f64.to_ne_bytes());
439 if f1 < f2 {
440 assert_relative_eq!(*state.param.as_ref().unwrap(), x1, epsilon = f64::EPSILON);
441 assert_relative_eq!(state.cost, f1, epsilon = f64::EPSILON);
442 } else {
443 assert_relative_eq!(*state.param.as_ref().unwrap(), x2, epsilon = f64::EPSILON);
444 assert_relative_eq!(state.cost, f2, epsilon = f64::EPSILON);
445 }
446 }
447
448 #[test]
449 fn test_next_iter_2() {
450 let mut gss = GoldenSectionSearch::new(-2.5f64, 3.0f64).unwrap();
451 let mut problem = Problem::new(GssTestProblem {});
452
453 gss.f1 = 5.0f64;
454 gss.f2 = 10.0f64;
455 gss.x0 = 0.0f64;
456 gss.x1 = 1.0f64;
457 gss.x2 = 2.0f64;
458 gss.x3 = 3.0f64;
459
460 let (state, kv) = gss
461 .next_iter(&mut problem, IterState::new().param(-0.5f64))
462 .unwrap();
463
464 assert!(kv.is_none());
465
466 let GoldenSectionSearch {
467 g1,
468 g2,
469 min_bound,
470 max_bound,
471 tolerance,
472 x0,
473 x1,
474 x2,
475 x3,
476 f1,
477 f2,
478 } = gss.clone();
479
480 assert_relative_eq!(x0, 0.0f64, epsilon = f64::EPSILON);
481 assert_relative_eq!(x1, g1 * x2 + g2 * x0, epsilon = f64::EPSILON);
482 assert_relative_eq!(x2, 1.0f64, epsilon = f64::EPSILON);
483 assert_relative_eq!(x3, 2.0f64, epsilon = f64::EPSILON);
484 assert_relative_eq!(f1, problem.cost(&x1).unwrap(), epsilon = f64::EPSILON);
485 assert_relative_eq!(f2, 5.0f64, epsilon = f64::EPSILON);
486 assert_eq!(g1.to_ne_bytes(), G1.to_ne_bytes());
487 assert_eq!(g2.to_ne_bytes(), G2.to_ne_bytes());
488 assert_eq!(min_bound.to_ne_bytes(), (-2.5f64).to_ne_bytes());
489 assert_eq!(max_bound.to_ne_bytes(), 3.0f64.to_ne_bytes());
490 assert_eq!(tolerance.to_ne_bytes(), 0.01f64.to_ne_bytes());
491 if f1 < f2 {
492 assert_relative_eq!(*state.param.as_ref().unwrap(), x1, epsilon = f64::EPSILON);
493 assert_relative_eq!(state.cost, f1, epsilon = f64::EPSILON);
494 } else {
495 assert_relative_eq!(*state.param.as_ref().unwrap(), x2, epsilon = f64::EPSILON);
496 assert_relative_eq!(state.cost, f2, epsilon = f64::EPSILON);
497 }
498 }
499}