use super::LineSearchCondition;
use crate::core::{ArgminFloat, Error};
use argmin_math::ArgminDot;
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
/// Strong Wolfe conditions for accepting a step length in an inexact line search.
///
/// A step length `a` along the search direction `p` starting at `x` is accepted when both
///
/// * `f(x + a*p) <= f(x) + c1 * a * (grad f(x))^T p` (sufficient decrease / Armijo), and
/// * `|(grad f(x + a*p))^T p| <= c2 * |(grad f(x))^T p|` (strong curvature condition)
///
/// hold, where `0 < c1 < c2 < 1` is enforced by [`StrongWolfeCondition::new`].
///
/// Note: for floating point `F`, the derived `Default` sets both `c1` and `c2` to zero, which is
/// a combination `new` rejects; prefer constructing instances through `new`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct StrongWolfeCondition<F> {
    /// Sufficient decrease parameter, validated to lie in `(0, 1)`.
    c1: F,
    /// Curvature parameter, validated to lie in `(c1, 1)`.
    c2: F,
}

impl<F> StrongWolfeCondition<F>
where
    F: ArgminFloat,
{
    /// Constructs a new `StrongWolfeCondition` from the parameters `c1` and `c2`.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidParameter` error if `c1` is not in the open interval `(0, 1)`, or if
    /// `c2` is not in the open interval `(c1, 1)`. NaN values are rejected for both parameters.
    pub fn new(c1: F, c2: F) -> Result<Self, Error> {
        // The checks are phrased as "is inside the interval" and then negated, because every
        // comparison involving NaN evaluates to false. Testing "is outside the interval"
        // directly would let NaN through.
        let c1_in_range = c1 > float!(0.0) && c1 < float!(1.0);
        if !c1_in_range {
            return Err(argmin_error!(
                InvalidParameter,
                "StrongWolfeCondition: Parameter c1 must be in (0, 1)"
            ));
        }
        let c2_in_range = c2 > c1 && c2 < float!(1.0);
        if !c2_in_range {
            return Err(argmin_error!(
                InvalidParameter,
                "StrongWolfeCondition: Parameter c2 must be in (c1, 1)"
            ));
        }
        Ok(StrongWolfeCondition { c1, c2 })
    }
}
impl<T, G, F> LineSearchCondition<T, G, F> for StrongWolfeCondition<F>
where
    G: ArgminDot<T, F>,
    F: ArgminFloat,
{
    /// Returns `true` if `step_length` satisfies both strong Wolfe conditions:
    ///
    /// 1. sufficient decrease:
    ///    `current_cost <= initial_cost + c1 * step_length * (initial_gradient . search_direction)`
    /// 2. strong curvature:
    ///    `|current_gradient . search_direction| <= c2 * |initial_gradient . search_direction|`
    ///
    /// The curvature condition is only evaluated when the sufficient decrease condition holds.
    ///
    /// # Panics
    ///
    /// Panics if `current_gradient` is `None` and the sufficient decrease condition holds.
    fn evaluate_condition(
        &self,
        current_cost: F,
        current_gradient: Option<&G>,
        initial_cost: F,
        initial_gradient: &G,
        search_direction: &T,
        step_length: F,
    ) -> bool {
        // Directional derivative of the cost along the search direction at the starting point.
        let initial_slope = initial_gradient.dot(search_direction);
        let sufficient_decrease =
            current_cost <= initial_cost + self.c1 * step_length * initial_slope;
        // `&&` short-circuits, so the gradient is only required once sufficient decrease holds.
        sufficient_decrease
            && current_gradient
                .expect("Gradient not supplied to `evaluate_condition` of `StrongWolfeCondition`")
                .dot(search_direction)
                .abs()
                <= self.c2 * initial_slope.abs()
    }

    /// Always `true`: the curvature condition needs the gradient at the trial point.
    fn requires_current_gradient(&self) -> bool {
        true
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ArgminError;

    test_trait_impl!(strongwolfe, StrongWolfeCondition<f64>);

    #[test]
    fn test_strongwolfe_new() {
        let (c1, c2): (f64, f64) = (0.01, 0.08);
        let StrongWolfeCondition {
            c1: c1_wolfe,
            c2: c2_wolfe,
        } = StrongWolfeCondition::new(c1, c2).unwrap();
        // Compare bit patterns so the stored parameters must match exactly.
        assert_eq!(c1.to_ne_bytes(), c1_wolfe.to_ne_bytes());
        assert_eq!(c2.to_ne_bytes(), c2_wolfe.to_ne_bytes());

        // c1 outside of (0, 1) is rejected first.
        for bad_c1 in [1.0_f64, 0.0, -1.0, 2.0] {
            assert_error!(
                StrongWolfeCondition::new(bad_c1, 0.5),
                ArgminError,
                "Invalid parameter: \"StrongWolfeCondition: Parameter c1 must be in (0, 1)\""
            );
        }

        // With a valid c1, c2 outside of (c1, 1) is rejected.
        for bad_c2 in [-1.0_f64, 0.0, 0.5, 1.0, 2.0] {
            assert_error!(
                StrongWolfeCondition::new(0.5, bad_c2),
                ArgminError,
                "Invalid parameter: \"StrongWolfeCondition: Parameter c2 must be in (c1, 1)\""
            );
        }
    }

    /// Minimizes `f(x, y) = x^2 + y^2` from `(-1, 0)` along the direction `(1, 0)` and checks
    /// whether each step length in `cases` is accepted as expected.
    fn check_step_lengths(c1: f64, c2: f64, cases: &[(f64, bool)]) {
        let cond = StrongWolfeCondition::new(c1, c2).unwrap();
        let cost = |x: f64, y: f64| x.powf(2.0) + y.powf(2.0);
        let grad = |x: f64, y: f64| vec![2.0 * x, 2.0 * y];
        let (x0, y0) = (-1.0, -0.0);
        let direction = vec![1.0, 0.0];
        for &(alpha, expected) in cases {
            let accepted = cond.evaluate_condition(
                cost(x0 + alpha, y0),
                Some(&grad(x0 + alpha, y0)),
                cost(x0, y0),
                &grad(x0, y0),
                &direction,
                alpha,
            );
            assert_eq!(accepted, expected, "step length {} misclassified", alpha);
        }
    }

    #[test]
    fn test_strongwolfe() {
        check_step_lengths(
            0.01,
            0.9,
            &[
                (0.001, false),
                (0.03, false),
                (0.1 - f64::EPSILON, false),
                (0.1, true),
                (0.15, true),
                (0.9, true),
                (0.99, true),
                (1.0, true),
                (1.9, true),
                (1.9 + f64::EPSILON, false),
                (2.0, false),
                (2.3, false),
            ],
        );
        check_step_lengths(
            0.5,
            0.9,
            &[
                (0.001, false),
                (0.03, false),
                (0.1 - f64::EPSILON, false),
                (0.1, true),
                (0.15, true),
                (0.9, true),
                (0.99, true),
                (1.0, true),
                (1.0 + f64::EPSILON, false),
                (1.9, false),
                (2.0, false),
                (2.3, false),
            ],
        );
    }
}