argmin/solver/linesearch/condition/
armijo.rsuse super::LineSearchCondition;
use crate::core::{ArgminFloat, Error};
use argmin_math::ArgminDot;
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
/// Armijo (sufficient decrease) condition for line searches.
///
/// A step of length `alpha` along the search direction `p`, starting from a point with
/// cost `f0` and gradient `g0`, is accepted when
/// `f(alpha) <= f0 + c * alpha * (g0 . p)`.
///
/// Build it with [`ArmijoCondition::new`], which checks that `c` lies in `(0, 1)`.
/// Note that the derived `Default` stores `F::default()` (zero for `f32`/`f64`),
/// which is a value `new` would reject.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct ArmijoCondition<F> {
    /// Sufficient-decrease parameter; `new` only accepts values in `(0, 1)`.
    c: F,
}
impl<F> ArmijoCondition<F>
where
    F: ArgminFloat,
{
    /// Creates a new Armijo condition with sufficient-decrease parameter `c`.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidParameter` error if `c` is not strictly inside `(0, 1)`.
    /// This includes the case where `c` is NaN.
    pub fn new(c: F) -> Result<Self, Error> {
        // Written as "c is inside the open interval" instead of "c is outside it":
        // every comparison involving NaN is false, so the negated-outside form
        // (`c <= 0 || c >= 1`) would let NaN through, while this form rejects it.
        let in_open_unit_interval = c > float!(0.0) && c < float!(1.0);
        if !in_open_unit_interval {
            return Err(argmin_error!(
                InvalidParameter,
                "ArmijoCondition: Parameter c must be in (0, 1)"
            ));
        }
        Ok(ArmijoCondition { c })
    }
}
impl<T, G, F> LineSearchCondition<T, G, F> for ArmijoCondition<F>
where
    G: ArgminDot<T, F>,
    F: ArgminFloat,
{
    /// Checks the sufficient-decrease (Armijo) inequality
    /// `current_cost <= initial_cost + c * step_length * (initial_gradient . search_direction)`.
    ///
    /// Only the gradient at the starting point is used; the gradient at the trial
    /// point (`_current_gradient`) is ignored.
    fn evaluate_condition(
        &self,
        current_cost: F,
        _current_gradient: Option<&G>,
        initial_cost: F,
        initial_gradient: &G,
        search_direction: &T,
        step_length: F,
    ) -> bool {
        // Slope of the cost at the starting point along the search direction.
        let directional_derivative = initial_gradient.dot(search_direction);
        // Largest cost still counted as a "sufficient decrease" for this step length.
        let armijo_bound = initial_cost + self.c * step_length * directional_derivative;
        current_cost <= armijo_bound
    }

    /// The Armijo condition never needs the gradient at the trial point.
    fn requires_current_gradient(&self) -> bool {
        false
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ArgminError;

    test_trait_impl!(armijo, ArmijoCondition<f64>);

    /// `new` stores a valid `c` unchanged and rejects values outside `(0, 1)`.
    #[test]
    fn test_armijo_new() {
        let c: f64 = 0.01;
        let ArmijoCondition { c: stored } = ArmijoCondition::new(c).unwrap();
        // Bitwise comparison: the parameter must be stored exactly as given.
        assert_eq!(c.to_bits(), stored.to_bits());

        for invalid in [1.0f64, 2.0, 0.0, -1.0] {
            assert_error!(
                ArmijoCondition::new(invalid),
                ArgminError,
                "Invalid parameter: \"ArmijoCondition: Parameter c must be in (0, 1)\""
            );
        }
    }

    /// For f(x, y) = x^2 + y^2, start (-1, 0), direction (1, 0) and c = 0.5, the condition
    /// reduces to (alpha - 1)^2 <= 1 - alpha, which holds exactly when alpha <= 1.
    #[test]
    fn test_armijo() {
        let cond = ArmijoCondition::new(0.50f64).unwrap();
        let cost = |x: f64, y: f64| x.powf(2.0) + y.powf(2.0);
        let grad = |x: f64, y: f64| vec![2.0 * x, 2.0 * y];
        let (x0, y0) = (-1.0, -0.0);
        let direction = vec![1.0, 0.0];

        let cases = [
            (0.001, true),
            (0.03, true),
            (0.2, true),
            (0.5, true),
            (0.9, true),
            (0.99, true),
            (1.0, true),
            (1.0 + f64::EPSILON, false),
            (1.5, false),
            (1.8, false),
            (2.0, false),
            (2.3, false),
        ];

        for &(alpha, expected) in cases.iter() {
            let (x1, y1) = (x0 + alpha, y0);
            let accepted = cond.evaluate_condition(
                cost(x1, y1),
                Some(&grad(x1, y1)),
                cost(x0, y0),
                &grad(x0, y0),
                &direction,
                alpha,
            );
            assert_eq!(accepted, expected, "alpha = {}", alpha);
        }
    }
}