1 #ifndef _TBLIS_IFACE_1T_DOT_H_
2 #define _TBLIS_IFACE_1T_DOT_H_
3 
4 #include "../../util/thread.h"
5 #include "../../util/basic_types.h"
6 
7 #ifdef __cplusplus
8 
9 namespace tblis
10 {
11 
12 extern "C"
13 {
14 
15 #endif
16 
/**
 * C-linkage entry point for the tensor dot operation; every C++ overload in
 * this header that works on plain varray views ends up calling it.
 *
 * @param comm   communicator handle; the communicator-less C++ wrapper in this
 *               header passes nullptr here (meaning of a null handle is
 *               defined by the implementation — presumably "use defaults",
 *               TODO confirm)
 * @param cfg    configuration; the C++ wrappers in this header always pass
 *               nullptr
 * @param A      first operand
 * @param idx_A  index labels for A (presumably one label per dimension)
 * @param B      second operand
 * @param idx_B  index labels for B (presumably one label per dimension)
 * @param result output scalar; the C++ wrappers construct it from their T
 *               argument and read it back with get<T>() after this returns
 */
void tblis_tensor_dot(const tblis_comm* comm, const tblis_config* cfg,
                      const tblis_tensor* A, const label_type* idx_A,
                      const tblis_tensor* B, const label_type* idx_B,
                      tblis_scalar* result);
21 
22 #ifdef __cplusplus
23 }
24 #endif
25 
26 #if defined(__cplusplus) && !defined(TBLIS_DONT_USE_CXX11)
27 
28 template <typename T>
dot(varray_view<const T> A,const label_type * idx_A,varray_view<const T> B,const label_type * idx_B,T & result)29 void dot(varray_view<const T> A, const label_type* idx_A,
30          varray_view<const T> B, const label_type* idx_B, T& result)
31 {
32     tblis_tensor A_s(A);
33     tblis_tensor B_s(B);
34     tblis_scalar result_s(result);
35     tblis_tensor_dot(nullptr, nullptr, &A_s, idx_A, &B_s, idx_B, &result_s);
36     result = result_s.get<T>();
37 }
38 
39 template <typename T>
dot(const communicator & comm,varray_view<const T> A,const label_type * idx_A,varray_view<const T> B,const label_type * idx_B,T & result)40 void dot(const communicator& comm,
41          varray_view<const T> A, const label_type* idx_A,
42          varray_view<const T> B, const label_type* idx_B, T& result)
43 {
44     tblis_tensor A_s(A);
45     tblis_tensor B_s(B);
46     tblis_scalar result_s(result);
47     tblis_tensor_dot(comm, nullptr, &A_s, idx_A, &B_s, idx_B, &result_s);
48     result = result_s.get<T>();
49 }
50 
51 template <typename T>
dot(varray_view<const T> A,const label_type * idx_A,varray_view<const T> B,const label_type * idx_B)52 T dot(varray_view<const T> A, const label_type* idx_A,
53       varray_view<const T> B, const label_type* idx_B)
54 {
55     T result;
56     dot(A, idx_A, B, idx_B, result);
57     return result;
58 }
59 
60 template <typename T>
dot(const communicator & comm,varray_view<const T> A,const label_type * idx_A,varray_view<const T> B,const label_type * idx_B)61 T dot(const communicator& comm,
62       varray_view<const T> A, const label_type* idx_A,
63       varray_view<const T> B, const label_type* idx_B)
64 {
65     T result;
66     dot(comm, A, idx_A, B, idx_B, result);
67     return result;
68 }
69 
/**
 * Dot of two DPD tensor views using an explicit communicator, written into an
 * output parameter.
 *
 * Only declared here; the definition lives elsewhere (presumably a compiled
 * source with explicit instantiations — TODO confirm). The communicator-less
 * DPD overload below calls this from inside the lambda it passes to
 * parallelize(), and every invocation receives the same `result` reference.
 */
template <typename T>
void dot(const communicator& comm,
         dpd_varray_view<const T> A, const label_type* idx_A,
         dpd_varray_view<const T> B, const label_type* idx_B, T& result);
74 
75 template <typename T>
dot(dpd_varray_view<const T> A,const label_type * idx_A,dpd_varray_view<const T> B,const label_type * idx_B,T & result)76 void dot(dpd_varray_view<const T> A, const label_type* idx_A,
77          dpd_varray_view<const T> B, const label_type* idx_B, T& result)
78 {
79     parallelize
80     (
81         [&](const communicator& comm)
82         {
83             dot(comm, A, idx_A, B, idx_B, result);
84         },
85         tblis_get_num_threads()
86     );
87 }
88 
89 template <typename T>
dot(dpd_varray_view<const T> A,const label_type * idx_A,dpd_varray_view<const T> B,const label_type * idx_B)90 T dot(dpd_varray_view<const T> A, const label_type* idx_A,
91       dpd_varray_view<const T> B, const label_type* idx_B)
92 {
93     T result;
94     dot(A, idx_A, B, idx_B, result);
95     return result;
96 }
97 
98 template <typename T>
dot(const communicator & comm,dpd_varray_view<const T> A,const label_type * idx_A,dpd_varray_view<const T> B,const label_type * idx_B)99 T dot(const communicator& comm,
100       dpd_varray_view<const T> A, const label_type* idx_A,
101       dpd_varray_view<const T> B, const label_type* idx_B)
102 {
103     T result;
104     dot(comm, A, idx_A, B, idx_B, result);
105     return result;
106 }
107 
/**
 * Dot of two indexed tensor views using an explicit communicator, written
 * into an output parameter.
 *
 * Only declared here; the definition lives elsewhere (presumably a compiled
 * source with explicit instantiations — TODO confirm). The communicator-less
 * indexed overload below calls this from inside the lambda it passes to
 * parallelize(), and every invocation receives the same `result` reference.
 */
template <typename T>
void dot(const communicator& comm,
         indexed_varray_view<const T> A, const label_type* idx_A,
         indexed_varray_view<const T> B, const label_type* idx_B, T& result);
112 
113 template <typename T>
dot(indexed_varray_view<const T> A,const label_type * idx_A,indexed_varray_view<const T> B,const label_type * idx_B,T & result)114 void dot(indexed_varray_view<const T> A, const label_type* idx_A,
115          indexed_varray_view<const T> B, const label_type* idx_B, T& result)
116 {
117     parallelize
118     (
119         [&](const communicator& comm)
120         {
121             dot(comm, A, idx_A, B, idx_B, result);
122         },
123         tblis_get_num_threads()
124     );
125 }
126 
127 template <typename T>
dot(indexed_varray_view<const T> A,const label_type * idx_A,indexed_varray_view<const T> B,const label_type * idx_B)128 T dot(indexed_varray_view<const T> A, const label_type* idx_A,
129       indexed_varray_view<const T> B, const label_type* idx_B)
130 {
131     T result;
132     dot(A, idx_A, B, idx_B, result);
133     return result;
134 }
135 
136 template <typename T>
dot(const communicator & comm,indexed_varray_view<const T> A,const label_type * idx_A,indexed_varray_view<const T> B,const label_type * idx_B)137 T dot(const communicator& comm,
138       indexed_varray_view<const T> A, const label_type* idx_A,
139       indexed_varray_view<const T> B, const label_type* idx_B)
140 {
141     T result;
142     dot(comm, A, idx_A, B, idx_B, result);
143     return result;
144 }
145 
/**
 * Dot of two indexed DPD tensor views using an explicit communicator, written
 * into an output parameter.
 *
 * Only declared here; the definition lives elsewhere (presumably a compiled
 * source with explicit instantiations — TODO confirm). The communicator-less
 * indexed-DPD overload below calls this from inside the lambda it passes to
 * parallelize(), and every invocation receives the same `result` reference.
 */
template <typename T>
void dot(const communicator& comm,
         indexed_dpd_varray_view<const T> A, const label_type* idx_A,
         indexed_dpd_varray_view<const T> B, const label_type* idx_B, T& result);
150 
151 template <typename T>
dot(indexed_dpd_varray_view<const T> A,const label_type * idx_A,indexed_dpd_varray_view<const T> B,const label_type * idx_B,T & result)152 void dot(indexed_dpd_varray_view<const T> A, const label_type* idx_A,
153          indexed_dpd_varray_view<const T> B, const label_type* idx_B, T& result)
154 {
155     parallelize
156     (
157         [&](const communicator& comm)
158         {
159             dot(comm, A, idx_A, B, idx_B, result);
160         },
161         tblis_get_num_threads()
162     );
163 }
164 
165 template <typename T>
dot(indexed_dpd_varray_view<const T> A,const label_type * idx_A,indexed_dpd_varray_view<const T> B,const label_type * idx_B)166 T dot(indexed_dpd_varray_view<const T> A, const label_type* idx_A,
167       indexed_dpd_varray_view<const T> B, const label_type* idx_B)
168 {
169     T result;
170     dot(A, idx_A, B, idx_B, result);
171     return result;
172 }
173 
174 template <typename T>
dot(const communicator & comm,indexed_dpd_varray_view<const T> A,const label_type * idx_A,indexed_dpd_varray_view<const T> B,const label_type * idx_B)175 T dot(const communicator& comm,
176       indexed_dpd_varray_view<const T> A, const label_type* idx_A,
177       indexed_dpd_varray_view<const T> B, const label_type* idx_B)
178 {
179     T result;
180     dot(comm, A, idx_A, B, idx_B, result);
181     return result;
182 }
183 
184 #endif
185 
186 #ifdef __cplusplus
187 }
188 #endif
189 
190 #endif
191