1 #ifndef _TBLIS_IFACE_1V_DOT_H_
2 #define _TBLIS_IFACE_1V_DOT_H_
3
4 #include "../../util/thread.h"
5 #include "../../util/basic_types.h"
6
7 #ifdef __cplusplus
8
9 namespace tblis
10 {
11
12 extern "C"
13 {
14
15 #endif
16
17 void tblis_vector_dot(const tblis_comm* comm, const tblis_config* cfg,
18 const tblis_vector* A, const tblis_vector* B,
19 tblis_scalar* result);
20
21 #ifdef __cplusplus
22 }
23 #endif
24
25 #if defined(__cplusplus) && !defined(TBLIS_DONT_USE_CXX11)
26
27 template <typename T>
dot(row_view<const T> A,row_view<const T> B,T & result)28 void dot(row_view<const T> A, row_view<const T> B, T& result)
29 {
30 tblis_vector A_s(A);
31 tblis_vector B_s(B);
32 tblis_scalar result_s(result);
33 tblis_vector_dot(nullptr, nullptr, &A_s, &B_s, &result_s);
34 result = result_s.get<T>();
35 }
36
37 template <typename T>
dot(const communicator & comm,row_view<const T> A,row_view<const T> B,T & result)38 void dot(const communicator& comm, row_view<const T> A, row_view<const T> B, T& result)
39 {
40 tblis_vector A_s(A);
41 tblis_vector B_s(B);
42 tblis_scalar result_s(result);
43 tblis_vector_dot(comm, nullptr, &A_s, &B_s, &result_s);
44 result = result_s.get<T>();
45 }
46
47 template <typename T>
dot(row_view<const T> A,row_view<const T> B)48 T dot(row_view<const T> A, row_view<const T> B)
49 {
50 T result;
51 dot(A, B, result);
52 return result;
53 }
54
55 template <typename T>
dot(const communicator & comm,row_view<const T> A,row_view<const T> B)56 T dot(const communicator& comm, row_view<const T> A, row_view<const T> B)
57 {
58 T result;
59 dot(comm, A, B, result);
60 return result;
61 }
62
63 #endif
64
65 #ifdef __cplusplus
66 }
67 #endif
68
69 #endif
70