1 #include "dot.h"
2
3 #include "util/macros.h"
4 #include "internal/1m/dot.hpp"
5
6 namespace tblis
7 {
8
9 extern "C"
10 {
11
tblis_matrix_dot(const tblis_comm * comm,const tblis_config * cfg,const tblis_matrix * A,const tblis_matrix * B,tblis_scalar * result)12 void tblis_matrix_dot(const tblis_comm* comm, const tblis_config* cfg,
13 const tblis_matrix* A, const tblis_matrix* B,
14 tblis_scalar* result)
15 {
16 TBLIS_ASSERT(A->m == B->m);
17 TBLIS_ASSERT(A->n == B->n);
18 TBLIS_ASSERT(A->type == B->type);
19 TBLIS_ASSERT(A->type == result->type);
20
21 TBLIS_WITH_TYPE_AS(A->type, T,
22 {
23 parallelize_if(
24 [&](const communicator& comm)
25 {
26 internal::dot<T>(comm, get_config(cfg), A->m, A->n,
27 A->conj, static_cast<const T*>(A->data), A->rs, A->cs,
28 B->conj, static_cast<const T*>(B->data), B->rs, B->cs,
29 result->get<T>());
30 }, comm);
31
32 result->get<T>() *= A->alpha<T>()*B->alpha<T>();
33 })
34 }
35
36 }
37
38 }
39