argmin_observer_paramwriter/lib.rs
1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! Writes parameter vectors to a file during optimization.
9//!
10//! See documentation of [`ParamWriter`] and [`ParamWriterFormat`] for details.
11//!
12//! # Usage
13//!
14//! Add the following line to your dependencies list:
15//!
16//! ```toml
17//! [dependencies]
18#![doc = concat!("argmin-observer-paramwriter = \"", env!("CARGO_PKG_VERSION"), "\"")]
19//! ```
20//!
21//! # License
22//!
23//! Licensed under either of
24//!
25//! * Apache License, Version 2.0,
26//! ([LICENSE-APACHE](https://github.com/argmin-rs/argmin/blob/main/LICENSE-APACHE) or
27//! <http://www.apache.org/licenses/LICENSE-2.0>)
28//! * MIT License ([LICENSE-MIT](https://github.com/argmin-rs/argmin/blob/main/LICENSE-MIT) or
29//! <http://opensource.org/licenses/MIT>)
30//!
31//! at your option.
32//!
33//! ## Contribution
34//!
35//! Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion
36//! in the work by you, as defined in the Apache-2.0 license, shall be dual licensed as above,
37//! without any additional terms or conditions.
38
39use argmin::core::observers::Observe;
40use argmin::core::{Error, State, KV};
41use serde::Serialize;
42use std::default::Default;
43use std::fs::File;
44use std::io::BufWriter;
45use std::path::PathBuf;
46
47/// Write parameter vectors to a file during optimization.
48///
49/// This observer requires a directory to save the files to and a file prefix. Files will be
50/// written to disk as `<directory>/<file_prefix>_<iteration_number>.<extension>`. For
51/// serialization either `JSON` or `Binary` (via [bincode](https://crates.io/crates/bincode))
52/// can be chosen via the enum [`ParamWriterFormat`].
53///
54/// # Example
55///
56/// Create an observer for saving the parameter vector into a JSON file.
57///
58/// ```
59/// use argmin_observer_paramwriter::{ParamWriter, ParamWriterFormat};
60///
61/// let observer = ParamWriter::new("directory", "file_prefix", ParamWriterFormat::JSON);
62/// ```
63///
64/// Create an observer for saving the parameter vector into a binary file using
65/// [`bincode`](https://crates.io/crates/bincode).
66///
67/// ```
68/// use argmin_observer_paramwriter::{ParamWriter, ParamWriterFormat};
69///
70/// let observer = ParamWriter::new("directory", "file_prefix", ParamWriterFormat::Binary);
71/// ```
72#[derive(Clone, Debug, Eq, PartialEq)]
73pub struct ParamWriter {
74 /// Directory where files are saved to
75 dir: PathBuf,
76 /// File prefix
77 prefix: String,
78 /// Chosen serializer
79 serializer: ParamWriterFormat,
80}
81
82impl ParamWriter {
83 /// Create a new instance of `ParamWriter`.
84 ///
85 /// # Example
86 /// ```
87 /// # use argmin_observer_paramwriter::{ParamWriter, ParamWriterFormat};
88 /// let observer = ParamWriter::new("directory", "file_prefix", ParamWriterFormat::JSON);
89 /// ```
90 pub fn new<N: AsRef<str>>(dir: N, prefix: N, serializer: ParamWriterFormat) -> Self {
91 ParamWriter {
92 dir: PathBuf::from(dir.as_ref()),
93 prefix: String::from(prefix.as_ref()),
94 serializer,
95 }
96 }
97}
98
99/// `ParamWriter` only implements `observer_iter` and not `observe_init` to avoid saving the
100/// initial parameter vector. It will only save if there is a parameter vector available in the
101/// state, otherwise it will skip saving silently.
102impl<I> Observe<I> for ParamWriter
103where
104 I: State,
105 <I as State>::Param: Serialize,
106{
107 fn observe_iter(&mut self, state: &I, _kv: &KV) -> Result<(), Error> {
108 if let Some(param) = state.get_param() {
109 let iter = state.get_iter();
110 if !self.dir.exists() {
111 std::fs::create_dir_all(&self.dir)?
112 }
113
114 let fname = self.dir.join(format!(
115 "{}_{}.{}",
116 self.prefix,
117 iter,
118 self.serializer.extension()
119 ));
120 let f = BufWriter::new(File::create(fname)?);
121
122 match self.serializer {
123 ParamWriterFormat::Binary => {
124 bincode::serialize_into(f, param)?;
125 }
126 ParamWriterFormat::JSON => {
127 serde_json::to_writer_pretty(f, param)?;
128 }
129 }
130 }
131 Ok(())
132 }
133}
134
/// Available serializers for [`ParamWriter`].
///
/// # Extensions
///
/// * JSON: `.json`
/// * Binary: `.bin`
///
/// The default format is [`ParamWriterFormat::Binary`].
///
/// # Example
///
/// ```
/// use argmin_observer_paramwriter::ParamWriterFormat;
///
/// let bincode = ParamWriterFormat::Binary;
/// let json = ParamWriterFormat::JSON;
/// ```
#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
pub enum ParamWriterFormat {
    /// Use [`bincode`](https://crates.io/crates/bincode) for creating binary files
    #[default]
    Binary,
    /// Use [`serde_json`](https://crates.io/crates/serde_json) for creating JSON files
    JSON,
}

impl ParamWriterFormat {
    /// Returns the file extension (without the leading dot) used for files of this format.
    ///
    /// The returned string is a literal and therefore `'static`; it does not borrow `self`.
    ///
    /// # Example
    ///
    /// ```
    /// use argmin_observer_paramwriter::ParamWriterFormat;
    ///
    /// assert_eq!(ParamWriterFormat::Binary.extension(), "bin");
    /// assert_eq!(ParamWriterFormat::JSON.extension(), "json");
    /// ```
    #[must_use]
    pub fn extension(&self) -> &'static str {
        match self {
            ParamWriterFormat::Binary => "bin",
            ParamWriterFormat::JSON => "json",
        }
    }
}
167
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn format_extensions() {
        assert_eq!(ParamWriterFormat::Binary.extension(), "bin");
        assert_eq!(ParamWriterFormat::JSON.extension(), "json");
    }

    #[test]
    fn format_default_is_binary() {
        assert_eq!(ParamWriterFormat::default(), ParamWriterFormat::Binary);
    }

    #[test]
    fn new_stores_arguments() {
        let writer = ParamWriter::new("directory", "file_prefix", ParamWriterFormat::JSON);
        assert_eq!(writer.dir, PathBuf::from("directory"));
        assert_eq!(writer.prefix, "file_prefix");
        assert_eq!(writer.serializer, ParamWriterFormat::JSON);
    }
}