argmin_observer_paramwriter/
lib.rs

1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! Writes parameter vectors to a file during optimization.
9//!
10//! See documentation of [`ParamWriter`] and [`ParamWriterFormat`] for details.
11//!
12//! # Usage
13//!
14//! Add the following line to your dependencies list:
15//!
16//! ```toml
17//! [dependencies]
18#![doc = concat!("argmin-observer-paramwriter = \"", env!("CARGO_PKG_VERSION"), "\"")]
19//! ```
20//!
21//! # License
22//!
23//! Licensed under either of
24//!
25//!   * Apache License, Version 2.0,
26//!     ([LICENSE-APACHE](https://github.com/argmin-rs/argmin/blob/main/LICENSE-APACHE) or
27//!     <http://www.apache.org/licenses/LICENSE-2.0>)
28//!   * MIT License ([LICENSE-MIT](https://github.com/argmin-rs/argmin/blob/main/LICENSE-MIT) or
29//!     <http://opensource.org/licenses/MIT>)
30//!
31//! at your option.
32//!
33//! ## Contribution
34//!
35//! Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion
36//! in the work by you, as defined in the Apache-2.0 license, shall be dual licensed as above,
37//! without any additional terms or conditions.
38
use argmin::core::observers::Observe;
use argmin::core::{Error, State, KV};
use serde::Serialize;
use std::default::Default;
use std::fs::File;
use std::io::BufWriter;
use std::io::Write;
use std::path::PathBuf;
46
47/// Write parameter vectors to a file during optimization.
48///
49/// This observer requires a directory to save the files to and a file prefix. Files will be
50/// written to disk as `<directory>/<file_prefix>_<iteration_number>.<extension>`. For
51/// serialization either `JSON` or `Binary` (via [bincode](https://crates.io/crates/bincode))
52/// can be chosen via the enum [`ParamWriterFormat`].
53///
54/// # Example
55///
56/// Create an observer for saving the parameter vector into a JSON file.
57///
58/// ```
59/// use argmin_observer_paramwriter::{ParamWriter, ParamWriterFormat};
60///
61/// let observer = ParamWriter::new("directory", "file_prefix", ParamWriterFormat::JSON);
62/// ```
63///
64/// Create an observer for saving the parameter vector into a binary file using
65/// [`bincode`](https://crates.io/crates/bincode).
66///
67/// ```
68/// use argmin_observer_paramwriter::{ParamWriter, ParamWriterFormat};
69///
70/// let observer = ParamWriter::new("directory", "file_prefix", ParamWriterFormat::Binary);
71/// ```
72#[derive(Clone, Debug, Eq, PartialEq)]
73pub struct ParamWriter {
74    /// Directory where files are saved to
75    dir: PathBuf,
76    /// File prefix
77    prefix: String,
78    /// Chosen serializer
79    serializer: ParamWriterFormat,
80}
81
82impl ParamWriter {
83    /// Create a new instance of `ParamWriter`.
84    ///
85    /// # Example
86    /// ```
87    /// # use argmin_observer_paramwriter::{ParamWriter, ParamWriterFormat};
88    /// let observer = ParamWriter::new("directory", "file_prefix", ParamWriterFormat::JSON);
89    /// ```
90    pub fn new<N: AsRef<str>>(dir: N, prefix: N, serializer: ParamWriterFormat) -> Self {
91        ParamWriter {
92            dir: PathBuf::from(dir.as_ref()),
93            prefix: String::from(prefix.as_ref()),
94            serializer,
95        }
96    }
97}
98
99/// `ParamWriter` only implements `observer_iter` and not `observe_init` to avoid saving the
100/// initial parameter vector. It will only save if there is a parameter vector available in the
101/// state, otherwise it will skip saving silently.
102impl<I> Observe<I> for ParamWriter
103where
104    I: State,
105    <I as State>::Param: Serialize,
106{
107    fn observe_iter(&mut self, state: &I, _kv: &KV) -> Result<(), Error> {
108        if let Some(param) = state.get_param() {
109            let iter = state.get_iter();
110            if !self.dir.exists() {
111                std::fs::create_dir_all(&self.dir)?
112            }
113
114            let fname = self.dir.join(format!(
115                "{}_{}.{}",
116                self.prefix,
117                iter,
118                self.serializer.extension()
119            ));
120            let f = BufWriter::new(File::create(fname)?);
121
122            match self.serializer {
123                ParamWriterFormat::Binary => {
124                    bincode::serialize_into(f, param)?;
125                }
126                ParamWriterFormat::JSON => {
127                    serde_json::to_writer_pretty(f, param)?;
128                }
129            }
130        }
131        Ok(())
132    }
133}
134
/// Available serializers for [`ParamWriter`].
///
/// # Extensions
///
/// * JSON: `.json`
/// * Binary: `.bin`
///
/// The default format is [`ParamWriterFormat::Binary`].
///
/// # Example
///
/// ```
/// use argmin_observer_paramwriter::ParamWriterFormat;
///
/// let bincode = ParamWriterFormat::Binary;
/// let json = ParamWriterFormat::JSON;
/// ```
#[derive(Copy, Clone, Debug, Eq, PartialEq, Default)]
pub enum ParamWriterFormat {
    /// Use [`bincode`](https://crates.io/crates/bincode) for creating binary files
    #[default]
    Binary,
    /// Use [`serde_json`](https://crates.io/crates/serde_json) for creating JSON files
    JSON,
}

impl ParamWriterFormat {
    /// Returns the file extension (without the leading dot) used for this format.
    ///
    /// The returned string is a `'static` literal, so it does not borrow from `self`.
    ///
    /// # Example
    ///
    /// ```
    /// # use argmin_observer_paramwriter::ParamWriterFormat;
    /// assert_eq!(ParamWriterFormat::Binary.extension(), "bin");
    /// assert_eq!(ParamWriterFormat::JSON.extension(), "json");
    /// ```
    pub fn extension(&self) -> &'static str {
        match self {
            Self::Binary => "bin",
            Self::JSON => "json",
        }
    }
}
167
// Unit-test module compiled only for `cargo test`; it currently contains no tests.
// NOTE(review): consider covering `ParamWriterFormat::extension` and the file naming/flush
// behavior of `observe_iter` (e.g. writing into a temporary directory).
#[cfg(test)]
mod tests {}