1!--------------------------------------------------------------------------------------------------!
2! Copyright (C) by the DBCSR developers group - All rights reserved                                !
3! This file is part of the DBCSR library.                                                          !
4!                                                                                                  !
5! For information on the license, see the LICENSE file.                                            !
6! For further information please visit https://dbcsr.cp2k.org                                      !
7! SPDX-License-Identifier: GPL-2.0+                                                                !
8!--------------------------------------------------------------------------------------------------!
9
10MODULE dbcsr_tas_mm
11   !! Matrix multiplication for tall-and-skinny matrices. This uses the k-split (non-recursive) CARMA
12   !! algorithm that is communication-optimal as long as the two smaller dimensions have
13   !! the same size.
14   !! Submatrices are obtained by splitting a dimension of the process grid. Multiplication of
15   !! submatrices uses DBCSR Cannon algorithm. Due to unknown sparsity pattern of result matrix, parameters
16   !! (group sizes and process grid dimensions) can not be derived from matrix dimensions and need to be
17   !! set manually.
18
19#:include "dbcsr_tas.fypp"
20
21   USE dbcsr_data_methods, ONLY: &
22      dbcsr_scalar_zero, dbcsr_scalar
23   USE dbcsr_data_types, ONLY: &
24      dbcsr_scalar_type, dbcsr_type_real_8, dbcsr_type_real_4, dbcsr_type_complex_8, dbcsr_type_complex_4
25   USE dbcsr_multiply_api, ONLY: dbcsr_multiply
26   USE dbcsr_tas_base, ONLY: &
27      dbcsr_tas_create, dbcsr_tas_destroy, dbcsr_tas_distribution_destroy, dbcsr_tas_distribution_new, &
28      dbcsr_tas_get_data_type, dbcsr_tas_info, dbcsr_tas_nblkcols_total, &
29      dbcsr_tas_nblkrows_total, dbcsr_tas_filter, dbcsr_tas_get_info, dbcsr_tas_iterator_blocks_left, &
30      dbcsr_tas_get_nze_total, dbcsr_tas_reserve_blocks, dbcsr_tas_iterator_start, dbcsr_tas_iterator_next_block, &
31      dbcsr_tas_iterator_stop, dbcsr_tas_copy, dbcsr_tas_get_block_p, dbcsr_tas_clear, dbcsr_tas_get_num_blocks
32   USE dbcsr_tas_types, ONLY: &
33      dbcsr_tas_distribution_type, dbcsr_tas_split_info, dbcsr_tas_type, dbcsr_tas_iterator
34   USE dbcsr_tas_global, ONLY: &
35      dbcsr_tas_dist_cyclic, dbcsr_tas_dist_arb, dbcsr_tas_distribution, dbcsr_tas_dist_arb_default, &
36      dbcsr_tas_rowcol_data, dbcsr_tas_blk_size_one, dbcsr_tas_default_distvec
37   USE dbcsr_tas_reshape_ops, ONLY: &
38      dbcsr_tas_merge, dbcsr_tas_replicate, dbcsr_tas_reshape
39   USE dbcsr_tas_split, ONLY: &
40      rowsplit, colsplit, dbcsr_tas_get_split_info, dbcsr_tas_create_split, dbcsr_tas_mp_comm, &
41      dbcsr_tas_release_info, accept_pgrid_dims, dbcsr_tas_info_hold, default_nsplit_accept_ratio_1
42   USE dbcsr_tas_util, ONLY: &
43      swap, invert_transpose_flag, array_eq, dbcsr_mp_environ
44   USE dbcsr_types, ONLY: &
45      dbcsr_no_transpose, dbcsr_transpose, dbcsr_type, dbcsr_distribution_obj, dbcsr_mp_obj, &
46      dbcsr_type_no_symmetry
47   USE dbcsr_kinds, ONLY: &
48      int_8, real_8, real_4, default_string_length
49   USE dbcsr_mpiwrap, ONLY: &
50      mp_comm_compare, mp_environ, mp_sum, mp_comm_free, mp_cart_create
51   USE dbcsr_operations, ONLY: &
52      dbcsr_scale, dbcsr_get_info, dbcsr_copy, dbcsr_clear, dbcsr_add
53   USE dbcsr_tas_io, ONLY: &
54      dbcsr_tas_write_dist, dbcsr_tas_write_matrix_info, dbcsr_tas_write_split_info, prep_output_unit
55   USE dbcsr_work_operations, ONLY: dbcsr_create, dbcsr_finalize
56   USE dbcsr_transformations, ONLY: dbcsr_redistribute
57   USE dbcsr_dist_methods, ONLY: dbcsr_distribution_new
58   USE dbcsr_methods, ONLY: &
59      dbcsr_mp_release, dbcsr_release, dbcsr_distribution_release, dbcsr_get_nze
60#include "base/dbcsr_base_uses.f90"
61
62   IMPLICIT NONE
63   PRIVATE
64
65   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'dbcsr_tas_mm'
66
67   PUBLIC :: &
68      dbcsr_tas_multiply, &
69      dbcsr_tas_batched_mm_init, &
70      dbcsr_tas_batched_mm_finalize, &
71      dbcsr_tas_result_index
72
73CONTAINS
74
75   RECURSIVE SUBROUTINE dbcsr_tas_multiply(transa, transb, transc, alpha, matrix_a, matrix_b, beta, matrix_c, &
76                                           optimize_dist, split_opt, filter_eps, flop, move_data_a, &
77                                           move_data_b, retain_sparsity, simple_split, result_index, unit_nr, log_verbose)
78      !! tall-and-skinny matrix-matrix multiplication. Undocumented dummy arguments are identical to
79      !! arguments of dbcsr_multiply (see dbcsr_mm, dbcsr_multiply_generic).
80
81      CHARACTER(LEN=1), INTENT(IN)               :: transa, transb, transc
82      TYPE(dbcsr_scalar_type), INTENT(IN)        :: alpha, beta
83      TYPE(dbcsr_tas_type), TARGET, &
84         INTENT(INOUT)                           :: matrix_a, matrix_b, matrix_c
85      LOGICAL, INTENT(IN), OPTIONAL              :: optimize_dist
86         !! Whether distribution should be optimized internally. In the current implementation this guarantees optimal parameters
87         !! only for dense matrices.
88      TYPE(dbcsr_tas_split_info), INTENT(OUT), &
89         POINTER, OPTIONAL                       :: split_opt
90         !! optionally return split info containing optimal grid and split parameters. This can be used to choose optimal process
91         !! grids for subsequent matrix multiplications with matrices of similar shape and sparsity. under some conditions,
92         !! split_opt can not be returned, in this case the pointer is not associated
93      REAL(KIND=real_8), INTENT(IN), OPTIONAL    :: filter_eps
94      INTEGER(KIND=int_8), INTENT(OUT), OPTIONAL :: flop
95      LOGICAL, INTENT(IN), OPTIONAL              :: move_data_a, move_data_b, simple_split, retain_sparsity
96         !! memory optimization: move data to matrix_c such that matrix_a is empty on return
97         !! memory optimization: move data to matrix_c such that matrix_b is empty on return
98         !! for internal use only
99      INTEGER(int_8), DIMENSION(:, :), ALLOCATABLE, INTENT(OUT), OPTIONAL :: result_index
100      INTEGER, OPTIONAL, INTENT(IN)              :: unit_nr
101         !! unit number for logging output
102      LOGICAL, OPTIONAL, INTENT(IN)              :: log_verbose
103         !! only for testing: verbose output
104
105      TYPE(dbcsr_tas_type), POINTER              :: matrix_b_rs, matrix_a_rs, matrix_c_rs, &
106                                                    matrix_c_rep, matrix_b_rep, matrix_a_rep
107
108      REAL(KIND=real_8)                          :: filter_eps_prv
109      INTEGER(KIND=int_8), DIMENSION(2)          :: dims_a, dims_b, dims_c
110      INTEGER, DIMENSION(2)                      :: pdims, pcoord, pcoord_sub, pdims_sub
111      INTEGER(KIND=int_8), DIMENSION(3)          :: dims
112      INTEGER                                    :: max_mm_dim, data_type, mp_comm, comm_tmp, &
113                                                    handle, handle2, unit_nr_prv, nsplit, nsplit_opt, numproc, numproc_sub, iproc, &
114                                                    mp_comm_group, mp_comm_mm, split_rc, split_a, split_b, split_c, &
115                                                    mp_comm_opt, batched_repl, max_mm_dim_batched, nsplit_batched
116      CHARACTER(LEN=1)                           :: tr_case, transa_prv, transb_prv, transc_prv
117      TYPE(dbcsr_scalar_type)                    :: zero
118      LOGICAL                                    :: new_a, new_b, new_c, simple_split_prv, opt_pgrid, &
119                                                    move_a, move_b, do_batched, simple_split_save, &
120                                                    nodata_3
121      TYPE(dbcsr_tas_split_info)                 :: info, info_a, info_b, info_c
122      CHARACTER(LEN=*), PARAMETER                :: routineN = 'dbcsr_tas_multiply', &
123                                                    routineP = moduleN//':'//routineN
124      INTEGER(KIND=int_8)                        :: nze_a, nze_b, nze_c
125      TYPE(dbcsr_type), POINTER                  :: matrix_a_mm, matrix_b_mm, matrix_c_mm
126
127      CALL timeset(routineN, handle)
128
129      NULLIFY (matrix_b_rs, matrix_a_rs, matrix_c_rs, matrix_a_mm, matrix_b_mm, matrix_c_mm)
130
131      unit_nr_prv = prep_output_unit(unit_nr)
132
133      IF (PRESENT(simple_split)) THEN
134         simple_split_prv = simple_split
135      ELSE
136         simple_split_prv = .FALSE.
137
138         info_a = dbcsr_tas_info(matrix_a); info_b = dbcsr_tas_info(matrix_b); info_c = dbcsr_tas_info(matrix_c)
139         IF (info_a%strict_split .OR. info_b%strict_split .OR. info_c%strict_split) simple_split_prv = .TRUE.
140      ENDIF
141
142      nodata_3 = .TRUE.
143      IF (PRESENT(retain_sparsity)) THEN
144         IF (retain_sparsity) nodata_3 = .FALSE.
145      ENDIF
146
147      ! get prestored info for multiplication strategy in case of batched mm
148      batched_repl = 0
149      do_batched = .FALSE.
150      IF (matrix_a%do_batched > 0) THEN
151         do_batched = .TRUE.
152         IF (matrix_a%do_batched == 2) THEN
153            DBCSR_ASSERT(batched_repl == 0)
154            batched_repl = 1
155            CALL dbcsr_tas_get_split_info( &
156               dbcsr_tas_info(matrix_a%mm_storage%store_batched_repl), &
157               nsplit=nsplit_batched)
158            DBCSR_ASSERT(nsplit_batched > 0)
159            max_mm_dim_batched = 3
160         ENDIF
161      ENDIF
162
163      IF (matrix_b%do_batched > 0) THEN
164         do_batched = .TRUE.
165         IF (matrix_b%do_batched == 2) THEN
166            DBCSR_ASSERT(batched_repl == 0)
167            batched_repl = 2
168            CALL dbcsr_tas_get_split_info( &
169               dbcsr_tas_info(matrix_b%mm_storage%store_batched_repl), &
170               nsplit=nsplit_batched)
171            DBCSR_ASSERT(nsplit_batched > 0)
172            max_mm_dim_batched = 1
173         ENDIF
174      ENDIF
175
176      IF (matrix_c%do_batched > 0) THEN
177         do_batched = .TRUE.
178         IF (matrix_c%do_batched == 2) THEN
179            DBCSR_ASSERT(batched_repl == 0)
180            batched_repl = 3
181            CALL dbcsr_tas_get_split_info( &
182               dbcsr_tas_info(matrix_c%mm_storage%store_batched_repl), &
183               nsplit=nsplit_batched)
184            DBCSR_ASSERT(nsplit_batched > 0)
185            max_mm_dim_batched = 2
186         ENDIF
187      ENDIF
188
189      IF (batched_repl > 0) THEN
190         simple_split_save = simple_split_prv
191         simple_split_prv = .TRUE.
192      ENDIF
193
194      move_a = .FALSE.
195      move_b = .FALSE.
196
197      IF (PRESENT(move_data_a)) move_a = move_data_a
198      IF (PRESENT(move_data_b)) move_b = move_data_b
199
200      IF (.NOT. dbcsr_tas_get_data_type(matrix_a) .EQ. dbcsr_tas_get_data_type(matrix_b)) THEN
201         DBCSR_ABORT("matrices must have same datatype")
202      ENDIF
203
204      data_type = dbcsr_tas_get_data_type(matrix_a)
205
206      transa_prv = transa; transb_prv = transb; transc_prv = transc
207
208      dims_a = [dbcsr_tas_nblkrows_total(matrix_a), dbcsr_tas_nblkcols_total(matrix_a)]
209      dims_b = [dbcsr_tas_nblkrows_total(matrix_b), dbcsr_tas_nblkcols_total(matrix_b)]
210      dims_c = [dbcsr_tas_nblkrows_total(matrix_c), dbcsr_tas_nblkcols_total(matrix_c)]
211
212      IF (unit_nr_prv .GT. 0) THEN
213         WRITE (unit_nr_prv, '(A)') repeat("-", 80)
214         WRITE (unit_nr_prv, '(A,1X,A,1X,A,1X,A,1X,A,1X,A)') "DBCSR TAS MATRIX MULTIPLICATION:", &
215            TRIM(matrix_a%matrix%name), 'x', TRIM(matrix_b%matrix%name), '=', TRIM(matrix_c%matrix%name)
216         WRITE (unit_nr_prv, '(A)') repeat("-", 80)
217      ENDIF
218      IF (do_batched) THEN
219         IF (unit_nr_prv > 0) THEN
220            WRITE (unit_nr_prv, "(T2,A)") &
221               "BATCHED PROCESSING OF MATMUL"
222            IF (batched_repl > 0) THEN
223               WRITE (unit_nr_prv, "(T4,A,T80,I1)") "reusing replicated matrix:", batched_repl
224            ENDIF
225         ENDIF
226      ENDIF
227
228      IF (transa_prv .EQ. dbcsr_transpose) THEN
229         CALL swap(dims_a)
230      ENDIF
231
232      IF (transb_prv .EQ. dbcsr_transpose) THEN
233         CALL swap(dims_b)
234      ENDIF
235
236      dims_c = [dims_a(1), dims_b(2)]
237
238      IF (.NOT. (dims_a(2) .EQ. dims_b(1))) THEN
239         DBCSR_ABORT("inconsistent matrix dimensions")
240      ENDIF
241
242      dims(:) = [dims_a(1), dims_a(2), dims_b(2)]
243
244      tr_case = ''
245
246      IF (unit_nr_prv > 0) THEN
247         WRITE (unit_nr_prv, "(T2,A, 1X, I12, 1X, I12, 1X, I12)") "mm dims:", dims(1), dims(2), dims(3)
248      ENDIF
249
250      CALL dbcsr_tas_get_split_info(dbcsr_tas_info(matrix_a), mp_comm=mp_comm)
251      CALL mp_environ(numproc, iproc, mp_comm)
252
253      ! derive optimal matrix layout and split factor from occupancies
254      IF (.NOT. simple_split_prv) THEN
255         nze_a = dbcsr_tas_get_nze_total(matrix_a)
256         nze_b = dbcsr_tas_get_nze_total(matrix_b)
257         CALL dbcsr_tas_result_index(transa, transb, transc, matrix_a, matrix_b, matrix_c, filter_eps, &
258                                     blk_ind=result_index, nze=nze_c, retain_sparsity=retain_sparsity)
259         IF (PRESENT(result_index)) THEN
260            CALL timestop(handle)
261            RETURN
262         ENDIF
263
264         max_mm_dim = MAXLOC(dims, 1)
265         nsplit = split_factor_estimate(max_mm_dim, nze_a, nze_b, nze_c, numproc)
266         nsplit_opt = nsplit
267
268         IF (unit_nr_prv > 0) THEN
269            WRITE (unit_nr_prv, "(T2,A)") &
270               "MM PARAMETERS"
271            WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Est. number of matrix elements per CPU of result matrix:", &
272               (nze_c + numproc - 1)/numproc
273
274            WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Est. optimal split factor:", nsplit
275         ENDIF
276
277      ELSEIF (batched_repl > 0) THEN
278         nsplit = nsplit_batched
279         nsplit_opt = nsplit
280         max_mm_dim = max_mm_dim_batched
281         simple_split_prv = simple_split_save
282         IF (unit_nr_prv > 0) THEN
283            WRITE (unit_nr_prv, "(T2,A)") &
284               "MM PARAMETERS"
285            WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Est. optimal split factor:", nsplit
286         ENDIF
287
288      ELSE
289         nsplit = 0
290         max_mm_dim = MAXLOC(dims, 1)
291      ENDIF
292
293      ! reshape matrices to the optimal layout and split factor
294      split_a = rowsplit; split_b = rowsplit; split_c = rowsplit
295      SELECT CASE (max_mm_dim)
296      CASE (1)
297
298         split_a = rowsplit; split_c = rowsplit
299         CALL reshape_mm_compatible(matrix_a, matrix_c, matrix_a_rs, matrix_c_rs, &
300                                    new_a, new_c, transa_prv, transc_prv, optimize_dist=optimize_dist, &
301                                    nsplit=nsplit, &
302                                    opt_nsplit=batched_repl == 0, &
303                                    split_rc_1=split_a, split_rc_2=split_c, &
304                                    nodata2=nodata_3, comm_new=comm_tmp, &
305                                    move_data_1=move_a, unit_nr=unit_nr_prv)
306
307         info = dbcsr_tas_info(matrix_a_rs)
308         CALL dbcsr_tas_get_split_info(info, split_rowcol=split_rc, mp_comm=mp_comm)
309
310         new_b = .FALSE.
311         IF (matrix_b%do_batched <= 1) THEN
312            ALLOCATE (matrix_b_rs)
313            CALL reshape_mm_small(mp_comm, matrix_b, matrix_b_rs, transb_prv == dbcsr_transpose, transb_prv, move_data=move_b)
314            new_b = .TRUE.
315         ENDIF
316
317         tr_case = transa_prv
318
319         IF (unit_nr_prv > 0) THEN
320            IF (tr_case == 'N') THEN
321               WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "| x + = |"
322            ELSE
323               WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "--T x + = --T"
324            ENDIF
325         ENDIF
326
327      CASE (2)
328
329         split_a = colsplit; split_b = rowsplit
330         CALL reshape_mm_compatible(matrix_a, matrix_b, matrix_a_rs, matrix_b_rs, new_a, new_b, transa_prv, transb_prv, &
331                                    optimize_dist=optimize_dist, &
332                                    nsplit=nsplit, &
333                                    opt_nsplit=batched_repl == 0, &
334                                    split_rc_1=split_a, split_rc_2=split_b, &
335                                    comm_new=comm_tmp, &
336                                    move_data_1=move_a, move_data_2=move_b, unit_nr=unit_nr_prv)
337
338         info = dbcsr_tas_info(matrix_a_rs)
339         CALL dbcsr_tas_get_split_info(info, split_rowcol=split_rc, mp_comm=mp_comm)
340
341         IF (matrix_c%do_batched <= 1) THEN
342            ALLOCATE (matrix_c_rs)
343            CALL reshape_mm_small(mp_comm, matrix_c, matrix_c_rs, transc_prv == dbcsr_transpose, transc_prv, nodata=nodata_3)
344            IF (matrix_c%do_batched == 1) THEN
345               matrix_c%mm_storage%store_batched => matrix_c_rs
346            ENDIF
347         ELSEIF (matrix_c%do_batched == 2) THEN
348            matrix_c_rs => matrix_c%mm_storage%store_batched
349         ENDIF
350
351         new_c = matrix_c%do_batched == 0
352         tr_case = transa_prv
353
354         IF (unit_nr_prv > 0) THEN
355            IF (tr_case == 'N') THEN
356               WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "-- x --T = +"
357            ELSE
358               WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "|T x | = +"
359            ENDIF
360         ENDIF
361
362      CASE (3)
363
364         split_b = colsplit; split_c = colsplit
365         CALL reshape_mm_compatible(matrix_b, matrix_c, matrix_b_rs, matrix_c_rs, new_b, new_c, transb_prv, &
366                                    transc_prv, optimize_dist=optimize_dist, &
367                                    nsplit=nsplit, &
368                                    opt_nsplit=batched_repl == 0, &
369                                    split_rc_1=split_b, split_rc_2=split_c, &
370                                    nodata2=nodata_3, comm_new=comm_tmp, &
371                                    move_data_1=move_b, unit_nr=unit_nr_prv)
372         info = dbcsr_tas_info(matrix_b_rs)
373         CALL dbcsr_tas_get_split_info(info, split_rowcol=split_rc, mp_comm=mp_comm)
374
375         new_a = .FALSE.
376         IF (matrix_a%do_batched <= 1) THEN
377            ALLOCATE (matrix_a_rs)
378            CALL reshape_mm_small(mp_comm, matrix_a, matrix_a_rs, transa_prv == dbcsr_transpose, transa_prv, move_data=move_a)
379            new_a = .TRUE.
380         ENDIF
381
382         tr_case = transb_prv
383
384         IF (unit_nr_prv > 0) THEN
385            IF (tr_case == 'N') THEN
386               WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "+ x -- = --"
387            ELSE
388               WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "+ x |T = |T"
389            ENDIF
390         ENDIF
391
392      END SELECT
393
394      CALL dbcsr_tas_get_split_info(info, nsplit=nsplit, mp_comm=mp_comm, mp_comm_group=mp_comm_group)
395
396      CALL mp_environ(numproc, pdims, pcoord, mp_comm)
397      CALL mp_environ(numproc_sub, pdims_sub, pcoord_sub, mp_comm_group)
398
399      IF (.NOT. simple_split_prv) THEN
400         opt_pgrid = .NOT. accept_pgrid_dims(pdims_sub, relative=.TRUE.)
401      ELSE
402         opt_pgrid = .FALSE.
403      ENDIF
404
405      IF (PRESENT(filter_eps)) THEN
406         filter_eps_prv = filter_eps
407      ELSE
408         filter_eps_prv = 0.0_real_8
409      ENDIF
410
411      IF (unit_nr_prv /= 0) THEN
412         IF (unit_nr_prv > 0) THEN
413            WRITE (unit_nr_prv, "(T2, A)") "SPLIT / PARALLELIZATION INFO"
414         ENDIF
415         CALL dbcsr_tas_write_split_info(info, unit_nr_prv)
416         IF (ASSOCIATED(matrix_a_rs)) CALL dbcsr_tas_write_matrix_info(matrix_a_rs, unit_nr_prv, full_info=log_verbose)
417         IF (ASSOCIATED(matrix_b_rs)) CALL dbcsr_tas_write_matrix_info(matrix_b_rs, unit_nr_prv, full_info=log_verbose)
418         IF (ASSOCIATED(matrix_c_rs)) CALL dbcsr_tas_write_matrix_info(matrix_c_rs, unit_nr_prv, full_info=log_verbose)
419         IF (unit_nr_prv > 0) THEN
420            IF (opt_pgrid) THEN
421               WRITE (unit_nr_prv, "(T4, A, 1X, A)") "Change process grid:", "Yes"
422            ELSE
423               WRITE (unit_nr_prv, "(T4, A, 1X, A)") "Change process grid:", "No"
424            ENDIF
425         ENDIF
426      ENDIF
427
428      zero = dbcsr_scalar_zero(data_type)
429
430      pdims = 0
431      CALL mp_cart_create(mp_comm_group, 2, pdims, pcoord, mp_comm_mm)
432
433      ! Convert DBCSR submatrices to optimized process grids and multiply
434      SELECT CASE (max_mm_dim)
435      CASE (1)
436         IF (matrix_b%do_batched <= 1) THEN
437            ALLOCATE (matrix_b_rep)
438            CALL dbcsr_tas_replicate(matrix_b_rs%matrix, dbcsr_tas_info(matrix_a_rs), matrix_b_rep, move_data=.TRUE.)
439            IF (matrix_b%do_batched == 1) THEN
440               matrix_b%mm_storage%store_batched_repl => matrix_b_rep
441               matrix_b%do_batched = 2
442            ENDIF
443         ELSEIF (matrix_b%do_batched == 2) THEN
444            matrix_b_rep => matrix_b%mm_storage%store_batched_repl
445         ENDIF
446
447         IF (new_b) THEN
448            CALL dbcsr_tas_destroy(matrix_b_rs)
449            DEALLOCATE (matrix_b_rs)
450         ENDIF
451         IF (unit_nr_prv /= 0) THEN
452            CALL dbcsr_tas_write_dist(matrix_a_rs, unit_nr_prv)
453            CALL dbcsr_tas_write_dist(matrix_b_rep, unit_nr_prv, full_info=log_verbose)
454         ENDIF
455
456         CALL convert_to_new_pgrid(mp_comm_mm, matrix_a_rs%matrix, matrix_a_mm, optimize_pgrid=opt_pgrid, move_data=move_a)
457
458         IF (new_a .AND. opt_pgrid) THEN
459            CALL dbcsr_tas_destroy(matrix_a_rs)
460            DEALLOCATE (matrix_a_rs)
461         ENDIF
462         CALL convert_to_new_pgrid(mp_comm_mm, matrix_b_rep%matrix, matrix_b_mm, optimize_pgrid=opt_pgrid, &
463                                   move_data=matrix_b%do_batched == 0)
464
465         IF (opt_pgrid .AND. matrix_b%do_batched == 0) THEN
466            CALL dbcsr_tas_destroy(matrix_b_rep)
467            DEALLOCATE (matrix_b_rep)
468         ENDIF
469
470         CALL convert_to_new_pgrid(mp_comm_mm, matrix_c_rs%matrix, matrix_c_mm, nodata=nodata_3, optimize_pgrid=opt_pgrid)
471
472         SELECT CASE (tr_case)
473         CASE (dbcsr_no_transpose)
474            CALL timeset(routineN//"_mm_1N", handle2)
475
476            CALL dbcsr_multiply(transa=dbcsr_no_transpose, transb=dbcsr_no_transpose, alpha=alpha, &
477                                matrix_a=matrix_a_mm, matrix_b=matrix_b_mm, beta=beta, matrix_c=matrix_c_mm, &
478                                filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
479            CALL timestop(handle2)
480         CASE (dbcsr_transpose)
481            CALL timeset(routineN//"_mm_1T", handle2)
482            CALL dbcsr_multiply(transa=dbcsr_transpose, transb=dbcsr_no_transpose, alpha=alpha, &
483                                matrix_a=matrix_b_mm, matrix_b=matrix_a_mm, beta=beta, matrix_c=matrix_c_mm, &
484                                filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
485
486            CALL timestop(handle2)
487         END SELECT
488
489         IF (opt_pgrid) THEN
490            CALL dbcsr_release(matrix_a_mm)
491            CALL dbcsr_release(matrix_b_mm)
492            IF (.NOT. new_c) THEN
493               CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix, alpha=beta)
494            ELSE
495               CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix)
496            ENDIF
497
498            CALL dbcsr_release(matrix_c_mm)
499
500            DEALLOCATE (matrix_a_mm, matrix_b_mm, matrix_c_mm)
501         ELSE
502            IF (new_a) CALL dbcsr_tas_destroy(matrix_a_rs)
503            IF (new_a) DEALLOCATE (matrix_a_rs)
504            IF (matrix_b%do_batched == 0) THEN
505               CALL dbcsr_tas_destroy(matrix_b_rep)
506               DEALLOCATE (matrix_b_rep)
507            ENDIF
508         ENDIF
509
510         IF (PRESENT(filter_eps)) CALL dbcsr_tas_filter(matrix_c_rs, filter_eps)
511
512         IF (unit_nr_prv /= 0) THEN
513            CALL dbcsr_tas_write_dist(matrix_c_rs, unit_nr_prv)
514         ENDIF
515
516      CASE (2)
517         IF (matrix_c%do_batched <= 1) THEN
518            ALLOCATE (matrix_c_rep)
519            CALL dbcsr_tas_replicate(matrix_c_rs%matrix, dbcsr_tas_info(matrix_a_rs), matrix_c_rep, nodata=nodata_3)
520            IF (matrix_c%do_batched == 1) THEN
521               matrix_c%mm_storage%store_batched_repl => matrix_c_rep
522               matrix_c%do_batched = 2
523            ENDIF
524         ELSEIF (matrix_c%do_batched == 2) THEN
525            matrix_c_rep => matrix_c%mm_storage%store_batched_repl
526         ENDIF
527
528         IF (unit_nr_prv /= 0) THEN
529            CALL dbcsr_tas_write_dist(matrix_a_rs, unit_nr_prv)
530            CALL dbcsr_tas_write_dist(matrix_b_rs, unit_nr_prv)
531         ENDIF
532
533         CALL convert_to_new_pgrid(mp_comm_mm, matrix_a_rs%matrix, matrix_a_mm, optimize_pgrid=opt_pgrid, move_data=move_a)
534         IF (new_a .AND. opt_pgrid) THEN
535            CALL dbcsr_tas_destroy(matrix_a_rs)
536            DEALLOCATE (matrix_a_rs)
537         ENDIF
538
539         CALL convert_to_new_pgrid(mp_comm_mm, matrix_b_rs%matrix, matrix_b_mm, optimize_pgrid=opt_pgrid, move_data=move_b)
540         IF (new_b .AND. opt_pgrid) THEN
541            CALL dbcsr_tas_destroy(matrix_b_rs)
542            DEALLOCATE (matrix_b_rs)
543         ENDIF
544
545         CALL convert_to_new_pgrid(mp_comm_mm, matrix_c_rep%matrix, matrix_c_mm, nodata=nodata_3, optimize_pgrid=opt_pgrid)
546
547         CALL timeset(routineN//"_mm_2", handle2)
548         CALL dbcsr_multiply(transa=transa_prv, transb=transb_prv, alpha=alpha, matrix_a=matrix_a_mm, &
549                             matrix_b=matrix_b_mm, beta=beta, matrix_c=matrix_c_mm, &
550                             filter_eps=filter_eps_prv/REAL(nsplit, KIND=real_8), retain_sparsity=retain_sparsity, flop=flop)
551         CALL timestop(handle2)
552
553         IF (opt_pgrid) THEN
554            CALL dbcsr_release(matrix_a_mm)
555            CALL dbcsr_release(matrix_b_mm)
556            CALL redistribute_and_sum(matrix_c_mm, matrix_c_rep%matrix, alpha=beta)
557            CALL dbcsr_release(matrix_c_mm)
558
559            DEALLOCATE (matrix_a_mm, matrix_b_mm, matrix_c_mm)
560         ELSE
561            IF (new_a) CALL dbcsr_tas_destroy(matrix_a_rs)
562            IF (new_a) DEALLOCATE (matrix_a_rs)
563            IF (new_b) CALL dbcsr_tas_destroy(matrix_b_rs)
564            IF (new_b) DEALLOCATE (matrix_b_rs)
565         ENDIF
566
567         IF (unit_nr_prv /= 0) THEN
568            CALL dbcsr_tas_write_dist(matrix_c_rep, unit_nr_prv, full_info=log_verbose)
569         ENDIF
570
571         IF (matrix_c%do_batched == 0) THEN
572            CALL dbcsr_tas_merge(matrix_c_rs%matrix, matrix_c_rep, move_data=.TRUE.)
573         ELSE
574            matrix_c%mm_storage%batched_out = .TRUE. ! postpone merging submatrices to dbcsr_tas_batched_mm_finalize
575         ENDIF
576
577         IF (matrix_c%do_batched == 0) THEN
578            CALL dbcsr_tas_destroy(matrix_c_rep)
579            DEALLOCATE (matrix_c_rep)
580         ENDIF
581         IF (PRESENT(filter_eps)) CALL dbcsr_tas_filter(matrix_c_rs, filter_eps)
582      CASE (3)
583         IF (matrix_a%do_batched <= 1) THEN
584            ALLOCATE (matrix_a_rep)
585            CALL dbcsr_tas_replicate(matrix_a_rs%matrix, dbcsr_tas_info(matrix_b_rs), matrix_a_rep, move_data=.TRUE.)
586            IF (matrix_a%do_batched == 1) THEN
587               matrix_a%mm_storage%store_batched_repl => matrix_a_rep
588               matrix_a%do_batched = 2
589            ENDIF
590         ELSEIF (matrix_a%do_batched == 2) THEN
591            matrix_a_rep => matrix_a%mm_storage%store_batched_repl
592         ENDIF
593
594         IF (new_a) THEN
595            CALL dbcsr_tas_destroy(matrix_a_rs)
596            DEALLOCATE (matrix_a_rs)
597         ENDIF
598         IF (unit_nr_prv /= 0) THEN
599            CALL dbcsr_tas_write_dist(matrix_a_rep, unit_nr_prv, full_info=log_verbose)
600            CALL dbcsr_tas_write_dist(matrix_b_rs, unit_nr_prv)
601         ENDIF
602
603         CALL convert_to_new_pgrid(mp_comm_mm, matrix_a_rep%matrix, matrix_a_mm, optimize_pgrid=opt_pgrid, &
604                                   move_data=matrix_a%do_batched == 0)
605
606         IF (opt_pgrid .AND. matrix_a%do_batched == 0) THEN
607            CALL dbcsr_tas_destroy(matrix_a_rep)
608            DEALLOCATE (matrix_a_rep)
609         ENDIF
610
611         CALL convert_to_new_pgrid(mp_comm_mm, matrix_b_rs%matrix, matrix_b_mm, optimize_pgrid=opt_pgrid, move_data=move_b)
612
613         IF (new_b .AND. opt_pgrid) THEN
614            CALL dbcsr_tas_destroy(matrix_b_rs)
615            DEALLOCATE (matrix_b_rs)
616         ENDIF
617         CALL convert_to_new_pgrid(mp_comm_mm, matrix_c_rs%matrix, matrix_c_mm, nodata=nodata_3, optimize_pgrid=opt_pgrid)
618
619         SELECT CASE (tr_case)
620         CASE (dbcsr_no_transpose)
621            CALL timeset(routineN//"_mm_3N", handle2)
622            CALL dbcsr_multiply(transa=dbcsr_no_transpose, transb=dbcsr_no_transpose, alpha=alpha, &
623                                matrix_a=matrix_a_mm, matrix_b=matrix_b_mm, beta=beta, matrix_c=matrix_c_mm, &
624                                filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
625            CALL timestop(handle2)
626         CASE (dbcsr_transpose)
627            CALL timeset(routineN//"_mm_3T", handle2)
628            CALL dbcsr_multiply(transa=dbcsr_no_transpose, transb=dbcsr_transpose, alpha=alpha, &
629                                matrix_a=matrix_b_mm, matrix_b=matrix_a_mm, beta=beta, matrix_c=matrix_c_mm, &
630                                filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
631            CALL timestop(handle2)
632         END SELECT
633
634         IF (opt_pgrid) THEN
635            CALL dbcsr_release(matrix_a_mm)
636            CALL dbcsr_release(matrix_b_mm)
637
638            IF (.NOT. new_c) THEN
639               CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix, alpha=beta)
640            ELSE
641               CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix)
642            ENDIF
643
644            CALL dbcsr_release(matrix_c_mm)
645
646            DEALLOCATE (matrix_a_mm, matrix_b_mm, matrix_c_mm)
647         ELSE
648            IF (new_b) CALL dbcsr_tas_destroy(matrix_b_rs)
649            IF (new_b) DEALLOCATE (matrix_b_rs)
650            IF (matrix_a%do_batched == 0) THEN
651               CALL dbcsr_tas_destroy(matrix_a_rep)
652               DEALLOCATE (matrix_a_rep)
653            ENDIF
654         ENDIF
655
656         IF (PRESENT(filter_eps)) CALL dbcsr_tas_filter(matrix_c_rs, filter_eps)
657
658         IF (unit_nr_prv /= 0) THEN
659            CALL dbcsr_tas_write_dist(matrix_c_rs, unit_nr_prv)
660         ENDIF
661      END SELECT
662
663      CALL mp_comm_free(mp_comm_mm)
664
665      CALL dbcsr_tas_get_split_info(dbcsr_tas_info(matrix_c), mp_comm=mp_comm)
666
667      IF (PRESENT(split_opt) .AND. .NOT. simple_split_prv) THEN
668         ! ideally we should rederive the split factor from the actual sparsity of C, but
669         ! due to parameter beta, we can not get the sparsity of AxB from DBCSR if not new_c
670         ALLOCATE (split_opt)
671         mp_comm_opt = dbcsr_tas_mp_comm(mp_comm, split_rc, nsplit_opt)
672         CALL dbcsr_tas_create_split(split_opt, mp_comm_opt, split_rc, nsplit_opt, own_comm=.TRUE.)
673      ENDIF
674
675      IF (new_c) THEN
676         CALL dbcsr_scale(matrix_c%matrix, beta)
677         CALL dbcsr_tas_reshape(matrix_c_rs, matrix_c, summation=.TRUE., transposed=transc_prv /= transc, &
678                                move_data=.TRUE.)
679         CALL dbcsr_tas_destroy(matrix_c_rs)
680         DEALLOCATE (matrix_c_rs)
681         IF (PRESENT(filter_eps)) CALL dbcsr_tas_filter(matrix_c, filter_eps)
682      ELSEIF (matrix_c%do_batched > 0) THEN
683         IF (matrix_c%mm_storage%batched_out) THEN
684            matrix_c%mm_storage%batched_beta = beta
685            matrix_c%mm_storage%batched_trans = transc_prv /= transc
686         ENDIF
687      ENDIF
688
689      IF (PRESENT(move_data_a)) THEN
690         IF (move_data_a) CALL dbcsr_tas_clear(matrix_a)
691      ENDIF
692      IF (PRESENT(move_data_b)) THEN
693         IF (move_data_b) CALL dbcsr_tas_clear(matrix_b)
694      ENDIF
695
696      IF (PRESENT(flop)) THEN
697         CALL mp_sum(flop, mp_comm)
698         flop = (flop + numproc - 1)/numproc
699      ENDIF
700
701      IF (PRESENT(optimize_dist)) THEN
702         IF (optimize_dist) CALL mp_comm_free(comm_tmp)
703      ENDIF
704      IF (unit_nr_prv > 0) THEN
705         WRITE (unit_nr_prv, '(A)') repeat("-", 80)
706         WRITE (unit_nr_prv, '(A,1X,A,1X,A,1X,A,1X,A,1X,A)') "TAS MATRIX MULTIPLICATION DONE"
707         WRITE (unit_nr_prv, '(A)') repeat("-", 80)
708      ENDIF
709
710      CALL timestop(handle)
711
712   END SUBROUTINE
713
714   SUBROUTINE redistribute_and_sum(matrix_in, matrix_out, alpha)
715      TYPE(dbcsr_type), INTENT(IN) :: matrix_in
716      TYPE(dbcsr_type), INTENT(INOUT) :: matrix_out
717      TYPE(dbcsr_scalar_type), INTENT(IN), OPTIONAL :: alpha
718      TYPE(dbcsr_type) :: matrix_tmp
719
720      CALL dbcsr_create(matrix_tmp, matrix_out)
721      CALL dbcsr_redistribute(matrix_in, matrix_tmp)
722      CALL dbcsr_add(matrix_out, matrix_tmp, alpha_scalar=alpha)
723      CALL dbcsr_release(matrix_tmp)
724
725   END SUBROUTINE
726
727   SUBROUTINE reshape_mm_small(mp_comm, matrix_in, matrix_out, transposed, trans, nodata, move_data)
728      !! Make sure that smallest matrix involved in a multiplication is not split and bring it to
729      !! the same process grid as the other 2 matrices.
730
731      INTEGER, INTENT(IN)               :: mp_comm
732         !! communicator that defines Cartesian topology
733      TYPE(dbcsr_tas_type), INTENT(INOUT) :: matrix_in
734      TYPE(dbcsr_tas_type), INTENT(OUT)   :: matrix_out
735      LOGICAL, INTENT(IN)               :: transposed
736         !! Whether matrix_out should be transposed
737      CHARACTER(LEN=1), INTENT(INOUT)   :: trans
738         !! update transpose flag for DBCSR mm according to 'transposed' argument
739      LOGICAL, INTENT(IN), OPTIONAL     :: nodata, move_data
740         !! Data of matrix_in should not be copied to matrix_out
741         !! memory optimization: move data such that matrix_in is empty on return.
742
743      INTEGER                           :: numnodes
744      INTEGER(KIND=int_8), DIMENSION(2) :: dims
745      INTEGER, DIMENSION(2)             :: pdims, pcoord
746      TYPE(dbcsr_tas_dist_arb)            :: new_row_dist, new_col_dist
747      TYPE(dbcsr_tas_distribution_type)   :: dist
748      LOGICAL                           :: nodata_prv
749      CHARACTER(LEN=*), PARAMETER       :: routineN = 'reshape_mm_small', &
750                                           routineP = moduleN//':'//routineN
751      INTEGER                           :: handle
752
753      CALL timeset(routineN, handle)
754
755      IF (PRESENT(nodata)) THEN
756         nodata_prv = nodata
757      ELSE
758         nodata_prv = .FALSE.
759      ENDIF
760
761      IF (transposed) THEN
762         SELECT CASE (trans)
763         CASE (dbcsr_transpose)
764            trans = dbcsr_no_transpose
765         CASE (dbcsr_no_transpose)
766            trans = dbcsr_transpose
767         END SELECT
768      ENDIF
769
770      CALL mp_environ(numnodes, pdims, pcoord, mp_comm)
771
772      dims = [dbcsr_tas_nblkrows_total(matrix_in), dbcsr_tas_nblkcols_total(matrix_in)]
773
774      IF (transposed) CALL swap(dims)
775
776      IF (.NOT. transposed) THEN
777         new_row_dist = dbcsr_tas_dist_arb_default(pdims(1), dims(1), matrix_in%row_blk_size)
778         new_col_dist = dbcsr_tas_dist_arb_default(pdims(2), dims(2), matrix_in%col_blk_size)
779         CALL dbcsr_tas_distribution_new(dist, mp_comm, new_row_dist, new_col_dist, nosplit=.TRUE.)
780         CALL dbcsr_tas_create(matrix_out, matrix_in%matrix%name, dist, dbcsr_tas_get_data_type(matrix_in), &
781                               matrix_in%row_blk_size, matrix_in%col_blk_size, own_dist=.TRUE.)
782      ELSE
783         new_row_dist = dbcsr_tas_dist_arb_default(pdims(1), dims(1), matrix_in%col_blk_size)
784         new_col_dist = dbcsr_tas_dist_arb_default(pdims(2), dims(2), matrix_in%row_blk_size)
785         CALL dbcsr_tas_distribution_new(dist, mp_comm, new_row_dist, new_col_dist, nosplit=.TRUE.)
786         CALL dbcsr_tas_create(matrix_out, matrix_in%matrix%name, dist, dbcsr_tas_get_data_type(matrix_in), &
787                               matrix_in%col_blk_size, matrix_in%row_blk_size, own_dist=.TRUE.)
788
789      ENDIF
790      IF (.NOT. nodata_prv) CALL dbcsr_tas_reshape(matrix_in, matrix_out, transposed=transposed, move_data=move_data)
791
792      CALL timestop(handle)
793
794   END SUBROUTINE
795
796   SUBROUTINE reshape_mm_compatible(matrix1_in, matrix2_in, matrix1_out, matrix2_out, new1, new2, trans1, trans2, &
797                                    optimize_dist, nsplit, opt_nsplit, split_rc_1, split_rc_2, nodata1, nodata2, &
798                                    move_data_1, move_data_2, comm_new, unit_nr)
799      !! Reshape either matrix1 or matrix2 to make sure that their process grids are compatible with
800      !! the same split factor.
801
802      TYPE(dbcsr_tas_type), TARGET, &
803         INTENT(INOUT)                           :: matrix1_in, matrix2_in
804      TYPE(dbcsr_tas_type), POINTER, INTENT(OUT) :: matrix1_out, matrix2_out
805      LOGICAL, INTENT(OUT)                       :: new1, new2
806         !! Whether matrix1_out is a new matrix or simply pointing to matrix1_in
807         !! Whether matrix2_out is a new matrix or simply pointing to matrix2_in
808      CHARACTER(LEN=1), INTENT(INOUT)            :: trans1, trans2
809         !! transpose flag of matrix1_in for multiplication
810         !! transpose flag of matrix2_in for multiplication
811      LOGICAL, INTENT(IN), OPTIONAL              :: optimize_dist
812         !! experimental: optimize matrix splitting and distribution
813      INTEGER, INTENT(IN), OPTIONAL              :: nsplit
814         !! Optimal split factor (set to 0 if split factor should not be changed)
815      LOGICAL, INTENT(IN), OPTIONAL              :: opt_nsplit
816      INTEGER, INTENT(INOUT)                     :: split_rc_1, split_rc_2
817         !! Whether to split rows or columns for matrix 1
818         !! Whether to split rows or columns for matrix 2
819      INTEGER, INTENT(OUT), OPTIONAL             :: comm_new
820         !! returns the new communicator only if optimize_dist
821      LOGICAL, OPTIONAL, INTENT(IN)              :: nodata1, nodata2
822         !! Don't copy matrix data from matrix1_in to matrix1_out
823         !! Don't copy matrix data from matrix2_in to matrix2_out
824      LOGICAL, OPTIONAL, INTENT(INOUT)           :: move_data_1, move_data_2
825         !! memory optimization: move data such that matrix1_in may be empty on return.
826         !! memory optimization: move data such that matrix2_in may be empty on return.
827      INTEGER, INTENT(IN), OPTIONAL              :: unit_nr
828         !! output unit
829
830      INTEGER(KIND=int_8), DIMENSION(2)          :: dims1, dims2, dims_ref
831      INTEGER(KIND=int_8)                        :: d1, d2
832      CHARACTER(LEN=*), PARAMETER                :: routineN = 'reshape_mm_compatible', &
833                                                    routineP = moduleN//':'//routineN
834      INTEGER                                    :: handle, mp_comm, numnodes, unit_nr_prv, &
835                                                    nsplit_prv, ref, split_rc_ref
836      INTEGER, DIMENSION(2)                      :: pcoord, pdims
837      LOGICAL                                    :: optimize_dist_prv, trans1_newdist, trans2_newdist
838      TYPE(dbcsr_tas_dist_cyclic)                :: row_dist_1, col_dist_1, row_dist_2, col_dist_2
839      TYPE(dbcsr_tas_distribution_type)          :: dist_1, dist_2
840      TYPE(dbcsr_tas_split_info)                 :: split_info
841      INTEGER(KIND=int_8)                        :: nze1, nze2
842      LOGICAL                                    :: nodata1_prv, nodata2_prv
843
844      CALL timeset(routineN, handle)
845      new1 = .FALSE.; new2 = .FALSE.
846
847      IF (PRESENT(nodata1)) THEN
848         nodata1_prv = nodata1
849      ELSE
850         nodata1_prv = .FALSE.
851      ENDIF
852
853      IF (PRESENT(nodata2)) THEN
854         nodata2_prv = nodata2
855      ELSE
856         nodata2_prv = .FALSE.
857      ENDIF
858
859      unit_nr_prv = prep_output_unit(unit_nr)
860
861      NULLIFY (matrix1_out, matrix2_out)
862
863      IF (PRESENT(optimize_dist)) THEN
864         optimize_dist_prv = optimize_dist
865      ELSE
866         optimize_dist_prv = .FALSE.
867      ENDIF
868
869      dims1 = [dbcsr_tas_nblkrows_total(matrix1_in), dbcsr_tas_nblkcols_total(matrix1_in)]
870      dims2 = [dbcsr_tas_nblkrows_total(matrix2_in), dbcsr_tas_nblkcols_total(matrix2_in)]
871      nze1 = dbcsr_tas_get_nze_total(matrix1_in)
872      nze2 = dbcsr_tas_get_nze_total(matrix2_in)
873
874      IF (trans1 == dbcsr_transpose) split_rc_1 = MOD(split_rc_1, 2) + 1
875
876      IF (trans2 == dbcsr_transpose) split_rc_2 = MOD(split_rc_2, 2) + 1
877
878      IF (nze1 >= nze2) THEN
879         ref = 1
880         split_rc_ref = split_rc_1
881         dims_ref = dims1
882      ELSE
883         ref = 2
884         split_rc_ref = split_rc_2
885         dims_ref = dims2
886      ENDIF
887
888      IF (PRESENT(nsplit)) THEN
889         nsplit_prv = nsplit
890      ELSE
891         nsplit_prv = 0
892      ENDIF
893
894      IF (optimize_dist_prv) THEN
895         DBCSR_ASSERT(PRESENT(comm_new))
896      ENDIF
897
898      IF ((.NOT. optimize_dist_prv) .AND. dist_compatible(matrix1_in, matrix2_in, split_rc_1, split_rc_2)) THEN
899         CALL change_split(matrix1_in, matrix1_out, nsplit_prv, split_rc_1, new1, &
900                           move_data=move_data_1, nodata=nodata1, opt_nsplit=opt_nsplit)
901         CALL change_split(matrix2_in, matrix2_out, nsplit_prv, split_rc_2, new2, &
902                           move_data=move_data_2, nodata=nodata2, opt_nsplit=opt_nsplit)
903         IF (unit_nr_prv > 0) THEN
904            WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A,1X,A)") "No redistribution of", TRIM(matrix1_in%matrix%name), &
905               "and", TRIM(matrix2_in%matrix%name)
906            IF (new1) THEN
907               WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A)") "Change split factor of", TRIM(matrix1_in%matrix%name), ": Yes"
908            ELSE
909               WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A)") "Change split factor of", TRIM(matrix1_in%matrix%name), ": No"
910            ENDIF
911            IF (new2) THEN
912               WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A)") "Change split factor of", TRIM(matrix2_in%matrix%name), ": Yes"
913            ELSE
914               WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A)") "Change split factor of", TRIM(matrix2_in%matrix%name), ": No"
915            ENDIF
916         ENDIF
917      ELSE
918
919         IF (optimize_dist_prv) THEN
920            IF (unit_nr_prv > 0) THEN
921               WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A,1X,A)") "Optimizing distribution of", TRIM(matrix1_in%matrix%name), &
922                  "and", TRIM(matrix2_in%matrix%name)
923            ENDIF
924
925            trans1_newdist = (split_rc_1 == colsplit)
926            trans2_newdist = (split_rc_2 == colsplit)
927
928            IF (trans1_newdist) THEN
929               CALL swap(dims1)
930               CALL invert_transpose_flag(trans1)
931            ENDIF
932
933            IF (trans2_newdist) THEN
934               CALL swap(dims2)
935               CALL invert_transpose_flag(trans2)
936            ENDIF
937
938            IF (nsplit_prv == 0) THEN
939               SELECT CASE (split_rc_ref)
940               CASE (rowsplit)
941                  d1 = dims_ref(1)
942                  d2 = dims_ref(2)
943               CASE (colsplit)
944                  d1 = dims_ref(2)
945                  d2 = dims_ref(1)
946               END SELECT
947               nsplit_prv = INT((d1 - 1)/d2 + 1)
948            ENDIF
949
950            DBCSR_ASSERT(nsplit_prv > 0)
951
952            CALL dbcsr_tas_get_split_info(dbcsr_tas_info(matrix1_in), mp_comm=mp_comm)
953            comm_new = dbcsr_tas_mp_comm(mp_comm, rowsplit, nsplit_prv)
954            CALL dbcsr_tas_create_split(split_info, comm_new, rowsplit, nsplit_prv)
955
956            CALL mp_environ(numnodes, pdims, pcoord, comm_new)
957
958            ! use a very simple cyclic distribution that may not be load balanced if block
959            ! sizes are not equal. However we can not use arbitrary distributions
960            ! for large dimensions since this would require storing distribution vectors as arrays
961            ! which can not be stored for large dimensions.
962            row_dist_1 = dbcsr_tas_dist_cyclic(1, pdims(1), dims1(1))
963            col_dist_1 = dbcsr_tas_dist_cyclic(1, pdims(2), dims1(2))
964
965            row_dist_2 = dbcsr_tas_dist_cyclic(1, pdims(1), dims2(1))
966            col_dist_2 = dbcsr_tas_dist_cyclic(1, pdims(2), dims2(2))
967
968            CALL dbcsr_tas_distribution_new(dist_1, comm_new, row_dist_1, col_dist_1, split_info=split_info)
969            CALL dbcsr_tas_distribution_new(dist_2, comm_new, row_dist_2, col_dist_2, split_info=split_info)
970            CALL dbcsr_tas_release_info(split_info)
971
972            ALLOCATE (matrix1_out)
973            IF (.NOT. trans1_newdist) THEN
974               CALL dbcsr_tas_create(matrix1_out, matrix1_in%matrix%name, dist_1, dbcsr_tas_get_data_type(matrix1_in), &
975                                     matrix1_in%row_blk_size, matrix1_in%col_blk_size, own_dist=.TRUE.)
976
977            ELSE
978               CALL dbcsr_tas_create(matrix1_out, matrix1_in%matrix%name, dist_1, dbcsr_tas_get_data_type(matrix1_in), &
979                                     matrix1_in%col_blk_size, matrix1_in%row_blk_size, own_dist=.TRUE.)
980            ENDIF
981
982            ALLOCATE (matrix2_out)
983            IF (.NOT. trans2_newdist) THEN
984               CALL dbcsr_tas_create(matrix2_out, matrix2_in%matrix%name, dist_2, dbcsr_tas_get_data_type(matrix2_in), &
985                                     matrix2_in%row_blk_size, matrix2_in%col_blk_size, own_dist=.TRUE.)
986            ELSE
987               CALL dbcsr_tas_create(matrix2_out, matrix2_in%matrix%name, dist_2, dbcsr_tas_get_data_type(matrix2_in), &
988                                     matrix2_in%col_blk_size, matrix2_in%row_blk_size, own_dist=.TRUE.)
989            ENDIF
990
991            IF (.NOT. nodata1_prv) CALL dbcsr_tas_reshape(matrix1_in, matrix1_out, transposed=trans1_newdist, move_data=move_data_1)
992            IF (.NOT. nodata2_prv) CALL dbcsr_tas_reshape(matrix2_in, matrix2_out, transposed=trans2_newdist, move_data=move_data_2)
993            new1 = .TRUE.
994            new2 = .TRUE.
995
996         ELSE
997            SELECT CASE (ref)
998            CASE (1)
999               IF (unit_nr_prv > 0) THEN
1000                  WRITE (unit_nr_prv, "(T2,A,1X,A)") "Redistribution of", TRIM(matrix2_in%matrix%name)
1001               ENDIF
1002
1003               CALL change_split(matrix1_in, matrix1_out, nsplit_prv, split_rc_1, new1, &
1004                                 move_data=move_data_1, nodata=nodata1, opt_nsplit=opt_nsplit)
1005
1006               ALLOCATE (matrix2_out)
1007               CALL reshape_mm_template(matrix1_out, matrix2_in, matrix2_out, trans2, split_rc_2, &
1008                                        nodata=nodata2, move_data=move_data_2)
1009               new2 = .TRUE.
1010            CASE (2)
1011               IF (unit_nr_prv > 0) THEN
1012                  WRITE (unit_nr_prv, "(T2,A,1X,A)") "Redistribution of", TRIM(matrix1_in%matrix%name)
1013               ENDIF
1014
1015               CALL change_split(matrix2_in, matrix2_out, nsplit_prv, split_rc_2, new2, &
1016                                 move_data=move_data_2, nodata=nodata2, opt_nsplit=opt_nsplit)
1017
1018               ALLOCATE (matrix1_out)
1019               CALL reshape_mm_template(matrix2_out, matrix1_in, matrix1_out, trans1, split_rc_1, &
1020                                        nodata=nodata1, move_data=move_data_1)
1021               new1 = .TRUE.
1022            END SELECT
1023         ENDIF
1024      ENDIF
1025
1026      IF (PRESENT(move_data_1) .AND. new1) move_data_1 = .TRUE.
1027      IF (PRESENT(move_data_2) .AND. new2) move_data_2 = .TRUE.
1028
1029      CALL timestop(handle)
1030
1031   END SUBROUTINE
1032
1033   SUBROUTINE change_split(matrix_in, matrix_out, nsplit, split_rowcol, is_new, opt_nsplit, move_data, nodata)
1034      !! Change split factor without redistribution
1035
1036      TYPE(dbcsr_tas_type), TARGET, &
1037         INTENT(INOUT)                           :: matrix_in
1038      TYPE(dbcsr_tas_type), POINTER, INTENT(OUT) :: matrix_out
1039      INTEGER, INTENT(IN)                        :: nsplit
1040         !! new split factor, set to 0 to not change split of matrix_in
1041      INTEGER, INTENT(IN)                        :: split_rowcol
1042         !! split rows or columns
1043      LOGICAL, INTENT(OUT)                       :: is_new
1044         !! whether matrix_out is new or a pointer to matrix_in
1045      LOGICAL, INTENT(IN), OPTIONAL              :: opt_nsplit
1046         !! whether nsplit should be optimized for current process grid
1047      LOGICAL, INTENT(IN), OPTIONAL              :: nodata
1048         !! Data of matrix_in should not be copied to matrix_out
1049      LOGICAL, INTENT(INOUT), OPTIONAL           :: move_data
1050         !! memory optimization: move data such that matrix_in is empty on return.
1051
1052      INTEGER                                    :: &
1053         mp_comm, split_rc, nsplit_old, handle, data_type, nsplit_new, nsplit_prv
1054      TYPE(dbcsr_tas_split_info)                 :: split_info
1055      CHARACTER(len=default_string_length)       :: name
1056      TYPE(dbcsr_tas_distribution_type)          :: dist
1057      LOGICAL                                    :: nodata_prv
1058      CLASS(dbcsr_tas_distribution), ALLOCATABLE :: rdist, cdist
1059      CLASS(dbcsr_tas_rowcol_data), ALLOCATABLE  :: rbsize, cbsize
1060      CHARACTER(LEN=*), PARAMETER                :: routineN = 'change_split', &
1061                                                    routineP = moduleN//':'//routineN
1062      NULLIFY (matrix_out)
1063
1064      is_new = .TRUE.
1065
1066      CALL dbcsr_tas_get_split_info(dbcsr_tas_info(matrix_in), mp_comm=mp_comm, &
1067                                    split_rowcol=split_rc, nsplit=nsplit_old)
1068
1069      IF (nsplit == 0) THEN
1070         IF (split_rowcol == split_rc) THEN
1071            matrix_out => matrix_in
1072            is_new = .FALSE.
1073            RETURN
1074         ELSE
1075            nsplit_prv = 1
1076         ENDIF
1077      ELSE
1078         nsplit_prv = nsplit
1079      ENDIF
1080
1081      CALL timeset(routineN, handle)
1082
1083      nodata_prv = .FALSE.
1084      IF (PRESENT(nodata)) nodata_prv = nodata
1085
1086      CALL dbcsr_tas_get_info(matrix_in, data_type=data_type, name=name, &
1087                              row_blk_size=rbsize, col_blk_size=cbsize, &
1088                              proc_row_dist=rdist, proc_col_dist=cdist)
1089
1090      CALL dbcsr_tas_create_split(split_info, mp_comm, split_rowcol, nsplit_prv, opt_nsplit=opt_nsplit)
1091
1092      CALL dbcsr_tas_get_split_info(split_info, nsplit=nsplit_new)
1093
1094      IF (nsplit_old == nsplit_new .AND. split_rc == split_rowcol) THEN
1095         matrix_out => matrix_in
1096         is_new = .FALSE.
1097         CALL dbcsr_tas_release_info(split_info)
1098         CALL timestop(handle)
1099         RETURN
1100      ENDIF
1101
1102      CALL dbcsr_tas_distribution_new(dist, mp_comm, rdist, cdist, &
1103                                      split_info=split_info)
1104
1105      CALL dbcsr_tas_release_info(split_info)
1106
1107      ALLOCATE (matrix_out)
1108      CALL dbcsr_tas_create(matrix_out, name, dist, &
1109                            data_type, &
1110                            rbsize, cbsize, own_dist=.TRUE.)
1111
1112      IF (.NOT. nodata_prv) CALL dbcsr_tas_copy(matrix_out, matrix_in)
1113
1114      IF (PRESENT(move_data)) THEN
1115         IF (.NOT. nodata_prv) THEN
1116            IF (move_data) CALL dbcsr_tas_clear(matrix_in)
1117            move_data = .TRUE.
1118         ENDIF
1119      ENDIF
1120
1121      CALL timestop(handle)
1122   END SUBROUTINE
1123
1124   FUNCTION dist_compatible(mat_a, mat_b, split_rc_a, split_rc_b, unit_nr)
1125      !! Check whether matrices have same distribution and same split.
1126      TYPE(dbcsr_tas_type), INTENT(IN)           :: mat_a, mat_b
1127      INTEGER, INTENT(IN)                        :: split_rc_a, split_rc_b
1128      INTEGER, INTENT(IN), OPTIONAL              :: unit_nr
1129      LOGICAL                                    :: dist_compatible
1130
1131      INTEGER                                    :: res, same_local_rowcols, split_check
1132      TYPE(dbcsr_tas_split_info)                 :: info_a, info_b
1133      INTEGER                                    :: unit_nr_prv, numproc
1134      INTEGER, DIMENSION(2)                      :: pdims_a, pdims_b, pcoord_a, pcoord_b
1135
1136      unit_nr_prv = prep_output_unit(unit_nr)
1137
1138      dist_compatible = .FALSE.
1139
1140      info_a = dbcsr_tas_info(mat_a)
1141      info_b = dbcsr_tas_info(mat_b)
1142      CALL dbcsr_tas_get_split_info(info_a, split_rowcol=split_check)
1143      IF (split_check /= split_rc_a) RETURN
1144      CALL dbcsr_tas_get_split_info(info_b, split_rowcol=split_check)
1145      IF (split_check /= split_rc_b) RETURN
1146
1147      ! check if communicators are equivalent (global process grid and subgrids)
1148      ! Note: mpi_comm_compare is not sufficient since this does not compare associated Cartesian grids.
1149      ! It's sufficient to check dimensions of global grid and rank order of subgrids.
1150      CALL mp_environ(numproc, pdims_a, pcoord_a, info_a%mp_comm)
1151      CALL mp_environ(numproc, pdims_b, pcoord_b, info_b%mp_comm)
1152      IF (.NOT. array_eq(pdims_a, pdims_b)) THEN
1153         RETURN
1154      ENDIF
1155
1156      CALL mp_comm_compare(info_a%mp_comm_group, info_b%mp_comm_group, res)
1157      IF (res .GT. 1) THEN
1158         RETURN
1159      ENDIF
1160
1161      IF (unit_nr_prv > 0) THEN
1162         WRITE (unit_nr_prv, *) "mp comm compatible"
1163      ENDIF
1164
1165      IF (mat_a%dist%info%split_rowcol == mat_b%dist%info%split_rowcol) THEN
1166
1167         IF (unit_nr_prv > 0) THEN
1168            WRITE (unit_nr_prv, *) "split compatible"
1169         ENDIF
1170
1171         same_local_rowcols = MERGE(1, 0, array_eq(mat_a%dist%local_rowcols, mat_b%dist%local_rowcols))
1172         CALL mp_sum(same_local_rowcols, info_a%mp_comm)
1173
1174         IF (same_local_rowcols == numproc) THEN
1175            IF (unit_nr_prv > 0) THEN
1176               WRITE (unit_nr_prv, *) "local rowcols compatible"
1177            ENDIF
1178            dist_compatible = .TRUE.
1179         ELSE
1180            IF (unit_nr_prv > 0) THEN
1181               WRITE (unit_nr_prv, *) "local rowcols A", mat_a%dist%local_rowcols
1182               WRITE (unit_nr_prv, *) "local rowcols B", mat_b%dist%local_rowcols
1183            ENDIF
1184         ENDIF
1185      ENDIF
1186
1187      IF (unit_nr_prv > 0) THEN
1188         WRITE (unit_nr_prv, *) "is compatible?", dist_compatible
1189      ENDIF
1190
1191   END FUNCTION
1192
1193   SUBROUTINE reshape_mm_template(template, matrix_in, matrix_out, trans, split_rc, nodata, move_data)
1194      !! Reshape matrix_in s.t. it has same process grid, distribution and split as template
1195      TYPE(dbcsr_tas_type), INTENT(IN)           :: template
1196      TYPE(dbcsr_tas_type), INTENT(INOUT)        :: matrix_in
1197      TYPE(dbcsr_tas_type), INTENT(OUT)          :: matrix_out
1198      CHARACTER(LEN=1), INTENT(INOUT)            :: trans
1199      INTEGER, INTENT(IN)                        :: split_rc
1200      LOGICAL, INTENT(IN), OPTIONAL              :: nodata, move_data
1201      CLASS(dbcsr_tas_distribution), ALLOCATABLE :: row_dist, col_dist
1202
1203      TYPE(dbcsr_tas_distribution_type)          :: dist_new
1204      TYPE(dbcsr_tas_split_info)                 :: info_template, info_matrix
1205      INTEGER                                    :: mp_comm, dim_split_template, dim_split_matrix, &
1206                                                    numnodes, handle
1207      INTEGER, DIMENSION(2)                      :: pcoord, pdims
1208      LOGICAL                                    :: nodata_prv, transposed
1209      CHARACTER(LEN=*), PARAMETER :: routineN = 'reshape_mm_template', &
1210                                     routineP = moduleN//':'//routineN
1211
1212      CALL timeset(routineN, handle)
1213
1214      IF (PRESENT(nodata)) THEN
1215         nodata_prv = nodata
1216      ELSE
1217         nodata_prv = .FALSE.
1218      ENDIF
1219
1220      info_template = dbcsr_tas_info(template)
1221      info_matrix = dbcsr_tas_info(matrix_in)
1222
1223      dim_split_template = info_template%split_rowcol
1224      dim_split_matrix = split_rc
1225
1226      transposed = dim_split_template .NE. dim_split_matrix
1227      IF (transposed) THEN
1228         SELECT CASE (trans)
1229         CASE (dbcsr_transpose)
1230            trans = dbcsr_no_transpose
1231         CASE (dbcsr_no_transpose)
1232            trans = dbcsr_transpose
1233         END SELECT
1234      ENDIF
1235
1236      CALL mp_environ(numnodes, pdims, pcoord, info_template%mp_comm)
1237
1238      SELECT CASE (dim_split_template)
1239      CASE (1)
1240         IF (.NOT. transposed) THEN
1241            ALLOCATE (row_dist, source=template%dist%row_dist)
1242            ALLOCATE (col_dist, source=dbcsr_tas_dist_arb_default(pdims(2), matrix_in%nblkcols, matrix_in%col_blk_size))
1243         ELSE
1244            ALLOCATE (row_dist, source=template%dist%row_dist)
1245            ALLOCATE (col_dist, source=dbcsr_tas_dist_arb_default(pdims(2), matrix_in%nblkrows, matrix_in%row_blk_size))
1246         ENDIF
1247      CASE (2)
1248         IF (.NOT. transposed) THEN
1249            ALLOCATE (row_dist, source=dbcsr_tas_dist_arb_default(pdims(1), matrix_in%nblkrows, matrix_in%row_blk_size))
1250            ALLOCATE (col_dist, source=template%dist%col_dist)
1251         ELSE
1252            ALLOCATE (row_dist, source=dbcsr_tas_dist_arb_default(pdims(1), matrix_in%nblkcols, matrix_in%col_blk_size))
1253            ALLOCATE (col_dist, source=template%dist%col_dist)
1254         ENDIF
1255      END SELECT
1256
1257      CALL dbcsr_tas_get_split_info(info_template, mp_comm=mp_comm)
1258      CALL dbcsr_tas_distribution_new(dist_new, mp_comm, row_dist, col_dist, split_info=info_template)
1259      IF (.NOT. transposed) THEN
1260         CALL dbcsr_tas_create(matrix_out, matrix_in%matrix%name, dist_new, dbcsr_tas_get_data_type(matrix_in), &
1261                               matrix_in%row_blk_size, matrix_in%col_blk_size, own_dist=.TRUE.)
1262      ELSE
1263         CALL dbcsr_tas_create(matrix_out, matrix_in%matrix%name, dist_new, dbcsr_tas_get_data_type(matrix_in), &
1264                               matrix_in%col_blk_size, matrix_in%row_blk_size, own_dist=.TRUE.)
1265      ENDIF
1266
1267      IF (.NOT. nodata_prv) CALL dbcsr_tas_reshape(matrix_in, matrix_out, transposed=transposed, move_data=move_data)
1268
1269      CALL timestop(handle)
1270
1271   END SUBROUTINE
1272
1273   SUBROUTINE dbcsr_tas_result_index(transa, transb, transc, matrix_a, matrix_b, matrix_c, filter_eps, &
1274                                     unit_nr, blk_ind, nze, retain_sparsity)
1275      !! Estimate sparsity pattern of C resulting from A x B = C by multiplying the block norms of A and B
1276      !! Same dummy arguments as dbcsr_tas_multiply
1277      CHARACTER(LEN=1), INTENT(IN)               :: transa, transb, transc
1278      TYPE(dbcsr_tas_type), INTENT(INOUT), TARGET        :: matrix_a, matrix_b, matrix_c
1279      TYPE(dbcsr_tas_type), POINTER                      :: matrix_a_bnorm, matrix_b_bnorm, matrix_c_bnorm
1280      REAL(KIND=real_8), INTENT(IN), OPTIONAL    :: filter_eps
1281      INTEGER, INTENT(IN), OPTIONAL              :: unit_nr
1282      INTEGER(int_8), DIMENSION(:, :), ALLOCATABLE, INTENT(OUT), OPTIONAL :: blk_ind
1283      LOGICAL, INTENT(IN), OPTIONAL :: retain_sparsity
1284      INTEGER(int_8), INTENT(OUT), OPTIONAL :: nze
1285
1286      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbcsr_tas_result_index', &
1287                                     routineP = moduleN//':'//routineN
1288      LOGICAL :: retain_sparsity_prv
1289      INTEGER :: bn, row_size, col_size, handle, iblk, mp_comm, nblk
1290      INTEGER(int_8) :: row, col
1291      TYPE(dbcsr_tas_iterator) :: iter
1292
1293      CALL timeset(routineN, handle)
1294
1295      IF (PRESENT(retain_sparsity)) THEN
1296         retain_sparsity_prv = retain_sparsity
1297      ELSE
1298         retain_sparsity_prv = .FALSE.
1299      ENDIF
1300
1301      IF (.NOT. retain_sparsity_prv) THEN
1302         ALLOCATE (matrix_a_bnorm, matrix_b_bnorm, matrix_c_bnorm)
1303         CALL create_block_norms_matrix(matrix_a, matrix_a_bnorm)
1304         CALL create_block_norms_matrix(matrix_b, matrix_b_bnorm)
1305         CALL create_block_norms_matrix(matrix_c, matrix_c_bnorm, nodata=.TRUE.)
1306
1307         CALL dbcsr_tas_multiply(transa, transb, transc, dbcsr_scalar(1.0_real_8), matrix_a_bnorm, &
1308                                 matrix_b_bnorm, dbcsr_scalar(0.0_real_8), matrix_c_bnorm, &
1309                                 filter_eps=filter_eps, move_data_a=.TRUE., move_data_b=.TRUE., &
1310                                 simple_split=.TRUE., unit_nr=unit_nr)
1311         CALL dbcsr_tas_destroy(matrix_a_bnorm)
1312         CALL dbcsr_tas_destroy(matrix_b_bnorm)
1313
1314         DEALLOCATE (matrix_a_bnorm, matrix_b_bnorm)
1315      ELSE
1316         matrix_c_bnorm => matrix_c
1317      ENDIF
1318
1319      nblk = dbcsr_tas_get_num_blocks(matrix_c_bnorm)
1320      IF (PRESENT(blk_ind)) ALLOCATE (blk_ind(nblk, 2))
1321
1322      CALL dbcsr_tas_iterator_start(iter, matrix_c_bnorm)
1323      IF (PRESENT(nze)) nze = 0
1324      DO iblk = 1, nblk
1325         CALL dbcsr_tas_iterator_next_block(iter, row, col, bn)
1326         row_size = matrix_c%row_blk_size%data(row)
1327         col_size = matrix_c%col_blk_size%data(col)
1328         IF (PRESENT(nze)) nze = nze + row_size*col_size
1329         IF (PRESENT(blk_ind)) blk_ind(iblk, :) = [row, col]
1330      ENDDO
1331      CALL dbcsr_tas_iterator_stop(iter)
1332
1333      IF (PRESENT(nze)) THEN
1334         CALL dbcsr_tas_get_split_info(dbcsr_tas_info(matrix_a), mp_comm=mp_comm)
1335         CALL mp_sum(nze, mp_comm)
1336      ENDIF
1337
1338      IF (.NOT. retain_sparsity_prv) THEN
1339         CALL dbcsr_tas_destroy(matrix_c_bnorm)
1340         DEALLOCATE (matrix_c_bnorm)
1341      ENDIF
1342
1343      CALL timestop(handle)
1344
1345   END SUBROUTINE
1346
1347   FUNCTION split_factor_estimate(max_mm_dim, nze_a, nze_b, nze_c, numnodes) RESULT(nsplit)
1348      !! Estimate optimal split factor for AxB=C from occupancies (number of non-zero elements)
1349      !! This estimate is based on the minimization of communication volume whereby
1350      !! the communication of CARMA n-split step and CANNON-multiplication of submatrices are
1351      !! considered.
1352      !! \result estimated split factor
1353
1354      INTEGER, INTENT(IN)                         :: max_mm_dim
1355      INTEGER(KIND=int_8), INTENT(IN)             :: nze_a, nze_b, nze_c
1356         !! number of non-zeroes in A
1357         !! number of non-zeroes in B
1358         !! number of non-zeroes in C
1359      INTEGER, INTENT(IN)                         :: numnodes
1360         !! number of MPI ranks
1361      INTEGER                                     :: nsplit, nsplit_comm, nsplit_memory
1362      INTEGER(KIND=int_8)                         :: max_nze, min_nze
1363
1364      SELECT CASE (max_mm_dim)
1365      CASE (1)
1366         min_nze = MAX(nze_b, 1_int_8)
1367         max_nze = MAX(MAXVAL([nze_a, nze_c]), 1_int_8)
1368      CASE (2)
1369         min_nze = MAX(nze_c, 1_int_8)
1370         max_nze = MAX(MAXVAL([nze_a, nze_b]), 1_int_8)
1371      CASE (3)
1372         min_nze = MAX(nze_a, 1_int_8)
1373         max_nze = MAX(MAXVAL([nze_b, nze_c]), 1_int_8)
1374      CASE DEFAULT
1375         DBCSR_ABORT("")
1376      END SELECT
1377
1378      nsplit_comm = NINT((REAL(nze_a + nze_b, real_8)/(2*min_nze))**(2._real_8/3)*REAL(numnodes, real_8)**(1._real_8/3))
1379      IF (nsplit_comm == 0) nsplit_comm = 1
1380
1381      ! nsplit_memory protects against excess memory usage
1382      ! actual split factor may be up to default_nsplit_accept_ratio_1 times larger, so the largest nsplit
1383      ! that fits into memory used by A or B needs to be divided by this factor
1384      nsplit_memory = CEILING(REAL((max_nze - 1)/min_nze + 1, real_8)/default_nsplit_accept_ratio_1)
1385
1386      nsplit = MIN(nsplit_comm, nsplit_memory)
1387
1388   END FUNCTION
1389
1390   SUBROUTINE create_block_norms_matrix(matrix_in, matrix_out, nodata)
1391      !! Create a matrix with block sizes one that contains the block norms of matrix_in
1392      TYPE(dbcsr_tas_type), INTENT(INOUT)        :: matrix_in
1393      TYPE(dbcsr_tas_type), INTENT(OUT)          :: matrix_out
1394      LOGICAL, INTENT(IN), OPTIONAL              :: nodata
1395      TYPE(dbcsr_tas_blk_size_one)               :: row_blk_size, col_blk_size
1396      TYPE(dbcsr_tas_iterator)                   :: iter
1397      INTEGER(KIND=int_8)                        :: row, column, nblkrows, nblkcols
1398      CHARACTER(len=default_string_length)       :: name
1399      INTEGER                                    :: data_type
1400
1401#:for dparam, dtype, dsuffix in dtype_float_list
1402      ${dtype}$, DIMENSION(:, :), POINTER        :: block_get_${dsuffix}$
1403      ${dtype}$, DIMENSION(:, :), POINTER        :: block_put_${dsuffix}$
1404#:endfor
1405      LOGICAL                                    :: tr, nodata_prv, found
1406
1407      DBCSR_ASSERT(matrix_in%valid)
1408
1409      IF (PRESENT(nodata)) THEN
1410         nodata_prv = nodata
1411      ELSE
1412         nodata_prv = .FALSE.
1413      ENDIF
1414
1415      CALL dbcsr_tas_get_info(matrix_in, data_type=data_type, name=name, &
1416                              nblkrows_total=nblkrows, nblkcols_total=nblkcols)
1417
1418      row_blk_size = dbcsr_tas_blk_size_one(nblkrows)
1419      col_blk_size = dbcsr_tas_blk_size_one(nblkcols)
1420
1421      ! not sure if assumption that same distribution can be taken still holds
1422      CALL dbcsr_tas_create(matrix_out, name, matrix_in%dist, &
1423                            data_type, &
1424                            row_blk_size, col_blk_size)
1425
1426      IF (.NOT. nodata_prv) THEN
1427         CALL dbcsr_tas_reserve_blocks(matrix_in, matrix_out)
1428
1429         CALL dbcsr_tas_iterator_start(iter, matrix_in)
1430
1431         DO WHILE (dbcsr_tas_iterator_blocks_left(iter))
1432
1433#:for dparam, dtype, dsuffix in dtype_float_list
1434            IF (data_type == ${dparam}$) THEN
1435               CALL dbcsr_tas_iterator_next_block(iter, row, column, block_get_${dsuffix}$, tr)
1436               CALL dbcsr_tas_get_block_p(matrix_out, row, column, block_put_${dsuffix}$, tr, found)
1437               DBCSR_ASSERT(found)
1438               block_put_${dsuffix}$ (1, 1) = SQRT(SUM(block_get_${dsuffix}$**2)) ! norm2 works only for real
1439            ENDIF
1440#:endfor
1441         ENDDO
1442         CALL dbcsr_tas_iterator_stop(iter)
1443      ENDIF
1444
1445   END SUBROUTINE
1446
1447   SUBROUTINE convert_to_new_pgrid(mp_comm_cart, matrix_in, matrix_out, move_data, nodata, optimize_pgrid)
1448      !! Convert a DBCSR matrix to a new process grid
1449
1450      INTEGER, INTENT(IN)                        :: mp_comm_cart
1451         !! new process grid
1452      TYPE(dbcsr_type), INTENT(INOUT), TARGET    :: matrix_in
1453      TYPE(dbcsr_type), INTENT(OUT), POINTER     :: matrix_out
1454      LOGICAL, INTENT(IN), OPTIONAL              :: move_data, nodata
1455         !! memory optimization: move data such that matrix_in is empty on return.
1456         !! Data of matrix_in should not be copied to matrix_out
1457      LOGICAL, INTENT(IN), OPTIONAL              :: optimize_pgrid
1458         !! Whether to change process grid
1459
1460      INTEGER                                    :: &
1461         nbrows, nbcols, data_type, nproc, handle
1462      INTEGER, DIMENSION(2)                      :: pdims, pcoord
1463      INTEGER, DIMENSION(:), POINTER, CONTIGUOUS :: row_dist, col_dist, rbsize, rcsize
1464      TYPE(dbcsr_distribution_obj)               :: dist, dist_old
1465      TYPE(dbcsr_mp_obj)                         :: mp_obj
1466      CHARACTER(len=default_string_length)       :: name
1467      LOGICAL                                    :: nodata_prv, optimize_pgrid_prv
1468      CHARACTER(LEN=*), PARAMETER                :: routineN = 'convert_to_new_pgrid', &
1469                                                    routineP = moduleN//':'//routineN
1470
1471      NULLIFY (row_dist, col_dist, rbsize, rcsize)
1472
1473      IF (PRESENT(optimize_pgrid)) THEN
1474         optimize_pgrid_prv = optimize_pgrid
1475      ELSE
1476         optimize_pgrid_prv = .TRUE.
1477      ENDIF
1478
1479      IF (.NOT. optimize_pgrid_prv) THEN
1480         matrix_out => matrix_in
1481         RETURN
1482      ENDIF
1483
1484      CALL timeset(routineN, handle)
1485
1486      IF (PRESENT(nodata)) THEN
1487         nodata_prv = nodata
1488      ELSE
1489         nodata_prv = .FALSE.
1490      ENDIF
1491
1492      ALLOCATE (matrix_out)
1493
1494      CALL dbcsr_get_info(matrix_in, nblkrows_total=nbrows, nblkcols_total=nbcols, &
1495                          row_blk_size=rbsize, col_blk_size=rcsize, &
1496                          data_type=data_type, distribution=dist_old, name=name)
1497      CALL mp_environ(nproc, pdims, pcoord, mp_comm_cart)
1498
1499      ALLOCATE (row_dist(nbrows), col_dist(nbcols))
1500      CALL dbcsr_tas_default_distvec(nbrows, pdims(1), rbsize, row_dist)
1501      CALL dbcsr_tas_default_distvec(nbcols, pdims(2), rcsize, col_dist)
1502
1503      mp_obj = dbcsr_mp_environ(mp_comm_cart)
1504      CALL dbcsr_distribution_new(dist, mp_obj, row_dist, col_dist, reuse_arrays=.TRUE.)
1505      CALL dbcsr_mp_release(mp_obj)
1506
1507      CALL dbcsr_create(matrix_out, name, dist, dbcsr_type_no_symmetry, rbsize, rcsize, data_type=data_type)
1508      CALL dbcsr_distribution_release(dist)
1509
1510      IF (.NOT. nodata_prv) THEN
1511         CALL dbcsr_redistribute(matrix_in, matrix_out)
1512         IF (PRESENT(move_data)) THEN
1513            IF (move_data) CALL dbcsr_clear(matrix_in)
1514         ENDIF
1515      ENDIF
1516
1517      CALL timestop(handle)
1518   END SUBROUTINE
1519
1520   SUBROUTINE dbcsr_tas_batched_mm_init(matrix)
1521      TYPE(dbcsr_tas_type), INTENT(INOUT) :: matrix
1522      matrix%do_batched = 1
1523      ALLOCATE (matrix%mm_storage)
1524      matrix%mm_storage%batched_out = .FALSE.
1525   END SUBROUTINE
1526
1527   SUBROUTINE dbcsr_tas_batched_mm_finalize(matrix)
1528      TYPE(dbcsr_tas_type), INTENT(INOUT) :: matrix
1529
1530      IF (matrix%do_batched == 0) RETURN
1531      ASSOCIATE (storage => matrix%mm_storage)
1532         IF (storage%batched_out) THEN
1533            CALL dbcsr_tas_merge(storage%store_batched%matrix, storage%store_batched_repl, move_data=.TRUE.)
1534            CALL dbcsr_scale(matrix%matrix, storage%batched_beta)
1535            CALL dbcsr_tas_reshape(storage%store_batched, matrix, summation=.TRUE., &
1536                                   transposed=storage%batched_trans, move_data=.TRUE.)
1537            CALL dbcsr_tas_destroy(storage%store_batched)
1538            DEALLOCATE (storage%store_batched)
1539         ENDIF
1540
1541         IF (ASSOCIATED(storage%store_batched_repl)) THEN
1542            CALL dbcsr_tas_destroy(storage%store_batched_repl)
1543            DEALLOCATE (storage%store_batched_repl)
1544         ENDIF
1545
1546         storage%batched_out = .FALSE.
1547      END ASSOCIATE
1548
1549      DEALLOCATE (matrix%mm_storage)
1550      matrix%do_batched = 0
1551
1552   END SUBROUTINE
1553
1554END MODULE
1555