argmin/core/
test_utils.rs

1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use crate::core::{
9    CostFunction, Error, Gradient, Hessian, IterState, Jacobian, Operator, Problem, Solver, KV,
10};
11#[cfg(feature = "rand")]
12use crate::solver::simulatedannealing::Anneal;
13#[cfg(feature = "serde1")]
14use serde::{Deserialize, Serialize};
15use std::fmt::Debug;
16
/// Minimal stand-in problem intended for use in tests.
///
/// Provides implementations of [`CostFunction`], [`Operator`], [`Gradient`], [`Jacobian`],
/// [`Hessian`] and, when the `rand` feature is enabled, [`Anneal`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct TestProblem {}
24
25impl TestProblem {
26    /// Create an instance of `TestProblem`.
27    ///
28    /// # Example
29    ///
30    /// ```
31    /// use argmin::core::test_utils::TestProblem;
32    ///
33    /// let problem = TestProblem::new();
34    /// # assert_eq!(problem, TestProblem {});
35    /// ```
36    #[allow(dead_code)]
37    pub fn new() -> Self {
38        TestProblem {}
39    }
40}
41
42impl Operator for TestProblem {
43    type Param = Vec<f64>;
44    type Output = Vec<f64>;
45
46    /// Returns a clone of parameter `p`.
47    ///
48    /// # Example
49    ///
50    /// ```
51    /// use argmin::core::test_utils::TestProblem;
52    /// use argmin::core::Operator;
53    /// # use argmin::core::Error;
54    ///
55    /// # fn main() -> Result<(), Error> {
56    /// let problem = TestProblem::new();
57    ///
58    /// let param = vec![1.0, 2.0];
59    ///
60    /// let res = problem.apply(&param)?;
61    /// # assert_eq!(res, param);
62    /// # Ok(())
63    /// # }
64    /// ```
65    fn apply(&self, p: &Self::Param) -> Result<Self::Output, Error> {
66        Ok(p.clone())
67    }
68}
69
70impl CostFunction for TestProblem {
71    type Param = Vec<f64>;
72    type Output = f64;
73
74    /// Returns `1.0f64`.
75    ///
76    /// # Example
77    ///
78    /// ```
79    /// use argmin::core::test_utils::TestProblem;
80    /// use argmin::core::CostFunction;
81    /// # use argmin::core::Error;
82    ///
83    /// # fn main() -> Result<(), Error> {
84    /// let problem = TestProblem::new();
85    ///
86    /// let param = vec![1.0, 2.0];
87    ///
88    /// let res = problem.cost(&param)?;
89    /// # assert_eq!(res, 1.0f64);
90    /// # Ok(())
91    /// # }
92    /// ```
93    fn cost(&self, _p: &Self::Param) -> Result<Self::Output, Error> {
94        Ok(1.0f64)
95    }
96}
97
98impl Gradient for TestProblem {
99    type Param = Vec<f64>;
100    type Gradient = Vec<f64>;
101
102    /// Returns a clone of parameter `p`.
103    ///
104    /// # Example
105    ///
106    /// ```
107    /// use argmin::core::test_utils::TestProblem;
108    /// use argmin::core::Gradient;
109    /// # use argmin::core::Error;
110    ///
111    /// # fn main() -> Result<(), Error> {
112    /// let problem = TestProblem::new();
113    ///
114    /// let param = vec![1.0, 2.0];
115    ///
116    /// let res = problem.gradient(&param)?;
117    /// # assert_eq!(res, param);
118    /// # Ok(())
119    /// # }
120    /// ```
121    fn gradient(&self, p: &Self::Param) -> Result<Self::Param, Error> {
122        Ok(p.clone())
123    }
124}
125
126impl Hessian for TestProblem {
127    type Param = Vec<f64>;
128    type Hessian = Vec<Vec<f64>>;
129
130    /// Returns `vec![p, p]`.
131    ///
132    /// # Example
133    ///
134    /// ```
135    /// use argmin::core::test_utils::TestProblem;
136    /// use argmin::core::Hessian;
137    /// # use argmin::core::Error;
138    ///
139    /// # fn main() -> Result<(), Error> {
140    /// let problem = TestProblem::new();
141    ///
142    /// let param = vec![1.0, 2.0];
143    ///
144    /// let res = problem.hessian(&param)?;
145    /// # assert_eq!(res, vec![param.clone(), param.clone()]);
146    /// # Ok(())
147    /// # }
148    /// ```
149    fn hessian(&self, p: &Self::Param) -> Result<Self::Hessian, Error> {
150        Ok(vec![p.clone(), p.clone()])
151    }
152}
153
154impl Jacobian for TestProblem {
155    type Param = Vec<f64>;
156    type Jacobian = Vec<Vec<f64>>;
157
158    /// Returns `vec![p, p]`.
159    ///
160    /// # Example
161    ///
162    /// ```
163    /// use argmin::core::test_utils::TestProblem;
164    /// use argmin::core::Jacobian;
165    /// # use argmin::core::Error;
166    ///
167    /// # fn main() -> Result<(), Error> {
168    /// let problem = TestProblem::new();
169    ///
170    /// let param = vec![1.0, 2.0];
171    ///
172    /// let res = problem.jacobian(&param)?;
173    /// # assert_eq!(res, vec![param.clone(), param.clone()]);
174    /// # Ok(())
175    /// # }
176    /// ```
177    fn jacobian(&self, p: &Self::Param) -> Result<Self::Jacobian, Error> {
178        Ok(vec![p.clone(), p.clone()])
179    }
180}
181
#[cfg(feature = "rand")]
impl Anneal for TestProblem {
    type Param = Vec<f64>;
    type Output = Vec<f64>;
    type Float = f64;

    /// Hands back a copy of `p`; the temperature `_t` is ignored.
    ///
    /// # Example
    ///
    /// ```
    /// use argmin::core::test_utils::TestProblem;
    /// use argmin::solver::simulatedannealing::Anneal;
    /// # use argmin::core::Error;
    ///
    /// # fn main() -> Result<(), Error> {
    /// let problem = TestProblem::new();
    ///
    /// let param = vec![1.0, 2.0];
    ///
    /// let res = problem.anneal(&param, 1.0)?;
    /// # assert_eq!(res, param);
    /// # Ok(())
    /// # }
    /// ```
    fn anneal(&self, p: &Self::Param, _t: Self::Float) -> Result<Self::Output, Error> {
        let unchanged = p.to_vec();
        Ok(unchanged)
    }
}
211
/// Small least-squares problem over sparse feature vectors, useful for testing.
///
/// The data set consists of four samples:
///
/// ```text
/// Example 1: x = [1, 1, 0, 0], y =  1
/// Example 2: x = [0, 0, 1, 1], y = -1
/// Example 3: x = [1, 0, 0, 0], y =  1
/// Example 4: x = [0, 0, 1, 0], y = -1
/// ```
///
/// The cost of a weight vector `w` is `Σ (w^T x - y)^2`, summed over all samples.
///
/// Implements [`CostFunction`] and [`Gradient`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct TestSparseProblem {}
225
226impl TestSparseProblem {
227    /// Create an instance of `TestSparseProblem`.
228    ///
229    /// # Example
230    ///
231    /// ```
232    /// use argmin::core::test_utils::TestSparseProblem;
233    ///
234    /// let problem = TestSparseProblem::new();
235    /// # assert_eq!(problem, TestSparseProblem {});
236    /// ```
237    #[allow(dead_code)]
238    pub fn new() -> Self {
239        TestSparseProblem {}
240    }
241}
242
243impl CostFunction for TestSparseProblem {
244    type Param = Vec<f64>;
245    type Output = f64;
246
247    /// Returns a sum of squared errors.
248    ///
249    /// # Example
250    ///
251    /// ```
252    /// use argmin::core::test_utils::TestSparseProblem;
253    /// use argmin::core::CostFunction;
254    /// # use argmin::core::Error;
255    ///
256    /// # fn main() -> Result<(), Error> {
257    /// let problem = TestSparseProblem::new();
258    ///
259    /// let param = vec![1.0, 2.0, 3.0, 4.0];
260    ///
261    /// let res = problem.cost(&param)?;
262    /// # assert_eq!(res, 84f64);
263    /// # Ok(())
264    /// # }
265    /// ```
266    fn cost(&self, param: &Self::Param) -> Result<Self::Output, Error> {
267        let err1 = (param[0] + param[1] - 1.0).powi(2);
268        let err2 = (param[2] + param[3] + 1.0).powi(2);
269        let err3 = (param[0] - 1.0).powi(2);
270        let err4 = (param[2] + 1.0).powi(2);
271        Ok(err1 + err2 + err3 + err4)
272    }
273}
274
275impl Gradient for TestSparseProblem {
276    type Param = Vec<f64>;
277    type Gradient = Vec<f64>;
278
279    /// Returns a gradient of the cost function.
280    ///
281    /// # Example
282    ///
283    /// ```
284    /// use argmin::core::test_utils::TestSparseProblem;
285    /// use argmin::core::Gradient;
286    /// # use argmin::core::Error;
287    ///
288    /// # fn main() -> Result<(), Error> {
289    /// let problem = TestSparseProblem::new();
290    ///
291    /// let param = vec![1.0, 2.0, 3.0, 4.0];
292    ///
293    /// let res = problem.gradient(&param)?;
294    /// # assert_eq!(res, vec![4.0, 4.0, 24.0, 16.0]);
295    /// # Ok(())
296    /// # }
297    /// ```
298    fn gradient(&self, param: &Self::Param) -> Result<Self::Gradient, Error> {
299        let mut g = vec![0.0; 4];
300        g[0] = 4.0 * param[0] + 2.0 * param[1] - 4.0;
301        g[1] = 2.0 * param[0] + 2.0 * param[1] - 2.0;
302        g[2] = 4.0 * param[2] + 2.0 * param[3] + 4.0;
303        g[3] = 2.0 * param[2] + 2.0 * param[3] + 2.0;
304        Ok(g)
305    }
306}
307
/// Do-nothing solver intended for use in tests.
///
/// It does not optimize anything; it exists so that code requiring a [`Solver`] can be exercised.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct TestSolver {}
314
315impl TestSolver {
316    /// Create an instance of `TestSolver`.
317    ///
318    /// # Example
319    ///
320    /// ```
321    /// use argmin::core::test_utils::TestSolver;
322    ///
323    /// let solver = TestSolver::new();
324    /// # assert_eq!(solver, TestSolver {});
325    /// ```
326    pub fn new() -> TestSolver {
327        TestSolver {}
328    }
329}
330
331impl<O> Solver<O, IterState<Vec<f64>, (), (), (), (), f64>> for TestSolver {
332    fn name(&self) -> &str {
333        "TestSolver"
334    }
335
336    fn next_iter(
337        &mut self,
338        _problem: &mut Problem<O>,
339        state: IterState<Vec<f64>, (), (), (), (), f64>,
340    ) -> Result<(IterState<Vec<f64>, (), (), (), (), f64>, Option<KV>), Error> {
341        Ok((state, None))
342    }
343}