argmin_checkpointing_file/
lib.rs

1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! This crate creates checkpoints on disk for an optimization run.
9//!
//! The checkpoint saved on disk allows an interrupted optimization run to be resumed.
//! For usage details, see the documentation of [`FileCheckpoint`] or the
//! [example](https://github.com/argmin-rs/argmin/tree/main/examples/checkpoint).
13//!
14//! # Usage
15//!
16//! Add the following line to your dependencies list:
17//!
18//! ```toml
19//! [dependencies]
20#![doc = concat!("argmin-checkpointing-file = \"", env!("CARGO_PKG_VERSION"), "\"")]
21//! ```
22//!
23//! # License
24//!
25//! Licensed under either of
26//!
27//!   * Apache License, Version 2.0,
28//!     ([LICENSE-APACHE](https://github.com/argmin-rs/argmin/blob/main/LICENSE-APACHE) or
29//!     <http://www.apache.org/licenses/LICENSE-2.0>)
30//!   * MIT License ([LICENSE-MIT](https://github.com/argmin-rs/argmin/blob/main/LICENSE-MIT) or
31//!     <http://opensource.org/licenses/MIT>)
32//!
33//! at your option.
34//!
35//! ## Contribution
36//!
37//! Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion
38//! in the work by you, as defined in the Apache-2.0 license, shall be dual licensed as above,
39//! without any additional terms or conditions.
40
41pub use argmin::core::checkpointing::{Checkpoint, CheckpointingFrequency};
42use argmin::core::Error;
43use serde::{de::DeserializeOwned, Serialize};
44use std::fs::File;
45use std::io::{BufReader, BufWriter};
46use std::path::PathBuf;
47
/// Handles saving a checkpoint to disk as a binary file.
///
/// The checkpoint is stored at `directory.join(filename)` and encoded with `bincode`.
/// Create an instance with [`FileCheckpoint::new`] or [`FileCheckpoint::default`] and pass it
/// to an `Executor`.
#[derive(Clone, Eq, PartialEq, Debug, Hash)]
pub struct FileCheckpoint {
    /// Indicates how often a checkpoint is created (see [`CheckpointingFrequency`]).
    pub frequency: CheckpointingFrequency,
    /// Directory where the checkpoints are saved to; it is created when saving if it is missing.
    pub directory: PathBuf,
    /// Name of the checkpoint file inside `directory`, including its extension.
    pub filename: PathBuf,
}
58
59impl Default for FileCheckpoint {
60    /// Create a default `FileCheckpoint` instance.
61    ///
62    /// This will save the checkpoint in the file `.checkpoints/checkpoint.arg`.
63    ///
64    /// # Example
65    ///
66    /// ```
67    /// use argmin_checkpointing_file::FileCheckpoint;
68    /// # use argmin::core::checkpointing::CheckpointingFrequency;
69    /// # use std::path::PathBuf;
70    ///
71    /// let checkpoint = FileCheckpoint::default();
72    /// # assert_eq!(checkpoint.frequency, CheckpointingFrequency::default());
73    /// # assert_eq!(checkpoint.directory, PathBuf::from(".checkpoints"));
74    /// # assert_eq!(checkpoint.filename, PathBuf::from("checkpoint.arg"));
75    /// ```
76    fn default() -> FileCheckpoint {
77        FileCheckpoint {
78            frequency: CheckpointingFrequency::default(),
79            directory: PathBuf::from(".checkpoints"),
80            filename: PathBuf::from("checkpoint.arg"),
81        }
82    }
83}
84
85impl FileCheckpoint {
86    /// Create a new `FileCheckpoint` instance
87    ///
88    /// # Example
89    ///
90    /// ```
91    /// use argmin_checkpointing_file::{FileCheckpoint, CheckpointingFrequency};
92    /// # use std::path::PathBuf;
93    ///
94    /// let directory = "checkpoints";
95    /// let filename = "optimization";
96    ///
97    /// // When passed to an `Executor`, this will save a checkpoint in the file
98    /// // `checkpoints/optimization.arg` in every iteration.
99    /// let checkpoint = FileCheckpoint::new(directory, filename, CheckpointingFrequency::Always);
100    /// # assert_eq!(checkpoint.frequency, CheckpointingFrequency::Always);
101    /// # assert_eq!(checkpoint.directory, PathBuf::from("checkpoints"));
102    /// # assert_eq!(checkpoint.filename, PathBuf::from("optimization.arg"));
103    /// ```
104    pub fn new<N: AsRef<str>>(directory: N, name: N, frequency: CheckpointingFrequency) -> Self {
105        FileCheckpoint {
106            frequency,
107            directory: PathBuf::from(directory.as_ref()),
108            filename: PathBuf::from(format!("{}.arg", name.as_ref())),
109        }
110    }
111}
112
113impl<S, I> Checkpoint<S, I> for FileCheckpoint
114where
115    S: Serialize + DeserializeOwned,
116    I: Serialize + DeserializeOwned,
117{
118    /// Writes checkpoint to disk.
119    ///
120    /// If the directory does not exist already, it will be created. It uses `bincode` to serialize
121    /// the data.
122    /// It will return an error if creating the directory or file or serialization failed.
123    ///
124    /// # Example
125    ///
126    /// ```
127    /// use argmin_checkpointing_file::{FileCheckpoint, CheckpointingFrequency, Checkpoint};
128    ///
129    /// # use std::fs::File;
130    /// # use std::io::BufReader;
131    /// # let checkpoint = FileCheckpoint::new(".checkpoints", "save_test" , CheckpointingFrequency::Always);
132    /// # let solver: u64 = 12;
133    /// # let state: u64 = 21;
134    /// # let _ = std::fs::remove_file(".checkpoints/save_test.arg");
135    /// checkpoint.save(&solver, &state);
136    /// # let (f_solver, f_state): (u64, u64) = bincode::deserialize_from(
137    /// #     BufReader::new(File::open(".checkpoints/save_test.arg").unwrap())
138    /// # ).unwrap();
139    /// # assert_eq!(solver, f_solver);
140    /// # assert_eq!(state, f_state);
141    /// # let _ = std::fs::remove_file(".checkpoints/save_test.arg");
142    /// ```
143    fn save(&self, solver: &S, state: &I) -> Result<(), Error> {
144        if !self.directory.exists() {
145            std::fs::create_dir_all(&self.directory)?
146        }
147        let fname = self.directory.join(&self.filename);
148        let f = BufWriter::new(File::create(fname)?);
149        bincode::serialize_into(f, &(solver, state))?;
150        Ok(())
151    }
152
153    /// Load a checkpoint from disk.
154    ///
155    ///
156    /// If there is no checkpoint on disk, it will return `Ok(None)`.
157    /// Returns an error if opening the file or deserialization failed.
158    ///
159    /// # Example
160    ///
161    /// ```
162    /// use argmin_checkpointing_file::{FileCheckpoint, CheckpointingFrequency, Checkpoint};
163    /// # use argmin::core::Error;
164    ///
165    /// # use std::fs::File;
166    /// # use std::io::BufWriter;
167    /// # fn main() -> Result<(), Error> {
168    /// # std::fs::DirBuilder::new().recursive(true).create(".checkpoints").unwrap();
169    /// # let f = BufWriter::new(File::create(".checkpoints/load_test.arg")?);
170    /// # let f_solver: u64 = 12;
171    /// # let f_state: u64 = 21;
172    /// # bincode::serialize_into(f, &(f_solver, f_state))?;
173    /// # let checkpoint = FileCheckpoint::new(".checkpoints", "load_test" , CheckpointingFrequency::Always);
174    /// let (solver, state) = checkpoint.load()?.unwrap();
175    /// # // Let the compiler know which types to expect.
176    /// # let blah1: u64 = solver;
177    /// # let blah2: u64 = state;
178    /// # assert_eq!(solver, f_solver);
179    /// # assert_eq!(state, f_state);
180    /// # let _ = std::fs::remove_file(".checkpoints/load_test.arg");
181    /// #
182    /// # // Return none if File does not exist
183    /// # let checkpoint = FileCheckpoint::new(".checkpoints", "certainly_does_not_exist" , CheckpointingFrequency::Always);
184    /// # let loaded: Option<(u64, u64)> = checkpoint.load()?;
185    /// # assert!(loaded.is_none());
186    /// # Ok(())
187    /// # }
188    /// ```
189    fn load(&self) -> Result<Option<(S, I)>, Error> {
190        let path = &self.directory.join(&self.filename);
191        if !path.exists() {
192            return Ok(None);
193        }
194        let file = File::open(path)?;
195        let reader = BufReader::new(file);
196        Ok(Some(bincode::deserialize_from(reader)?))
197    }
198
199    /// Returns the how often a checkpoint is to be saved.
200    ///
201    /// Used internally by [`save_cond`](`argmin::core::checkpointing::Checkpoint::save_cond`).
202    fn frequency(&self) -> CheckpointingFrequency {
203        self.frequency
204    }
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use argmin::core::test_utils::TestSolver;
    use argmin::core::{IterState, State};

    /// State type used by the tests in this module.
    type TestState = IterState<Vec<f64>, (), (), (), (), f64>;

    /// Saving a checkpoint and loading it again must return the saved state.
    #[test]
    fn test_save() {
        // Use a process-specific directory in the system temp dir. This keeps the crate
        // directory clean and avoids collisions with concurrently running test binaries.
        let dir = std::env::temp_dir().join(format!(
            "argmin_checkpointing_file_test_save_{}",
            std::process::id()
        ));
        let _ = std::fs::remove_dir_all(&dir);

        let solver = TestSolver::new();
        let state: TestState = IterState::new().param(vec![1.0f64, 0.0]);
        let check = FileCheckpoint {
            directory: dir.clone(),
            ..FileCheckpoint::new("", "solver", CheckpointingFrequency::Always)
        };
        check.save_cond(&solver, &state, 20).unwrap();

        let loaded: Option<(TestSolver, TestState)> = check.load().unwrap();
        // Clean up before asserting so a failure does not leave files behind.
        let _ = std::fs::remove_dir_all(&dir);

        let (_, loaded_state) = loaded.expect("checkpoint should exist after saving");
        assert_eq!(loaded_state.get_param(), Some(&vec![1.0f64, 0.0]));
    }
}
225}