argmin_checkpointing_file/lib.rs
1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! This crate creates checkpoints on disk for an optimization run.
9//!
10//! Saves a checkpoint on disk from which an interrupted optimization run can be resumed.
11//! For details on the usage please see the documentation of [`FileCheckpoint`] or have a look at
12//! the [example](https://github.com/argmin-rs/argmin/tree/main/examples/checkpoint).
13//!
14//! # Usage
15//!
16//! Add the following line to your dependencies list:
17//!
18//! ```toml
19//! [dependencies]
20#![doc = concat!("argmin-checkpointing-file = \"", env!("CARGO_PKG_VERSION"), "\"")]
21//! ```
22//!
23//! # License
24//!
25//! Licensed under either of
26//!
27//! * Apache License, Version 2.0,
28//! ([LICENSE-APACHE](https://github.com/argmin-rs/argmin/blob/main/LICENSE-APACHE) or
29//! <http://www.apache.org/licenses/LICENSE-2.0>)
30//! * MIT License ([LICENSE-MIT](https://github.com/argmin-rs/argmin/blob/main/LICENSE-MIT) or
31//! <http://opensource.org/licenses/MIT>)
32//!
33//! at your option.
34//!
35//! ## Contribution
36//!
37//! Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion
38//! in the work by you, as defined in the Apache-2.0 license, shall be dual licensed as above,
39//! without any additional terms or conditions.
40
41pub use argmin::core::checkpointing::{Checkpoint, CheckpointingFrequency};
42use argmin::core::Error;
43use serde::{de::DeserializeOwned, Serialize};
44use std::fs::File;
45use std::io::{BufReader, BufWriter};
46use std::path::PathBuf;
47
48/// Handles saving a checkpoint to disk as a binary file.
49#[derive(Clone, Eq, PartialEq, Debug, Hash)]
50pub struct FileCheckpoint {
51 /// Indicates how often a checkpoint is created
52 pub frequency: CheckpointingFrequency,
53 /// Directory where the checkpoints are saved to
54 pub directory: PathBuf,
55 /// Name of the checkpoint files
56 pub filename: PathBuf,
57}
58
59impl Default for FileCheckpoint {
60 /// Create a default `FileCheckpoint` instance.
61 ///
62 /// This will save the checkpoint in the file `.checkpoints/checkpoint.arg`.
63 ///
64 /// # Example
65 ///
66 /// ```
67 /// use argmin_checkpointing_file::FileCheckpoint;
68 /// # use argmin::core::checkpointing::CheckpointingFrequency;
69 /// # use std::path::PathBuf;
70 ///
71 /// let checkpoint = FileCheckpoint::default();
72 /// # assert_eq!(checkpoint.frequency, CheckpointingFrequency::default());
73 /// # assert_eq!(checkpoint.directory, PathBuf::from(".checkpoints"));
74 /// # assert_eq!(checkpoint.filename, PathBuf::from("checkpoint.arg"));
75 /// ```
76 fn default() -> FileCheckpoint {
77 FileCheckpoint {
78 frequency: CheckpointingFrequency::default(),
79 directory: PathBuf::from(".checkpoints"),
80 filename: PathBuf::from("checkpoint.arg"),
81 }
82 }
83}
84
85impl FileCheckpoint {
86 /// Create a new `FileCheckpoint` instance
87 ///
88 /// # Example
89 ///
90 /// ```
91 /// use argmin_checkpointing_file::{FileCheckpoint, CheckpointingFrequency};
92 /// # use std::path::PathBuf;
93 ///
94 /// let directory = "checkpoints";
95 /// let filename = "optimization";
96 ///
97 /// // When passed to an `Executor`, this will save a checkpoint in the file
98 /// // `checkpoints/optimization.arg` in every iteration.
99 /// let checkpoint = FileCheckpoint::new(directory, filename, CheckpointingFrequency::Always);
100 /// # assert_eq!(checkpoint.frequency, CheckpointingFrequency::Always);
101 /// # assert_eq!(checkpoint.directory, PathBuf::from("checkpoints"));
102 /// # assert_eq!(checkpoint.filename, PathBuf::from("optimization.arg"));
103 /// ```
104 pub fn new<N: AsRef<str>>(directory: N, name: N, frequency: CheckpointingFrequency) -> Self {
105 FileCheckpoint {
106 frequency,
107 directory: PathBuf::from(directory.as_ref()),
108 filename: PathBuf::from(format!("{}.arg", name.as_ref())),
109 }
110 }
111}
112
113impl<S, I> Checkpoint<S, I> for FileCheckpoint
114where
115 S: Serialize + DeserializeOwned,
116 I: Serialize + DeserializeOwned,
117{
118 /// Writes checkpoint to disk.
119 ///
120 /// If the directory does not exist already, it will be created. It uses `bincode` to serialize
121 /// the data.
122 /// It will return an error if creating the directory or file or serialization failed.
123 ///
124 /// # Example
125 ///
126 /// ```
127 /// use argmin_checkpointing_file::{FileCheckpoint, CheckpointingFrequency, Checkpoint};
128 ///
129 /// # use std::fs::File;
130 /// # use std::io::BufReader;
131 /// # let checkpoint = FileCheckpoint::new(".checkpoints", "save_test" , CheckpointingFrequency::Always);
132 /// # let solver: u64 = 12;
133 /// # let state: u64 = 21;
134 /// # let _ = std::fs::remove_file(".checkpoints/save_test.arg");
135 /// checkpoint.save(&solver, &state);
136 /// # let (f_solver, f_state): (u64, u64) = bincode::deserialize_from(
137 /// # BufReader::new(File::open(".checkpoints/save_test.arg").unwrap())
138 /// # ).unwrap();
139 /// # assert_eq!(solver, f_solver);
140 /// # assert_eq!(state, f_state);
141 /// # let _ = std::fs::remove_file(".checkpoints/save_test.arg");
142 /// ```
143 fn save(&self, solver: &S, state: &I) -> Result<(), Error> {
144 if !self.directory.exists() {
145 std::fs::create_dir_all(&self.directory)?
146 }
147 let fname = self.directory.join(&self.filename);
148 let f = BufWriter::new(File::create(fname)?);
149 bincode::serialize_into(f, &(solver, state))?;
150 Ok(())
151 }
152
153 /// Load a checkpoint from disk.
154 ///
155 ///
156 /// If there is no checkpoint on disk, it will return `Ok(None)`.
157 /// Returns an error if opening the file or deserialization failed.
158 ///
159 /// # Example
160 ///
161 /// ```
162 /// use argmin_checkpointing_file::{FileCheckpoint, CheckpointingFrequency, Checkpoint};
163 /// # use argmin::core::Error;
164 ///
165 /// # use std::fs::File;
166 /// # use std::io::BufWriter;
167 /// # fn main() -> Result<(), Error> {
168 /// # std::fs::DirBuilder::new().recursive(true).create(".checkpoints").unwrap();
169 /// # let f = BufWriter::new(File::create(".checkpoints/load_test.arg")?);
170 /// # let f_solver: u64 = 12;
171 /// # let f_state: u64 = 21;
172 /// # bincode::serialize_into(f, &(f_solver, f_state))?;
173 /// # let checkpoint = FileCheckpoint::new(".checkpoints", "load_test" , CheckpointingFrequency::Always);
174 /// let (solver, state) = checkpoint.load()?.unwrap();
175 /// # // Let the compiler know which types to expect.
176 /// # let blah1: u64 = solver;
177 /// # let blah2: u64 = state;
178 /// # assert_eq!(solver, f_solver);
179 /// # assert_eq!(state, f_state);
180 /// # let _ = std::fs::remove_file(".checkpoints/load_test.arg");
181 /// #
182 /// # // Return none if File does not exist
183 /// # let checkpoint = FileCheckpoint::new(".checkpoints", "certainly_does_not_exist" , CheckpointingFrequency::Always);
184 /// # let loaded: Option<(u64, u64)> = checkpoint.load()?;
185 /// # assert!(loaded.is_none());
186 /// # Ok(())
187 /// # }
188 /// ```
189 fn load(&self) -> Result<Option<(S, I)>, Error> {
190 let path = &self.directory.join(&self.filename);
191 if !path.exists() {
192 return Ok(None);
193 }
194 let file = File::open(path)?;
195 let reader = BufReader::new(file);
196 Ok(Some(bincode::deserialize_from(reader)?))
197 }
198
199 /// Returns the how often a checkpoint is to be saved.
200 ///
201 /// Used internally by [`save_cond`](`argmin::core::checkpointing::Checkpoint::save_cond`).
202 fn frequency(&self) -> CheckpointingFrequency {
203 self.frequency
204 }
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use argmin::core::test_utils::TestSolver;
    use argmin::core::{IterState, State};

    /// Saving a checkpoint and loading it again must yield the stored parameter vector.
    #[test]
    #[allow(clippy::type_complexity)]
    fn test_save() {
        let solver = TestSolver::new();
        let state: IterState<Vec<f64>, (), (), (), (), f64> =
            IterState::new().param(vec![1.0f64, 0.0]);
        let check = FileCheckpoint::new("checkpoints", "solver", CheckpointingFrequency::Always);
        check.save_cond(&solver, &state, 20).unwrap();

        let loaded: Option<(TestSolver, IterState<Vec<f64>, (), (), (), (), f64>)> =
            check.load().unwrap();

        // Remove the file before asserting so that a failing assertion does not leave it behind.
        let _ = std::fs::remove_file(check.directory.join(&check.filename));

        let (_, loaded_state) = loaded.expect("a checkpoint must exist after saving");
        assert_eq!(loaded_state.get_param(), state.get_param());
    }
}