1 #ifndef _TBLIS_IFACE_1T_DOT_H_
2 #define _TBLIS_IFACE_1T_DOT_H_
3
4 #include "../../util/thread.h"
5 #include "../../util/basic_types.h"
6
7 #ifdef __cplusplus
8
9 namespace tblis
10 {
11
12 extern "C"
13 {
14
15 #endif
16
/*
 * C entry point for the tensor dot operation used by every C++ dot()
 * overload in this header.
 *
 * comm   - communicator; the C++ wrappers in this header pass either nullptr
 *          or a communicator object (converted implicitly).
 * cfg    - configuration; the C++ wrappers in this header always pass nullptr.
 * A, B   - operand tensor descriptors.
 * idx_A  - index labels for A.
 * idx_B  - index labels for B.
 * result - output scalar. The wrappers read the computed value back from it
 *          with get<T>() after the call.
 *
 * NOTE(review): presumably contracts A and B over their matching labels into
 * a scalar. Confirm against the implementation.
 */
void tblis_tensor_dot(const tblis_comm* comm, const tblis_config* cfg,
                      const tblis_tensor* A, const label_type* idx_A,
                      const tblis_tensor* B, const label_type* idx_B,
                      tblis_scalar* result);
21
22 #ifdef __cplusplus
23 }
24 #endif
25
26 #if defined(__cplusplus) && !defined(TBLIS_DONT_USE_CXX11)
27
28 template <typename T>
dot(varray_view<const T> A,const label_type * idx_A,varray_view<const T> B,const label_type * idx_B,T & result)29 void dot(varray_view<const T> A, const label_type* idx_A,
30 varray_view<const T> B, const label_type* idx_B, T& result)
31 {
32 tblis_tensor A_s(A);
33 tblis_tensor B_s(B);
34 tblis_scalar result_s(result);
35 tblis_tensor_dot(nullptr, nullptr, &A_s, idx_A, &B_s, idx_B, &result_s);
36 result = result_s.get<T>();
37 }
38
39 template <typename T>
dot(const communicator & comm,varray_view<const T> A,const label_type * idx_A,varray_view<const T> B,const label_type * idx_B,T & result)40 void dot(const communicator& comm,
41 varray_view<const T> A, const label_type* idx_A,
42 varray_view<const T> B, const label_type* idx_B, T& result)
43 {
44 tblis_tensor A_s(A);
45 tblis_tensor B_s(B);
46 tblis_scalar result_s(result);
47 tblis_tensor_dot(comm, nullptr, &A_s, idx_A, &B_s, idx_B, &result_s);
48 result = result_s.get<T>();
49 }
50
51 template <typename T>
dot(varray_view<const T> A,const label_type * idx_A,varray_view<const T> B,const label_type * idx_B)52 T dot(varray_view<const T> A, const label_type* idx_A,
53 varray_view<const T> B, const label_type* idx_B)
54 {
55 T result;
56 dot(A, idx_A, B, idx_B, result);
57 return result;
58 }
59
60 template <typename T>
dot(const communicator & comm,varray_view<const T> A,const label_type * idx_A,varray_view<const T> B,const label_type * idx_B)61 T dot(const communicator& comm,
62 varray_view<const T> A, const label_type* idx_A,
63 varray_view<const T> B, const label_type* idx_B)
64 {
65 T result;
66 dot(comm, A, idx_A, B, idx_B, result);
67 return result;
68 }
69
/**
 * dot() for DPD tensor views taking an explicit communicator.
 *
 * Only declared here; the definition is not in this header. The overloads
 * below that take no communicator call it from inside parallelize(), once per
 * worker, with that worker's communicator.
 */
template <typename T>
void dot(const communicator& comm,
         dpd_varray_view<const T> A, const label_type* idx_A,
         dpd_varray_view<const T> B, const label_type* idx_B, T& result);
74
75 template <typename T>
dot(dpd_varray_view<const T> A,const label_type * idx_A,dpd_varray_view<const T> B,const label_type * idx_B,T & result)76 void dot(dpd_varray_view<const T> A, const label_type* idx_A,
77 dpd_varray_view<const T> B, const label_type* idx_B, T& result)
78 {
79 parallelize
80 (
81 [&](const communicator& comm)
82 {
83 dot(comm, A, idx_A, B, idx_B, result);
84 },
85 tblis_get_num_threads()
86 );
87 }
88
89 template <typename T>
dot(dpd_varray_view<const T> A,const label_type * idx_A,dpd_varray_view<const T> B,const label_type * idx_B)90 T dot(dpd_varray_view<const T> A, const label_type* idx_A,
91 dpd_varray_view<const T> B, const label_type* idx_B)
92 {
93 T result;
94 dot(A, idx_A, B, idx_B, result);
95 return result;
96 }
97
98 template <typename T>
dot(const communicator & comm,dpd_varray_view<const T> A,const label_type * idx_A,dpd_varray_view<const T> B,const label_type * idx_B)99 T dot(const communicator& comm,
100 dpd_varray_view<const T> A, const label_type* idx_A,
101 dpd_varray_view<const T> B, const label_type* idx_B)
102 {
103 T result;
104 dot(comm, A, idx_A, B, idx_B, result);
105 return result;
106 }
107
/**
 * dot() for indexed tensor views taking an explicit communicator.
 *
 * Only declared here; the definition is not in this header. The overloads
 * below that take no communicator call it from inside parallelize(), once per
 * worker, with that worker's communicator.
 */
template <typename T>
void dot(const communicator& comm,
         indexed_varray_view<const T> A, const label_type* idx_A,
         indexed_varray_view<const T> B, const label_type* idx_B, T& result);
112
113 template <typename T>
dot(indexed_varray_view<const T> A,const label_type * idx_A,indexed_varray_view<const T> B,const label_type * idx_B,T & result)114 void dot(indexed_varray_view<const T> A, const label_type* idx_A,
115 indexed_varray_view<const T> B, const label_type* idx_B, T& result)
116 {
117 parallelize
118 (
119 [&](const communicator& comm)
120 {
121 dot(comm, A, idx_A, B, idx_B, result);
122 },
123 tblis_get_num_threads()
124 );
125 }
126
127 template <typename T>
dot(indexed_varray_view<const T> A,const label_type * idx_A,indexed_varray_view<const T> B,const label_type * idx_B)128 T dot(indexed_varray_view<const T> A, const label_type* idx_A,
129 indexed_varray_view<const T> B, const label_type* idx_B)
130 {
131 T result;
132 dot(A, idx_A, B, idx_B, result);
133 return result;
134 }
135
136 template <typename T>
dot(const communicator & comm,indexed_varray_view<const T> A,const label_type * idx_A,indexed_varray_view<const T> B,const label_type * idx_B)137 T dot(const communicator& comm,
138 indexed_varray_view<const T> A, const label_type* idx_A,
139 indexed_varray_view<const T> B, const label_type* idx_B)
140 {
141 T result;
142 dot(comm, A, idx_A, B, idx_B, result);
143 return result;
144 }
145
/**
 * dot() for indexed DPD tensor views taking an explicit communicator.
 *
 * Only declared here; the definition is not in this header. The overloads
 * below that take no communicator call it from inside parallelize(), once per
 * worker, with that worker's communicator.
 */
template <typename T>
void dot(const communicator& comm,
         indexed_dpd_varray_view<const T> A, const label_type* idx_A,
         indexed_dpd_varray_view<const T> B, const label_type* idx_B, T& result);
150
151 template <typename T>
dot(indexed_dpd_varray_view<const T> A,const label_type * idx_A,indexed_dpd_varray_view<const T> B,const label_type * idx_B,T & result)152 void dot(indexed_dpd_varray_view<const T> A, const label_type* idx_A,
153 indexed_dpd_varray_view<const T> B, const label_type* idx_B, T& result)
154 {
155 parallelize
156 (
157 [&](const communicator& comm)
158 {
159 dot(comm, A, idx_A, B, idx_B, result);
160 },
161 tblis_get_num_threads()
162 );
163 }
164
165 template <typename T>
dot(indexed_dpd_varray_view<const T> A,const label_type * idx_A,indexed_dpd_varray_view<const T> B,const label_type * idx_B)166 T dot(indexed_dpd_varray_view<const T> A, const label_type* idx_A,
167 indexed_dpd_varray_view<const T> B, const label_type* idx_B)
168 {
169 T result;
170 dot(A, idx_A, B, idx_B, result);
171 return result;
172 }
173
174 template <typename T>
dot(const communicator & comm,indexed_dpd_varray_view<const T> A,const label_type * idx_A,indexed_dpd_varray_view<const T> B,const label_type * idx_B)175 T dot(const communicator& comm,
176 indexed_dpd_varray_view<const T> A, const label_type* idx_A,
177 indexed_dpd_varray_view<const T> B, const label_type* idx_B)
178 {
179 T result;
180 dot(comm, A, idx_A, B, idx_B, result);
181 return result;
182 }
183
184 #endif
185
186 #ifdef __cplusplus
187 }
188 #endif
189
190 #endif
191