// Copyright 2018-2024 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.
//! # Checkpointing
//!
//! Checkpointing is a useful mechanism for mitigating the effects of crashes when software is run
//! in an unstable environment, particularly for long run times. Checkpoints are saved regularly
//! with a user-chosen frequency. Optimizations can then be resumed from a given checkpoint after a
//! crash.
//!
//! For saving checkpoints to disk, `FileCheckpoint` is provided in the `argmin-checkpointing-file`
//! crate.
//! Via the `Checkpoint` trait other checkpointing approaches can be implemented.
//!
//! The `CheckpointingFrequency` defines how often checkpoints are saved and can be chosen to be
//! either `Always` (every iteration), `Every(u64)` (every Nth iteration) or `Never`.
//!
//! The following example shows how the `checkpointing` method is used to activate checkpointing.
//! If no checkpoint is available on disk, an optimization will be started from scratch. If the run
//! crashes and a checkpoint is found on disk, then it will resume from the checkpoint.
//!
//! ## Example
//!
//! ```rust
//! # extern crate argmin;
//! # extern crate argmin_testfunctions;
//! # use argmin::core::{CostFunction, Error, Executor, Gradient, observers::ObserverMode};
//! # #[cfg(feature = "serde1")]
//! use argmin::core::checkpointing::CheckpointingFrequency;
//! # #[cfg(feature = "serde1")]
//! use argmin_checkpointing_file::FileCheckpoint;
//! # use argmin_observer_slog::SlogLogger;
//! # use argmin::solver::landweber::Landweber;
//! # use argmin_testfunctions::{rosenbrock, rosenbrock_derivative};
//! #
//! # #[derive(Default)]
//! # struct Rosenbrock {}
//! #
//! # /// Implement `CostFunction` for `Rosenbrock`
//! # impl CostFunction for Rosenbrock {
//! # /// Type of the parameter vector
//! # type Param = Vec<f64>;
//! # /// Type of the return value computed by the cost function
//! # type Output = f64;
//! #
//! # /// Apply the cost function to a parameter `p`
//! # fn cost(&self, p: &Self::Param) -> Result<Self::Output, Error> {
//! # Ok(rosenbrock(p))
//! # }
//! # }
//! #
//! # /// Implement `Gradient` for `Rosenbrock`
//! # impl Gradient for Rosenbrock {
//! # /// Type of the parameter vector
//! # type Param = Vec<f64>;
//! # /// Type of the gradient
//! # type Gradient = Vec<f64>;
//! #
//! # /// Compute the gradient at parameter `p`.
//! # fn gradient(&self, p: &Self::Param) -> Result<Self::Gradient, Error> {
//! # Ok(rosenbrock_derivative(p))
//! # }
//! # }
//! #
//! # fn run() -> Result<(), Error> {
//! # // define initial parameter vector
//! # let init_param: Vec<f64> = vec![1.2, 1.2];
//! # let my_optimization_problem = Rosenbrock {};
//! #
//! # let iters = 35;
//! # let solver = Landweber::new(0.001);
//!
//! // [...]
//!
//! # #[cfg(feature = "serde1")]
//! let checkpoint = FileCheckpoint::new(
//! ".checkpoints",
//! "optim",
//! CheckpointingFrequency::Every(20)
//! );
//!
//! #
//! # #[cfg(feature = "serde1")]
//! let res = Executor::new(my_optimization_problem, solver)
//! .configure(|config| config.param(init_param).max_iters(iters))
//! .checkpointing(checkpoint)
//! .run()?;
//!
//! // [...]
//! #
//! # Ok(())
//! # }
//! #
//! # fn main() {
//! # if let Err(ref e) = run() {
//! # println!("{}", e);
//! # }
//! # }
//! ```
use crate::core::Error;
use std::default::Default;
use std::fmt::Display;
/// An interface for checkpointing methods
///
/// Handles saving and loading of checkpoints. The methods [`save`](`Checkpoint::save`) (saving a
/// checkpoint), [`load`](`Checkpoint::load`) (loading a checkpoint) and
/// [`frequency`](`Checkpoint::frequency`) (returning the saving condition in form of a
/// [`CheckpointingFrequency`]) are mandatory to implement. The provided method
/// [`save_cond`](`Checkpoint::save_cond`) determines if the condition for calling `save` is met,
/// and if yes, calls `save`.
///
/// # Example
///
/// ```
/// use argmin::core::Error;
/// use argmin::core::checkpointing::{Checkpoint, CheckpointingFrequency};
/// # #[cfg(feature = "serde1")]
/// use serde::{Serialize, de::DeserializeOwned};
///
/// struct MyCheckpoint {
/// frequency: CheckpointingFrequency,
/// // ..
/// }
///
/// # #[cfg(feature = "serde1")]
/// impl<S, I> Checkpoint<S, I> for MyCheckpoint
/// where
/// // Both `solver` (`S`) and `state` (`I`) (probably) need to be (de)serializable
/// S: Serialize + DeserializeOwned,
/// I: Serialize + DeserializeOwned,
/// # S: Default,
/// # I: Default,
/// {
/// fn save(&self, solver: &S, state: &I) -> Result<(), Error> {
/// // Save `solver` and `state`
/// Ok(())
/// }
///
/// fn load(&self) -> Result<Option<(S, I)>, Error> {
/// // Load `solver` and `state` from checkpoint
/// // Return `Ok(None)` in case checkpoint is not found.
/// # let solver = S::default();
/// # let state = I::default();
/// Ok(Some((solver, state)))
/// }
///
/// fn frequency(&self) -> CheckpointingFrequency {
/// self.frequency
/// }
/// }
/// # fn main() {}
/// ```
pub trait Checkpoint<S, I> {
    /// Save a checkpoint
    ///
    /// Gets a reference to the current `solver` of type `S` and to the current `state` of type
    /// `I`. Both solver and state can maintain state. The optimization problem itself is not
    /// allowed to have state which changes during an optimization (at least not in the context
    /// of checkpointing).
    fn save(&self, solver: &S, state: &I) -> Result<(), Error>;

    /// Saves a checkpoint when the checkpointing condition is met.
    ///
    /// Calls [`save`](`Checkpoint::save`) in each iteration (`CheckpointingFrequency::Always`),
    /// whenever `iter` is a multiple of X (`CheckpointingFrequency::Every(X)`) or never
    /// (`CheckpointingFrequency::Never`).
    ///
    /// An interval of zero (`CheckpointingFrequency::Every(0)`) has no iteration it could match
    /// and therefore never saves a checkpoint.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`save`](`Checkpoint::save`).
    fn save_cond(&self, solver: &S, state: &I, iter: u64) -> Result<(), Error> {
        let checkpoint_due = match self.frequency() {
            CheckpointingFrequency::Always => true,
            // `checked_rem` yields `None` for a zero interval, where a plain `%` would panic
            // with a division by zero.
            CheckpointingFrequency::Every(it) => iter.checked_rem(it) == Some(0),
            CheckpointingFrequency::Never => false,
        };
        if checkpoint_due {
            self.save(solver, state)?;
        }
        Ok(())
    }

    /// Loads a saved checkpoint
    ///
    /// Returns the solver of type `S` and the `state` of type `I` wrapped in `Some`. The `Option`
    /// in the return type lets implementations report a missing checkpoint as `Ok(None)`.
    fn load(&self) -> Result<Option<(S, I)>, Error>;

    /// Indicates how often checkpoints should be saved
    ///
    /// Returns enum `CheckpointingFrequency`.
    fn frequency(&self) -> CheckpointingFrequency;
}
/// How often a checkpoint is written during an optimization run.
///
/// # Example
///
/// ```
/// use argmin::core::checkpointing::CheckpointingFrequency;
///
/// // Checkpoint once every 10 iterations
/// let every_10 = CheckpointingFrequency::Every(10);
///
/// // Checkpoint in every single iteration
/// let always = CheckpointingFrequency::Always;
///
/// // `CheckpointingFrequency::Always` is the default
/// assert_eq!(CheckpointingFrequency::default(), CheckpointingFrequency::Always);
/// ```
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CheckpointingFrequency {
    /// No checkpoints are created
    Never,
    /// A checkpoint is created every N iterations
    Every(u64),
    /// A checkpoint is created in every iteration (the default)
    #[default]
    Always,
}
impl Display for CheckpointingFrequency {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match *self {
CheckpointingFrequency::Never => write!(f, "Never"),
CheckpointingFrequency::Every(i) => write!(f, "Every({i})"),
CheckpointingFrequency::Always => write!(f, "Always"),
}
}
}