1 /*
2
3 Copyright (C) 2014, The University of Texas at Austin
4
5 This file is part of libflame and is available under the 3-Clause
6 BSD license, which can be found in the LICENSE file at the top-level
7 directory, or at http://opensource.org/licenses/BSD-3-Clause
8
9 */
10
11 #include "FLAME.h"
12
FLA_Gemm_external(FLA_Trans transa,FLA_Trans transb,FLA_Obj alpha,FLA_Obj A,FLA_Obj B,FLA_Obj beta,FLA_Obj C)13 FLA_Error FLA_Gemm_external( FLA_Trans transa, FLA_Trans transb, FLA_Obj alpha, FLA_Obj A, FLA_Obj B, FLA_Obj beta, FLA_Obj C )
14 {
15 FLA_Datatype datatype;
16 int k_AB;
17 int m_A, n_A;
18 int m_C, n_C;
19 int rs_A, cs_A;
20 int rs_B, cs_B;
21 int rs_C, cs_C;
22 trans1_t blis_transa;
23 trans1_t blis_transb;
24
25 if ( FLA_Check_error_level() == FLA_FULL_ERROR_CHECKING )
26 FLA_Gemm_check( transa, transb, alpha, A, B, beta, C );
27
28 if ( FLA_Obj_has_zero_dim( C ) ) return FLA_SUCCESS;
29
30 if ( FLA_Obj_has_zero_dim( A ) || FLA_Obj_has_zero_dim( B ) )
31 {
32 FLA_Scal_external( beta, C );
33 return FLA_SUCCESS;
34 }
35
36 datatype = FLA_Obj_datatype( A );
37
38 m_A = FLA_Obj_length( A );
39 n_A = FLA_Obj_width( A );
40 rs_A = FLA_Obj_row_stride( A );
41 cs_A = FLA_Obj_col_stride( A );
42
43 rs_B = FLA_Obj_row_stride( B );
44 cs_B = FLA_Obj_col_stride( B );
45
46 m_C = FLA_Obj_length( C );
47 n_C = FLA_Obj_width( C );
48 rs_C = FLA_Obj_row_stride( C );
49 cs_C = FLA_Obj_col_stride( C );
50
51 if ( transa == FLA_NO_TRANSPOSE || transa == FLA_CONJ_NO_TRANSPOSE )
52 k_AB = n_A;
53 else
54 k_AB = m_A;
55
56 FLA_Param_map_flame_to_blis_trans( transa, &blis_transa );
57 FLA_Param_map_flame_to_blis_trans( transb, &blis_transb );
58
59
60 switch( datatype ){
61
62 case FLA_FLOAT:
63 {
64 float *buff_A = ( float * ) FLA_FLOAT_PTR( A );
65 float *buff_B = ( float * ) FLA_FLOAT_PTR( B );
66 float *buff_C = ( float * ) FLA_FLOAT_PTR( C );
67 float *buff_alpha = ( float * ) FLA_FLOAT_PTR( alpha );
68 float *buff_beta = ( float * ) FLA_FLOAT_PTR( beta );
69
70 bl1_sgemm( blis_transa,
71 blis_transb,
72 m_C,
73 k_AB,
74 n_C,
75 buff_alpha,
76 buff_A, rs_A, cs_A,
77 buff_B, rs_B, cs_B,
78 buff_beta,
79 buff_C, rs_C, cs_C );
80
81 break;
82 }
83
84 case FLA_DOUBLE:
85 {
86 double *buff_A = ( double * ) FLA_DOUBLE_PTR( A );
87 double *buff_B = ( double * ) FLA_DOUBLE_PTR( B );
88 double *buff_C = ( double * ) FLA_DOUBLE_PTR( C );
89 double *buff_alpha = ( double * ) FLA_DOUBLE_PTR( alpha );
90 double *buff_beta = ( double * ) FLA_DOUBLE_PTR( beta );
91
92 bl1_dgemm( blis_transa,
93 blis_transb,
94 m_C,
95 k_AB,
96 n_C,
97 buff_alpha,
98 buff_A, rs_A, cs_A,
99 buff_B, rs_B, cs_B,
100 buff_beta,
101 buff_C, rs_C, cs_C );
102
103 break;
104 }
105
106 case FLA_COMPLEX:
107 {
108 scomplex *buff_A = ( scomplex * ) FLA_COMPLEX_PTR( A );
109 scomplex *buff_B = ( scomplex * ) FLA_COMPLEX_PTR( B );
110 scomplex *buff_C = ( scomplex * ) FLA_COMPLEX_PTR( C );
111 scomplex *buff_alpha = ( scomplex * ) FLA_COMPLEX_PTR( alpha );
112 scomplex *buff_beta = ( scomplex * ) FLA_COMPLEX_PTR( beta );
113
114 bl1_cgemm( blis_transa,
115 blis_transb,
116 m_C,
117 k_AB,
118 n_C,
119 buff_alpha,
120 buff_A, rs_A, cs_A,
121 buff_B, rs_B, cs_B,
122 buff_beta,
123 buff_C, rs_C, cs_C );
124
125 break;
126 }
127
128 case FLA_DOUBLE_COMPLEX:
129 {
130 dcomplex *buff_A = ( dcomplex * ) FLA_DOUBLE_COMPLEX_PTR( A );
131 dcomplex *buff_B = ( dcomplex * ) FLA_DOUBLE_COMPLEX_PTR( B );
132 dcomplex *buff_C = ( dcomplex * ) FLA_DOUBLE_COMPLEX_PTR( C );
133 dcomplex *buff_alpha = ( dcomplex * ) FLA_DOUBLE_COMPLEX_PTR( alpha );
134 dcomplex *buff_beta = ( dcomplex * ) FLA_DOUBLE_COMPLEX_PTR( beta );
135
136 bl1_zgemm( blis_transa,
137 blis_transb,
138 m_C,
139 k_AB,
140 n_C,
141 buff_alpha,
142 buff_A, rs_A, cs_A,
143 buff_B, rs_B, cs_B,
144 buff_beta,
145 buff_C, rs_C, cs_C );
146
147 break;
148 }
149
150 }
151
152 return FLA_SUCCESS;
153 }
154
155