1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
// Copyright 2018-2024 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

//! # Checkpointing
//!
//! Checkpointing is a useful mechanism for mitigating the effects of crashes when software is run
//! in an unstable environment, particularly for long run times. Checkpoints are saved regularly
//! with a user-chosen frequency. Optimizations can then be resumed from a given checkpoint after a
//! crash.
//!
//! For saving checkpoints to disk, `FileCheckpoint` is provided in the `argmin-checkpointing-file`
//! crate.
//! Via the `Checkpoint` trait other checkpointing approaches can be implemented.
//!
//! The `CheckpointingFrequency` defines how often checkpoints are saved and can be chosen to be
//! either `Always` (every iteration), `Every(u64)` (every Nth iteration) or `Never`.
//!
//! The following example shows how the `checkpointing` method is used to activate checkpointing.
//! If no checkpoint is available on disk, an optimization will be started from scratch. If the run
//! crashes and a checkpoint is found on disk, then it will resume from the checkpoint.
//!
//! ## Example
//!
//! ```rust
//! # extern crate argmin;
//! # extern crate argmin_testfunctions;
//! # use argmin::core::{CostFunction, Error, Executor, Gradient, observers::ObserverMode};
//! # #[cfg(feature = "serde1")]
//! use argmin::core::checkpointing::CheckpointingFrequency;
//! # #[cfg(feature = "serde1")]
//! use argmin_checkpointing_file::FileCheckpoint;
//! # use argmin_observer_slog::SlogLogger;
//! # use argmin::solver::landweber::Landweber;
//! # use argmin_testfunctions::{rosenbrock, rosenbrock_derivative};
//! #
//! # #[derive(Default)]
//! # struct Rosenbrock {}
//! #
//! # /// Implement `CostFunction` for `Rosenbrock`
//! # impl CostFunction for Rosenbrock {
//! #     /// Type of the parameter vector
//! #     type Param = Vec<f64>;
//! #     /// Type of the return value computed by the cost function
//! #     type Output = f64;
//! #
//! #     /// Apply the cost function to a parameter `p`
//! #     fn cost(&self, p: &Self::Param) -> Result<Self::Output, Error> {
//! #         Ok(rosenbrock(p))
//! #     }
//! # }
//! #
//! # /// Implement `Gradient` for `Rosenbrock`
//! # impl Gradient for Rosenbrock {
//! #     /// Type of the parameter vector
//! #     type Param = Vec<f64>;
//! #     /// Type of the return value computed by the cost function
//! #     type Gradient = Vec<f64>;
//! #
//! #     /// Compute the gradient at parameter `p`.
//! #     fn gradient(&self, p: &Self::Param) -> Result<Self::Gradient, Error> {
//! #         Ok(rosenbrock_derivative(p))
//! #     }
//! # }
//! #
//! # fn run() -> Result<(), Error> {
//! #     // define initial parameter vector
//! #     let init_param: Vec<f64> = vec![1.2, 1.2];
//! #     let my_optimization_problem = Rosenbrock {};
//! #
//! #     let iters = 35;
//! #     let solver = Landweber::new(0.001);
//!
//! // [...]
//!
//! # #[cfg(feature = "serde1")]
//! let checkpoint = FileCheckpoint::new(
//!     ".checkpoints",
//!     "optim",
//!     CheckpointingFrequency::Every(20)
//! );
//!
//! #
//! # #[cfg(feature = "serde1")]
//! let res = Executor::new(my_optimization_problem, solver)
//!     .configure(|config| config.param(init_param).max_iters(iters))
//!     .checkpointing(checkpoint)
//!     .run()?;
//!
//! // [...]
//! #
//! #     Ok(())
//! # }
//! #
//! # fn main() {
//! #     if let Err(ref e) = run() {
//! #         println!("{}", e);
//! #     }
//! # }
//! ```

use crate::core::Error;
use std::default::Default;
use std::fmt::Display;

/// An interface for checkpointing methods
///
/// Handles saving and loading of checkpoints. The methods [`save`](`Checkpoint::save`) (saving a
/// checkpoint), [`load`](`Checkpoint::load`) (loading a checkpoint) and
/// [`frequency`](`Checkpoint::frequency`) (returning the saving condition as a
/// [`CheckpointingFrequency`]) are mandatory to implement. The provided method
/// [`save_cond`](`Checkpoint::save_cond`) determines if the condition for calling `save` is met,
/// and if so, calls `save`.
///
/// # Example
///
/// ```
/// use argmin::core::Error;
/// use argmin::core::checkpointing::{Checkpoint, CheckpointingFrequency};
/// # #[cfg(feature = "serde1")]
/// use serde::{Serialize, de::DeserializeOwned};
///
/// struct MyCheckpoint {
///    frequency: CheckpointingFrequency,
///    // ..
/// }
///
/// # #[cfg(feature = "serde1")]
/// impl<S, I> Checkpoint<S, I> for MyCheckpoint
/// where
///     // Both `solver` (`S`) and `state` (`I`) (probably) need to be (de)serializable
///     S: Serialize + DeserializeOwned,
///     I: Serialize + DeserializeOwned,
/// #     S: Default,
/// #     I: Default,
/// {
///     fn save(&self, solver: &S, state: &I) -> Result<(), Error> {
///         // Save `solver` and `state`
///         Ok(())
///     }
///
///     fn load(&self) -> Result<Option<(S, I)>, Error> {
///         // Load `solver` and `state` from checkpoint
///         // Return `Ok(None)` in case checkpoint is not found.
/// #         let solver = S::default();
/// #         let state = I::default();
///         Ok(Some((solver, state)))
///     }
///
///     fn frequency(&self) -> CheckpointingFrequency {
///         self.frequency
///     }
/// }
/// # fn main() {}
/// ```
pub trait Checkpoint<S, I> {
    /// Save a checkpoint
    ///
    /// Gets a reference to the current `solver` of type `S` and to the current `state` of type
    /// `I`. Both solver and state can maintain state. Optimization problems themselves are not
    /// allowed to have state which changes during an optimization (at least not in the context of
    /// checkpointing).
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports when saving fails.
    fn save(&self, solver: &S, state: &I) -> Result<(), Error>;

    /// Saves a checkpoint when the checkpointing condition is met.
    ///
    /// Calls [`save`](`Checkpoint::save`) in each iteration (`CheckpointingFrequency::Always`),
    /// whenever `iter` is a multiple of X (`CheckpointingFrequency::Every(X)`) or never
    /// (`CheckpointingFrequency::Never`).
    ///
    /// `CheckpointingFrequency::Every(0)` does not describe a valid interval. It never triggers
    /// a save (rather than panicking because of a remainder by zero).
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`save`](`Checkpoint::save`).
    fn save_cond(&self, solver: &S, state: &I, iter: u64) -> Result<(), Error> {
        match self.frequency() {
            CheckpointingFrequency::Always => self.save(solver, state)?,
            // `checked_rem` yields `None` for a zero interval, so `Every(0)` falls through to
            // the no-op arm below instead of panicking.
            CheckpointingFrequency::Every(it) if iter.checked_rem(it) == Some(0) => {
                self.save(solver, state)?
            }
            CheckpointingFrequency::Never | CheckpointingFrequency::Every(_) => {}
        };
        Ok(())
    }

    /// Loads a saved checkpoint
    ///
    /// Returns the solver of type `S` and the `state` of type `I`, or `Ok(None)` if no
    /// checkpoint is found.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports when loading fails.
    fn load(&self) -> Result<Option<(S, I)>, Error>;

    /// Indicates how often checkpoints should be saved
    ///
    /// Returns enum `CheckpointingFrequency`.
    fn frequency(&self) -> CheckpointingFrequency;
}

/// Defines at which intervals a checkpoint is saved.
///
/// # Example
///
/// ```
/// use argmin::core::checkpointing::CheckpointingFrequency;
///
/// // A checkpoint every 10 iterations
/// let every_10 = CheckpointingFrequency::Every(10);
///
/// // A checkpoint in each iteration
/// let always = CheckpointingFrequency::Always;
///
/// // The default is `CheckpointingFrequency::Always`
/// assert_eq!(CheckpointingFrequency::default(), CheckpointingFrequency::Always);
/// ```
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CheckpointingFrequency {
    /// No checkpoint is ever created
    Never,
    /// A checkpoint is created every N iterations, with N being the wrapped value
    Every(u64),
    /// A checkpoint is created in each iteration (this is the default)
    #[default]
    Always,
}

impl Display for CheckpointingFrequency {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match *self {
            CheckpointingFrequency::Never => write!(f, "Never"),
            CheckpointingFrequency::Every(i) => write!(f, "Every({i})"),
            CheckpointingFrequency::Always => write!(f, "Always"),
        }
    }
}