// argmin_math/primitives/weighteddot.rs

// Copyright 2018-2024 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

8use crate::ArgminDot;
9use crate::ArgminWeightedDot;
10
11impl<T, U, V> ArgminWeightedDot<T, U, V> for T
12where
13    Self: ArgminDot<T, U>,
14    V: ArgminDot<T, T>,
15{
16    #[inline]
17    fn weighted_dot(&self, w: &V, v: &T) -> U {
18        self.dot(&w.dot(v))
19    }
20}
21
// `Vec`-backed tests for the blanket `weighted_dot` implementation.
//
// Every generated test evaluates aᵀ·(W·b) with a = [2, 1, 2], b = [1, 2, 1] and the 3×3
// matrix W below. Since W·b = [16, 20, 24], the expected result is 2·16 + 1·20 + 2·24 = 100.
#[cfg(feature = "vec")]
#[cfg(test)]
mod tests_vec {
    use super::*;
    use approx::assert_relative_eq;
    use paste::item;

    // Expands to one `test_<type>` function for each element type in the argument list.
    macro_rules! make_test {
        ($($t:ty),+ $(,)?) => {
            $(
                item! {
                    #[test]
                    fn [<test_ $t>]() {
                        let w = vec![
                            vec![8 as $t, 1 as $t, 6 as $t],
                            vec![3 as $t, 5 as $t, 7 as $t],
                            vec![4 as $t, 9 as $t, 2 as $t],
                        ];
                        let lhs = vec![2 as $t, 1 as $t, 2 as $t];
                        let rhs = vec![1 as $t, 2 as $t, 1 as $t];
                        let res: $t = lhs.weighted_dot(&w, &rhs);
                        assert_relative_eq!(100 as f64, res as f64, epsilon = f64::EPSILON);
                    }
                }
            )+
        };
    }

    make_test!(i8, u8, i16, u16, i32, u32, i64, u64, f32, f64);
}
59
// `ndarray`-backed tests for the blanket `weighted_dot` implementation.
//
// Every generated test evaluates aᵀ·(W·b) with a = [2, 1, 2], b = [1, 2, 1] and the 3×3
// matrix W below. Since W·b = [16, 20, 24], the expected result is 100, which fits in every
// tested element type (including `i8`).
#[cfg(feature = "ndarray_all")]
#[cfg(test)]
mod tests_ndarray {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;
    use paste::item;

    // Generates one `test_<type>` function for the given element type.
    macro_rules! make_test {
        ($t:ty) => {
            item! {
                #[test]
                fn [<test_ $t>]() {
                    let a = array![2 as $t, 1 as $t, 2 as $t];
                    let b = array![1 as $t, 2 as $t, 1 as $t];
                    let w = array![
                        [8 as $t, 1 as $t, 6 as $t],
                        [3 as $t, 5 as $t, 7 as $t],
                        [4 as $t, 9 as $t, 2 as $t],
                    ];
                    let res: $t = a.weighted_dot(&w, &b);
                    // Compare in `f64` rather than subtracting in `$t`: for unsigned element
                    // types `res - 100` overflows (and panics in debug builds) whenever
                    // `res < 100`, which would mask the actual mismatch.
                    assert_relative_eq!(100 as f64, res as f64, epsilon = f64::EPSILON);
                }
            }
        };
    }

    make_test!(i8);
    make_test!(u8);
    make_test!(i16);
    make_test!(u16);
    make_test!(i32);
    make_test!(u32);
    make_test!(i64);
    make_test!(u64);
    make_test!(f32);
    make_test!(f64);
}
97
// `nalgebra`-backed tests for the blanket `weighted_dot` implementation.
//
// Every generated test evaluates aᵀ·(W·b) with a = [2, 1, 2], b = [1, 2, 1] and the 3×3
// matrix W below (`Matrix3::new` takes its arguments in row-major order). Since
// W·b = [16, 20, 24], the expected result is 100, which fits in every tested element type.
#[cfg(feature = "nalgebra_all")]
#[cfg(test)]
mod tests_nalgebra {
    use super::*;
    use approx::assert_relative_eq;
    use nalgebra::{Matrix3, Vector3};
    use paste::item;

    // Generates one `test_<type>` function for the given element type.
    macro_rules! make_test {
        ($t:ty) => {
            item! {
                #[test]
                fn [<test_ $t>]() {
                    let a = Vector3::new(2 as $t, 1 as $t, 2 as $t);
                    let b = Vector3::new(1 as $t, 2 as $t, 1 as $t);
                    let w = Matrix3::new(
                        8 as $t, 1 as $t, 6 as $t,
                        3 as $t, 5 as $t, 7 as $t,
                        4 as $t, 9 as $t, 2 as $t,
                    );
                    let res: $t = a.weighted_dot(&w, &b);
                    // Compare in `f64` rather than subtracting in `$t`: for unsigned element
                    // types `res - 100` overflows (and panics in debug builds) whenever
                    // `res < 100`, which would mask the actual mismatch.
                    assert_relative_eq!(100 as f64, res as f64, epsilon = f64::EPSILON);
                }
            }
        };
    }

    make_test!(i8);
    make_test!(u8);
    make_test!(i16);
    make_test!(u16);
    make_test!(i32);
    make_test!(u32);
    make_test!(i64);
    make_test!(u64);
    make_test!(f32);
    make_test!(f64);
}