1!--------------------------------------------------------------------------------------------------!
2!   CP2K: A general program to perform molecular dynamics simulations                              !
3!   Copyright (C) 2000 - 2019  CP2K developers group                                               !
4!--------------------------------------------------------------------------------------------------!
5
6! **************************************************************************************************
!> \brief basic linear algebra operations for full matrices: a GEMM interface that dispatches
!>        either to cp_fm_gemm (ScaLAPACK) or to DBCSR
!> \par History
!>      08.2002 split out of qs_blacs [fawzi]
10!> \author Fawzi Mohamed
11! **************************************************************************************************
MODULE cp_gemm_interface
   USE cp_dbcsr_operations,             ONLY: copy_dbcsr_to_fm_bc,&
                                              copy_fm_to_dbcsr_bc
   USE cp_fm_basic_linalg,              ONLY: cp_fm_gemm
   USE cp_fm_types,                     ONLY: cp_fm_get_info,&
                                              cp_fm_get_mm_type,&
                                              cp_fm_type
   USE dbcsr_api,                       ONLY: dbcsr_multiply,&
                                              dbcsr_release,&
                                              dbcsr_type
   USE input_constants,                 ONLY: do_dbcsr,&
                                              do_pdgemm
   USE kinds,                           ONLY: dp
   USE message_passing,                 ONLY: mp_min
   USE string_utilities,                ONLY: uppercase
#include "./base/base_uses.f90"

   IMPLICIT NONE
   PRIVATE

   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'cp_gemm_interface'

   PUBLIC :: cp_gemm

CONTAINS

! **************************************************************************************************
!> \brief Computes C = alpha*op(A)*op(B) + beta*C (BLAS gemm argument convention) for full
!>        matrices. DBCSR is used when it is the globally selected multiplication type and the
!>        operands are compatible with it; otherwise cp_fm_gemm is used.
!> \param transa form of op(A): 'N' for A, 'T' or 'C' for the transpose of A (real matrices)
!> \param transb form of op(B), same convention as transa
!> \param m number of rows of op(A) and of the updated part of C
!> \param n number of columns of op(B) and of the updated part of C
!> \param k inner dimension (columns of op(A), rows of op(B))
!> \param alpha scaling factor of the product op(A)*op(B)
!> \param matrix_a matrix A
!> \param matrix_b matrix B
!> \param beta scaling factor applied to C
!> \param matrix_c matrix C, updated in place
!> \param a_first_col optional start column of A, forwarded to cp_fm_gemm
!> \param a_first_row optional start row of A, forwarded to cp_fm_gemm
!> \param b_first_col optional start column of B, forwarded to cp_fm_gemm
!> \param b_first_row optional start row of B, forwarded to cp_fm_gemm
!> \param c_first_col optional start column of C, forwarded to cp_fm_gemm
!> \param c_first_row optional start row of C, forwarded to cp_fm_gemm
!> \note The DBCSR path is taken only if no first_row/first_col argument is present, the m x n
!>       update spans all of C, and op(A), op(B) and C share the same block distribution.
!>       All ranks agree on the backend through mp_min (do_pdgemm is the lowest value).
! **************************************************************************************************
   SUBROUTINE cp_gemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, &
                      matrix_c, a_first_col, a_first_row, b_first_col, b_first_row, &
                      c_first_col, c_first_row)
      CHARACTER(LEN=1), INTENT(IN)                       :: transa, transb
      INTEGER, INTENT(IN)                                :: m, n, k
      REAL(KIND=dp), INTENT(IN)                          :: alpha
      TYPE(cp_fm_type), POINTER                          :: matrix_a, matrix_b
      REAL(KIND=dp), INTENT(IN)                          :: beta
      TYPE(cp_fm_type), POINTER                          :: matrix_c
      INTEGER, INTENT(IN), OPTIONAL                      :: a_first_col, a_first_row, b_first_col, &
                                                            b_first_row, c_first_col, c_first_row

      CHARACTER(len=*), PARAMETER :: routineN = 'cp_gemm', routineP = moduleN//':'//routineN

      CHARACTER(LEN=1)                                   :: my_transa, my_transb
      INTEGER                                            :: c_ncol_global, c_nrow_global, handle, &
                                                            handle1, my_multi
      INTEGER, DIMENSION(:), POINTER                     :: a_col_loc, a_row_loc, b_col_loc, &
                                                            b_row_loc, c_col_loc, c_row_loc
      TYPE(dbcsr_type)                                   :: a_db, b_db, c_db

      CALL timeset(routineN, handle)

      my_multi = cp_fm_get_mm_type()

      ! Sub-matrix start offsets are only supported by the cp_fm_gemm path
      IF (PRESENT(a_first_row) .OR. PRESENT(a_first_col) .OR. &
          PRESENT(b_first_row) .OR. PRESENT(b_first_col) .OR. &
          PRESENT(c_first_row) .OR. PRESENT(c_first_col)) my_multi = do_pdgemm

      ! The DBCSR path below restricts only k (last_k) and otherwise updates all of C,
      ! so it matches cp_fm_gemm only when the m x n update covers the whole of C
      CALL cp_fm_get_info(matrix_c, nrow_global=c_nrow_global, ncol_global=c_ncol_global, &
                          nrow_locals=c_row_loc, ncol_locals=c_col_loc)
      IF (m .NE. c_nrow_global .OR. n .NE. c_ncol_global) my_multi = do_pdgemm

      ! Normalize the transposition flags; for real matrices 'C' is equivalent to 'T'
      my_transa = transa; CALL uppercase(my_transa)
      my_transb = transb; CALL uppercase(my_transb)

      ! Fetch the distribution of op(A) and op(B) (rows and columns swap on transposition)
      IF (my_transa == 'T' .OR. my_transa == 'C') THEN
         CALL cp_fm_get_info(matrix_a, nrow_locals=a_col_loc, ncol_locals=a_row_loc)
      ELSE
         CALL cp_fm_get_info(matrix_a, nrow_locals=a_row_loc, ncol_locals=a_col_loc)
      END IF
      IF (my_transb == 'T' .OR. my_transb == 'C') THEN
         CALL cp_fm_get_info(matrix_b, nrow_locals=b_col_loc, ncol_locals=b_row_loc)
      ELSE
         CALL cp_fm_get_info(matrix_b, nrow_locals=b_row_loc, ncol_locals=b_col_loc)
      END IF

      ! catch the special case that matrices have different blocking:
      ! SCALAPACK can deal with it but dbcsr doesn't like it
      IF (my_multi .NE. do_pdgemm) THEN
         IF (.NOT. (layouts_match(a_row_loc, c_row_loc) .AND. &
                    layouts_match(b_col_loc, c_col_loc) .AND. &
                    layouts_match(a_col_loc, b_row_loc))) my_multi = do_pdgemm
      END IF

      ! IMPORTANT do_pdgemm is lowest value. If one processor has it set make all do pdgemm
      IF (cp_fm_get_mm_type() .NE. do_pdgemm) CALL mp_min(my_multi, matrix_a%matrix_struct%para_env%group)

      SELECT CASE (my_multi)
      CASE (do_pdgemm)
         CALL timeset("cp_gemm_fm_gemm", handle1)
         CALL cp_fm_gemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, matrix_c, &
                         a_first_col=a_first_col, &
                         a_first_row=a_first_row, &
                         b_first_col=b_first_col, &
                         b_first_row=b_first_row, &
                         c_first_col=c_first_col, &
                         c_first_row=c_first_row)
         CALL timestop(handle1)
      CASE (do_dbcsr)
         CALL timeset("cp_gemm_dbcsr_mm", handle1)
         CALL copy_fm_to_dbcsr_bc(matrix_a, a_db)
         CALL copy_fm_to_dbcsr_bc(matrix_b, b_db)
         CALL copy_fm_to_dbcsr_bc(matrix_c, c_db)

         CALL dbcsr_multiply(my_transa, my_transb, alpha, a_db, b_db, beta, c_db, last_k=k)

         CALL copy_dbcsr_to_fm_bc(c_db, matrix_c)
         CALL dbcsr_release(a_db)
         CALL dbcsr_release(b_db)
         CALL dbcsr_release(c_db)
         CALL timestop(handle1)
      CASE DEFAULT
         ! Never leave matrix_c silently untouched
         CPABORT("cp_gemm: unknown matrix multiplication type")
      END SELECT
      CALL timestop(handle)

   END SUBROUTINE cp_gemm

! **************************************************************************************************
!> \brief Checks whether two arrays of per-process local row/column counts (as returned by
!>        cp_fm_get_info nrow_locals/ncol_locals) describe the same distribution.
!> \param x first array of local counts
!> \param y second array of local counts
!> \return .TRUE. if both arrays have the same size and identical entries
! **************************************************************************************************
   PURE FUNCTION layouts_match(x, y) RESULT(res)
      INTEGER, DIMENSION(:), INTENT(IN)                  :: x, y
      LOGICAL                                            :: res

      ! Fortran does not short-circuit .AND., so guard ALL() against non-conformable arrays
      res = .FALSE.
      IF (SIZE(x) == SIZE(y)) res = ALL(x == y)
   END FUNCTION layouts_match

END MODULE cp_gemm_interface
157