use super::LineSearchCondition;
use crate::core::{ArgminFloat, Error};
use argmin_math::ArgminDot;
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
/// Goldstein acceptance condition for line searches.
///
/// A step of length `alpha` is accepted when the cost at the trial point lies between
/// `initial_cost + (1 - c) * alpha * <initial_gradient, search_direction>` and
/// `initial_cost + c * alpha * <initial_gradient, search_direction>` (both bounds inclusive).
///
/// Instances are normally created with `GoldsteinCondition::new`, which only accepts `c`
/// strictly between 0 and 0.5.
///
/// Note: the derived `Default` builds the struct from `F::default()` without going through
/// `new`, so the range check performed there is not applied to a defaulted value.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct GoldsteinCondition<F> {
/// Parameter `c` weighting the directional-derivative term in both bounds.
c: F,
}
impl<F> GoldsteinCondition<F>
where
    F: ArgminFloat,
{
    /// Creates a new `GoldsteinCondition` with the given parameter `c`.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidParameter` error unless `c` lies strictly inside the open
    /// interval `(0, 0.5)`. A NaN value is rejected as well, since it is not a member of
    /// that interval.
    pub fn new(c: F) -> Result<Self, Error> {
        // Test for membership and negate, rather than testing for "too small or too
        // large": every comparison involving NaN is false, so the membership form is
        // the one that correctly classifies NaN as invalid.
        let in_open_interval = c > float!(0.0) && c < float!(0.5);
        if !in_open_interval {
            return Err(argmin_error!(
                InvalidParameter,
                "GoldsteinCondition: Parameter c must be in (0, 0.5)"
            ));
        }
        Ok(GoldsteinCondition { c })
    }
}
impl<T, G, F> LineSearchCondition<T, G, F> for GoldsteinCondition<F>
where
    G: ArgminDot<T, F>,
    F: ArgminFloat,
{
    /// Checks whether a trial step satisfies the Goldstein condition.
    ///
    /// With `s = step_length * initial_gradient.dot(search_direction)`, the step is
    /// accepted when
    /// `initial_cost + (1 - c) * s <= current_cost <= initial_cost + c * s`.
    ///
    /// The gradient at the trial point is not used by this condition.
    fn evaluate_condition(
        &self,
        current_cost: F,
        _current_gradient: Option<&G>,
        initial_cost: F,
        initial_gradient: &G,
        search_direction: &T,
        step_length: F,
    ) -> bool {
        // Directional derivative at the starting point, scaled by the step length.
        let scaled_slope = step_length * initial_gradient.dot(search_direction);

        // The trial cost has to land inside this inclusive band.
        let lower_bound = initial_cost + (float!(1.0) - self.c) * scaled_slope;
        let upper_bound = initial_cost + self.c * scaled_slope;

        lower_bound <= current_cost && current_cost <= upper_bound
    }

    /// Always returns `false`: only the initial gradient enters the condition.
    fn requires_current_gradient(&self) -> bool {
        false
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ArgminError;

    test_trait_impl!(goldstein, GoldsteinCondition<f64>);

    /// `new` stores a valid `c` unchanged and rejects values on or outside the
    /// boundaries of the open interval (0, 0.5).
    #[test]
    fn test_goldstein_new() {
        let c: f64 = 0.01;
        let GoldsteinCondition { c: c_arm } = GoldsteinCondition::new(c).unwrap();
        // Compare raw bytes so the stored value must be bit-identical to the input.
        assert_eq!(c.to_ne_bytes(), c_arm.to_ne_bytes());

        // Upper boundary itself is excluded.
        assert_error!(
            GoldsteinCondition::new(0.5f64),
            ArgminError,
            "Invalid parameter: \"GoldsteinCondition: Parameter c must be in (0, 0.5)\""
        );

        // Above the interval.
        assert_error!(
            GoldsteinCondition::new(1.0f64),
            ArgminError,
            "Invalid parameter: \"GoldsteinCondition: Parameter c must be in (0, 0.5)\""
        );

        // Lower boundary itself is excluded.
        assert_error!(
            GoldsteinCondition::new(0.0f64),
            ArgminError,
            "Invalid parameter: \"GoldsteinCondition: Parameter c must be in (0, 0.5)\""
        );

        // Below the interval.
        assert_error!(
            GoldsteinCondition::new(-1.0f64),
            ArgminError,
            "Invalid parameter: \"GoldsteinCondition: Parameter c must be in (0, 0.5)\""
        );
    }

    /// Evaluates the condition on `f(x, y) = x^2 + y^2`, starting at `(-1, 0)` and
    /// moving along `(1, 0)`, for a range of step lengths.
    #[test]
    fn test_goldstein() {
        let c: f64 = 0.1;
        let cond = GoldsteinCondition::new(c).unwrap();
        let f = |x: f64, y: f64| x.powf(2.0) + y.powf(2.0);
        let g = |x: f64, y: f64| vec![2.0 * x, 2.0 * y];
        let initial_x = -1.0;
        let initial_y = -0.0;
        let search_direction = vec![1.0, 0.0];

        // (step length, expected verdict). For this quadratic and c = 0.1 the two
        // inequalities reduce analytically to 0.2 <= alpha <= 1.8; the entries sitting
        // right at those limits are decided by floating-point rounding.
        let cases = [
            (0.001, false),
            (0.03, false),
            (0.2 - 6.0 * f64::EPSILON, false),
            (0.2, true),
            (0.5, true),
            (0.9, true),
            (0.99, true),
            (1.0, true),
            (1.5, true),
            (1.8 - f64::EPSILON, true),
            (1.8, false),
            (2.0, false),
            (2.3, false),
        ];

        for (alpha, acc) in cases {
            assert_eq!(
                cond.evaluate_condition(
                    f(initial_x + alpha, initial_y),
                    Some(&g(initial_x + alpha, initial_y)),
                    f(initial_x, initial_y),
                    &g(initial_x, initial_y),
                    &search_direction,
                    alpha,
                ),
                acc,
                "unexpected Goldstein verdict for alpha = {}",
                alpha
            );
        }
    }
}