1 /*
2 
3     Copyright (C) 2014, The University of Texas at Austin
4 
5     This file is part of libflame and is available under the 3-Clause
6     BSD license, which can be found in the LICENSE file at the top-level
7     directory, or at http://opensource.org/licenses/BSD-3-Clause
8 
9 */
10 
11 #include "FLAME.h"
12 
FLA_Trmmsx_external(FLA_Side side,FLA_Uplo uplo,FLA_Trans trans,FLA_Diag diag,FLA_Obj alpha,FLA_Obj A,FLA_Obj B,FLA_Obj beta,FLA_Obj C)13 FLA_Error FLA_Trmmsx_external( FLA_Side side, FLA_Uplo uplo, FLA_Trans trans, FLA_Diag diag, FLA_Obj alpha, FLA_Obj A, FLA_Obj B, FLA_Obj beta, FLA_Obj C )
14 {
15   FLA_Datatype datatype;
16   int          m_B, n_B;
17   int          rs_A, cs_A;
18   int          rs_B, cs_B;
19   int          rs_C, cs_C;
20   side1_t       blis_side;
21   uplo1_t       blis_uplo;
22   trans1_t      blis_trans;
23   diag1_t       blis_diag;
24 
25   if ( FLA_Check_error_level() == FLA_FULL_ERROR_CHECKING )
26     FLA_Trmmsx_check( side, uplo, trans, diag, alpha, A, B, beta, C );
27 
28   if ( FLA_Obj_has_zero_dim( B ) ) return FLA_SUCCESS;
29 
30   datatype = FLA_Obj_datatype( A );
31 
32   rs_A     = FLA_Obj_row_stride( A );
33   cs_A     = FLA_Obj_col_stride( A );
34 
35   m_B      = FLA_Obj_length( B );
36   n_B      = FLA_Obj_width( B );
37   rs_B     = FLA_Obj_row_stride( B );
38   cs_B     = FLA_Obj_col_stride( B );
39 
40   rs_C     = FLA_Obj_row_stride( C );
41   cs_C     = FLA_Obj_col_stride( C );
42 
43   FLA_Param_map_flame_to_blis_side( side, &blis_side );
44   FLA_Param_map_flame_to_blis_uplo( uplo, &blis_uplo );
45   FLA_Param_map_flame_to_blis_trans( trans, &blis_trans );
46   FLA_Param_map_flame_to_blis_diag( diag, &blis_diag );
47 
48 
49   switch( datatype ){
50 
51   case FLA_FLOAT:
52   {
53     float *buff_A     = ( float * ) FLA_FLOAT_PTR( A );
54     float *buff_B     = ( float * ) FLA_FLOAT_PTR( B );
55     float *buff_C     = ( float * ) FLA_FLOAT_PTR( C );
56     float *buff_alpha = ( float * ) FLA_FLOAT_PTR( alpha );
57     float *buff_beta  = ( float * ) FLA_FLOAT_PTR( beta );
58 
59     bl1_strmmsx( blis_side,
60                  blis_uplo,
61                  blis_trans,
62                  blis_diag,
63                  m_B,
64                  n_B,
65                  buff_alpha,
66                  buff_A, rs_A, cs_A,
67                  buff_B, rs_B, cs_B,
68                  buff_beta,
69                  buff_C, rs_C, cs_C );
70 
71     break;
72   }
73 
74   case FLA_DOUBLE:
75   {
76     double *buff_A     = ( double * ) FLA_DOUBLE_PTR( A );
77     double *buff_B     = ( double * ) FLA_DOUBLE_PTR( B );
78     double *buff_C     = ( double * ) FLA_DOUBLE_PTR( C );
79     double *buff_alpha = ( double * ) FLA_DOUBLE_PTR( alpha );
80     double *buff_beta  = ( double * ) FLA_DOUBLE_PTR( beta );
81 
82     bl1_dtrmmsx( blis_side,
83                  blis_uplo,
84                  blis_trans,
85                  blis_diag,
86                  m_B,
87                  n_B,
88                  buff_alpha,
89                  buff_A, rs_A, cs_A,
90                  buff_B, rs_B, cs_B,
91                  buff_beta,
92                  buff_C, rs_C, cs_C );
93 
94     break;
95   }
96 
97   case FLA_COMPLEX:
98   {
99     scomplex *buff_A     = ( scomplex * ) FLA_COMPLEX_PTR( A );
100     scomplex *buff_B     = ( scomplex * ) FLA_COMPLEX_PTR( B );
101     scomplex *buff_C     = ( scomplex * ) FLA_COMPLEX_PTR( C );
102     scomplex *buff_alpha = ( scomplex * ) FLA_COMPLEX_PTR( alpha );
103     scomplex *buff_beta  = ( scomplex * ) FLA_COMPLEX_PTR( beta );
104 
105     bl1_ctrmmsx( blis_side,
106                  blis_uplo,
107                  blis_trans,
108                  blis_diag,
109                  m_B,
110                  n_B,
111                  buff_alpha,
112                  buff_A, rs_A, cs_A,
113                  buff_B, rs_B, cs_B,
114                  buff_beta,
115                  buff_C, rs_C, cs_C );
116 
117     break;
118   }
119 
120 
121   case FLA_DOUBLE_COMPLEX:
122   {
123     dcomplex *buff_A     = ( dcomplex * ) FLA_DOUBLE_COMPLEX_PTR( A );
124     dcomplex *buff_B     = ( dcomplex * ) FLA_DOUBLE_COMPLEX_PTR( B );
125     dcomplex *buff_C     = ( dcomplex * ) FLA_DOUBLE_COMPLEX_PTR( C );
126     dcomplex *buff_alpha = ( dcomplex * ) FLA_DOUBLE_COMPLEX_PTR( alpha );
127     dcomplex *buff_beta  = ( dcomplex * ) FLA_DOUBLE_COMPLEX_PTR( beta );
128 
129     bl1_ztrmmsx( blis_side,
130                  blis_uplo,
131                  blis_trans,
132                  blis_diag,
133                  m_B,
134                  n_B,
135                  buff_alpha,
136                  buff_A, rs_A, cs_A,
137                  buff_B, rs_B, cs_B,
138                  buff_beta,
139                  buff_C, rs_C, cs_C );
140 
141     break;
142   }
143 
144   }
145 
146   return FLA_SUCCESS;
147 }
148 
149