1 #ifndef _TBLIS_IFACE_1V_DOT_H_
2 #define _TBLIS_IFACE_1V_DOT_H_
3 
4 #include "../../util/thread.h"
5 #include "../../util/basic_types.h"
6 
7 #ifdef __cplusplus
8 
9 namespace tblis
10 {
11 
12 extern "C"
13 {
14 
15 #endif
16 
/*
 * C interface: dot product of two vectors.
 *
 * comm   - communicator for the computation. The C++ wrappers in this header
 *          pass nullptr when the caller supplies no communicator.
 *          NOTE(review): presumably nullptr selects a library default; confirm
 *          against the implementation.
 * cfg    - configuration. The C++ wrappers always pass nullptr (same
 *          assumption as for comm).
 * A, B   - input vector descriptors. Presumably they must have equal length
 *          and matching element type; TODO confirm.
 * result - output scalar descriptor. After the call, the wrappers read the
 *          dot product back out of it with get<T>().
 */
void tblis_vector_dot(const tblis_comm* comm, const tblis_config* cfg,
                      const tblis_vector* A, const tblis_vector* B,
                      tblis_scalar* result);
20 
21 #ifdef __cplusplus
22 }
23 #endif
24 
25 #if defined(__cplusplus) && !defined(TBLIS_DONT_USE_CXX11)
26 
27 template <typename T>
dot(row_view<const T> A,row_view<const T> B,T & result)28 void dot(row_view<const T> A, row_view<const T> B, T& result)
29 {
30     tblis_vector A_s(A);
31     tblis_vector B_s(B);
32     tblis_scalar result_s(result);
33     tblis_vector_dot(nullptr, nullptr, &A_s, &B_s, &result_s);
34     result = result_s.get<T>();
35 }
36 
37 template <typename T>
dot(const communicator & comm,row_view<const T> A,row_view<const T> B,T & result)38 void dot(const communicator& comm, row_view<const T> A, row_view<const T> B, T& result)
39 {
40     tblis_vector A_s(A);
41     tblis_vector B_s(B);
42     tblis_scalar result_s(result);
43     tblis_vector_dot(comm, nullptr, &A_s, &B_s, &result_s);
44     result = result_s.get<T>();
45 }
46 
47 template <typename T>
dot(row_view<const T> A,row_view<const T> B)48 T dot(row_view<const T> A, row_view<const T> B)
49 {
50     T result;
51     dot(A, B, result);
52     return result;
53 }
54 
55 template <typename T>
dot(const communicator & comm,row_view<const T> A,row_view<const T> B)56 T dot(const communicator& comm, row_view<const T> A, row_view<const T> B)
57 {
58     T result;
59     dot(comm, A, B, result);
60     return result;
61 }
62 
63 #endif
64 
65 #ifdef __cplusplus
66 }
67 #endif
68 
69 #endif
70