1use crate::ArgminDot;
9use crate::ArgminTranspose;
10use num_complex::Complex;
11
12macro_rules! make_dot_vec {
13 ($t:ty) => {
14 impl ArgminDot<Vec<$t>, $t> for Vec<$t> {
15 #[inline]
16 fn dot(&self, other: &Vec<$t>) -> $t {
17 self.iter().zip(other.iter()).map(|(a, b)| a * b).sum()
18 }
19 }
20
21 impl ArgminDot<$t, Vec<$t>> for Vec<$t> {
22 #[inline]
23 fn dot(&self, other: &$t) -> Vec<$t> {
24 self.iter().map(|a| a * other).collect()
25 }
26 }
27
28 impl ArgminDot<Vec<$t>, Vec<$t>> for $t {
29 #[inline]
30 fn dot(&self, other: &Vec<$t>) -> Vec<$t> {
31 other.iter().map(|a| a * self).collect()
32 }
33 }
34
35 impl ArgminDot<Vec<$t>, Vec<Vec<$t>>> for Vec<$t> {
36 #[inline]
37 fn dot(&self, other: &Vec<$t>) -> Vec<Vec<$t>> {
38 self.iter()
39 .map(|b| other.iter().map(|a| a * b).collect())
40 .collect()
41 }
42 }
43
44 impl ArgminDot<Vec<$t>, Vec<$t>> for Vec<Vec<$t>> {
45 #[inline]
46 fn dot(&self, other: &Vec<$t>) -> Vec<$t> {
47 (0..self.len()).map(|i| self[i].dot(other)).collect()
48 }
49 }
50
51 impl ArgminDot<Vec<Vec<$t>>, Vec<Vec<$t>>> for Vec<Vec<$t>> {
52 #[inline]
53 fn dot(&self, other: &Vec<Vec<$t>>) -> Vec<Vec<$t>> {
54 let other = other.clone().t();
56 let sr = self.len();
57 assert!(sr > 0);
58 let sc = self[0].len();
59 assert!(sc > 0);
60 let or = other.len();
61 assert!(or > 0);
62 let oc = other[0].len();
63 assert_eq!(sc, or);
64 assert!(oc > 0);
65 let v = vec![<$t>::default(); oc];
66 let mut out = vec![v; sr];
67 for i in 0..sr {
68 assert_eq!(self[i].len(), sc);
69 for j in 0..oc {
70 out[i][j] = self[i].dot(&other[j]);
71 }
72 }
73 out
74 }
75 }
76
77 impl ArgminDot<$t, Vec<Vec<$t>>> for Vec<Vec<$t>> {
78 #[inline]
79 fn dot(&self, other: &$t) -> Vec<Vec<$t>> {
80 (0..self.len())
81 .map(|i| self[i].iter().map(|a| a * other).collect())
82 .collect()
83 }
84 }
85
86 impl ArgminDot<Vec<Vec<$t>>, Vec<Vec<$t>>> for $t {
87 #[inline]
88 fn dot(&self, other: &Vec<Vec<$t>>) -> Vec<Vec<$t>> {
89 (0..other.len())
90 .map(|i| other[i].iter().map(|a| a * self).collect())
91 .collect()
92 }
93 }
94 };
95}
96
97make_dot_vec!(f32);
98make_dot_vec!(f64);
99make_dot_vec!(i8);
100make_dot_vec!(i16);
101make_dot_vec!(i32);
102make_dot_vec!(i64);
103make_dot_vec!(u8);
104make_dot_vec!(u16);
105make_dot_vec!(u32);
106make_dot_vec!(u64);
107make_dot_vec!(Complex<f32>);
108make_dot_vec!(Complex<f64>);
109make_dot_vec!(Complex<i8>);
110make_dot_vec!(Complex<i16>);
111make_dot_vec!(Complex<i32>);
112make_dot_vec!(Complex<i64>);
113make_dot_vec!(Complex<u8>);
114make_dot_vec!(Complex<u16>);
115make_dot_vec!(Complex<u32>);
116make_dot_vec!(Complex<u64>);
117
118#[cfg(test)]
119mod tests {
120 use super::*;
121 use approx::assert_relative_eq;
122 use paste::item;
123
124 macro_rules! make_test {
125 ($t:ty) => {
126 item! {
127 #[test]
128 fn [<test_vec_vec_ $t>]() {
129 let a = vec![1 as $t, 2 as $t, 3 as $t];
130 let b = vec![4 as $t, 5 as $t, 6 as $t];
131 let res: $t = a.dot(&b);
132 assert_relative_eq!(32 as f64, res as f64, epsilon = f64::EPSILON);
133 }
134 }
135
136 item! {
137 #[test]
138 fn [<test_vec_vec_complex_ $t>]() {
139 let a = vec![
140 Complex::new(2 as $t, 2 as $t),
141 Complex::new(5 as $t, 2 as $t),
142 Complex::new(3 as $t, 2 as $t),
143 ];
144 let b = vec![
145 Complex::new(5 as $t, 3 as $t),
146 Complex::new(2 as $t, 4 as $t),
147 Complex::new(8 as $t, 4 as $t),
148 ];
149 let res: Complex<$t> = a.dot(&b);
150 let target = a[0]*b[0] + a[1]*b[1] + a[2]*b[2];
151 assert_relative_eq!(res.re as f64, target.re as f64, epsilon = f64::EPSILON);
152 assert_relative_eq!(res.im as f64, target.im as f64, epsilon = f64::EPSILON);
153 }
154 }
155
156 item! {
157 #[test]
158 fn [<test_vec_scalar_ $t>]() {
159 let a = vec![1 as $t, 2 as $t, 3 as $t];
160 let b = 2 as $t;
161 let product = a.dot(&b);
162 let res = vec![2 as $t, 4 as $t, 6 as $t];
163 for i in 0..3 {
164 assert_relative_eq!(res[i] as f64, product[i] as f64, epsilon = f64::EPSILON);
165 }
166 }
167 }
168
169 item! {
170 #[test]
171 fn [<test_vec_scalar_complex_ $t>]() {
172 let a = vec![
173 Complex::new(2 as $t, 2 as $t),
174 Complex::new(5 as $t, 2 as $t),
175 Complex::new(3 as $t, 2 as $t),
176 ];
177 let b = Complex::new(4 as $t, 2 as $t);
178 let product = a.dot(&b);
179 let res = vec![a[0]*b, a[1]*b, a[2]*b];
180 for i in 0..3 {
181 assert_relative_eq!(res[i].re as f64, product[i].re as f64, epsilon = f64::EPSILON);
182 assert_relative_eq!(res[i].im as f64, product[i].im as f64, epsilon = f64::EPSILON);
183 }
184 }
185 }
186
187 item! {
188 #[test]
189 fn [<test_scalar_vec_ $t>]() {
190 let a = vec![1 as $t, 2 as $t, 3 as $t];
191 let b = 2 as $t;
192 let product = b.dot(&a);
193 let res = vec![2 as $t, 4 as $t, 6 as $t];
194 for i in 0..3 {
195 assert_relative_eq!(res[i] as f64, product[i] as f64, epsilon = f64::EPSILON);
196 }
197 }
198 }
199
200 item! {
201 #[test]
202 fn [<test_scalar_vec_complex_ $t>]() {
203 let a = vec![
204 Complex::new(2 as $t, 2 as $t),
205 Complex::new(5 as $t, 2 as $t),
206 Complex::new(3 as $t, 2 as $t),
207 ];
208 let b = Complex::new(4 as $t, 2 as $t);
209 let product = b.dot(&a);
210 let res = vec![a[0]*b, a[1]*b, a[2]*b];
211 for i in 0..3 {
212 assert_relative_eq!(res[i].re as f64, product[i].re as f64, epsilon = f64::EPSILON);
213 assert_relative_eq!(res[i].im as f64, product[i].im as f64, epsilon = f64::EPSILON);
214 }
215 }
216 }
217
218 item! {
219 #[test]
220 fn [<test_mat_vec_ $t>]() {
221 let a = vec![1 as $t, 2 as $t, 3 as $t];
222 let b = vec![4 as $t, 5 as $t, 6 as $t];
223 let res = vec![
224 vec![4 as $t, 5 as $t, 6 as $t],
225 vec![8 as $t, 10 as $t, 12 as $t],
226 vec![12 as $t, 15 as $t, 18 as $t]
227 ];
228 let product: Vec<Vec<$t>> = a.dot(&b);
229 for i in 0..3 {
230 for j in 0..3 {
231 assert_relative_eq!(res[i][j] as f64, product[i][j] as f64, epsilon = f64::EPSILON);
232 }
233 }
234 }
235 }
236
237 item! {
238 #[test]
239 fn [<test_mat_vec_complex_ $t>]() {
240 let a = vec![
241 Complex::new(2 as $t, 2 as $t),
242 Complex::new(5 as $t, 2 as $t),
243 ];
244 let b = vec![
245 Complex::new(5 as $t, 1 as $t),
246 Complex::new(2 as $t, 1 as $t),
247 ];
248 let res = vec![
249 vec![a[0]*b[0], a[0]*b[1]],
250 vec![a[1]*b[0], a[1]*b[1]],
251 ];
252 let product: Vec<Vec<Complex<$t>>> = a.dot(&b);
253 for i in 0..2 {
254 for j in 0..2 {
255 assert_relative_eq!(res[i][j].re as f64, product[i][j].re as f64, epsilon = f64::EPSILON);
256 assert_relative_eq!(res[i][j].im as f64, product[i][j].im as f64, epsilon = f64::EPSILON);
257 }
258 }
259 }
260 }
261
262 item! {
263 #[test]
264 fn [<test_mat_vec_2_ $t>]() {
265 let a = vec![
266 vec![1 as $t, 2 as $t, 3 as $t],
267 vec![4 as $t, 5 as $t, 6 as $t],
268 vec![7 as $t, 8 as $t, 9 as $t]
269 ];
270 let b = vec![1 as $t, 2 as $t, 3 as $t];
271 let res = vec![14 as $t, 32 as $t, 50 as $t];
272 let product = a.dot(&b);
273 for i in 0..3 {
274 assert_relative_eq!(res[i] as f64, product[i] as f64, epsilon = f64::EPSILON);
275 }
276 }
277 }
278
279 item! {
280 #[test]
281 fn [<test_mat_vec_2_complex $t>]() {
282 let a = vec![
283 vec![Complex::new(2 as $t, 2 as $t), Complex::new(5 as $t, 2 as $t)],
284 vec![Complex::new(2 as $t, 2 as $t), Complex::new(5 as $t, 2 as $t)],
285 ];
286 let b = vec![
287 Complex::new(5 as $t, 1 as $t),
288 Complex::new(2 as $t, 1 as $t),
289 ];
290 let res = vec![
291 a[0][0] * b[0] + a[0][1] * b[1],
292 a[1][0] * b[0] + a[1][1] * b[1],
293 ];
294 let product = a.dot(&b);
295 for i in 0..2 {
296 assert_relative_eq!(res[i].re as f64, product[i].re as f64, epsilon = f64::EPSILON);
297 assert_relative_eq!(res[i].im as f64, product[i].im as f64, epsilon = f64::EPSILON);
298 }
299 }
300 }
301
302 item! {
303 #[test]
304 fn [<test_mat_mat_ $t>]() {
305 let a = vec![
306 vec![1 as $t, 2 as $t, 3 as $t],
307 vec![4 as $t, 5 as $t, 6 as $t],
308 vec![3 as $t, 2 as $t, 1 as $t]
309 ];
310 let b = vec![
311 vec![3 as $t, 2 as $t, 1 as $t],
312 vec![6 as $t, 5 as $t, 4 as $t],
313 vec![2 as $t, 4 as $t, 3 as $t]
314 ];
315 let res = vec![
316 vec![21 as $t, 24 as $t, 18 as $t],
317 vec![54 as $t, 57 as $t, 42 as $t],
318 vec![23 as $t, 20 as $t, 14 as $t]
319 ];
320 let product = a.dot(&b);
321 for i in 0..3 {
322 for j in 0..3 {
323 assert!((((res[i][j] - product[i][j]) as f64).abs()) < f64::EPSILON);
324 }
325 }
326 }
327 }
328
329 item! {
330 #[test]
331 fn [<test_mat_mat_complex $t>]() {
332 let a = vec![
333 vec![Complex::new(2 as $t, 1 as $t), Complex::new(5 as $t, 2 as $t)],
334 vec![Complex::new(4 as $t, 2 as $t), Complex::new(7 as $t, 1 as $t)],
335 ];
336 let b = vec![
337 vec![Complex::new(2 as $t, 2 as $t), Complex::new(5 as $t, 1 as $t)],
338 vec![Complex::new(3 as $t, 1 as $t), Complex::new(4 as $t, 2 as $t)],
339 ];
340 let res = vec![
341 vec![
342 a[0][0] * b[0][0] + a[0][1] * b[1][0],
343 a[0][0] * b[0][1] + a[0][1] * b[1][1]
344 ],
345 vec![
346 a[1][0] * b[0][0] + a[1][1] * b[1][0],
347 a[1][0] * b[0][1] + a[1][1] * b[1][1]
348 ],
349 ];
350 let product = a.dot(&b);
351 for i in 0..2 {
352 for j in 0..2 {
353 assert_relative_eq!(res[i][j].re as f64, product[i][j].re as f64, epsilon = f64::EPSILON);
354 assert_relative_eq!(res[i][j].im as f64, product[i][j].im as f64, epsilon = f64::EPSILON);
355 }
356 }
357 }
358 }
359
360 item! {
361 #[test]
362 #[should_panic]
363 fn [<test_mat_mat_panic_1_ $t>]() {
364 let a = vec![];
365 let b = vec![
366 vec![3 as $t, 2 as $t, 1 as $t],
367 vec![6 as $t, 5 as $t, 4 as $t],
368 vec![2 as $t, 4 as $t, 3 as $t]
369 ];
370 a.dot(&b);
371 }
372 }
373
374 item! {
375 #[test]
376 #[should_panic]
377 fn [<test_mat_mat_panic_2_ $t>]() {
378 let a: Vec<Vec<$t>> = vec![];
379 let b = vec![
380 vec![3 as $t, 2 as $t, 1 as $t],
381 vec![6 as $t, 5 as $t, 4 as $t],
382 vec![2 as $t, 4 as $t, 3 as $t]
383 ];
384 b.dot(&a);
385 }
386 }
387
388 item! {
389 #[test]
390 #[should_panic]
391 fn [<test_mat_mat_panic_3_ $t>]() {
392 let a = vec![
393 vec![1 as $t, 2 as $t],
394 vec![4 as $t, 5 as $t],
395 vec![3 as $t, 2 as $t]
396 ];
397 let b = vec![
398 vec![3 as $t, 2 as $t, 1 as $t],
399 vec![6 as $t, 5 as $t, 4 as $t],
400 vec![2 as $t, 4 as $t, 3 as $t]
401 ];
402 a.dot(&b);
403 }
404 }
405
406 item! {
407 #[test]
408 #[should_panic]
409 fn [<test_mat_mat_panic_4_ $t>]() {
410 let a = vec![
411 vec![1 as $t, 2 as $t, 3 as $t],
412 vec![4 as $t, 5 as $t, 6 as $t],
413 vec![3 as $t, 2 as $t, 1 as $t]
414 ];
415 let b = vec![
416 vec![3 as $t, 2 as $t],
417 vec![6 as $t, 5 as $t],
418 vec![3 as $t, 2 as $t]
419 ];
420 a.dot(&b);
421 }
422 }
423
424 item! {
425 #[test]
426 #[should_panic]
427 fn [<test_mat_mat_panic_5_ $t>]() {
428 let a = vec![
429 vec![1 as $t, 2 as $t, 3 as $t],
430 vec![4 as $t, 5 as $t, 6 as $t],
431 vec![3 as $t, 2 as $t, 1 as $t]
432 ];
433 let b = vec![
434 vec![3 as $t, 2 as $t, 1 as $t],
435 vec![6 as $t, 5 as $t, 4 as $t],
436 vec![2 as $t, 3 as $t]
437 ];
438 a.dot(&b);
439 }
440 }
441
442 item! {
443 #[test]
444 #[should_panic]
445 fn [<test_mat_mat_panic_6_ $t>]() {
446 let a = vec![
447 vec![1 as $t, 2 as $t, 3 as $t],
448 vec![4 as $t, 5 as $t],
449 vec![3 as $t, 2 as $t, 1 as $t]
450 ];
451 let b = vec![
452 vec![3 as $t, 2 as $t, 1 as $t],
453 vec![6 as $t, 5 as $t, 4 as $t],
454 vec![2 as $t, 4 as $t, 3 as $t]
455 ];
456 a.dot(&b);
457 }
458 }
459
460 item! {
461 #[test]
462 fn [<test_mat_primitive_ $t>]() {
463 let a = vec![
464 vec![1 as $t, 2 as $t, 3 as $t],
465 vec![4 as $t, 5 as $t, 6 as $t],
466 vec![3 as $t, 2 as $t, 1 as $t]
467 ];
468 let res = vec![
469 vec![2 as $t, 4 as $t, 6 as $t],
470 vec![8 as $t, 10 as $t, 12 as $t],
471 vec![6 as $t, 4 as $t, 2 as $t]
472 ];
473 let product = a.dot(&(2 as $t));
474 for i in 0..3 {
475 for j in 0..3 {
476 assert_relative_eq!(res[i][j] as f64, product[i][j] as f64, epsilon = f64::EPSILON);
477 }
478 }
479 }
480 }
481
482 item! {
483 #[test]
484 fn [<test_mat_primitive_complex_ $t>]() {
485 let a = vec![
486 vec![Complex::new(2 as $t, 1 as $t), Complex::new(5 as $t, 2 as $t)],
487 vec![Complex::new(4 as $t, 2 as $t), Complex::new(7 as $t, 1 as $t)],
488 ];
489 let b = Complex::new(4 as $t, 1 as $t);
490 let res = vec![
491 vec![a[0][0] * b, a[0][1] * b],
492 vec![a[1][0] * b, a[1][1] * b],
493 ];
494 let product = a.dot(&b);
495 for i in 0..2 {
496 for j in 0..2 {
497 assert_relative_eq!(res[i][j].re as f64, product[i][j].re as f64, epsilon = f64::EPSILON);
498 assert_relative_eq!(res[i][j].im as f64, product[i][j].im as f64, epsilon = f64::EPSILON);
499 }
500 }
501 }
502 }
503
504 item! {
505 #[test]
506 fn [<test_primitive_mat_ $t>]() {
507 let a = vec![
508 vec![1 as $t, 2 as $t, 3 as $t],
509 vec![4 as $t, 5 as $t, 6 as $t],
510 vec![3 as $t, 2 as $t, 1 as $t]
511 ];
512 let res = vec![
513 vec![2 as $t, 4 as $t, 6 as $t],
514 vec![8 as $t, 10 as $t, 12 as $t],
515 vec![6 as $t, 4 as $t, 2 as $t]
516 ];
517 let product = (2 as $t).dot(&a);
518 for i in 0..3 {
519 for j in 0..3 {
520 assert_relative_eq!(res[i][j] as f64, product[i][j] as f64, epsilon = f64::EPSILON);
521 }
522 }
523 }
524 }
525
526 item! {
527 #[test]
528 fn [<test_primitive_mat_complex_ $t>]() {
529 let a = vec![
530 vec![Complex::new(2 as $t, 1 as $t), Complex::new(5 as $t, 2 as $t)],
531 vec![Complex::new(4 as $t, 2 as $t), Complex::new(7 as $t, 1 as $t)],
532 ];
533 let b = Complex::new(4 as $t, 1 as $t);
534 let res = vec![
535 vec![a[0][0] * b, a[0][1] * b],
536 vec![a[1][0] * b, a[1][1] * b],
537 ];
538 let product = b.dot(&a);
539 for i in 0..2 {
540 for j in 0..2 {
541 assert_relative_eq!(res[i][j].re as f64, product[i][j].re as f64, epsilon = f64::EPSILON);
542 assert_relative_eq!(res[i][j].im as f64, product[i][j].im as f64, epsilon = f64::EPSILON);
543 }
544 }
545 }
546 }
547 };
548 }
549
550 make_test!(i8);
551 make_test!(u8);
552 make_test!(i16);
553 make_test!(u16);
554 make_test!(i32);
555 make_test!(u32);
556 make_test!(i64);
557 make_test!(u64);
558 make_test!(f32);
559 make_test!(f64);
560}