argmin_testfunctions_py/
lib.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
use argmin_testfunctions::*;
use paste::paste;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use std::stringify;

#[macro_export]
macro_rules! func {
    ($function:ident) => {
        func!(name = $function, function = $function,);
    };
    ($function:ident, num = $num:expr) => {
        func!(name = $function, function = $function, num = $num,);
    };
    (name = $name:ident, function = $function:ident, $($a:ident : $t:ty = $v:expr),* ) => {
        paste! {
            #[pyfunction(name = $name "", signature = (param, $($a = $v),*))]
            fn [<$name _py>](param: Vec<f64>, $($a: $t),*) -> f64 {
                $function(&param[..], $($a),*)
            }

            #[pyfunction(name = $name "_derivative", signature = (param, $($a = $v),*))]
            fn [<$name _derivative_py>](param: Vec<f64>, $($a: $t),*) -> Vec<f64> {
                [<$function _derivative>](&param[..], $($a),*)
            }

            #[pyfunction(name = $name "_hessian", signature = (param, $($a = $v),*))]
            fn [<$name _hessian_py>](param: Vec<f64>, $($a: $t),*) -> Vec<Vec<f64>> {
                [<$function _hessian>](&param[..], $($a),*)
            }
        }
    };
    (name = $name:ident, function = $function:ident, num = $num:expr, $($a:ident : $t:ty = $v:expr),* ) => {
        paste! {
            #[pyfunction(name = $name "", signature = (param, $($a = $v),*))]
            fn [<$name _py>](param: Vec<f64>, $($a: $t),*) -> PyResult<f64> {
                let n = param.len();
                if let Ok(param) = param.try_into() {
                    Ok($function(&param, $($a),*))
                } else {
                    Err(PyValueError::new_err(format!("incompatible number of parameters: expected {}, found {}", stringify!($num), n)))
                }
            }

            #[pyfunction(name = $name "_derivative", signature = (param, $($a = $v),*))]
            fn [<$name _derivative_py>](param: Vec<f64>, $($a: $t),*) -> PyResult<Vec<f64>> {
                let n = param.len();
                if let Ok(param) = param.try_into() {
                    Ok([<$function _derivative>](&param, $($a),*).to_vec())
                } else {
                    Err(PyValueError::new_err(format!("incompatible number of parameters: expected {}, found {}", stringify!($num), n)))
                }
            }

            #[pyfunction(name = $name "_hessian", signature = (param, $($a = $v),*))]
            fn [<$name _hessian_py>](param: Vec<f64>, $($a: $t),*) -> PyResult<Vec<Vec<f64>>> {
                let n = param.len();
                if let Ok(param) = param.try_into() {
                    Ok([<$function _hessian>](&param, $($a),*).iter().map(|r| r.to_vec()).collect::<Vec<_>>())
                } else {
                    Err(PyValueError::new_err(format!("incompatible number of parameters: expected {}, found {}", stringify!($num), n)))
                }
            }
        }
    };
}

#[macro_export]
macro_rules! add_function {
    ($m:ident, $function:ident) => {
        paste! {
            $m.add_function(wrap_pyfunction!([<$function _py>], $m)?)?;
            $m.add_function(wrap_pyfunction!([<$function _derivative_py>], $m)?)?;
            $m.add_function(wrap_pyfunction!([<$function _hessian_py>], $m)?)?;
        }
    };
}

func!(name = ackley, function = ackley_abc, a: f64 = 20.0, b: f64 = 0.2, c: f64 = 6.2831853071795864769252867665590057683943387987502116419498891846);
func!(beale, num = 2);
func!(booth, num = 2);
func!(bukin_n6, num = 2);
func!(cross_in_tray, num = 2);
func!(easom, num = 2);
func!(eggholder, num = 2);
func!(goldsteinprice, num = 2);
func!(himmelblau, num = 2);
func!(holder_table, num = 2);
func!(levy);
func!(levy_n13, num = 2);
func!(matyas, num = 2);
func!(mccorminck, num = 2);
func!(picheny, num = 2);
func!(name = rastrigin, function = rastrigin_a, a: f64 = 10.0);
func!(name = rosenbrock, function = rosenbrock_ab, a: f64 = 1.0, b: f64 = 100.0);
func!(schaffer_n2, num = 2);
func!(schaffer_n4, num = 2);
func!(sphere);
func!(styblinski_tang);
func!(threehumpcamel, num = 2);

#[pymodule]
fn argmin_testfunctions_py(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
    add_function!(m, ackley);
    add_function!(m, beale);
    add_function!(m, booth);
    add_function!(m, bukin_n6);
    add_function!(m, cross_in_tray);
    add_function!(m, easom);
    add_function!(m, eggholder);
    add_function!(m, goldsteinprice);
    add_function!(m, himmelblau);
    add_function!(m, holder_table);
    add_function!(m, levy);
    add_function!(m, levy_n13);
    add_function!(m, matyas);
    add_function!(m, mccorminck);
    add_function!(m, picheny);
    add_function!(m, rastrigin);
    add_function!(m, rosenbrock);
    add_function!(m, schaffer_n2);
    add_function!(m, schaffer_n4);
    add_function!(m, sphere);
    add_function!(m, styblinski_tang);
    add_function!(m, threehumpcamel);
    Ok(())
}