1 /*
2 
3     Copyright (C) 2014, The University of Texas at Austin
4 
5     This file is part of libflame and is available under the 3-Clause
6     BSD license, which can be found in the LICENSE file at the top-level
7     directory, or at http://opensource.org/licenses/BSD-3-Clause
8 
9 */
10 
11 #include "FLAME.h"
12 
FLA_Gemm_external(FLA_Trans transa,FLA_Trans transb,FLA_Obj alpha,FLA_Obj A,FLA_Obj B,FLA_Obj beta,FLA_Obj C)13 FLA_Error FLA_Gemm_external( FLA_Trans transa, FLA_Trans transb, FLA_Obj alpha, FLA_Obj A, FLA_Obj B, FLA_Obj beta, FLA_Obj C )
14 {
15   FLA_Datatype datatype;
16   int          k_AB;
17   int          m_A, n_A;
18   int          m_C, n_C;
19   int          rs_A, cs_A;
20   int          rs_B, cs_B;
21   int          rs_C, cs_C;
22   trans1_t      blis_transa;
23   trans1_t      blis_transb;
24 
25   if ( FLA_Check_error_level() == FLA_FULL_ERROR_CHECKING )
26     FLA_Gemm_check( transa, transb, alpha, A, B, beta, C );
27 
28   if ( FLA_Obj_has_zero_dim( C ) ) return FLA_SUCCESS;
29 
30   if ( FLA_Obj_has_zero_dim( A ) || FLA_Obj_has_zero_dim( B ) )
31   {
32     FLA_Scal_external( beta, C );
33     return FLA_SUCCESS;
34   }
35 
36   datatype = FLA_Obj_datatype( A );
37 
38   m_A      = FLA_Obj_length( A );
39   n_A      = FLA_Obj_width( A );
40   rs_A     = FLA_Obj_row_stride( A );
41   cs_A     = FLA_Obj_col_stride( A );
42 
43   rs_B     = FLA_Obj_row_stride( B );
44   cs_B     = FLA_Obj_col_stride( B );
45 
46   m_C      = FLA_Obj_length( C );
47   n_C      = FLA_Obj_width( C );
48   rs_C     = FLA_Obj_row_stride( C );
49   cs_C     = FLA_Obj_col_stride( C );
50 
51   if ( transa == FLA_NO_TRANSPOSE || transa == FLA_CONJ_NO_TRANSPOSE )
52     k_AB = n_A;
53   else
54     k_AB = m_A;
55 
56   FLA_Param_map_flame_to_blis_trans( transa, &blis_transa );
57   FLA_Param_map_flame_to_blis_trans( transb, &blis_transb );
58 
59 
60   switch( datatype ){
61 
62   case FLA_FLOAT:
63   {
64     float *buff_A     = ( float * ) FLA_FLOAT_PTR( A );
65     float *buff_B     = ( float * ) FLA_FLOAT_PTR( B );
66     float *buff_C     = ( float * ) FLA_FLOAT_PTR( C );
67     float *buff_alpha = ( float * ) FLA_FLOAT_PTR( alpha );
68     float *buff_beta  = ( float * ) FLA_FLOAT_PTR( beta );
69 
70     bl1_sgemm( blis_transa,
71                blis_transb,
72                m_C,
73                k_AB,
74                n_C,
75                buff_alpha,
76                buff_A, rs_A, cs_A,
77                buff_B, rs_B, cs_B,
78                buff_beta,
79                buff_C, rs_C, cs_C );
80 
81     break;
82   }
83 
84   case FLA_DOUBLE:
85   {
86     double *buff_A     = ( double * ) FLA_DOUBLE_PTR( A );
87     double *buff_B     = ( double * ) FLA_DOUBLE_PTR( B );
88     double *buff_C     = ( double * ) FLA_DOUBLE_PTR( C );
89     double *buff_alpha = ( double * ) FLA_DOUBLE_PTR( alpha );
90     double *buff_beta  = ( double * ) FLA_DOUBLE_PTR( beta );
91 
92     bl1_dgemm( blis_transa,
93                blis_transb,
94                m_C,
95                k_AB,
96                n_C,
97                buff_alpha,
98                buff_A, rs_A, cs_A,
99                buff_B, rs_B, cs_B,
100                buff_beta,
101                buff_C, rs_C, cs_C );
102 
103     break;
104   }
105 
106   case FLA_COMPLEX:
107   {
108     scomplex *buff_A     = ( scomplex * ) FLA_COMPLEX_PTR( A );
109     scomplex *buff_B     = ( scomplex * ) FLA_COMPLEX_PTR( B );
110     scomplex *buff_C     = ( scomplex * ) FLA_COMPLEX_PTR( C );
111     scomplex *buff_alpha = ( scomplex * ) FLA_COMPLEX_PTR( alpha );
112     scomplex *buff_beta  = ( scomplex * ) FLA_COMPLEX_PTR( beta );
113 
114     bl1_cgemm( blis_transa,
115                blis_transb,
116                m_C,
117                k_AB,
118                n_C,
119                buff_alpha,
120                buff_A, rs_A, cs_A,
121                buff_B, rs_B, cs_B,
122                buff_beta,
123                buff_C, rs_C, cs_C );
124 
125     break;
126   }
127 
128   case FLA_DOUBLE_COMPLEX:
129   {
130     dcomplex *buff_A     = ( dcomplex * ) FLA_DOUBLE_COMPLEX_PTR( A );
131     dcomplex *buff_B     = ( dcomplex * ) FLA_DOUBLE_COMPLEX_PTR( B );
132     dcomplex *buff_C     = ( dcomplex * ) FLA_DOUBLE_COMPLEX_PTR( C );
133     dcomplex *buff_alpha = ( dcomplex * ) FLA_DOUBLE_COMPLEX_PTR( alpha );
134     dcomplex *buff_beta  = ( dcomplex * ) FLA_DOUBLE_COMPLEX_PTR( beta );
135 
136     bl1_zgemm( blis_transa,
137                blis_transb,
138                m_C,
139                k_AB,
140                n_C,
141                buff_alpha,
142                buff_A, rs_A, cs_A,
143                buff_B, rs_B, cs_B,
144                buff_beta,
145                buff_C, rs_C, cs_C );
146 
147     break;
148   }
149 
150   }
151 
152   return FLA_SUCCESS;
153 }
154 
155