1!--------------------------------------------------------------------------------------------------!
2! Copyright (C) by the DBCSR developers group - All rights reserved                                !
3! This file is part of the DBCSR library.                                                          !
4!                                                                                                  !
5! For information on the license, see the LICENSE file.                                            !
6! For further information please visit https://dbcsr.cp2k.org                                      !
7! SPDX-License-Identifier: GPL-2.0+                                                                !
8!--------------------------------------------------------------------------------------------------!
9
10MODULE dbcsr_test_multiply
11   !! Tests for DBCSR multiply
12   USE dbcsr_data_methods, ONLY: dbcsr_data_get_sizes, &
13                                 dbcsr_data_init, &
14                                 dbcsr_data_new, &
15                                 dbcsr_data_release, &
16                                 dbcsr_scalar_negative, &
17                                 dbcsr_scalar_one, &
18                                 dbcsr_type_1d_to_2d
19   USE dbcsr_dist_methods, ONLY: dbcsr_distribution_new, &
20                                 dbcsr_distribution_release
21   USE dbcsr_io, ONLY: dbcsr_print
22   USE dbcsr_kinds, ONLY: real_4, &
23                          real_8
24   USE dbcsr_methods, ONLY: &
25      dbcsr_col_block_offsets, dbcsr_col_block_sizes, dbcsr_get_data_type, &
26      dbcsr_get_matrix_type, dbcsr_name, dbcsr_nblkcols_total, dbcsr_nblkrows_total, &
27      dbcsr_nfullcols_total, dbcsr_nfullrows_total, dbcsr_release, dbcsr_row_block_offsets, &
28      dbcsr_row_block_sizes
29   USE dbcsr_mpiwrap, ONLY: mp_bcast, &
30                            mp_environ
31   USE dbcsr_multiply_api, ONLY: dbcsr_multiply
32   USE dbcsr_operations, ONLY: dbcsr_copy, &
33                               dbcsr_get_occupation, &
34                               dbcsr_scale
35   USE dbcsr_test_methods, ONLY: compx_to_dbcsr_scalar, &
36                                 dbcsr_impose_sparsity, &
37                                 dbcsr_make_random_block_sizes, &
38                                 dbcsr_make_random_matrix, &
39                                 dbcsr_random_dist, &
40                                 dbcsr_to_dense_local
41   USE dbcsr_transformations, ONLY: dbcsr_redistribute, &
42                                    dbcsr_replicate_all
43   USE dbcsr_types, ONLY: &
44      dbcsr_conjugate_transpose, dbcsr_data_obj, dbcsr_distribution_obj, dbcsr_mp_obj, &
45      dbcsr_no_transpose, dbcsr_scalar_type, dbcsr_transpose, dbcsr_type, &
46      dbcsr_type_antisymmetric, dbcsr_type_complex_4, dbcsr_type_complex_4_2d, &
47      dbcsr_type_complex_8, dbcsr_type_complex_8_2d, dbcsr_type_no_symmetry, dbcsr_type_real_4, &
48      dbcsr_type_real_4_2d, dbcsr_type_real_8, dbcsr_type_real_8_2d, dbcsr_type_symmetric
49   USE dbcsr_work_operations, ONLY: dbcsr_create
50#include "base/dbcsr_base_uses.f90"
51
52!$ USE OMP_LIB, ONLY: omp_get_max_threads, omp_get_thread_num, omp_get_num_threads
53
54   IMPLICIT NONE
55
56   PRIVATE
57
58   PUBLIC :: dbcsr_test_multiplies
59
60   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'dbcsr_test_multiply'
61
62   LOGICAL, PARAMETER :: debug_mod = .FALSE.
63
64CONTAINS
65
66   SUBROUTINE dbcsr_test_multiplies(test_name, mp_group, mp_env, npdims, io_unit, &
67                                    matrix_sizes, bs_m, bs_n, bs_k, sparsities, &
68                                    alpha, beta, limits, retain_sparsity)
69      !! Performs a variety of matrix multiplies of same matrices on different
70      !! processor grids
71
72      CHARACTER(len=*), INTENT(IN)                       :: test_name
73      INTEGER, INTENT(IN)                                :: mp_group
74         !! MPI communicator
75      TYPE(dbcsr_mp_obj), INTENT(IN)                     :: mp_env
76      INTEGER, DIMENSION(2), INTENT(in)                  :: npdims
77      INTEGER, INTENT(IN)                                :: io_unit
78         !! which unit to write to, if not negative
79      INTEGER, DIMENSION(:), INTENT(in)                  :: matrix_sizes, bs_m, bs_n, bs_k
80         !! size of matrices to test
81         !! block sizes of the 3 dimension
82         !! block sizes of the 3 dimension
83         !! block sizes of the 3 dimension
84      REAL(real_8), DIMENSION(3), INTENT(in)             :: sparsities
85         !! sparsities of matrices to create
86      COMPLEX(real_8), INTENT(in)                        :: alpha, beta
87         !! alpha value to use in multiply
88         !! beta value to use in multiply
89      INTEGER, DIMENSION(6), INTENT(in)                  :: limits
90      LOGICAL, INTENT(in)                                :: retain_sparsity
91
92      CHARACTER(len=*), PARAMETER :: routineN = 'dbcsr_test_multiplies', &
93                                     routineP = moduleN//':'//routineN
94      CHARACTER, DIMENSION(3), PARAMETER :: &
95         trans = (/dbcsr_no_transpose, dbcsr_transpose, dbcsr_conjugate_transpose/)
96      CHARACTER, DIMENSION(3, 12), PARAMETER :: symmetries = &
97                                                RESHAPE((/dbcsr_type_no_symmetry, dbcsr_type_no_symmetry, &
98                                                          dbcsr_type_no_symmetry, dbcsr_type_symmetric, &
99                                                          dbcsr_type_no_symmetry, dbcsr_type_no_symmetry, &
100                                                          dbcsr_type_antisymmetric, dbcsr_type_no_symmetry, &
101                                                          dbcsr_type_no_symmetry, dbcsr_type_no_symmetry, &
102                                                          dbcsr_type_symmetric, dbcsr_type_no_symmetry, &
103                                                          dbcsr_type_symmetric, dbcsr_type_symmetric, &
104                                                          dbcsr_type_no_symmetry, dbcsr_type_antisymmetric, &
105                                                          dbcsr_type_symmetric, dbcsr_type_no_symmetry, &
106                                                          dbcsr_type_no_symmetry, dbcsr_type_antisymmetric, &
107                                                          dbcsr_type_no_symmetry, dbcsr_type_symmetric, &
108                                                          dbcsr_type_antisymmetric, dbcsr_type_no_symmetry, &
109                                                          dbcsr_type_antisymmetric, dbcsr_type_antisymmetric, &
110                                                          dbcsr_type_no_symmetry, dbcsr_type_no_symmetry, &
111                                                          dbcsr_type_no_symmetry, dbcsr_type_symmetric, &
112                                                          dbcsr_type_symmetric, dbcsr_type_symmetric, &
113                                                          dbcsr_type_symmetric, dbcsr_type_antisymmetric, &
114                                                          dbcsr_type_antisymmetric, dbcsr_type_symmetric/), (/3, 12/))
115      INTEGER, DIMENSION(4), PARAMETER :: types = (/dbcsr_type_real_4, dbcsr_type_real_8, &
116                                                    dbcsr_type_complex_4, dbcsr_type_complex_8/)
117
118      CHARACTER                                          :: a_symm, b_symm, c_symm, transa, transb
119      INTEGER                                            :: a_c, a_r, a_tr, b_c, b_r, b_tr, c_c, &
120                                                            c_r, handle, isymm, itype, mynode, &
121                                                            numnodes, numthreads, TYPE
122      INTEGER, DIMENSION(:), POINTER, CONTIGUOUS         :: my_sizes_k, my_sizes_m, my_sizes_n, &
123                                                            sizes_k, sizes_m, sizes_n
124      LOGICAL                                            :: do_complex
125      TYPE(dbcsr_data_obj)                               :: data_a, data_b, data_c, data_c_dbcsr
126      TYPE(dbcsr_scalar_type)                            :: alpha_obj, beta_obj
127      TYPE(dbcsr_type)                                   :: matrix_a, matrix_b, matrix_c
128
129!   ---------------------------------------------------------------------------
130
131      CALL timeset(routineN, handle)
132      NULLIFY (my_sizes_k, my_sizes_m, my_sizes_n, &
133               sizes_k, sizes_m, sizes_n)
134      !
135      ! print
136      CALL mp_environ(numnodes, mynode, mp_group)
137      IF (io_unit .GT. 0) THEN
138         WRITE (io_unit, *) 'test_name ', test_name
139         numthreads = 1
140!$OMP PARALLEL
141!$OMP MASTER
142!$       numthreads = omp_get_num_threads()
143!$OMP END MASTER
144!$OMP END PARALLEL
145         WRITE (io_unit, *) 'numthreads', numthreads
146         WRITE (io_unit, *) 'numnodes', numnodes
147         WRITE (io_unit, *) 'matrix_sizes', matrix_sizes
148         WRITE (io_unit, *) 'sparsities', sparsities
149         WRITE (io_unit, *) 'alpha', alpha
150         WRITE (io_unit, *) 'beta', beta
151         WRITE (io_unit, *) 'limits', limits
152         WRITE (io_unit, *) 'retain_sparsity', retain_sparsity
153         WRITE (io_unit, *) 'bs_m', bs_m
154         WRITE (io_unit, *) 'bs_n', bs_n
155         WRITE (io_unit, *) 'bs_k', bs_k
156      END IF
157      !
158      !
159      ! loop over symmetry
160      DO isymm = 1, SIZE(symmetries, 2)
161         a_symm = symmetries(1, isymm)
162         b_symm = symmetries(2, isymm)
163         c_symm = symmetries(3, isymm)
164
165         IF (a_symm .NE. dbcsr_type_no_symmetry .AND. matrix_sizes(1) .NE. matrix_sizes(3)) CYCLE
166         IF (b_symm .NE. dbcsr_type_no_symmetry .AND. matrix_sizes(2) .NE. matrix_sizes(3)) CYCLE
167         IF (c_symm .NE. dbcsr_type_no_symmetry .AND. matrix_sizes(1) .NE. matrix_sizes(2)) CYCLE
168
169         !
170         ! loop over types
171         DO itype = 1, SIZE(types)
172            TYPE = types(itype)
173
174            do_complex = TYPE .EQ. dbcsr_type_complex_4 .OR. TYPE .EQ. dbcsr_type_complex_8
175
176            alpha_obj = compx_to_dbcsr_scalar(alpha, TYPE)
177            beta_obj = compx_to_dbcsr_scalar(beta, TYPE)
178
179            IF (do_complex .AND. c_symm == dbcsr_type_symmetric) CYCLE
180
181            !
182            ! loop over transpositions
183            DO a_tr = 1, SIZE(trans)
184            DO b_tr = 1, SIZE(trans)
185               transa = trans(a_tr)
186               transb = trans(b_tr)
187
188               !
189               ! if C has a symmetry, we need special transpositions
190               IF (c_symm .NE. dbcsr_type_no_symmetry) THEN
191                  IF (.NOT. (transa .EQ. dbcsr_no_transpose .AND. transb .EQ. dbcsr_transpose .OR. &
192                             transa .EQ. dbcsr_transpose .AND. transb .EQ. dbcsr_no_transpose .OR. &
193                             transa .EQ. dbcsr_no_transpose .AND. transb .EQ. dbcsr_conjugate_transpose .AND. &
194                             .NOT. do_complex .OR. &
195                             transa .EQ. dbcsr_conjugate_transpose .AND. transb .EQ. dbcsr_no_transpose .AND. &
196                             .NOT. do_complex)) CYCLE
197               END IF
198               !
199               ! if C has symmetry and special limits
200               IF (c_symm .NE. dbcsr_type_no_symmetry) THEN
201                  IF (limits(1) .NE. 1 .OR. limits(2) .NE. matrix_sizes(1) .OR. &
202                      limits(3) .NE. 1 .OR. limits(4) .NE. matrix_sizes(2)) CYCLE
203               END IF
204
205               !
206               ! Create the row/column block sizes.
207               CALL dbcsr_make_random_block_sizes(sizes_m, matrix_sizes(1), bs_m)
208               CALL dbcsr_make_random_block_sizes(sizes_n, matrix_sizes(2), bs_n)
209               CALL dbcsr_make_random_block_sizes(sizes_k, matrix_sizes(3), bs_k)
210
211               !
212               ! if we have symmetry the row and column block sizes hae to match
213               IF (c_symm .NE. dbcsr_type_no_symmetry .AND. a_symm .NE. dbcsr_type_no_symmetry .AND. &
214                   b_symm .NE. dbcsr_type_no_symmetry) THEN
215                  my_sizes_m => sizes_m
216                  my_sizes_n => sizes_m
217                  my_sizes_k => sizes_m
218               ELSE IF ((c_symm .EQ. dbcsr_type_no_symmetry .AND. a_symm .NE. dbcsr_type_no_symmetry .AND. &
219                         b_symm .NE. dbcsr_type_no_symmetry) .OR. &
220                        (c_symm .NE. dbcsr_type_no_symmetry .AND. a_symm .EQ. dbcsr_type_no_symmetry .AND. &
221                         b_symm .NE. dbcsr_type_no_symmetry) .OR. &
222                        (c_symm .NE. dbcsr_type_no_symmetry .AND. a_symm .NE. dbcsr_type_no_symmetry .AND. &
223                         b_symm .EQ. dbcsr_type_no_symmetry)) THEN
224                  my_sizes_m => sizes_m
225                  my_sizes_n => sizes_m
226                  my_sizes_k => sizes_m
227               ELSE IF (c_symm .EQ. dbcsr_type_no_symmetry .AND. a_symm .EQ. dbcsr_type_no_symmetry .AND. &
228                        b_symm .NE. dbcsr_type_no_symmetry) THEN
229                  my_sizes_m => sizes_m
230                  my_sizes_n => sizes_n
231                  my_sizes_k => sizes_n
232               ELSE IF (c_symm .EQ. dbcsr_type_no_symmetry .AND. a_symm .NE. dbcsr_type_no_symmetry .AND. &
233                        b_symm .EQ. dbcsr_type_no_symmetry) THEN
234                  my_sizes_m => sizes_m
235                  my_sizes_n => sizes_n
236                  my_sizes_k => sizes_m
237               ELSE IF (c_symm .NE. dbcsr_type_no_symmetry .AND. a_symm .EQ. dbcsr_type_no_symmetry .AND. &
238                        b_symm .EQ. dbcsr_type_no_symmetry) THEN
239                  my_sizes_m => sizes_m
240                  my_sizes_n => sizes_m
241                  my_sizes_k => sizes_k
242               ELSE IF (c_symm .EQ. dbcsr_type_no_symmetry .AND. a_symm .EQ. dbcsr_type_no_symmetry .AND. &
243                        b_symm .EQ. dbcsr_type_no_symmetry) THEN
244                  my_sizes_m => sizes_m
245                  my_sizes_n => sizes_n
246                  my_sizes_k => sizes_k
247               ELSE
248                  CALL dbcsr_abort(__LOCATION__, &
249                                   "something wrong here... ")
250               END IF
251
252               IF (.FALSE.) THEN
253                  WRITE (*, *) 'sizes_m', my_sizes_m
254                  WRITE (*, *) 'sum(sizes_m)', SUM(my_sizes_m), ' matrix_sizes(1)', matrix_sizes(1)
255                  WRITE (*, *) 'sizes_n', my_sizes_n
256                  WRITE (*, *) 'sum(sizes_n)', SUM(my_sizes_n), ' matrix_sizes(2)', matrix_sizes(2)
257                  WRITE (*, *) 'sizes_k', my_sizes_k
258                  WRITE (*, *) 'sum(sizes_k)', SUM(my_sizes_k), ' matrix_sizes(3)', matrix_sizes(3)
259               END IF
260
261               !
262               ! Create the undistributed matrices.
263               CALL dbcsr_make_random_matrix(matrix_c, my_sizes_m, my_sizes_n, "Matrix C", &
264                                             sparsities(3), &
265                                             mp_group, data_type=TYPE, symmetry=c_symm)
266
267               IF (transa .NE. dbcsr_no_transpose) THEN
268                  CALL dbcsr_make_random_matrix(matrix_a, my_sizes_k, my_sizes_m, "Matrix A", &
269                                                sparsities(1), &
270                                                mp_group, data_type=TYPE, symmetry=a_symm)
271               ELSE
272                  CALL dbcsr_make_random_matrix(matrix_a, my_sizes_m, my_sizes_k, "Matrix A", &
273                                                sparsities(1), &
274                                                mp_group, data_type=TYPE, symmetry=a_symm)
275               END IF
276               IF (transb .NE. dbcsr_no_transpose) THEN
277                  CALL dbcsr_make_random_matrix(matrix_b, my_sizes_n, my_sizes_k, "Matrix B", &
278                                                sparsities(2), &
279                                                mp_group, data_type=TYPE, symmetry=b_symm)
280               ELSE
281                  CALL dbcsr_make_random_matrix(matrix_b, my_sizes_k, my_sizes_n, "Matrix B", &
282                                                sparsities(2), &
283                                                mp_group, data_type=TYPE, symmetry=b_symm)
284               END IF
285
286               DEALLOCATE (sizes_m, sizes_n, sizes_k)
287
288               !
289               ! if C has a symmetry, we build it accordingly, i.e. C=A*A and C=A*(-A)
290               IF (c_symm .NE. dbcsr_type_no_symmetry) THEN
291                  CALL dbcsr_copy(matrix_b, matrix_a)
292                  !print*, a_symm,b_symm,dbcsr_get_matrix_type(matrix_a),dbcsr_get_matrix_type(matrix_b)
293                  IF (c_symm .EQ. dbcsr_type_antisymmetric) THEN
294                     CALL dbcsr_scale(matrix_b, &
295                                      alpha_scalar=dbcsr_scalar_negative( &
296                                      dbcsr_scalar_one(TYPE)))
297                  END IF
298               END IF
299
300               !
301               ! convert the dbcsr matrices to denses
302               a_r = dbcsr_nfullrows_total(matrix_a)
303               a_c = dbcsr_nfullcols_total(matrix_a)
304               b_r = dbcsr_nfullrows_total(matrix_b)
305               b_c = dbcsr_nfullcols_total(matrix_b)
306               c_r = dbcsr_nfullrows_total(matrix_c)
307               c_c = dbcsr_nfullcols_total(matrix_c)
308               CALL dbcsr_data_init(data_a)
309               CALL dbcsr_data_init(data_b)
310               CALL dbcsr_data_init(data_c)
311               CALL dbcsr_data_init(data_c_dbcsr)
312               CALL dbcsr_data_new(data_a, dbcsr_type_1d_to_2d(TYPE), data_size=a_r, data_size2=a_c)
313               CALL dbcsr_data_new(data_b, dbcsr_type_1d_to_2d(TYPE), data_size=b_r, data_size2=b_c)
314               CALL dbcsr_data_new(data_c, dbcsr_type_1d_to_2d(TYPE), data_size=c_r, data_size2=c_c)
315               CALL dbcsr_data_new(data_c_dbcsr, dbcsr_type_1d_to_2d(TYPE), data_size=c_r, data_size2=c_c)
316               CALL dbcsr_to_dense_local(matrix_a, data_a)
317               CALL dbcsr_to_dense_local(matrix_b, data_b)
318               CALL dbcsr_to_dense_local(matrix_c, data_c)
319
320               !
321               ! Prepare test parameters
322               CALL test_multiply(test_name, mp_group, mp_env, npdims, io_unit, &
323                                  matrix_a, matrix_b, matrix_c, &
324                                  data_a, data_b, data_c, data_c_dbcsr, &
325                                  transa, transb, &
326                                  alpha_obj, beta_obj, &
327                                  limits, retain_sparsity)
328               !
329               ! cleanup
330               CALL dbcsr_release(matrix_a)
331               CALL dbcsr_release(matrix_b)
332               CALL dbcsr_release(matrix_c)
333               CALL dbcsr_data_release(data_a)
334               CALL dbcsr_data_release(data_b)
335               CALL dbcsr_data_release(data_c)
336               CALL dbcsr_data_release(data_c_dbcsr)
337
338            END DO
339            END DO
340
341         END DO ! itype
342
343      END DO !isymm
344
345      CALL timestop(handle)
346
347   END SUBROUTINE dbcsr_test_multiplies
348
349   SUBROUTINE test_multiply(test_name, mp_group, mp_env, npdims, io_unit, &
350                            matrix_a, matrix_b, matrix_c, &
351                            data_a, data_b, data_c, data_c_dbcsr, &
352                            transa, transb, alpha, beta, limits, retain_sparsity)
353      !! Performs a variety of matrix multiplies of same matrices on different
354      !! processor grids
355
356      CHARACTER(len=*), INTENT(IN)                       :: test_name
357      INTEGER, INTENT(IN)                                :: mp_group
358         !! MPI communicator
359      TYPE(dbcsr_mp_obj), INTENT(IN)                     :: mp_env
360      INTEGER, DIMENSION(2), INTENT(in)                  :: npdims
361      INTEGER, INTENT(IN)                                :: io_unit
362         !! which unit to write to, if not negative
363      TYPE(dbcsr_type), INTENT(in)                       :: matrix_a, matrix_b, matrix_c
364         !! matrices to multiply
365         !! matrices to multiply
366         !! matrices to multiply
367      TYPE(dbcsr_data_obj)                               :: data_a, data_b, data_c, data_c_dbcsr
368      CHARACTER, INTENT(in)                              :: transa, transb
369      TYPE(dbcsr_scalar_type), INTENT(in)                :: alpha, beta
370      INTEGER, DIMENSION(6), INTENT(in)                  :: limits
371      LOGICAL, INTENT(in)                                :: retain_sparsity
372
373      CHARACTER(len=*), PARAMETER :: routineN = 'test_multiply', routineP = moduleN//':'//routineN
374
375      INTEGER                                            :: c_a, c_b, c_c, handle, r_a, r_b, r_c
376      INTEGER, DIMENSION(:), POINTER, CONTIGUOUS         :: blk_offsets, col_dist_a, col_dist_b, &
377                                                            col_dist_c, row_dist_a, row_dist_b, &
378                                                            row_dist_c
379      LOGICAL                                            :: success
380      REAL(real_8)                                       :: occ_a, occ_b, occ_c_in, occ_c_out
381      TYPE(dbcsr_distribution_obj)                       :: dist_a, dist_b, dist_c
382      TYPE(dbcsr_type)                                   :: m_a, m_b, m_c
383
384!   ---------------------------------------------------------------------------
385
386      CALL timeset(routineN, handle)
387      NULLIFY (row_dist_a, col_dist_a, &
388               row_dist_b, col_dist_b, &
389               row_dist_c, col_dist_c)
390
391      IF (debug_mod .AND. io_unit .GT. 0) THEN
392         WRITE (io_unit, *) REPEAT("*", 70)
393         WRITE (io_unit, *) " -- TESTING dbcsr_multiply (", transa, ", ", transb, &
394            ", ", dbcsr_get_data_type(m_a), &
395            ", ", dbcsr_get_matrix_type(m_a), &
396            ", ", dbcsr_get_matrix_type(m_b), &
397            ", ", dbcsr_get_matrix_type(m_c), &
398            ") ............... !"
399         WRITE (io_unit, *) REPEAT("*", 70)
400      END IF
401
402      ! Row & column distributions
403      CALL dbcsr_random_dist(row_dist_a, dbcsr_nblkrows_total(matrix_a), npdims(1))
404      CALL dbcsr_random_dist(col_dist_a, dbcsr_nblkcols_total(matrix_a), npdims(2))
405      CALL dbcsr_random_dist(row_dist_b, dbcsr_nblkrows_total(matrix_b), npdims(1))
406      CALL dbcsr_random_dist(col_dist_b, dbcsr_nblkcols_total(matrix_b), npdims(2))
407      CALL dbcsr_random_dist(row_dist_c, dbcsr_nblkrows_total(matrix_c), npdims(1))
408      CALL dbcsr_random_dist(col_dist_c, dbcsr_nblkcols_total(matrix_c), npdims(2))
409      CALL dbcsr_distribution_new(dist_a, mp_env, row_dist_a, col_dist_a, reuse_arrays=.TRUE.)
410      CALL dbcsr_distribution_new(dist_b, mp_env, row_dist_b, col_dist_b, reuse_arrays=.TRUE.)
411      CALL dbcsr_distribution_new(dist_c, mp_env, row_dist_c, col_dist_c, reuse_arrays=.TRUE.)
412      ! Redistribute the matrices
413      ! A
414      CALL dbcsr_create(m_a, "Test for "//TRIM(dbcsr_name(matrix_a)), &
415                        dist_a, dbcsr_get_matrix_type(matrix_a), &
416                        row_blk_size_obj=matrix_a%row_blk_size, &
417                        col_blk_size_obj=matrix_a%col_blk_size, &
418                        data_type=dbcsr_get_data_type(matrix_a))
419      CALL dbcsr_distribution_release(dist_a)
420      CALL dbcsr_redistribute(matrix_a, m_a)
421      ! B
422      CALL dbcsr_create(m_b, "Test for "//TRIM(dbcsr_name(matrix_b)), &
423                        dist_b, dbcsr_get_matrix_type(matrix_b), &
424                        row_blk_size_obj=matrix_b%row_blk_size, &
425                        col_blk_size_obj=matrix_b%col_blk_size, &
426                        data_type=dbcsr_get_data_type(matrix_b))
427      CALL dbcsr_distribution_release(dist_b)
428      CALL dbcsr_redistribute(matrix_b, m_b)
429      ! C
430      CALL dbcsr_create(m_c, "Test for "//TRIM(dbcsr_name(matrix_c)), &
431                        dist_c, dbcsr_get_matrix_type(matrix_c), &
432                        row_blk_size_obj=matrix_c%row_blk_size, &
433                        col_blk_size_obj=matrix_c%col_blk_size, &
434                        data_type=dbcsr_get_data_type(matrix_c))
435      CALL dbcsr_distribution_release(dist_c)
436      CALL dbcsr_redistribute(matrix_c, m_c)
437
438      IF (.FALSE.) THEN
439         blk_offsets => dbcsr_row_block_offsets(matrix_c)
440         WRITE (*, *) 'row_block_offsets(matrix_c)', blk_offsets
441         blk_offsets => dbcsr_col_block_offsets(matrix_c)
442         WRITE (*, *) 'col_block_offsets(matrix_c)', blk_offsets
443      END IF
444
445      IF (.FALSE.) THEN
446         CALL dbcsr_print(m_c, matlab_format=.FALSE., variable_name='c_in_')
447         CALL dbcsr_print(m_a, matlab_format=.FALSE., variable_name='a_')
448         CALL dbcsr_print(m_b, matlab_format=.FALSE., variable_name='b_')
449         CALL dbcsr_print(m_c, matlab_format=.FALSE., variable_name='c_out_')
450      END IF
451
452      occ_a = dbcsr_get_occupation(m_a)
453      occ_b = dbcsr_get_occupation(m_b)
454      occ_c_in = dbcsr_get_occupation(m_c)
455
456      !
457      ! Perform multiply
458      IF (ALL(limits == 0)) THEN
459         DBCSR_ABORT("limits shouldnt be 0")
460      ELSE
461         CALL dbcsr_multiply(transa, transb, alpha, &
462                             m_a, m_b, beta, m_c, &
463                             first_row=limits(1), &
464                             last_row=limits(2), &
465                             first_column=limits(3), &
466                             last_column=limits(4), &
467                             first_k=limits(5), &
468                             last_k=limits(6), &
469                             retain_sparsity=retain_sparsity)
470      END IF
471
472      occ_c_out = dbcsr_get_occupation(m_c)
473
474      IF (.FALSE.) THEN
475         PRINT *, 'retain_sparsity', retain_sparsity, occ_a, occ_b, occ_c_in, occ_c_out
476         CALL dbcsr_print(m_a, matlab_format=.TRUE., variable_name='a_')
477         CALL dbcsr_print(m_b, matlab_format=.TRUE., variable_name='b_')
478         CALL dbcsr_print(m_c, matlab_format=.FALSE., variable_name='c_out_')
479      END IF
480
481      CALL dbcsr_replicate_all(m_c)
482      CALL dbcsr_to_dense_local(m_c, data_c_dbcsr)
483      CALL dbcsr_check_multiply(test_name, m_c, data_c_dbcsr, data_a, data_b, data_c, &
484                                transa, transb, alpha, beta, limits, retain_sparsity, io_unit, mp_group, &
485                                success)
486
487      r_a = dbcsr_nfullrows_total(m_a)
488      c_a = dbcsr_nfullcols_total(m_a)
489      r_b = dbcsr_nfullrows_total(m_b)
490      c_b = dbcsr_nfullcols_total(m_b)
491      r_c = dbcsr_nfullrows_total(m_c)
492      c_c = dbcsr_nfullcols_total(m_c)
493      IF (io_unit .GT. 0) THEN
494         IF (success) THEN
495            WRITE (io_unit, *) REPEAT("*", 70)
496            WRITE (io_unit, *) " -- TESTING dbcsr_multiply (", transa, ", ", transb, &
497               ", ", dbcsr_get_data_type(m_a), &
498               ", ", dbcsr_get_matrix_type(m_a), &
499               ", ", dbcsr_get_matrix_type(m_b), &
500               ", ", dbcsr_get_matrix_type(m_c), &
501               ") ............... PASSED !"
502            WRITE (io_unit, *) REPEAT("*", 70)
503         ELSE
504            WRITE (io_unit, *) REPEAT("*", 70)
505            WRITE (io_unit, *) " -- TESTING dbcsr_multiply (", transa, ", ", transb, &
506               ", ", dbcsr_get_data_type(m_a), &
507               ", ", dbcsr_get_matrix_type(m_a), &
508               ", ", dbcsr_get_matrix_type(m_b), &
509               ", ", dbcsr_get_matrix_type(m_c), &
510               ") ... FAILED !"
511            WRITE (io_unit, *) REPEAT("*", 70)
512            DBCSR_ABORT('Test failed')
513         END IF
514      END IF
515
516      CALL dbcsr_release(m_a)
517      CALL dbcsr_release(m_b)
518      CALL dbcsr_release(m_c)
519
520      CALL timestop(handle)
521
522   END SUBROUTINE test_multiply
523
524   SUBROUTINE dbcsr_check_multiply(test_name, matrix_c, dense_c_dbcsr, dense_a, dense_b, dense_c, &
525                                   transa, transb, alpha, beta, limits, retain_sparsity, io_unit, mp_group, &
526                                   success)
527      !! Performs a check of matrix multiplies
528
529      CHARACTER(len=*), INTENT(IN)                       :: test_name
530      TYPE(dbcsr_type), INTENT(IN)                       :: matrix_c
531      TYPE(dbcsr_data_obj), INTENT(inout)                :: dense_c_dbcsr, dense_a, dense_b, dense_c
532         !! dense result of the dbcsr_multiply
533         !! input dense matrices
534         !! input dense matrices
535         !! input dense matrices
536      CHARACTER, INTENT(in)                              :: transa, transb
537         !! transposition status
538         !! transposition status
539      TYPE(dbcsr_scalar_type), INTENT(in)                :: alpha, beta
540         !! coefficients for the gemm
541         !! coefficients for the gemm
542      INTEGER, DIMENSION(6), INTENT(in)                  :: limits
543         !! limits for the gemm
544      LOGICAL, INTENT(in)                                :: retain_sparsity
545      INTEGER, INTENT(IN)                                :: io_unit, mp_group
546         !! io unit for printing
547      LOGICAL, INTENT(out)                               :: success
548         !! if passed the check success=T
549
550      CHARACTER(len=*), PARAMETER :: routineN = 'dbcsr_check_multiply', &
551                                     routineP = moduleN//':'//routineN
552      INTEGER :: a_col, a_m, a_n, a_row, b_col, b_m, b_n, b_row, c_col, c_col_size, c_row, &
553                 c_row_size, handle, i, istat, j, k, lda, ldb, ldc, lwork, m, mynode, n, numnodes
554      CHARACTER, PARAMETER                               :: norm = 'I'
555
556      LOGICAL                                            :: valid
557      REAL(real_4), ALLOCATABLE, DIMENSION(:)            :: work_sp
558#if defined (__ACCELERATE)
559      REAL(real_8), EXTERNAL                             :: clange, slamch, slange
560#else
561      REAL(real_4), EXTERNAL                             :: clange, slamch, slange
562#endif
563      REAL(real_8)                                       :: a_norm, b_norm, c_norm_dbcsr, c_norm_in, &
564                                                            c_norm_out, eps, residual
565      REAL(real_8), ALLOCATABLE, DIMENSION(:)            :: work
566      REAL(real_8), EXTERNAL                             :: dlamch, dlange, zlange
567
568      CALL timeset(routineN, handle)
569
570      CALL mp_environ(numnodes, mynode, mp_group)
571
572      CALL dbcsr_data_get_sizes(dense_c, c_row_size, c_col_size, valid)
573      IF (.NOT. valid) &
574         DBCSR_ABORT("dense matrix not valid")
575      CALL dbcsr_data_get_sizes(dense_c, ldc, i, valid)
576      IF (.NOT. valid) &
577         DBCSR_ABORT("dense matrix not valid")
578      CALL dbcsr_data_get_sizes(dense_a, lda, i, valid)
579      IF (.NOT. valid) &
580         DBCSR_ABORT("dense matrix not valid")
581      CALL dbcsr_data_get_sizes(dense_b, ldb, i, valid)
582      IF (.NOT. valid) &
583         DBCSR_ABORT("dense matrix not valid")
584      !
585      !
586      m = limits(2) - limits(1) + 1
587      n = limits(4) - limits(3) + 1
588      k = limits(6) - limits(5) + 1
589      a_row = limits(1); a_col = limits(5)
590      b_row = limits(5); b_col = limits(3)
591      c_row = limits(1); c_col = limits(3)
592      !
593      !
594      IF (transA == dbcsr_no_transpose) THEN
595         a_m = m
596         a_n = k
597      ELSE
598         a_m = k
599         a_n = m
600         i = a_row
601         a_row = a_col
602         a_col = i
603      END IF
604      IF (transB == dbcsr_no_transpose) THEN
605         b_m = k
606         b_n = n
607      ELSE
608         b_m = n
609         b_n = k
610         i = b_row
611         b_row = b_col
612         b_col = i
613      END IF
614      !
615      ! set the size of the work array
616      lwork = MAXVAL((/lda, ldb, ldc/))
617      !
618      !
619      SELECT CASE (dense_a%d%data_type)
620      CASE (dbcsr_type_real_8_2d)
621         ALLOCATE (work(lwork), STAT=istat)
622         IF (istat /= 0) &
623            DBCSR_ABORT("allocation problem")
624         eps = dlamch('eps')
625         a_norm = dlange(norm, a_m, a_n, dense_a%d%r2_dp(a_row, a_col), lda, work)
626         b_norm = dlange(norm, b_m, b_n, dense_b%d%r2_dp(b_row, b_col), ldb, work)
627         c_norm_in = dlange(norm, c_row_size, c_col_size, dense_c%d%r2_dp(1, 1), ldc, work)
628         c_norm_dbcsr = dlange(norm, c_row_size, c_col_size, dense_c_dbcsr%d%r2_dp(1, 1), ldc, work)
629         !
630         CALL dgemm(transa, transb, m, n, k, alpha%r_dp, dense_a%d%r2_dp(a_row, a_col), lda, &
631                    dense_b%d%r2_dp(b_row, b_col), ldb, beta%r_dp, dense_c%d%r2_dp(c_row, c_col), ldc)
632         !
633         ! impose the sparsity if needed
634         IF (retain_sparsity) CALL dbcsr_impose_sparsity(matrix_c, dense_c)
635         !
636         c_norm_out = dlange(norm, m, n, dense_c%d%r2_dp(c_row, c_col), ldc, work)
637         !
638         ! take the difference dense/sparse
639         dense_c%d%r2_dp = dense_c%d%r2_dp - dense_c_dbcsr%d%r2_dp
640         !
641         ! compute the residual
642         residual = dlange(norm, c_row_size, c_col_size, dense_c%d%r2_dp(1, 1), ldc, work)
643         DEALLOCATE (work)
644      CASE (dbcsr_type_real_4_2d)
645         ALLOCATE (work_sp(lwork), STAT=istat)
646         IF (istat /= 0) &
647            DBCSR_ABORT("allocation problem")
648         eps = REAL(slamch('eps'), real_8)
649         a_norm = slange(norm, a_m, a_n, dense_a%d%r2_sp(a_row, a_col), lda, work_sp)
650         b_norm = slange(norm, b_m, b_n, dense_b%d%r2_sp(b_row, b_col), ldb, work_sp)
651         c_norm_in = slange(norm, c_row_size, c_col_size, dense_c%d%r2_sp(1, 1), ldc, work_sp)
652         c_norm_dbcsr = slange(norm, c_row_size, c_col_size, dense_c_dbcsr%d%r2_sp(1, 1), ldc, work_sp)
653         !
654
655         IF (.FALSE.) THEN
656            !IF (io_unit .GT. 0) THEN
657            DO j = 1, SIZE(dense_a%d%r2_sp, 2)
658               DO i = 1, SIZE(dense_a%d%r2_sp, 1)
659                  WRITE (*, '(A,I3,A,I3,A,E15.7,A)') 'a(', i, ',', j, ')=', dense_a%d%r2_sp(i, j), ';'
660               END DO
661            END DO
662            DO j = 1, SIZE(dense_b%d%r2_sp, 2)
663               DO i = 1, SIZE(dense_b%d%r2_sp, 1)
664                  WRITE (*, '(A,I3,A,I3,A,E15.7,A)') 'b(', i, ',', j, ')=', dense_b%d%r2_sp(i, j), ';'
665               END DO
666            END DO
667            DO j = 1, SIZE(dense_c%d%r2_sp, 2)
668               DO i = 1, SIZE(dense_c%d%r2_sp, 1)
669                  WRITE (*, '(A,I3,A,I3,A,E15.7,A)') 'c_in(', i, ',', j, ')=', dense_c%d%r2_sp(i, j), ';'
670               END DO
671            END DO
672         END IF
673
674         CALL sgemm(transa, transb, m, n, k, alpha%r_sp, dense_a%d%r2_sp(a_row, a_col), lda, &
675                    dense_b%d%r2_sp(b_row, b_col), ldb, beta%r_sp, dense_c%d%r2_sp(c_row, c_col), ldc)
676         !
677         ! impose the sparsity if needed
678         IF (retain_sparsity) CALL dbcsr_impose_sparsity(matrix_c, dense_c)
679
680         IF (.FALSE.) THEN
681            !IF (io_unit .GT. 0) THEN
682            DO j = 1, SIZE(dense_c%d%r2_sp, 2)
683               DO i = 1, SIZE(dense_c%d%r2_sp, 1)
684                  WRITE (*, '(A,I3,A,I3,A,E15.7,A)') 'c_out(', i, ',', j, ')=', dense_c%d%r2_sp(i, j), ';'
685               END DO
686            END DO
687            DO j = 1, SIZE(dense_c_dbcsr%d%r2_sp, 2)
688               DO i = 1, SIZE(dense_c_dbcsr%d%r2_sp, 1)
689                  WRITE (*, '(A,I3,A,I3,A,E15.7,A)') 'c_dbcsr(', i, ',', j, ')=', dense_c_dbcsr%d%r2_sp(i, j), ';'
690               END DO
691            END DO
692         END IF
693         !
694         c_norm_out = slange(norm, m, n, dense_c%d%r2_sp(c_row, c_col), ldc, work_sp)
695         !
696         ! take the difference dense/sparse
697         dense_c%d%r2_sp = dense_c%d%r2_sp - dense_c_dbcsr%d%r2_sp
698         !
699         ! compute the residual
700         residual = REAL(slange(norm, c_row_size, c_col_size, dense_c%d%r2_sp(1, 1), ldc, work_sp), real_8)
701         DEALLOCATE (work_sp)
702      CASE (dbcsr_type_complex_8_2d)
703         ALLOCATE (work(lwork), STAT=istat)
704         IF (istat /= 0) &
705            DBCSR_ABORT("allocation problem")
706         eps = dlamch('eps')
707         a_norm = zlange(norm, a_m, a_n, dense_a%d%c2_dp(a_row, a_col), lda, work)
708         b_norm = zlange(norm, b_m, b_n, dense_b%d%c2_dp(b_row, b_col), ldb, work)
709         c_norm_in = zlange(norm, c_row_size, c_col_size, dense_c%d%c2_dp(1, 1), ldc, work)
710         c_norm_dbcsr = zlange(norm, c_row_size, c_col_size, dense_c_dbcsr%d%c2_dp(1, 1), ldc, work)
711         !
712         CALL zgemm(transa, transb, m, n, k, alpha%c_dp, dense_a%d%c2_dp(a_row, a_col), lda, &
713                    dense_b%d%c2_dp(b_row, b_col), ldb, beta%c_dp, dense_c%d%c2_dp(c_row, c_col), ldc)
714         !
715         ! impose the sparsity if needed
716         IF (retain_sparsity) CALL dbcsr_impose_sparsity(matrix_c, dense_c)
717         !
718         c_norm_out = zlange(norm, m, n, dense_c%d%c2_dp(c_row, c_col), ldc, work)
719         !
720         ! take the difference dense/sparse
721         dense_c%d%c2_dp = dense_c%d%c2_dp - dense_c_dbcsr%d%c2_dp
722         !
723         ! compute the residual
724         residual = zlange(norm, c_row_size, c_col_size, dense_c%d%c2_dp(1, 1), ldc, work)
725         DEALLOCATE (work)
726      CASE (dbcsr_type_complex_4_2d)
727         ALLOCATE (work_sp(lwork), STAT=istat)
728         IF (istat /= 0) &
729            DBCSR_ABORT("allocation problem")
730         eps = REAL(slamch('eps'), real_8)
731         a_norm = clange(norm, a_m, a_n, dense_a%d%c2_sp(a_row, a_col), lda, work_sp)
732         b_norm = clange(norm, b_m, b_n, dense_b%d%c2_sp(b_row, b_col), ldb, work_sp)
733         c_norm_in = clange(norm, c_row_size, c_col_size, dense_c%d%c2_sp(1, 1), ldc, work_sp)
734         c_norm_dbcsr = clange(norm, c_row_size, c_col_size, dense_c_dbcsr%d%c2_sp(1, 1), ldc, work_sp)
735         !
736         CALL cgemm(transa, transb, m, n, k, alpha%c_sp, dense_a%d%c2_sp(a_row, a_col), lda, &
737                    dense_b%d%c2_sp(b_row, b_col), ldb, beta%c_sp, dense_c%d%c2_sp(c_row, c_col), ldc)
738         !
739         ! impose the sparsity if needed
740         IF (retain_sparsity) CALL dbcsr_impose_sparsity(matrix_c, dense_c)
741         !
742         c_norm_out = clange(norm, m, n, dense_c%d%c2_sp(c_row, c_col), ldc, work_sp)
743         !
744         ! take the difference dense/sparse
745         dense_c%d%c2_sp = dense_c%d%c2_sp - dense_c_dbcsr%d%c2_sp
746         !
747         ! compute the residual
748         residual = clange(norm, c_row_size, c_col_size, dense_c%d%c2_sp(1, 1), ldc, work_sp)
749         DEALLOCATE (work_sp)
750      CASE default
751         DBCSR_ABORT("Incorrect or 1-D data type")
752      END SELECT
753
754      IF (mynode .EQ. 0) THEN
755         IF (residual/((a_norm + b_norm + c_norm_in)*REAL(n, real_8)*eps) .GT. 10.0_real_8) THEN
756            success = .FALSE.
757         ELSE
758            success = .TRUE.
759         END IF
760      END IF
761      !
762      ! synchronize the result...
763      CALL mp_bcast(success, 0, mp_group)
764      !
765      ! printing
766      IF (io_unit .GT. 0) THEN
767         WRITE (io_unit, *) 'test_name ', test_name
768         !
769         ! check for nan or inf here
770         IF (success) THEN
771            WRITE (io_unit, '(A)') ' The solution is CORRECT !'
772         ELSE
773            WRITE (io_unit, '(A)') ' The solution is suspicious !'
774
775            WRITE (io_unit, '(3(A,E12.5))') ' residual ', residual, ', a_norm ', a_norm, ', b_norm ', b_norm
776            WRITE (io_unit, '(3(A,E12.5))') ' c_norm_in ', c_norm_in, ', c_norm_out ', c_norm_out, &
777               ', c_norm_dbcsr ', c_norm_dbcsr
778            WRITE (io_unit, '(A)') ' Checking the norm of the difference against reference GEMM '
779            WRITE (io_unit, '(A,E12.5)') ' -- ||C_dbcsr-C_dense||_oo/((||A||_oo+||B||_oo+||C||_oo).N.eps)=', &
780               residual/((a_norm + b_norm + c_norm_in)*n*eps)
781         END IF
782
783      END IF
784
785      CALL timestop(handle)
786
787   END SUBROUTINE dbcsr_check_multiply
788
789END MODULE dbcsr_test_multiply
790