argmin_testfunctions_py/
lib.rs

1use argmin_testfunctions::*;
2use paste::paste;
3use pyo3::exceptions::PyValueError;
4use pyo3::prelude::*;
5use std::stringify;
6
/// Generates the Python-facing wrappers for one `argmin_testfunctions` test function.
///
/// For a Python name `name` it defines three `#[pyfunction]`s, exported to Python as
/// `name`, `name_derivative` and `name_hessian`, which call `function`,
/// `function_derivative` and `function_hessian` respectively. The generated Rust items
/// are named `name_py`, `name_derivative_py` and `name_hessian_py`.
///
/// Accepted forms:
///
/// * `func!(f)` — `f` takes a parameter slice of any length; the Python list is passed
///   through as `&[f64]`.
/// * `func!(f, num = N)` — `f` takes a fixed-size `&[f64; N]`; a Python list of any other
///   length raises `ValueError`.
/// * `func!(name = n, function = f, a: T = default, ...)` — like `func!(f)`, but exported
///   under `n` and forwarding extra keyword arguments (with the given Python-side
///   defaults) to `f`. The comma after `function = f` is required even when no extra
///   arguments follow.
/// * `func!(name = n, function = f, num = N, a: T = default, ...)` — fixed-size variant of
///   the previous form. The comma after `num = N` is required.
#[macro_export]
macro_rules! func {
    // Internal rule: converts the `Vec<f64>` bound to `$param` into a `[f64; $num]`.
    // Evaluates to `PyResult<[f64; $num]>`. On a length mismatch the error is a Python
    // `ValueError` reporting the expected and actual number of parameters. Naming the
    // array type explicitly ties the checked length to the `num` shown in the message.
    (@to_array $param:ident, $num:expr) => {{
        let found = $param.len();
        let converted: Result<[f64; $num], _> = $param.try_into();
        converted.map_err(|_| {
            PyValueError::new_err(format!(
                "incompatible number of parameters: expected {}, found {}",
                stringify!($num),
                found
            ))
        })
    }};
    ($function:ident) => {
        func!(name = $function, function = $function,);
    };
    ($function:ident, num = $num:expr) => {
        func!(name = $function, function = $function, num = $num,);
    };
    // Slice-based functions: the parameter vector is forwarded as `&[f64]`.
    // NOTE(review): no length check is done here, so an input length the underlying
    // function cannot handle (e.g. an empty list) is not turned into a `ValueError` —
    // confirm this is intended.
    (name = $name:ident, function = $function:ident, $($a:ident : $t:ty = $v:expr),* ) => {
        paste! {
            #[pyfunction(name = $name "", signature = (param, $($a = $v),*))]
            fn [<$name _py>](param: Vec<f64>, $($a: $t),*) -> f64 {
                $function(&param[..], $($a),*)
            }

            #[pyfunction(name = $name "_derivative", signature = (param, $($a = $v),*))]
            fn [<$name _derivative_py>](param: Vec<f64>, $($a: $t),*) -> Vec<f64> {
                [<$function _derivative>](&param[..], $($a),*)
            }

            #[pyfunction(name = $name "_hessian", signature = (param, $($a = $v),*))]
            fn [<$name _hessian_py>](param: Vec<f64>, $($a: $t),*) -> Vec<Vec<f64>> {
                [<$function _hessian>](&param[..], $($a),*)
            }
        }
    };
    // Fixed-size functions: the parameter vector is first converted to `[f64; $num]`,
    // and array results are copied back into `Vec`s for Python.
    (name = $name:ident, function = $function:ident, num = $num:expr, $($a:ident : $t:ty = $v:expr),* ) => {
        paste! {
            #[pyfunction(name = $name "", signature = (param, $($a = $v),*))]
            fn [<$name _py>](param: Vec<f64>, $($a: $t),*) -> PyResult<f64> {
                let param = func!(@to_array param, $num)?;
                Ok($function(&param, $($a),*))
            }

            #[pyfunction(name = $name "_derivative", signature = (param, $($a = $v),*))]
            fn [<$name _derivative_py>](param: Vec<f64>, $($a: $t),*) -> PyResult<Vec<f64>> {
                let param = func!(@to_array param, $num)?;
                Ok([<$function _derivative>](&param, $($a),*).to_vec())
            }

            #[pyfunction(name = $name "_hessian", signature = (param, $($a = $v),*))]
            fn [<$name _hessian_py>](param: Vec<f64>, $($a: $t),*) -> PyResult<Vec<Vec<f64>>> {
                let param = func!(@to_array param, $num)?;
                Ok([<$function _hessian>](&param, $($a),*)
                    .iter()
                    .map(|row| row.to_vec())
                    .collect::<Vec<_>>())
            }
        }
    };
}
67
/// Registers on module `$m` the three wrappers that `func!` generates for `$function`:
/// `<function>_py`, `<function>_derivative_py` and `<function>_hessian_py`, in that order.
///
/// The expansion propagates errors with `?`, so it must be used inside a function whose
/// return type can absorb the errors produced by `wrap_pyfunction!`/`add_function`.
#[macro_export]
macro_rules! add_function {
    ($m:ident, $function:ident) => {
        paste! {
            {
                let value_fn = wrap_pyfunction!([<$function _py>], $m)?;
                $m.add_function(value_fn)?;

                let derivative_fn = wrap_pyfunction!([<$function _derivative_py>], $m)?;
                $m.add_function(derivative_fn)?;

                let hessian_fn = wrap_pyfunction!([<$function _hessian_py>], $m)?;
                $m.add_function(hessian_fn)?;
            }
        }
    };
}
78
79func!(name = ackley, function = ackley_abc, a: f64 = 20.0, b: f64 = 0.2, c: f64 = core::f64::consts::TAU);
80func!(beale, num = 2);
81func!(booth, num = 2);
82func!(bukin_n6, num = 2);
83func!(cross_in_tray, num = 2);
84func!(easom, num = 2);
85func!(eggholder, num = 2);
86func!(goldsteinprice, num = 2);
87func!(himmelblau, num = 2);
88func!(holder_table, num = 2);
89func!(levy);
90func!(levy_n13, num = 2);
91func!(matyas, num = 2);
92func!(mccorminck, num = 2);
93func!(picheny, num = 2);
94func!(name = rastrigin, function = rastrigin_a, a: f64 = 10.0);
95func!(name = rosenbrock, function = rosenbrock_ab, a: f64 = 1.0, b: f64 = 100.0);
96func!(schaffer_n2, num = 2);
97func!(schaffer_n4, num = 2);
98func!(sphere);
99func!(styblinski_tang);
100func!(threehumpcamel, num = 2);
101
102#[pymodule]
103fn argmin_testfunctions_py(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
104    add_function!(m, ackley);
105    add_function!(m, beale);
106    add_function!(m, booth);
107    add_function!(m, bukin_n6);
108    add_function!(m, cross_in_tray);
109    add_function!(m, easom);
110    add_function!(m, eggholder);
111    add_function!(m, goldsteinprice);
112    add_function!(m, himmelblau);
113    add_function!(m, holder_table);
114    add_function!(m, levy);
115    add_function!(m, levy_n13);
116    add_function!(m, matyas);
117    add_function!(m, mccorminck);
118    add_function!(m, picheny);
119    add_function!(m, rastrigin);
120    add_function!(m, rosenbrock);
121    add_function!(m, schaffer_n2);
122    add_function!(m, schaffer_n4);
123    add_function!(m, sphere);
124    add_function!(m, styblinski_tang);
125    add_function!(m, threehumpcamel);
126    Ok(())
127}