argmin_math/vec/
dot.rs

1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8use crate::ArgminDot;
9use crate::ArgminTranspose;
10use num_complex::Complex;
11
12macro_rules! make_dot_vec {
13    ($t:ty) => {
14        impl ArgminDot<Vec<$t>, $t> for Vec<$t> {
15            #[inline]
16            fn dot(&self, other: &Vec<$t>) -> $t {
17                self.iter().zip(other.iter()).map(|(a, b)| a * b).sum()
18            }
19        }
20
21        impl ArgminDot<$t, Vec<$t>> for Vec<$t> {
22            #[inline]
23            fn dot(&self, other: &$t) -> Vec<$t> {
24                self.iter().map(|a| a * other).collect()
25            }
26        }
27
28        impl ArgminDot<Vec<$t>, Vec<$t>> for $t {
29            #[inline]
30            fn dot(&self, other: &Vec<$t>) -> Vec<$t> {
31                other.iter().map(|a| a * self).collect()
32            }
33        }
34
35        impl ArgminDot<Vec<$t>, Vec<Vec<$t>>> for Vec<$t> {
36            #[inline]
37            fn dot(&self, other: &Vec<$t>) -> Vec<Vec<$t>> {
38                self.iter()
39                    .map(|b| other.iter().map(|a| a * b).collect())
40                    .collect()
41            }
42        }
43
44        impl ArgminDot<Vec<$t>, Vec<$t>> for Vec<Vec<$t>> {
45            #[inline]
46            fn dot(&self, other: &Vec<$t>) -> Vec<$t> {
47                (0..self.len()).map(|i| self[i].dot(other)).collect()
48            }
49        }
50
51        impl ArgminDot<Vec<Vec<$t>>, Vec<Vec<$t>>> for Vec<Vec<$t>> {
52            #[inline]
53            fn dot(&self, other: &Vec<Vec<$t>>) -> Vec<Vec<$t>> {
54                // Would be more efficient if this wasn't necessary!
55                let other = other.clone().t();
56                let sr = self.len();
57                assert!(sr > 0);
58                let sc = self[0].len();
59                assert!(sc > 0);
60                let or = other.len();
61                assert!(or > 0);
62                let oc = other[0].len();
63                assert_eq!(sc, or);
64                assert!(oc > 0);
65                let v = vec![<$t>::default(); oc];
66                let mut out = vec![v; sr];
67                for i in 0..sr {
68                    assert_eq!(self[i].len(), sc);
69                    for j in 0..oc {
70                        out[i][j] = self[i].dot(&other[j]);
71                    }
72                }
73                out
74            }
75        }
76
77        impl ArgminDot<$t, Vec<Vec<$t>>> for Vec<Vec<$t>> {
78            #[inline]
79            fn dot(&self, other: &$t) -> Vec<Vec<$t>> {
80                (0..self.len())
81                    .map(|i| self[i].iter().map(|a| a * other).collect())
82                    .collect()
83            }
84        }
85
86        impl ArgminDot<Vec<Vec<$t>>, Vec<Vec<$t>>> for $t {
87            #[inline]
88            fn dot(&self, other: &Vec<Vec<$t>>) -> Vec<Vec<$t>> {
89                (0..other.len())
90                    .map(|i| other[i].iter().map(|a| a * self).collect())
91                    .collect()
92            }
93        }
94    };
95}
96
97make_dot_vec!(f32);
98make_dot_vec!(f64);
99make_dot_vec!(i8);
100make_dot_vec!(i16);
101make_dot_vec!(i32);
102make_dot_vec!(i64);
103make_dot_vec!(u8);
104make_dot_vec!(u16);
105make_dot_vec!(u32);
106make_dot_vec!(u64);
107make_dot_vec!(Complex<f32>);
108make_dot_vec!(Complex<f64>);
109make_dot_vec!(Complex<i8>);
110make_dot_vec!(Complex<i16>);
111make_dot_vec!(Complex<i32>);
112make_dot_vec!(Complex<i64>);
113make_dot_vec!(Complex<u8>);
114make_dot_vec!(Complex<u16>);
115make_dot_vec!(Complex<u32>);
116make_dot_vec!(Complex<u64>);
117
118#[cfg(test)]
119mod tests {
120    use super::*;
121    use approx::assert_relative_eq;
122    use paste::item;
123
124    macro_rules! make_test {
125        ($t:ty) => {
126            item! {
127                #[test]
128                fn [<test_vec_vec_ $t>]() {
129                    let a = vec![1 as $t, 2 as $t, 3 as $t];
130                    let b = vec![4 as $t, 5 as $t, 6 as $t];
131                    let res: $t = a.dot(&b);
132                    assert_relative_eq!(32 as f64, res as f64, epsilon = f64::EPSILON);
133                }
134            }
135
136            item! {
137                #[test]
138                fn [<test_vec_vec_complex_ $t>]() {
139                    let a = vec![
140                        Complex::new(2 as $t, 2 as $t),
141                        Complex::new(5 as $t, 2 as $t),
142                        Complex::new(3 as $t, 2 as $t),
143                    ];
144                    let b = vec![
145                        Complex::new(5 as $t, 3 as $t),
146                        Complex::new(2 as $t, 4 as $t),
147                        Complex::new(8 as $t, 4 as $t),
148                    ];
149                    let res: Complex<$t> = a.dot(&b);
150                    let target = a[0]*b[0] + a[1]*b[1] + a[2]*b[2];
151                    assert_relative_eq!(res.re as f64, target.re as f64, epsilon = f64::EPSILON);
152                    assert_relative_eq!(res.im as f64, target.im as f64, epsilon = f64::EPSILON);
153                }
154            }
155
156            item! {
157                #[test]
158                fn [<test_vec_scalar_ $t>]() {
159                    let a = vec![1 as $t, 2 as $t, 3 as $t];
160                    let b = 2 as $t;
161                    let product = a.dot(&b);
162                    let res = vec![2 as $t, 4 as $t, 6 as $t];
163                    for i in 0..3 {
164                        assert_relative_eq!(res[i] as f64, product[i] as f64, epsilon = f64::EPSILON);
165                    }
166                }
167            }
168
169            item! {
170                #[test]
171                fn [<test_vec_scalar_complex_ $t>]() {
172                    let a = vec![
173                        Complex::new(2 as $t, 2 as $t),
174                        Complex::new(5 as $t, 2 as $t),
175                        Complex::new(3 as $t, 2 as $t),
176                    ];
177                    let b = Complex::new(4 as $t, 2 as $t);
178                    let product = a.dot(&b);
179                    let res = vec![a[0]*b, a[1]*b, a[2]*b];
180                    for i in 0..3 {
181                        assert_relative_eq!(res[i].re as f64, product[i].re as f64, epsilon = f64::EPSILON);
182                        assert_relative_eq!(res[i].im as f64, product[i].im as f64, epsilon = f64::EPSILON);
183                    }
184                }
185            }
186
187            item! {
188                #[test]
189                fn [<test_scalar_vec_ $t>]() {
190                    let a = vec![1 as $t, 2 as $t, 3 as $t];
191                    let b = 2 as $t;
192                    let product = b.dot(&a);
193                    let res = vec![2 as $t, 4 as $t, 6 as $t];
194                    for i in 0..3 {
195                        assert_relative_eq!(res[i] as f64, product[i] as f64, epsilon = f64::EPSILON);
196                    }
197                }
198            }
199
200            item! {
201                #[test]
202                fn [<test_scalar_vec_complex_ $t>]() {
203                    let a = vec![
204                        Complex::new(2 as $t, 2 as $t),
205                        Complex::new(5 as $t, 2 as $t),
206                        Complex::new(3 as $t, 2 as $t),
207                    ];
208                    let b = Complex::new(4 as $t, 2 as $t);
209                    let product = b.dot(&a);
210                    let res = vec![a[0]*b, a[1]*b, a[2]*b];
211                    for i in 0..3 {
212                        assert_relative_eq!(res[i].re as f64, product[i].re as f64, epsilon = f64::EPSILON);
213                        assert_relative_eq!(res[i].im as f64, product[i].im as f64, epsilon = f64::EPSILON);
214                    }
215                }
216            }
217
218            item! {
219                #[test]
220                fn [<test_mat_vec_ $t>]() {
221                    let a = vec![1 as $t, 2 as $t, 3 as $t];
222                    let b = vec![4 as $t, 5 as $t, 6 as $t];
223                    let res = vec![
224                        vec![4 as $t, 5 as $t, 6 as $t],
225                        vec![8 as $t, 10 as $t, 12 as $t],
226                        vec![12 as $t, 15 as $t, 18 as $t]
227                    ];
228                    let product: Vec<Vec<$t>> = a.dot(&b);
229                    for i in 0..3 {
230                        for j in 0..3 {
231                            assert_relative_eq!(res[i][j] as f64, product[i][j] as f64, epsilon = f64::EPSILON);
232                        }
233                    }
234                }
235            }
236
237            item! {
238                #[test]
239                fn [<test_mat_vec_complex_ $t>]() {
240                    let a = vec![
241                        Complex::new(2 as $t, 2 as $t),
242                        Complex::new(5 as $t, 2 as $t),
243                    ];
244                    let b = vec![
245                        Complex::new(5 as $t, 1 as $t),
246                        Complex::new(2 as $t, 1 as $t),
247                    ];
248                    let res = vec![
249                        vec![a[0]*b[0], a[0]*b[1]],
250                        vec![a[1]*b[0], a[1]*b[1]],
251                    ];
252                    let product: Vec<Vec<Complex<$t>>> = a.dot(&b);
253                    for i in 0..2 {
254                        for j in 0..2 {
255                            assert_relative_eq!(res[i][j].re as f64, product[i][j].re as f64, epsilon = f64::EPSILON);
256                            assert_relative_eq!(res[i][j].im as f64, product[i][j].im as f64, epsilon = f64::EPSILON);
257                        }
258                    }
259                }
260            }
261
262            item! {
263                #[test]
264                fn [<test_mat_vec_2_ $t>]() {
265                    let a = vec![
266                        vec![1 as $t, 2 as $t, 3 as $t],
267                        vec![4 as $t, 5 as $t, 6 as $t],
268                        vec![7 as $t, 8 as $t, 9 as $t]
269                    ];
270                    let b = vec![1 as $t, 2 as $t, 3 as $t];
271                    let res = vec![14 as $t, 32 as $t, 50 as $t];
272                    let product = a.dot(&b);
273                    for i in 0..3 {
274                        assert_relative_eq!(res[i] as f64, product[i] as f64, epsilon = f64::EPSILON);
275                    }
276                }
277            }
278
279            item! {
280                #[test]
281                fn [<test_mat_vec_2_complex $t>]() {
282                    let a = vec![
283                        vec![Complex::new(2 as $t, 2 as $t), Complex::new(5 as $t, 2 as $t)],
284                        vec![Complex::new(2 as $t, 2 as $t), Complex::new(5 as $t, 2 as $t)],
285                    ];
286                    let b = vec![
287                        Complex::new(5 as $t, 1 as $t),
288                        Complex::new(2 as $t, 1 as $t),
289                    ];
290                    let res = vec![
291                        a[0][0] * b[0] + a[0][1] * b[1],
292                        a[1][0] * b[0] + a[1][1] * b[1],
293                    ];
294                    let product = a.dot(&b);
295                    for i in 0..2 {
296                        assert_relative_eq!(res[i].re as f64, product[i].re as f64, epsilon = f64::EPSILON);
297                        assert_relative_eq!(res[i].im as f64, product[i].im as f64, epsilon = f64::EPSILON);
298                    }
299                }
300            }
301
302            item! {
303                #[test]
304                fn [<test_mat_mat_ $t>]() {
305                    let a = vec![
306                        vec![1 as $t, 2 as $t, 3 as $t],
307                        vec![4 as $t, 5 as $t, 6 as $t],
308                        vec![3 as $t, 2 as $t, 1 as $t]
309                    ];
310                    let b = vec![
311                        vec![3 as $t, 2 as $t, 1 as $t],
312                        vec![6 as $t, 5 as $t, 4 as $t],
313                        vec![2 as $t, 4 as $t, 3 as $t]
314                    ];
315                    let res = vec![
316                        vec![21 as $t, 24 as $t, 18 as $t],
317                        vec![54 as $t, 57 as $t, 42 as $t],
318                        vec![23 as $t, 20 as $t, 14 as $t]
319                    ];
320                    let product = a.dot(&b);
321                    for i in 0..3 {
322                        for j in 0..3 {
323                            assert!((((res[i][j] - product[i][j]) as f64).abs()) < f64::EPSILON);
324                        }
325                    }
326                }
327            }
328
329            item! {
330                #[test]
331                fn [<test_mat_mat_complex $t>]() {
332                    let a = vec![
333                        vec![Complex::new(2 as $t, 1 as $t), Complex::new(5 as $t, 2 as $t)],
334                        vec![Complex::new(4 as $t, 2 as $t), Complex::new(7 as $t, 1 as $t)],
335                    ];
336                    let b = vec![
337                        vec![Complex::new(2 as $t, 2 as $t), Complex::new(5 as $t, 1 as $t)],
338                        vec![Complex::new(3 as $t, 1 as $t), Complex::new(4 as $t, 2 as $t)],
339                    ];
340                    let res = vec![
341                        vec![
342                            a[0][0] * b[0][0] + a[0][1] * b[1][0],
343                            a[0][0] * b[0][1] + a[0][1] * b[1][1]
344                        ],
345                        vec![
346                            a[1][0] * b[0][0] + a[1][1] * b[1][0],
347                            a[1][0] * b[0][1] + a[1][1] * b[1][1]
348                        ],
349                    ];
350                    let product = a.dot(&b);
351                    for i in 0..2 {
352                        for j in 0..2 {
353                            assert_relative_eq!(res[i][j].re as f64, product[i][j].re as f64, epsilon = f64::EPSILON);
354                            assert_relative_eq!(res[i][j].im as f64, product[i][j].im as f64, epsilon = f64::EPSILON);
355                        }
356                    }
357                }
358            }
359
360            item! {
361                #[test]
362                #[should_panic]
363                fn [<test_mat_mat_panic_1_ $t>]() {
364                    let a = vec![];
365                    let b = vec![
366                        vec![3 as $t, 2 as $t, 1 as $t],
367                        vec![6 as $t, 5 as $t, 4 as $t],
368                        vec![2 as $t, 4 as $t, 3 as $t]
369                    ];
370                    a.dot(&b);
371                }
372            }
373
374            item! {
375                #[test]
376                #[should_panic]
377                fn [<test_mat_mat_panic_2_ $t>]() {
378                    let a: Vec<Vec<$t>> = vec![];
379                    let b = vec![
380                        vec![3 as $t, 2 as $t, 1 as $t],
381                        vec![6 as $t, 5 as $t, 4 as $t],
382                        vec![2 as $t, 4 as $t, 3 as $t]
383                    ];
384                    b.dot(&a);
385                }
386            }
387
388            item! {
389                #[test]
390                #[should_panic]
391                fn [<test_mat_mat_panic_3_ $t>]() {
392                    let a = vec![
393                        vec![1 as $t, 2 as $t],
394                        vec![4 as $t, 5 as $t],
395                        vec![3 as $t, 2 as $t]
396                    ];
397                    let b = vec![
398                        vec![3 as $t, 2 as $t, 1 as $t],
399                        vec![6 as $t, 5 as $t, 4 as $t],
400                        vec![2 as $t, 4 as $t, 3 as $t]
401                    ];
402                    a.dot(&b);
403                }
404            }
405
406            item! {
407                #[test]
408                #[should_panic]
409                fn [<test_mat_mat_panic_4_ $t>]() {
410                    let a = vec![
411                        vec![1 as $t, 2 as $t, 3 as $t],
412                        vec![4 as $t, 5 as $t, 6 as $t],
413                        vec![3 as $t, 2 as $t, 1 as $t]
414                    ];
415                    let b = vec![
416                        vec![3 as $t, 2 as $t],
417                        vec![6 as $t, 5 as $t],
418                        vec![3 as $t, 2 as $t]
419                    ];
420                    a.dot(&b);
421                }
422            }
423
424            item! {
425                #[test]
426                #[should_panic]
427                fn [<test_mat_mat_panic_5_ $t>]() {
428                    let a = vec![
429                        vec![1 as $t, 2 as $t, 3 as $t],
430                        vec![4 as $t, 5 as $t, 6 as $t],
431                        vec![3 as $t, 2 as $t, 1 as $t]
432                    ];
433                    let b = vec![
434                        vec![3 as $t, 2 as $t, 1 as $t],
435                        vec![6 as $t, 5 as $t, 4 as $t],
436                        vec![2 as $t, 3 as $t]
437                    ];
438                    a.dot(&b);
439                }
440            }
441
442            item! {
443                #[test]
444                #[should_panic]
445                fn [<test_mat_mat_panic_6_ $t>]() {
446                    let a = vec![
447                        vec![1 as $t, 2 as $t, 3 as $t],
448                        vec![4 as $t, 5 as $t],
449                        vec![3 as $t, 2 as $t, 1 as $t]
450                    ];
451                    let b = vec![
452                        vec![3 as $t, 2 as $t, 1 as $t],
453                        vec![6 as $t, 5 as $t, 4 as $t],
454                        vec![2 as $t, 4 as $t, 3 as $t]
455                    ];
456                    a.dot(&b);
457                }
458            }
459
460            item! {
461                #[test]
462                fn [<test_mat_primitive_ $t>]() {
463                    let a = vec![
464                        vec![1 as $t, 2 as $t, 3 as $t],
465                        vec![4 as $t, 5 as $t, 6 as $t],
466                        vec![3 as $t, 2 as $t, 1 as $t]
467                    ];
468                    let res = vec![
469                        vec![2 as $t, 4 as $t, 6 as $t],
470                        vec![8 as $t, 10 as $t, 12 as $t],
471                        vec![6 as $t, 4 as $t, 2 as $t]
472                    ];
473                    let product = a.dot(&(2 as $t));
474                    for i in 0..3 {
475                        for j in 0..3 {
476                            assert_relative_eq!(res[i][j] as f64, product[i][j] as f64, epsilon = f64::EPSILON);
477                        }
478                    }
479                }
480            }
481
482            item! {
483                #[test]
484                fn [<test_mat_primitive_complex_ $t>]() {
485                    let a = vec![
486                        vec![Complex::new(2 as $t, 1 as $t), Complex::new(5 as $t, 2 as $t)],
487                        vec![Complex::new(4 as $t, 2 as $t), Complex::new(7 as $t, 1 as $t)],
488                    ];
489                    let b = Complex::new(4 as $t, 1 as $t);
490                    let res = vec![
491                        vec![a[0][0] * b, a[0][1] * b],
492                        vec![a[1][0] * b, a[1][1] * b],
493                    ];
494                    let product = a.dot(&b);
495                    for i in 0..2 {
496                        for j in 0..2 {
497                            assert_relative_eq!(res[i][j].re as f64, product[i][j].re as f64, epsilon = f64::EPSILON);
498                            assert_relative_eq!(res[i][j].im as f64, product[i][j].im as f64, epsilon = f64::EPSILON);
499                        }
500                    }
501                }
502            }
503
504            item! {
505                #[test]
506                fn [<test_primitive_mat_ $t>]() {
507                    let a = vec![
508                        vec![1 as $t, 2 as $t, 3 as $t],
509                        vec![4 as $t, 5 as $t, 6 as $t],
510                        vec![3 as $t, 2 as $t, 1 as $t]
511                    ];
512                    let res = vec![
513                        vec![2 as $t, 4 as $t, 6 as $t],
514                        vec![8 as $t, 10 as $t, 12 as $t],
515                        vec![6 as $t, 4 as $t, 2 as $t]
516                    ];
517                    let product = (2 as $t).dot(&a);
518                    for i in 0..3 {
519                        for j in 0..3 {
520                            assert_relative_eq!(res[i][j] as f64, product[i][j] as f64, epsilon = f64::EPSILON);
521                        }
522                    }
523                }
524            }
525
526            item! {
527                #[test]
528                fn [<test_primitive_mat_complex_ $t>]() {
529                    let a = vec![
530                        vec![Complex::new(2 as $t, 1 as $t), Complex::new(5 as $t, 2 as $t)],
531                        vec![Complex::new(4 as $t, 2 as $t), Complex::new(7 as $t, 1 as $t)],
532                    ];
533                    let b = Complex::new(4 as $t, 1 as $t);
534                    let res = vec![
535                        vec![a[0][0] * b, a[0][1] * b],
536                        vec![a[1][0] * b, a[1][1] * b],
537                    ];
538                    let product = b.dot(&a);
539                    for i in 0..2 {
540                        for j in 0..2 {
541                            assert_relative_eq!(res[i][j].re as f64, product[i][j].re as f64, epsilon = f64::EPSILON);
542                            assert_relative_eq!(res[i][j].im as f64, product[i][j].im as f64, epsilon = f64::EPSILON);
543                        }
544                    }
545                }
546            }
547        };
548    }
549
550    make_test!(i8);
551    make_test!(u8);
552    make_test!(i16);
553    make_test!(u16);
554    make_test!(i32);
555    make_test!(u32);
556    make_test!(i64);
557    make_test!(u64);
558    make_test!(f32);
559    make_test!(f64);
560}