1 #include "dot.h"
2 
3 #include "util/macros.h"
4 #include "internal/1m/dot.hpp"
5 
6 namespace tblis
7 {
8 
9 extern "C"
10 {
11 
tblis_matrix_dot(const tblis_comm * comm,const tblis_config * cfg,const tblis_matrix * A,const tblis_matrix * B,tblis_scalar * result)12 void tblis_matrix_dot(const tblis_comm* comm, const tblis_config* cfg,
13                       const tblis_matrix* A, const tblis_matrix* B,
14                       tblis_scalar* result)
15 {
16     TBLIS_ASSERT(A->m == B->m);
17     TBLIS_ASSERT(A->n == B->n);
18     TBLIS_ASSERT(A->type == B->type);
19     TBLIS_ASSERT(A->type == result->type);
20 
21     TBLIS_WITH_TYPE_AS(A->type, T,
22     {
23         parallelize_if(
24         [&](const communicator& comm)
25         {
26             internal::dot<T>(comm, get_config(cfg), A->m, A->n,
27                              A->conj, static_cast<const T*>(A->data), A->rs, A->cs,
28                              B->conj, static_cast<const T*>(B->data), B->rs, B->cs,
29                              result->get<T>());
30         }, comm);
31 
32         result->get<T>() *= A->alpha<T>()*B->alpha<T>();
33     })
34 }
35 
36 }
37 
38 }
39