argmin/core/test_utils.rs
1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use crate::core::{
9 CostFunction, Error, Gradient, Hessian, IterState, Jacobian, Operator, Problem, Solver, KV,
10};
11#[cfg(feature = "rand")]
12use crate::solver::simulatedannealing::Anneal;
13#[cfg(feature = "serde1")]
14use serde::{Deserialize, Serialize};
15use std::fmt::Debug;
16
/// Pseudo problem useful for testing.
///
/// Every method returns a trivial, deterministic value, which makes this type a convenient
/// stand-in wherever a problem is required but its numerics are irrelevant.
///
/// Implements [`CostFunction`], [`Operator`], [`Gradient`], [`Jacobian`], [`Hessian`], and
/// (only when the `rand` feature is enabled) [`Anneal`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct TestProblem {}

impl TestProblem {
    /// Constructs a new `TestProblem`.
    ///
    /// The type carries no state, so this yields the same value as `TestProblem::default()`.
    ///
    /// # Example
    ///
    /// ```
    /// use argmin::core::test_utils::TestProblem;
    ///
    /// let problem = TestProblem::new();
    /// # assert_eq!(problem, TestProblem {});
    /// ```
    #[allow(dead_code)]
    pub fn new() -> Self {
        Self {}
    }
}
41
42impl Operator for TestProblem {
43 type Param = Vec<f64>;
44 type Output = Vec<f64>;
45
46 /// Returns a clone of parameter `p`.
47 ///
48 /// # Example
49 ///
50 /// ```
51 /// use argmin::core::test_utils::TestProblem;
52 /// use argmin::core::Operator;
53 /// # use argmin::core::Error;
54 ///
55 /// # fn main() -> Result<(), Error> {
56 /// let problem = TestProblem::new();
57 ///
58 /// let param = vec![1.0, 2.0];
59 ///
60 /// let res = problem.apply(¶m)?;
61 /// # assert_eq!(res, param);
62 /// # Ok(())
63 /// # }
64 /// ```
65 fn apply(&self, p: &Self::Param) -> Result<Self::Output, Error> {
66 Ok(p.clone())
67 }
68}
69
70impl CostFunction for TestProblem {
71 type Param = Vec<f64>;
72 type Output = f64;
73
74 /// Returns `1.0f64`.
75 ///
76 /// # Example
77 ///
78 /// ```
79 /// use argmin::core::test_utils::TestProblem;
80 /// use argmin::core::CostFunction;
81 /// # use argmin::core::Error;
82 ///
83 /// # fn main() -> Result<(), Error> {
84 /// let problem = TestProblem::new();
85 ///
86 /// let param = vec![1.0, 2.0];
87 ///
88 /// let res = problem.cost(¶m)?;
89 /// # assert_eq!(res, 1.0f64);
90 /// # Ok(())
91 /// # }
92 /// ```
93 fn cost(&self, _p: &Self::Param) -> Result<Self::Output, Error> {
94 Ok(1.0f64)
95 }
96}
97
98impl Gradient for TestProblem {
99 type Param = Vec<f64>;
100 type Gradient = Vec<f64>;
101
102 /// Returns a clone of parameter `p`.
103 ///
104 /// # Example
105 ///
106 /// ```
107 /// use argmin::core::test_utils::TestProblem;
108 /// use argmin::core::Gradient;
109 /// # use argmin::core::Error;
110 ///
111 /// # fn main() -> Result<(), Error> {
112 /// let problem = TestProblem::new();
113 ///
114 /// let param = vec![1.0, 2.0];
115 ///
116 /// let res = problem.gradient(¶m)?;
117 /// # assert_eq!(res, param);
118 /// # Ok(())
119 /// # }
120 /// ```
121 fn gradient(&self, p: &Self::Param) -> Result<Self::Param, Error> {
122 Ok(p.clone())
123 }
124}
125
126impl Hessian for TestProblem {
127 type Param = Vec<f64>;
128 type Hessian = Vec<Vec<f64>>;
129
130 /// Returns `vec![p, p]`.
131 ///
132 /// # Example
133 ///
134 /// ```
135 /// use argmin::core::test_utils::TestProblem;
136 /// use argmin::core::Hessian;
137 /// # use argmin::core::Error;
138 ///
139 /// # fn main() -> Result<(), Error> {
140 /// let problem = TestProblem::new();
141 ///
142 /// let param = vec![1.0, 2.0];
143 ///
144 /// let res = problem.hessian(¶m)?;
145 /// # assert_eq!(res, vec![param.clone(), param.clone()]);
146 /// # Ok(())
147 /// # }
148 /// ```
149 fn hessian(&self, p: &Self::Param) -> Result<Self::Hessian, Error> {
150 Ok(vec![p.clone(), p.clone()])
151 }
152}
153
154impl Jacobian for TestProblem {
155 type Param = Vec<f64>;
156 type Jacobian = Vec<Vec<f64>>;
157
158 /// Returns `vec![p, p]`.
159 ///
160 /// # Example
161 ///
162 /// ```
163 /// use argmin::core::test_utils::TestProblem;
164 /// use argmin::core::Jacobian;
165 /// # use argmin::core::Error;
166 ///
167 /// # fn main() -> Result<(), Error> {
168 /// let problem = TestProblem::new();
169 ///
170 /// let param = vec![1.0, 2.0];
171 ///
172 /// let res = problem.jacobian(¶m)?;
173 /// # assert_eq!(res, vec![param.clone(), param.clone()]);
174 /// # Ok(())
175 /// # }
176 /// ```
177 fn jacobian(&self, p: &Self::Param) -> Result<Self::Jacobian, Error> {
178 Ok(vec![p.clone(), p.clone()])
179 }
180}
181
#[cfg(feature = "rand")]
impl Anneal for TestProblem {
    type Param = Vec<f64>;
    type Output = Vec<f64>;
    type Float = f64;

    /// Returns a clone of parameter `p`; the second (float) argument is ignored.
    ///
    /// # Errors
    ///
    /// Never fails; always returns `Ok`.
    ///
    /// # Example
    ///
    /// ```
    /// use argmin::core::test_utils::TestProblem;
    /// use argmin::solver::simulatedannealing::Anneal;
    /// # use argmin::core::Error;
    ///
    /// # fn main() -> Result<(), Error> {
    /// let problem = TestProblem::new();
    ///
    /// let param = vec![1.0, 2.0];
    ///
    /// let res = problem.anneal(&param, 1.0)?;
    /// # assert_eq!(res, param);
    /// # Ok(())
    /// # }
    /// ```
    fn anneal(&self, p: &Self::Param, _t: Self::Float) -> Result<Self::Output, Error> {
        Ok(p.clone())
    }
}
211
/// A small sparse least-squares problem for testing.
///
/// The four-element parameter vector `w` is fitted to these examples:
///
/// * Example 1: `x = [1, 1, 0, 0]`, `y = 1`
/// * Example 2: `x = [0, 0, 1, 1]`, `y = -1`
/// * Example 3: `x = [1, 0, 0, 0]`, `y = 1`
/// * Example 4: `x = [0, 0, 1, 0]`, `y = -1`
///
/// with `cost = Σ (w^T x - y)^2` summed over all examples.
///
/// Implements [`CostFunction`] and [`Gradient`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct TestSparseProblem {}

impl TestSparseProblem {
    /// Constructs a new `TestSparseProblem`.
    ///
    /// The type carries no state, so this yields the same value as
    /// `TestSparseProblem::default()`.
    ///
    /// # Example
    ///
    /// ```
    /// use argmin::core::test_utils::TestSparseProblem;
    ///
    /// let problem = TestSparseProblem::new();
    /// # assert_eq!(problem, TestSparseProblem {});
    /// ```
    #[allow(dead_code)]
    pub fn new() -> Self {
        Self {}
    }
}
242
243impl CostFunction for TestSparseProblem {
244 type Param = Vec<f64>;
245 type Output = f64;
246
247 /// Returns a sum of squared errors.
248 ///
249 /// # Example
250 ///
251 /// ```
252 /// use argmin::core::test_utils::TestSparseProblem;
253 /// use argmin::core::CostFunction;
254 /// # use argmin::core::Error;
255 ///
256 /// # fn main() -> Result<(), Error> {
257 /// let problem = TestSparseProblem::new();
258 ///
259 /// let param = vec![1.0, 2.0, 3.0, 4.0];
260 ///
261 /// let res = problem.cost(¶m)?;
262 /// # assert_eq!(res, 84f64);
263 /// # Ok(())
264 /// # }
265 /// ```
266 fn cost(&self, param: &Self::Param) -> Result<Self::Output, Error> {
267 let err1 = (param[0] + param[1] - 1.0).powi(2);
268 let err2 = (param[2] + param[3] + 1.0).powi(2);
269 let err3 = (param[0] - 1.0).powi(2);
270 let err4 = (param[2] + 1.0).powi(2);
271 Ok(err1 + err2 + err3 + err4)
272 }
273}
274
275impl Gradient for TestSparseProblem {
276 type Param = Vec<f64>;
277 type Gradient = Vec<f64>;
278
279 /// Returns a gradient of the cost function.
280 ///
281 /// # Example
282 ///
283 /// ```
284 /// use argmin::core::test_utils::TestSparseProblem;
285 /// use argmin::core::Gradient;
286 /// # use argmin::core::Error;
287 ///
288 /// # fn main() -> Result<(), Error> {
289 /// let problem = TestSparseProblem::new();
290 ///
291 /// let param = vec![1.0, 2.0, 3.0, 4.0];
292 ///
293 /// let res = problem.gradient(¶m)?;
294 /// # assert_eq!(res, vec![4.0, 4.0, 24.0, 16.0]);
295 /// # Ok(())
296 /// # }
297 /// ```
298 fn gradient(&self, param: &Self::Param) -> Result<Self::Gradient, Error> {
299 let mut g = vec![0.0; 4];
300 g[0] = 4.0 * param[0] + 2.0 * param[1] - 4.0;
301 g[1] = 2.0 * param[0] + 2.0 * param[1] - 2.0;
302 g[2] = 4.0 * param[2] + 2.0 * param[3] + 4.0;
303 g[3] = 2.0 * param[2] + 2.0 * param[3] + 2.0;
304 Ok(g)
305 }
306}
307
/// A stub (non-working) solver for tests.
///
/// It performs no optimization. Its iteration step hands the incoming state back unchanged.
///
/// Implements the [`Solver`] trait.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct TestSolver {}

impl TestSolver {
    /// Constructs a new `TestSolver`.
    ///
    /// The type carries no state, so this yields the same value as `TestSolver::default()`.
    ///
    /// # Example
    ///
    /// ```
    /// use argmin::core::test_utils::TestSolver;
    ///
    /// let solver = TestSolver::new();
    /// # assert_eq!(solver, TestSolver {});
    /// ```
    pub fn new() -> Self {
        Self {}
    }
}
330
331impl<O> Solver<O, IterState<Vec<f64>, (), (), (), (), f64>> for TestSolver {
332 fn name(&self) -> &str {
333 "TestSolver"
334 }
335
336 fn next_iter(
337 &mut self,
338 _problem: &mut Problem<O>,
339 state: IterState<Vec<f64>, (), (), (), (), f64>,
340 ) -> Result<(IterState<Vec<f64>, (), (), (), (), f64>, Option<KV>), Error> {
341 Ok((state, None))
342 }
343}