1use argmin_testfunctions::*;
2use paste::paste;
3use pyo3::exceptions::PyValueError;
4use pyo3::prelude::*;
5use std::stringify;
6
7#[macro_export]
8macro_rules! func {
9 ($function:ident) => {
10 func!(name = $function, function = $function,);
11 };
12 ($function:ident, num = $num:expr) => {
13 func!(name = $function, function = $function, num = $num,);
14 };
15 (name = $name:ident, function = $function:ident, $($a:ident : $t:ty = $v:expr),* ) => {
16 paste! {
17 #[pyfunction(name = $name "", signature = (param, $($a = $v),*))]
18 fn [<$name _py>](param: Vec<f64>, $($a: $t),*) -> f64 {
19 $function(¶m[..], $($a),*)
20 }
21
22 #[pyfunction(name = $name "_derivative", signature = (param, $($a = $v),*))]
23 fn [<$name _derivative_py>](param: Vec<f64>, $($a: $t),*) -> Vec<f64> {
24 [<$function _derivative>](¶m[..], $($a),*)
25 }
26
27 #[pyfunction(name = $name "_hessian", signature = (param, $($a = $v),*))]
28 fn [<$name _hessian_py>](param: Vec<f64>, $($a: $t),*) -> Vec<Vec<f64>> {
29 [<$function _hessian>](¶m[..], $($a),*)
30 }
31 }
32 };
33 (name = $name:ident, function = $function:ident, num = $num:expr, $($a:ident : $t:ty = $v:expr),* ) => {
34 paste! {
35 #[pyfunction(name = $name "", signature = (param, $($a = $v),*))]
36 fn [<$name _py>](param: Vec<f64>, $($a: $t),*) -> PyResult<f64> {
37 let n = param.len();
38 if let Ok(param) = param.try_into() {
39 Ok($function(¶m, $($a),*))
40 } else {
41 Err(PyValueError::new_err(format!("incompatible number of parameters: expected {}, found {}", stringify!($num), n)))
42 }
43 }
44
45 #[pyfunction(name = $name "_derivative", signature = (param, $($a = $v),*))]
46 fn [<$name _derivative_py>](param: Vec<f64>, $($a: $t),*) -> PyResult<Vec<f64>> {
47 let n = param.len();
48 if let Ok(param) = param.try_into() {
49 Ok([<$function _derivative>](¶m, $($a),*).to_vec())
50 } else {
51 Err(PyValueError::new_err(format!("incompatible number of parameters: expected {}, found {}", stringify!($num), n)))
52 }
53 }
54
55 #[pyfunction(name = $name "_hessian", signature = (param, $($a = $v),*))]
56 fn [<$name _hessian_py>](param: Vec<f64>, $($a: $t),*) -> PyResult<Vec<Vec<f64>>> {
57 let n = param.len();
58 if let Ok(param) = param.try_into() {
59 Ok([<$function _hessian>](¶m, $($a),*).iter().map(|r| r.to_vec()).collect::<Vec<_>>())
60 } else {
61 Err(PyValueError::new_err(format!("incompatible number of parameters: expected {}, found {}", stringify!($num), n)))
62 }
63 }
64 }
65 };
66}
67
/// Registers on module `$m` the value, `_derivative` and `_hessian` wrappers that
/// `func!` generated for `$function`, in that order, propagating any error with `?`.
#[macro_export]
macro_rules! add_function {
    ($m:ident, $function:ident) => {
        paste! {
            add_function!(
                @each $m;
                [<$function _py>],
                [<$function _derivative_py>],
                [<$function _hessian_py>]
            );
        }
    };
    // Internal arm: wrap and register each listed `#[pyfunction]` one after another.
    (@each $m:ident; $($wrapper:ident),+) => {
        $(
            $m.add_function(wrap_pyfunction!($wrapper, $m)?)?;
        )+
    };
}
78
79func!(name = ackley, function = ackley_abc, a: f64 = 20.0, b: f64 = 0.2, c: f64 = core::f64::consts::TAU);
80func!(beale, num = 2);
81func!(booth, num = 2);
82func!(bukin_n6, num = 2);
83func!(cross_in_tray, num = 2);
84func!(easom, num = 2);
85func!(eggholder, num = 2);
86func!(goldsteinprice, num = 2);
87func!(himmelblau, num = 2);
88func!(holder_table, num = 2);
89func!(levy);
90func!(levy_n13, num = 2);
91func!(matyas, num = 2);
92func!(mccorminck, num = 2);
93func!(picheny, num = 2);
94func!(name = rastrigin, function = rastrigin_a, a: f64 = 10.0);
95func!(name = rosenbrock, function = rosenbrock_ab, a: f64 = 1.0, b: f64 = 100.0);
96func!(schaffer_n2, num = 2);
97func!(schaffer_n4, num = 2);
98func!(sphere);
99func!(styblinski_tang);
100func!(threehumpcamel, num = 2);
101
102#[pymodule]
103fn argmin_testfunctions_py(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
104 add_function!(m, ackley);
105 add_function!(m, beale);
106 add_function!(m, booth);
107 add_function!(m, bukin_n6);
108 add_function!(m, cross_in_tray);
109 add_function!(m, easom);
110 add_function!(m, eggholder);
111 add_function!(m, goldsteinprice);
112 add_function!(m, himmelblau);
113 add_function!(m, holder_table);
114 add_function!(m, levy);
115 add_function!(m, levy_n13);
116 add_function!(m, matyas);
117 add_function!(m, mccorminck);
118 add_function!(m, picheny);
119 add_function!(m, rastrigin);
120 add_function!(m, rosenbrock);
121 add_function!(m, schaffer_n2);
122 add_function!(m, schaffer_n4);
123 add_function!(m, sphere);
124 add_function!(m, styblinski_tang);
125 add_function!(m, threehumpcamel);
126 Ok(())
127}