1!--------------------------------------------------------------------------------------------------!
2! Copyright (C) by the DBCSR developers group - All rights reserved                                !
3! This file is part of the DBCSR library.                                                          !
4!                                                                                                  !
5! For information on the license, see the LICENSE file.                                            !
6! For further information please visit https://dbcsr.cp2k.org                                      !
7! SPDX-License-Identifier: GPL-2.0+                                                                !
8!--------------------------------------------------------------------------------------------------!
9
10MODULE dbcsr_tensor
11   !! DBCSR tensor framework for block-sparse tensor contraction.
12   !! Representation of n-rank tensors as DBCSR tall-and-skinny matrices.
13   !! Support for arbitrary redistribution between different representations.
14   !! Support for arbitrary tensor contractions
15   !! \todo implement checks and error messages
16
17#:include "dbcsr_tensor.fypp"
18#:set maxdim = maxrank
19#:set ndims = range(2,maxdim+1)
20
21   USE dbcsr_allocate_wrap, ONLY: &
22      allocate_any
23   USE dbcsr_array_list_methods, ONLY: &
24      get_arrays, reorder_arrays, get_ith_array, array_list, array_sublist, check_equal, array_eq_i, &
25      create_array_list, destroy_array_list, sizes_of_arrays
26   USE dbcsr_api, ONLY: &
27      dbcsr_type, dbcsr_iterator_type, dbcsr_iterator_blocks_left, &
28      dbcsr_iterator_next_block, dbcsr_iterator_start, dbcsr_iterator_stop, &
29      dbcsr_transpose, dbcsr_no_transpose, dbcsr_scalar, dbcsr_put_block, &
30      ${uselist(dtype_float_param)}$, dbcsr_clear, &
31      dbcsr_release, dbcsr_desymmetrize, dbcsr_has_symmetry
32   USE dbcsr_tas_types, ONLY: &
33      dbcsr_tas_split_info
34   USE dbcsr_tas_base, ONLY: &
35      dbcsr_tas_copy, dbcsr_tas_finalize, dbcsr_tas_get_data_type, dbcsr_tas_get_info, dbcsr_tas_info
36   USE dbcsr_tas_mm, ONLY: &
37      dbcsr_tas_multiply, dbcsr_tas_batched_mm_init, dbcsr_tas_batched_mm_finalize, dbcsr_tas_result_index, &
38      dbcsr_tas_batched_mm_complete, dbcsr_tas_set_batched_state
39   USE dbcsr_tensor_block, ONLY: &
40      dbcsr_t_iterator_type, dbcsr_t_get_block, dbcsr_t_put_block, dbcsr_t_iterator_start, &
41      dbcsr_t_iterator_blocks_left, dbcsr_t_iterator_stop, dbcsr_t_iterator_next_block, &
42      ndims_iterator, dbcsr_t_reserve_blocks, block_nd, destroy_block
43   USE dbcsr_tensor_index, ONLY: &
44      dbcsr_t_get_mapping_info, nd_to_2d_mapping, dbcsr_t_inverse_order, permute_index, get_nd_indices_tensor, &
45      ndims_mapping_row, ndims_mapping_column, ndims_mapping
46   USE dbcsr_tensor_types, ONLY: &
47      dbcsr_t_create, dbcsr_t_get_data_type, dbcsr_t_type, ndims_tensor, dims_tensor, &
48      dbcsr_t_distribution_type, dbcsr_t_distribution, dbcsr_t_nd_mp_comm, dbcsr_t_destroy, &
49      dbcsr_t_distribution_destroy, dbcsr_t_distribution_new_expert, dbcsr_t_get_stored_coordinates, &
50      blk_dims_tensor, dbcsr_t_hold, dbcsr_t_pgrid_type, mp_environ_pgrid, dbcsr_t_filter, &
51      dbcsr_t_clear, dbcsr_t_finalize, dbcsr_t_get_num_blocks, dbcsr_t_scale, &
52      dbcsr_t_get_num_blocks_total, dbcsr_t_get_info, ndims_matrix_row, ndims_matrix_column, &
53      dbcsr_t_max_nblks_local, dbcsr_t_default_distvec, dbcsr_t_contraction_storage, dbcsr_t_nblks_total, &
54      dbcsr_t_distribution_new, dbcsr_t_copy_contraction_storage, dbcsr_t_pgrid_destroy
55   USE dbcsr_kinds, ONLY: &
56      ${uselist(dtype_float_prec)}$, default_string_length, int_8, dp
57   USE dbcsr_mpiwrap, ONLY: &
58      mp_environ, mp_max, mp_sum, mp_comm_free, mp_cart_create, mp_sync
59   USE dbcsr_toollib, ONLY: &
60      sort
61   USE dbcsr_tensor_reshape, ONLY: &
62      dbcsr_t_reshape
63   USE dbcsr_tas_split, ONLY: &
64      dbcsr_tas_mp_comm, rowsplit, colsplit, dbcsr_tas_info_hold, dbcsr_tas_release_info, default_nsplit_accept_ratio, &
65      default_pdims_accept_ratio, dbcsr_tas_create_split
66   USE dbcsr_data_types, ONLY: &
67      dbcsr_scalar_type
68   USE dbcsr_tensor_split, ONLY: &
69      dbcsr_t_split_copyback, dbcsr_t_make_compatible_blocks, dbcsr_t_crop
70   USE dbcsr_tensor_io, ONLY: &
71      dbcsr_t_write_tensor_info, dbcsr_t_write_tensor_dist, prep_output_unit, dbcsr_t_write_split_info
72   USE dbcsr_dist_operations, ONLY: &
73      checker_tr
74   USE dbcsr_toollib, ONLY: &
75      swap
76
77#include "base/dbcsr_base_uses.f90"
78
79   IMPLICIT NONE
80   PRIVATE
81   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'dbcsr_tensor'
82
83   PUBLIC :: &
84      dbcsr_t_contract, &
85      dbcsr_t_copy, &
86      dbcsr_t_get_block, &
87      dbcsr_t_get_stored_coordinates, &
88      dbcsr_t_inverse_order, &
89      dbcsr_t_iterator_blocks_left, &
90      dbcsr_t_iterator_next_block, &
91      dbcsr_t_iterator_start, &
92      dbcsr_t_iterator_stop, &
93      dbcsr_t_iterator_type, &
94      dbcsr_t_put_block, &
95      dbcsr_t_reserve_blocks, &
96      dbcsr_t_copy_matrix_to_tensor, &
97      dbcsr_t_copy_tensor_to_matrix, &
98      dbcsr_t_contract_index, &
99      dbcsr_t_batched_contract_init, &
100      dbcsr_t_batched_contract_finalize
101
102CONTAINS
103
104   SUBROUTINE dbcsr_t_copy(tensor_in, tensor_out, order, summation, bounds, move_data, unit_nr)
105      !! Copy tensor data.
106      !! Redistributes tensor data according to distributions of target and source tensor.
107      !! Permutes tensor index according to `order` argument (if present).
108      !! Source and target tensor formats are arbitrary as long as the following requirements are met:
109      !! * source and target tensors have the same rank and the same sizes in each dimension in terms
110      !!   of tensor elements (block sizes don't need to be the same).
111      !!   If `order` argument is present, sizes must match after index permutation.
112      !! OR
113      !! * target tensor is not yet created, in this case an exact copy of source tensor is returned.
114
115      TYPE(dbcsr_t_type), INTENT(INOUT), TARGET      :: tensor_in, tensor_out
116         !! Source
117         !! Target
118      INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
119         INTENT(IN), OPTIONAL                        :: order
120         !! Permutation of target tensor index. Exact same convention as order argument of RESHAPE intrinsic
121      LOGICAL, INTENT(IN), OPTIONAL                  :: summation, move_data
122      INTEGER, DIMENSION(2, ndims_tensor(tensor_in)), &
123         INTENT(IN), OPTIONAL                        :: bounds
124         !! crop tensor data: start and end index for each tensor dimension
125      INTEGER, INTENT(IN), OPTIONAL                  :: unit_nr
126      INTEGER :: handle
127
128      CALL mp_sync(tensor_in%pgrid%mp_comm_2d)
129      CALL timeset("dbcsr_t_total", handle)
130
131      ! make sure that it is safe to use dbcsr_t_copy during a batched contraction
132      CALL dbcsr_tas_batched_mm_complete(tensor_in%matrix_rep, warn=.TRUE.)
133      CALL dbcsr_tas_batched_mm_complete(tensor_out%matrix_rep, warn=.TRUE.)
134
135      CALL dbcsr_t_copy_expert(tensor_in, tensor_out, order, summation, bounds, move_data, unit_nr)
136      CALL mp_sync(tensor_in%pgrid%mp_comm_2d)
137      CALL timestop(handle)
138   END SUBROUTINE
139
140   SUBROUTINE dbcsr_t_copy_expert(tensor_in, tensor_out, order, summation, bounds, move_data, unit_nr)
141      !! expert routine for copying a tensor. For internal use only.
142      TYPE(dbcsr_t_type), INTENT(INOUT), TARGET      :: tensor_in, tensor_out
143      INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
144         INTENT(IN), OPTIONAL                        :: order
145      LOGICAL, INTENT(IN), OPTIONAL                  :: summation, move_data
146      INTEGER, DIMENSION(2, ndims_tensor(tensor_in)), &
147         INTENT(IN), OPTIONAL                        :: bounds
148      INTEGER, INTENT(IN), OPTIONAL                  :: unit_nr
149
150      TYPE(dbcsr_t_type), POINTER                    :: in_tmp_1, in_tmp_2, &
151                                                        in_tmp_3, out_tmp_1
152      INTEGER                                        :: handle, unit_nr_prv
153      INTEGER, DIMENSION(:), ALLOCATABLE             :: map1_in_1, map1_in_2, map2_in_1, map2_in_2
154
155      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbcsr_t_copy'
156      LOGICAL                                        :: dist_compatible_tas, dist_compatible_tensor, &
157                                                        summation_prv, new_in_1, new_in_2, &
158                                                        new_in_3, new_out_1, block_compatible, &
159                                                        move_prv
160      TYPE(array_list)                               :: blk_sizes_in
161
162      CALL timeset(routineN, handle)
163
164      DBCSR_ASSERT(tensor_out%valid)
165
166      unit_nr_prv = prep_output_unit(unit_nr)
167
168      IF (PRESENT(move_data)) THEN
169         move_prv = move_data
170      ELSE
171         move_prv = .FALSE.
172      ENDIF
173
174      dist_compatible_tas = .FALSE.
175      dist_compatible_tensor = .FALSE.
176      block_compatible = .FALSE.
177      new_in_1 = .FALSE.
178      new_in_2 = .FALSE.
179      new_in_3 = .FALSE.
180      new_out_1 = .FALSE.
181
182      IF (PRESENT(summation)) THEN
183         summation_prv = summation
184      ELSE
185         summation_prv = .FALSE.
186      ENDIF
187
188      IF (PRESENT(bounds)) THEN
189         ALLOCATE (in_tmp_1)
190         CALL dbcsr_t_crop(tensor_in, in_tmp_1, bounds=bounds, move_data=move_prv)
191         new_in_1 = .TRUE.
192         move_prv = .TRUE.
193      ELSE
194         in_tmp_1 => tensor_in
195      ENDIF
196
197      IF (PRESENT(order)) THEN
198         CALL reorder_arrays(in_tmp_1%blk_sizes, blk_sizes_in, order=order)
199         block_compatible = check_equal(blk_sizes_in, tensor_out%blk_sizes)
200      ELSE
201         block_compatible = check_equal(in_tmp_1%blk_sizes, tensor_out%blk_sizes)
202      ENDIF
203
204      IF (.NOT. block_compatible) THEN
205         ALLOCATE (in_tmp_2, out_tmp_1)
206         CALL dbcsr_t_make_compatible_blocks(in_tmp_1, tensor_out, in_tmp_2, out_tmp_1, order=order, &
207                                             nodata2=.NOT. summation_prv, move_data=move_prv)
208         new_in_2 = .TRUE.; new_out_1 = .TRUE.
209         move_prv = .TRUE.
210      ELSE
211         in_tmp_2 => in_tmp_1
212         out_tmp_1 => tensor_out
213      ENDIF
214
215      IF (PRESENT(order)) THEN
216         ALLOCATE (in_tmp_3)
217         CALL dbcsr_t_permute_index(in_tmp_2, in_tmp_3, order)
218         new_in_3 = .TRUE.
219      ELSE
220         in_tmp_3 => in_tmp_2
221      ENDIF
222
223      ALLOCATE (map1_in_1(ndims_matrix_row(in_tmp_3)))
224      ALLOCATE (map1_in_2(ndims_matrix_column(in_tmp_3)))
225      CALL dbcsr_t_get_mapping_info(in_tmp_3%nd_index, map1_2d=map1_in_1, map2_2d=map1_in_2)
226
227      ALLOCATE (map2_in_1(ndims_matrix_row(out_tmp_1)))
228      ALLOCATE (map2_in_2(ndims_matrix_column(out_tmp_1)))
229      CALL dbcsr_t_get_mapping_info(out_tmp_1%nd_index, map1_2d=map2_in_1, map2_2d=map2_in_2)
230
231      IF (.NOT. PRESENT(order)) THEN
232         IF (array_eq_i(map1_in_1, map2_in_1) .AND. array_eq_i(map1_in_2, map2_in_2)) THEN
233            dist_compatible_tas = check_equal(in_tmp_3%nd_dist, out_tmp_1%nd_dist)
234         ELSEIF (array_eq_i([map1_in_1, map1_in_2], [map2_in_1, map2_in_2])) THEN
235            dist_compatible_tensor = check_equal(in_tmp_3%nd_dist, out_tmp_1%nd_dist)
236         ENDIF
237      ENDIF
238
239      IF (dist_compatible_tas) THEN
240         CALL dbcsr_tas_copy(out_tmp_1%matrix_rep, in_tmp_3%matrix_rep, summation)
241         IF (move_prv) CALL dbcsr_t_clear(in_tmp_3)
242      ELSEIF (dist_compatible_tensor) THEN
243         CALL dbcsr_t_copy_nocomm(in_tmp_3, out_tmp_1, summation)
244         IF (move_prv) CALL dbcsr_t_clear(in_tmp_3)
245      ELSE
246         CALL dbcsr_t_reshape(in_tmp_3, out_tmp_1, summation, move_data=move_prv)
247      ENDIF
248
249      IF (new_in_1) THEN
250         CALL dbcsr_t_destroy(in_tmp_1)
251         DEALLOCATE (in_tmp_1)
252      ENDIF
253
254      IF (new_in_2) THEN
255         CALL dbcsr_t_destroy(in_tmp_2)
256         DEALLOCATE (in_tmp_2)
257      ENDIF
258
259      IF (new_in_3) THEN
260         CALL dbcsr_t_destroy(in_tmp_3)
261         DEALLOCATE (in_tmp_3)
262      ENDIF
263
264      IF (new_out_1) THEN
265         IF (unit_nr_prv /= 0) THEN
266            CALL dbcsr_t_write_tensor_dist(out_tmp_1, unit_nr)
267         ENDIF
268         CALL dbcsr_t_split_copyback(out_tmp_1, tensor_out, summation)
269         CALL dbcsr_t_destroy(out_tmp_1)
270         DEALLOCATE (out_tmp_1)
271      ENDIF
272
273      CALL timestop(handle)
274
275   END SUBROUTINE
276
277   SUBROUTINE dbcsr_t_copy_nocomm(tensor_in, tensor_out, summation)
278      !! copy without communication, requires that both tensors have same process grid and distribution
279
280      TYPE(dbcsr_t_type), INTENT(INOUT) :: tensor_in
281      TYPE(dbcsr_t_type), INTENT(INOUT) :: tensor_out
282      LOGICAL, INTENT(IN), OPTIONAL                      :: summation
283         !! Whether to sum matrices b = a + b
284      TYPE(dbcsr_t_iterator_type) :: iter
285      INTEGER, DIMENSION(ndims_tensor(tensor_in))  :: ind_nd
286      INTEGER :: blk
287      TYPE(block_nd) :: blk_data
288      LOGICAL :: found
289
290      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbcsr_t_copy_nocomm'
291      INTEGER :: handle
292
293      CALL timeset(routineN, handle)
294      DBCSR_ASSERT(tensor_out%valid)
295
296      IF (PRESENT(summation)) THEN
297         IF (.NOT. summation) CALL dbcsr_t_clear(tensor_out)
298      ELSE
299         CALL dbcsr_t_clear(tensor_out)
300      ENDIF
301
302      CALL dbcsr_t_reserve_blocks(tensor_in, tensor_out)
303
304      CALL dbcsr_t_iterator_start(iter, tensor_in)
305      DO WHILE (dbcsr_t_iterator_blocks_left(iter))
306         CALL dbcsr_t_iterator_next_block(iter, ind_nd, blk)
307         CALL dbcsr_t_get_block(tensor_in, ind_nd, blk_data, found)
308         DBCSR_ASSERT(found)
309         CALL dbcsr_t_put_block(tensor_out, ind_nd, blk_data, summation=summation)
310         CALL destroy_block(blk_data)
311      ENDDO
312      CALL dbcsr_t_iterator_stop(iter)
313
314      CALL timestop(handle)
315   END SUBROUTINE
316
317   SUBROUTINE dbcsr_t_copy_matrix_to_tensor(matrix_in, tensor_out, summation)
318      !! copy matrix to tensor.
319
320      TYPE(dbcsr_type), TARGET, INTENT(IN)               :: matrix_in
321      TYPE(dbcsr_t_type), INTENT(INOUT)                  :: tensor_out
322      LOGICAL, INTENT(IN), OPTIONAL                      :: summation
323         !! tensor_out = tensor_out + matrix_in
324      TYPE(dbcsr_type), POINTER                          :: matrix_in_desym
325
326      INTEGER, DIMENSION(2)                              :: ind_2d
327      REAL(KIND=real_8), ALLOCATABLE, DIMENSION(:, :)    :: block_arr
328      REAL(KIND=real_8), DIMENSION(:, :), POINTER        :: block
329      TYPE(dbcsr_iterator_type)                          :: iter
330      LOGICAL                                            :: tr
331
332      INTEGER                                            :: handle
333      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbcsr_t_copy_matrix_to_tensor'
334
335      CALL timeset(routineN, handle)
336      DBCSR_ASSERT(tensor_out%valid)
337
338      NULLIFY (block)
339
340      IF (dbcsr_has_symmetry(matrix_in)) THEN
341         ALLOCATE (matrix_in_desym)
342         CALL dbcsr_desymmetrize(matrix_in, matrix_in_desym)
343      ELSE
344         matrix_in_desym => matrix_in
345      ENDIF
346
347      IF (PRESENT(summation)) THEN
348         IF (.NOT. summation) CALL dbcsr_t_clear(tensor_out)
349      ELSE
350         CALL dbcsr_t_clear(tensor_out)
351      ENDIF
352
353      CALL dbcsr_t_reserve_blocks(matrix_in_desym, tensor_out)
354
355      CALL dbcsr_iterator_start(iter, matrix_in_desym)
356      DO WHILE (dbcsr_iterator_blocks_left(iter))
357         CALL dbcsr_iterator_next_block(iter, ind_2d(1), ind_2d(2), block, tr)
358         CALL allocate_any(block_arr, source=block)
359         CALL dbcsr_t_put_block(tensor_out, ind_2d, SHAPE(block_arr), block_arr, summation=summation)
360         DEALLOCATE (block_arr)
361      ENDDO
362      CALL dbcsr_iterator_stop(iter)
363
364      IF (dbcsr_has_symmetry(matrix_in)) THEN
365         CALL dbcsr_release(matrix_in_desym)
366         DEALLOCATE (matrix_in_desym)
367      ENDIF
368
369      CALL timestop(handle)
370
371   END SUBROUTINE
372
373   SUBROUTINE dbcsr_t_copy_tensor_to_matrix(tensor_in, matrix_out, summation)
374      !! copy tensor to matrix
375
376      TYPE(dbcsr_t_type), INTENT(INOUT)      :: tensor_in
377      TYPE(dbcsr_type), INTENT(INOUT)        :: matrix_out
378      LOGICAL, INTENT(IN), OPTIONAL          :: summation
379         !! matrix_out = matrix_out + tensor_in
380      TYPE(dbcsr_t_iterator_type)            :: iter
381      INTEGER                                :: blk, handle
382      INTEGER, DIMENSION(2)                  :: ind_2d
383      REAL(KIND=real_8), DIMENSION(:, :), ALLOCATABLE :: block
384      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbcsr_t_copy_tensor_to_matrix'
385      LOGICAL :: found
386
387      CALL timeset(routineN, handle)
388
389      IF (PRESENT(summation)) THEN
390         IF (.NOT. summation) CALL dbcsr_clear(matrix_out)
391      ELSE
392         CALL dbcsr_clear(matrix_out)
393      ENDIF
394
395      CALL dbcsr_t_reserve_blocks(tensor_in, matrix_out)
396
397      CALL dbcsr_t_iterator_start(iter, tensor_in)
398      DO WHILE (dbcsr_t_iterator_blocks_left(iter))
399         CALL dbcsr_t_iterator_next_block(iter, ind_2d, blk)
400         IF (dbcsr_has_symmetry(matrix_out) .AND. checker_tr(ind_2d(1), ind_2d(2))) CYCLE
401
402         CALL dbcsr_t_get_block(tensor_in, ind_2d, block, found)
403         DBCSR_ASSERT(found)
404
405         IF (dbcsr_has_symmetry(matrix_out) .AND. ind_2d(1) > ind_2d(2)) THEN
406            CALL dbcsr_put_block(matrix_out, ind_2d(2), ind_2d(1), TRANSPOSE(block), summation=summation)
407         ELSE
408            CALL dbcsr_put_block(matrix_out, ind_2d(1), ind_2d(2), block, summation=summation)
409         ENDIF
410         DEALLOCATE (block)
411      ENDDO
412      CALL dbcsr_t_iterator_stop(iter)
413
414      CALL timestop(handle)
415
416   END SUBROUTINE
417
418   SUBROUTINE dbcsr_t_contract(alpha, tensor_1, tensor_2, beta, tensor_3, &
419                               contract_1, notcontract_1, &
420                               contract_2, notcontract_2, &
421                               map_1, map_2, &
422                               bounds_1, bounds_2, bounds_3, &
423                               optimize_dist, pgrid_opt_1, pgrid_opt_2, pgrid_opt_3, &
424                               filter_eps, flop, move_data, retain_sparsity, unit_nr, log_verbose)
425      !! Contract tensors by multiplying matrix representations.
426      !! tensor_3(map_1, map_2) := alpha * tensor_1(notcontract_1, contract_1)
427      !! * tensor_2(contract_2, notcontract_2)
428      !! + beta * tensor_3(map_1, map_2)
429      !!
430      !! @note
431      !! note 1: block sizes of the corresponding indices need to be the same in all tensors.
432      !!
433      !! note 2: for best performance the tensors should have been created in matrix layouts
434      !! compatible with the contraction, e.g. tensor_1 should have been created with either
435      !! map1_2d == contract_1 and map2_2d == notcontract_1 or map1_2d == notcontract_1 and
436      !! map2_2d == contract_1 (the same with tensor_2 and contract_2 / notcontract_2 and with
437      !! tensor_3 and map_1 / map_2).
438      !! Furthermore the two largest tensors involved in the contraction should map both to either
439      !! tall or short matrices: the largest matrix dimension should be "on the same side"
440      !! and should have identical distribution (which is always the case if the distributions were
441      !! obtained with dbcsr_t_default_distvec).
442      !!
443      !! note 3: if the same tensor occurs in multiple contractions, a different tensor object should
444      !! be created for each contraction and the data should be copied between the tensors by use of
445      !! dbcsr_t_copy. If the same tensor object is used in multiple contractions, matrix layouts are
446      !! not compatible for all contractions (see note 2).
447      !!
448      !! note 4: automatic optimizations are enabled by using the feature of batched contraction, see
449      !! dbcsr_t_batched_contract_init, dbcsr_t_batched_contract_finalize. The arguments bounds_1,
450      !! bounds_2, bounds_3 give the index ranges of the batches.
451      !! @endnote
452
453      TYPE(dbcsr_scalar_type), INTENT(IN)            :: alpha
454      TYPE(dbcsr_t_type), INTENT(INOUT), TARGET      :: tensor_1
455         !! first tensor (in)
456      TYPE(dbcsr_t_type), INTENT(INOUT), TARGET      :: tensor_2
457         !! second tensor (in)
458      TYPE(dbcsr_scalar_type), INTENT(IN)            :: beta
459      INTEGER, DIMENSION(:), INTENT(IN)              :: contract_1
460         !! indices of tensor_1 to contract
461      INTEGER, DIMENSION(:), INTENT(IN)              :: contract_2
462         !! indices of tensor_2 to contract (1:1 with contract_1)
463      INTEGER, DIMENSION(:), INTENT(IN)              :: map_1
464         !! which indices of tensor_3 map to non-contracted indices of tensor_1 (1:1 with notcontract_1)
465      INTEGER, DIMENSION(:), INTENT(IN)              :: map_2
466         !! which indices of tensor_3 map to non-contracted indices of tensor_2 (1:1 with notcontract_2)
467      INTEGER, DIMENSION(:), INTENT(IN)              :: notcontract_1
468         !! indices of tensor_1 not to contract
469      INTEGER, DIMENSION(:), INTENT(IN)              :: notcontract_2
470         !! indices of tensor_2 not to contract
471      TYPE(dbcsr_t_type), INTENT(INOUT), TARGET      :: tensor_3
472         !! contracted tensor (out)
473      INTEGER, DIMENSION(2, SIZE(contract_1)), &
474         INTENT(IN), OPTIONAL                        :: bounds_1
475         !! bounds corresponding to contract_1 AKA contract_2: start and end index of an index range over
476         !! which to contract. For use in batched contraction.
477      INTEGER, DIMENSION(2, SIZE(notcontract_1)), &
478         INTENT(IN), OPTIONAL                        :: bounds_2
479         !! bounds corresponding to notcontract_1: start and end index of an index range.
480         !! For use in batched contraction.
481      INTEGER, DIMENSION(2, SIZE(notcontract_2)), &
482         INTENT(IN), OPTIONAL                        :: bounds_3
483         !! bounds corresponding to notcontract_2: start and end index of an index range.
484         !! For use in batched contraction.
485      LOGICAL, INTENT(IN), OPTIONAL                  :: optimize_dist
486         !! Whether distribution should be optimized internally. In the current implementation this guarantees optimal parameters
487         !! only for dense matrices.
488      TYPE(dbcsr_t_pgrid_type), INTENT(OUT), &
489         POINTER, OPTIONAL                           :: pgrid_opt_1
490         !! Optionally return optimal process grid for tensor_1. This can be used to choose optimal process grids for subsequent
491         !! tensor contractions with tensors of similar shape and sparsity. Under some conditions, pgrid_opt_1 can not be returned,
492         !! in this case the pointer is not associated.
493      TYPE(dbcsr_t_pgrid_type), INTENT(OUT), &
494         POINTER, OPTIONAL                           :: pgrid_opt_2
495         !! Optionally return optimal process grid for tensor_2.
496      TYPE(dbcsr_t_pgrid_type), INTENT(OUT), &
497         POINTER, OPTIONAL                           :: pgrid_opt_3
498         !! Optionally return optimal process grid for tensor_3.
499      REAL(KIND=real_8), INTENT(IN), OPTIONAL        :: filter_eps
500         !! As in DBCSR mm
501      INTEGER(KIND=int_8), INTENT(OUT), OPTIONAL     :: flop
502         !! As in DBCSR mm
503      LOGICAL, INTENT(IN), OPTIONAL                  :: move_data
504         !! memory optimization: transfer data such that tensor_1 and tensor_2 are empty on return
505      LOGICAL, INTENT(IN), OPTIONAL                  :: retain_sparsity
506         !! enforce the sparsity pattern of the existing tensor_3; default is no
507      INTEGER, OPTIONAL, INTENT(IN)                  :: unit_nr
508         !! output unit for logging
509         !! set it to -1 on ranks that should not write (and any valid unit number on ranks that should write output)
510         !! if 0 on ALL ranks, no output is written
511      LOGICAL, INTENT(IN), OPTIONAL                  :: log_verbose
512         !! verbose logging (for testing only)
513
514      INTEGER                     :: handle
515
516      CALL mp_sync(tensor_1%pgrid%mp_comm_2d)
517      CALL timeset("dbcsr_t_total", handle)
518      CALL dbcsr_t_contract_expert(alpha, tensor_1, tensor_2, beta, tensor_3, &
519                                   contract_1, notcontract_1, &
520                                   contract_2, notcontract_2, &
521                                   map_1, map_2, &
522                                   bounds_1=bounds_1, &
523                                   bounds_2=bounds_2, &
524                                   bounds_3=bounds_3, &
525                                   optimize_dist=optimize_dist, &
526                                   pgrid_opt_1=pgrid_opt_1, &
527                                   pgrid_opt_2=pgrid_opt_2, &
528                                   pgrid_opt_3=pgrid_opt_3, &
529                                   filter_eps=filter_eps, &
530                                   flop=flop, &
531                                   move_data=move_data, &
532                                   retain_sparsity=retain_sparsity, &
533                                   unit_nr=unit_nr, &
534                                   log_verbose=log_verbose)
535      CALL mp_sync(tensor_1%pgrid%mp_comm_2d)
536      CALL timestop(handle)
537
538   END SUBROUTINE
539
540   SUBROUTINE dbcsr_t_contract_expert(alpha, tensor_1, tensor_2, beta, tensor_3, &
541                                      contract_1, notcontract_1, &
542                                      contract_2, notcontract_2, &
543                                      map_1, map_2, &
544                                      bounds_1, bounds_2, bounds_3, &
545                                      optimize_dist, pgrid_opt_1, pgrid_opt_2, pgrid_opt_3, &
546                                      filter_eps, flop, move_data, retain_sparsity, &
547                                      nblks_local, result_index, unit_nr, log_verbose)
548      !! expert routine for tensor contraction. For internal use only.
549      TYPE(dbcsr_scalar_type), INTENT(IN)            :: alpha
550      TYPE(dbcsr_t_type), INTENT(INOUT), TARGET      :: tensor_1
551      TYPE(dbcsr_t_type), INTENT(INOUT), TARGET      :: tensor_2
552      TYPE(dbcsr_scalar_type), INTENT(IN)            :: beta
553      INTEGER, DIMENSION(:), INTENT(IN)              :: contract_1
554      INTEGER, DIMENSION(:), INTENT(IN)              :: contract_2
555      INTEGER, DIMENSION(:), INTENT(IN)              :: map_1
556      INTEGER, DIMENSION(:), INTENT(IN)              :: map_2
557      INTEGER, DIMENSION(:), INTENT(IN)              :: notcontract_1
558      INTEGER, DIMENSION(:), INTENT(IN)              :: notcontract_2
559      TYPE(dbcsr_t_type), INTENT(INOUT), TARGET      :: tensor_3
560      INTEGER, DIMENSION(2, SIZE(contract_1)), &
561         INTENT(IN), OPTIONAL                        :: bounds_1
562      INTEGER, DIMENSION(2, SIZE(notcontract_1)), &
563         INTENT(IN), OPTIONAL                        :: bounds_2
564      INTEGER, DIMENSION(2, SIZE(notcontract_2)), &
565         INTENT(IN), OPTIONAL                        :: bounds_3
566      LOGICAL, INTENT(IN), OPTIONAL                  :: optimize_dist
567      TYPE(dbcsr_t_pgrid_type), INTENT(OUT), &
568         POINTER, OPTIONAL                           :: pgrid_opt_1
569      TYPE(dbcsr_t_pgrid_type), INTENT(OUT), &
570         POINTER, OPTIONAL                           :: pgrid_opt_2
571      TYPE(dbcsr_t_pgrid_type), INTENT(OUT), &
572         POINTER, OPTIONAL                           :: pgrid_opt_3
573      REAL(KIND=real_8), INTENT(IN), OPTIONAL        :: filter_eps
574      INTEGER(KIND=int_8), INTENT(OUT), OPTIONAL     :: flop
575      LOGICAL, INTENT(IN), OPTIONAL                  :: move_data
576      LOGICAL, INTENT(IN), OPTIONAL                  :: retain_sparsity
577      INTEGER, INTENT(OUT), OPTIONAL                 :: nblks_local
578         !! number of local blocks on this MPI rank
579      INTEGER, DIMENSION(dbcsr_t_max_nblks_local(tensor_3), ndims_tensor(tensor_3)), &
580         OPTIONAL, INTENT(OUT)                       :: result_index
581         !! get indices of non-zero tensor blocks for tensor_3 without actually performing contraction
582         !! this is an estimate based on block norm multiplication
583      INTEGER, OPTIONAL, INTENT(IN)                  :: unit_nr
584      LOGICAL, INTENT(IN), OPTIONAL                  :: log_verbose
585
586      TYPE(dbcsr_t_type), POINTER                    :: tensor_contr_1, tensor_contr_2, tensor_contr_3
587      TYPE(dbcsr_t_type), TARGET                     :: tensor_algn_1, tensor_algn_2, tensor_algn_3
588      TYPE(dbcsr_t_type), POINTER                    :: tensor_crop_1, tensor_crop_2
589      TYPE(dbcsr_t_type), POINTER                    :: tensor_small, tensor_large
590
591      INTEGER(int_8), DIMENSION(:, :), ALLOCATABLE  :: result_index_2d
592      LOGICAL                                        :: assert_stmt, tensors_remapped
593      INTEGER                                        :: data_type, max_mm_dim, max_tensor, mp_comm, &
594                                                        iblk, nblk, unit_nr_prv, ref_tensor, mp_comm_opt, &
595                                                        handle
596      INTEGER, DIMENSION(SIZE(contract_1))           :: contract_1_mod
597      INTEGER, DIMENSION(SIZE(notcontract_1))        :: notcontract_1_mod
598      INTEGER, DIMENSION(SIZE(contract_2))           :: contract_2_mod
599      INTEGER, DIMENSION(SIZE(notcontract_2))        :: notcontract_2_mod
600      INTEGER, DIMENSION(SIZE(map_1))                :: map_1_mod
601      INTEGER, DIMENSION(SIZE(map_2))                :: map_2_mod
602      CHARACTER(LEN=1)                               :: trans_1, trans_2, trans_3
603      LOGICAL                                        :: new_1, new_2, new_3, move_data_1, move_data_2
604      INTEGER                                        :: ndims1, ndims2, ndims3
605      INTEGER                                        :: occ_1, occ_2
606      INTEGER, DIMENSION(:), ALLOCATABLE             :: dims1, dims2, dims3
607
608      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbcsr_t_contract'
609      CHARACTER(LEN=1), DIMENSION(:), ALLOCATABLE    :: indchar1, indchar2, indchar3, indchar1_mod, &
610                                                        indchar2_mod, indchar3_mod
611      CHARACTER(LEN=1), DIMENSION(15) :: alph = &
612                                         ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o']
613      INTEGER, DIMENSION(2, ndims_tensor(tensor_1)) :: bounds_t1
614      INTEGER, DIMENSION(2, ndims_tensor(tensor_2)) :: bounds_t2
615      LOGICAL                                        :: do_crop_1, do_crop_2, do_write_3, nodata_3, do_batched, pgrid_changed, &
616                                                        pgrid_changed_any, do_change_pgrid(2)
617      TYPE(dbcsr_tas_split_info)                     :: split_opt, split, split_opt_avg
618      INTEGER, DIMENSION(2) :: pdims_2d_opt, pdims_2d, pcoord_2d, pdims_sub, pdims_sub_opt
619      LOGICAL, DIMENSION(2) :: periods_2d
620      REAL(real_8) :: pdim_ratio, pdim_ratio_opt
621
622      NULLIFY (tensor_contr_1, tensor_contr_2, tensor_contr_3, tensor_crop_1, tensor_crop_2, &
623               tensor_small)
624
625      CALL timeset(routineN, handle)
626
627      DBCSR_ASSERT(tensor_1%valid)
628      DBCSR_ASSERT(tensor_2%valid)
629      DBCSR_ASSERT(tensor_3%valid)
630
631      assert_stmt = SIZE(contract_1) .EQ. SIZE(contract_2)
632      DBCSR_ASSERT(assert_stmt)
633
634      assert_stmt = SIZE(map_1) .EQ. SIZE(notcontract_1)
635      DBCSR_ASSERT(assert_stmt)
636
637      assert_stmt = SIZE(map_2) .EQ. SIZE(notcontract_2)
638      DBCSR_ASSERT(assert_stmt)
639
640      assert_stmt = SIZE(notcontract_1) + SIZE(contract_1) .EQ. ndims_tensor(tensor_1)
641      DBCSR_ASSERT(assert_stmt)
642
643      assert_stmt = SIZE(notcontract_2) + SIZE(contract_2) .EQ. ndims_tensor(tensor_2)
644      DBCSR_ASSERT(assert_stmt)
645
646      assert_stmt = SIZE(map_1) + SIZE(map_2) .EQ. ndims_tensor(tensor_3)
647      DBCSR_ASSERT(assert_stmt)
648
649      assert_stmt = dbcsr_t_get_data_type(tensor_1) .EQ. dbcsr_t_get_data_type(tensor_2)
650      DBCSR_ASSERT(assert_stmt)
651
652      unit_nr_prv = prep_output_unit(unit_nr)
653
654      IF (PRESENT(flop)) flop = 0
655      IF (PRESENT(result_index)) result_index = 0
656      IF (PRESENT(nblks_local)) nblks_local = 0
657
658      IF (PRESENT(move_data)) THEN
659         move_data_1 = move_data
660         move_data_2 = move_data
661      ELSE
662         move_data_1 = .FALSE.
663         move_data_2 = .FALSE.
664      ENDIF
665
666      nodata_3 = .TRUE.
667      IF (PRESENT(retain_sparsity)) THEN
668         IF (retain_sparsity) nodata_3 = .FALSE.
669      ENDIF
670
671      CALL dbcsr_t_map_bounds_to_tensors(tensor_1, tensor_2, &
672                                         contract_1, notcontract_1, &
673                                         contract_2, notcontract_2, &
674                                         bounds_t1, bounds_t2, &
675                                         bounds_1=bounds_1, bounds_2=bounds_2, bounds_3=bounds_3, &
676                                         do_crop_1=do_crop_1, do_crop_2=do_crop_2)
677
678      IF (do_crop_1) THEN
679         ALLOCATE (tensor_crop_1)
680         CALL dbcsr_t_crop(tensor_1, tensor_crop_1, bounds_t1, move_data=move_data_1)
681         move_data_1 = .TRUE.
682      ELSE
683         tensor_crop_1 => tensor_1
684      ENDIF
685
686      IF (do_crop_2) THEN
687         ALLOCATE (tensor_crop_2)
688         CALL dbcsr_t_crop(tensor_2, tensor_crop_2, bounds_t2, move_data=move_data_2)
689         move_data_2 = .TRUE.
690      ELSE
691         tensor_crop_2 => tensor_2
692      ENDIF
693
694      ! shortcut for empty tensors
695      ! this is needed to avoid unnecessary work in case user contracts different portions of a
696      ! tensor consecutively to save memory
697      mp_comm = tensor_crop_1%pgrid%mp_comm_2d
698      occ_1 = dbcsr_t_get_num_blocks(tensor_crop_1)
699      CALL mp_max(occ_1, mp_comm)
700      occ_2 = dbcsr_t_get_num_blocks(tensor_crop_2)
701      CALL mp_max(occ_2, mp_comm)
702
703      IF (occ_1 == 0 .OR. occ_2 == 0) THEN
704         CALL dbcsr_t_scale(tensor_3, beta)
705         IF (do_crop_1) THEN
706            CALL dbcsr_t_destroy(tensor_crop_1)
707            DEALLOCATE (tensor_crop_1)
708         ENDIF
709         IF (do_crop_2) THEN
710            CALL dbcsr_t_destroy(tensor_crop_2)
711            DEALLOCATE (tensor_crop_2)
712         ENDIF
713
714         CALL timestop(handle)
715         RETURN
716      ENDIF
717
718      IF (unit_nr_prv /= 0) THEN
719         IF (unit_nr_prv > 0) THEN
720            WRITE (unit_nr_prv, '(A)') repeat("-", 80)
721            WRITE (unit_nr_prv, '(A,1X,A,1X,A,1X,A,1X,A,1X,A)') "DBCSR TENSOR CONTRACTION:", &
722               TRIM(tensor_crop_1%name), 'x', TRIM(tensor_crop_2%name), '=', TRIM(tensor_3%name)
723            WRITE (unit_nr_prv, '(A)') repeat("-", 80)
724         ENDIF
725         CALL dbcsr_t_write_tensor_info(tensor_crop_1, unit_nr_prv, full_info=log_verbose)
726         CALL dbcsr_t_write_tensor_dist(tensor_crop_1, unit_nr_prv)
727         CALL dbcsr_t_write_tensor_info(tensor_crop_2, unit_nr_prv, full_info=log_verbose)
728         CALL dbcsr_t_write_tensor_dist(tensor_crop_2, unit_nr_prv)
729      ENDIF
730
731      data_type = dbcsr_t_get_data_type(tensor_crop_1)
732
733      ! align tensor index with data, tensor data is not modified
734      ndims1 = ndims_tensor(tensor_crop_1)
735      ndims2 = ndims_tensor(tensor_crop_2)
736      ndims3 = ndims_tensor(tensor_3)
737      ALLOCATE (indchar1(ndims1), indchar1_mod(ndims1))
738      ALLOCATE (indchar2(ndims2), indchar2_mod(ndims2))
739      ALLOCATE (indchar3(ndims3), indchar3_mod(ndims3))
740
741      ! labeling tensor index with letters
742
743      indchar1([notcontract_1, contract_1]) = alph(1:ndims1) ! arb. choice
744      indchar2(notcontract_2) = alph(ndims1 + 1:ndims1 + SIZE(notcontract_2)) ! arb. choice
745      indchar2(contract_2) = indchar1(contract_1)
746      indchar3(map_1) = indchar1(notcontract_1)
747      indchar3(map_2) = indchar2(notcontract_2)
748
749      IF (unit_nr_prv /= 0) CALL dbcsr_t_print_contraction_index(tensor_crop_1, indchar1, &
750                                                                 tensor_crop_2, indchar2, &
751                                                                 tensor_3, indchar3, unit_nr_prv)
752      IF (unit_nr_prv > 0) THEN
753         WRITE (unit_nr_prv, '(T2,A)') "aligning tensor index with data"
754      ENDIF
755
756      CALL align_tensor(tensor_crop_1, contract_1, notcontract_1, &
757                        tensor_algn_1, contract_1_mod, notcontract_1_mod, indchar1, indchar1_mod)
758
759      CALL align_tensor(tensor_crop_2, contract_2, notcontract_2, &
760                        tensor_algn_2, contract_2_mod, notcontract_2_mod, indchar2, indchar2_mod)
761
762      CALL align_tensor(tensor_3, map_1, map_2, &
763                        tensor_algn_3, map_1_mod, map_2_mod, indchar3, indchar3_mod)
764
765      IF (unit_nr_prv /= 0) CALL dbcsr_t_print_contraction_index(tensor_algn_1, indchar1_mod, &
766                                                                 tensor_algn_2, indchar2_mod, &
767                                                                 tensor_algn_3, indchar3_mod, unit_nr_prv)
768
769      ALLOCATE (dims1(ndims1))
770      ALLOCATE (dims2(ndims2))
771      ALLOCATE (dims3(ndims3))
772
773      ! ideally we should consider block sizes and occupancy to measure tensor sizes but current solution should work for most
774      ! cases and is more elegant. Note that we can not easily consider occupancy since it is unknown for result tensor
775      CALL blk_dims_tensor(tensor_crop_1, dims1)
776      CALL blk_dims_tensor(tensor_crop_2, dims2)
777      CALL blk_dims_tensor(tensor_3, dims3)
778
779      max_mm_dim = MAXLOC([PRODUCT(INT(dims1(notcontract_1), int_8)), &
780                           PRODUCT(INT(dims1(contract_1), int_8)), &
781                           PRODUCT(INT(dims2(notcontract_2), int_8))], DIM=1)
782      max_tensor = MAXLOC([PRODUCT(INT(dims1, int_8)), PRODUCT(INT(dims2, int_8)), PRODUCT(INT(dims3, int_8))], DIM=1)
783      SELECT CASE (max_mm_dim)
784      CASE (1)
785         IF (unit_nr_prv > 0) THEN
786            WRITE (unit_nr_prv, '(T2,A)') "large tensors: 1, 3; small tensor: 2"
787            WRITE (unit_nr_prv, '(T2,A)') "sorting contraction indices"
788         ENDIF
789         CALL index_linked_sort(contract_1_mod, contract_2_mod)
790         CALL index_linked_sort(map_2_mod, notcontract_2_mod)
791         SELECT CASE (max_tensor)
792         CASE (1)
793            CALL index_linked_sort(notcontract_1_mod, map_1_mod)
794         CASE (3)
795            CALL index_linked_sort(map_1_mod, notcontract_1_mod)
796         CASE DEFAULT
797            DBCSR_ABORT("should not happen")
798         END SELECT
799
800         CALL reshape_mm_compatible(tensor_algn_1, tensor_algn_3, tensor_contr_1, tensor_contr_3, &
801                                    contract_1_mod, notcontract_1_mod, map_2_mod, map_1_mod, &
802                                    trans_1, trans_3, new_1, new_3, ref_tensor, nodata2=nodata_3, optimize_dist=optimize_dist, &
803                                    move_data_1=move_data_1, unit_nr=unit_nr_prv)
804
805         CALL reshape_mm_small(tensor_algn_2, contract_2_mod, notcontract_2_mod, tensor_contr_2, trans_2, &
806                               new_2, move_data=move_data_2, unit_nr=unit_nr_prv)
807
808         SELECT CASE (ref_tensor)
809         CASE (1)
810            tensor_large => tensor_contr_1
811         CASE (2)
812            tensor_large => tensor_contr_3
813         END SELECT
814         tensor_small => tensor_contr_2
815
816      CASE (2)
817         IF (unit_nr_prv > 0) THEN
818            WRITE (unit_nr_prv, '(T2,A)') "large tensors: 1, 2; small tensor: 3"
819            WRITE (unit_nr_prv, '(T2,A)') "sorting contraction indices"
820         ENDIF
821
822         CALL index_linked_sort(notcontract_1_mod, map_1_mod)
823         CALL index_linked_sort(notcontract_2_mod, map_2_mod)
824         SELECT CASE (max_tensor)
825         CASE (1)
826            CALL index_linked_sort(contract_1_mod, contract_2_mod)
827         CASE (2)
828            CALL index_linked_sort(contract_2_mod, contract_1_mod)
829         CASE DEFAULT
830            DBCSR_ABORT("should not happen")
831         END SELECT
832
833         CALL reshape_mm_compatible(tensor_algn_1, tensor_algn_2, tensor_contr_1, tensor_contr_2, &
834                                    notcontract_1_mod, contract_1_mod, notcontract_2_mod, contract_2_mod, &
835                                    trans_1, trans_2, new_1, new_2, ref_tensor, optimize_dist=optimize_dist, &
836                                    move_data_1=move_data_1, move_data_2=move_data_2, unit_nr=unit_nr_prv)
837         CALL invert_transpose_flag(trans_1)
838
839         CALL reshape_mm_small(tensor_algn_3, map_1_mod, map_2_mod, tensor_contr_3, trans_3, &
840                               new_3, nodata=nodata_3, unit_nr=unit_nr_prv)
841
842         SELECT CASE (ref_tensor)
843         CASE (1)
844            tensor_large => tensor_contr_1
845         CASE (2)
846            tensor_large => tensor_contr_2
847         END SELECT
848         tensor_small => tensor_contr_3
849
850      CASE (3)
851         IF (unit_nr_prv > 0) THEN
852            WRITE (unit_nr_prv, '(T2,A)') "large tensors: 2, 3; small tensor: 1"
853            WRITE (unit_nr_prv, '(T2,A)') "sorting contraction indices"
854         ENDIF
855         CALL index_linked_sort(map_1_mod, notcontract_1_mod)
856         CALL index_linked_sort(contract_2_mod, contract_1_mod)
857         SELECT CASE (max_tensor)
858         CASE (2)
859            CALL index_linked_sort(notcontract_2_mod, map_2_mod)
860         CASE (3)
861            CALL index_linked_sort(map_2_mod, notcontract_2_mod)
862         CASE DEFAULT
863            DBCSR_ABORT("should not happen")
864         END SELECT
865
866         CALL reshape_mm_compatible(tensor_algn_2, tensor_algn_3, tensor_contr_2, tensor_contr_3, &
867                                    contract_2_mod, notcontract_2_mod, map_1_mod, map_2_mod, &
868                                    trans_2, trans_3, new_2, new_3, ref_tensor, nodata2=nodata_3, optimize_dist=optimize_dist, &
869                                    move_data_1=move_data_2, unit_nr=unit_nr_prv)
870
871         CALL invert_transpose_flag(trans_2)
872         CALL invert_transpose_flag(trans_3)
873
874         CALL reshape_mm_small(tensor_algn_1, notcontract_1_mod, contract_1_mod, tensor_contr_1, &
875                               trans_1, new_1, move_data=move_data_1, unit_nr=unit_nr_prv)
876
877         SELECT CASE (ref_tensor)
878         CASE (1)
879            tensor_large => tensor_contr_2
880         CASE (2)
881            tensor_large => tensor_contr_3
882         END SELECT
883         tensor_small => tensor_contr_1
884
885      END SELECT
886
887      IF (unit_nr_prv /= 0) CALL dbcsr_t_print_contraction_index(tensor_contr_1, indchar1_mod, &
888                                                                 tensor_contr_2, indchar2_mod, &
889                                                                 tensor_contr_3, indchar3_mod, unit_nr_prv)
890      IF (unit_nr_prv /= 0) THEN
891         IF (new_1) CALL dbcsr_t_write_tensor_info(tensor_contr_1, unit_nr_prv, full_info=log_verbose)
892         IF (new_1) CALL dbcsr_t_write_tensor_dist(tensor_contr_1, unit_nr_prv)
893         IF (new_2) CALL dbcsr_t_write_tensor_info(tensor_contr_2, unit_nr_prv, full_info=log_verbose)
894         IF (new_2) CALL dbcsr_t_write_tensor_dist(tensor_contr_2, unit_nr_prv)
895      ENDIF
896
897      IF (.NOT. PRESENT(result_index)) THEN
898         CALL dbcsr_tas_multiply(trans_1, trans_2, trans_3, alpha, &
899                                 tensor_contr_1%matrix_rep, tensor_contr_2%matrix_rep, &
900                                 beta, &
901                                 tensor_contr_3%matrix_rep, filter_eps=filter_eps, flop=flop, &
902                                 unit_nr=unit_nr_prv, log_verbose=log_verbose, &
903                                 split_opt=split_opt, &
904                                 move_data_a=move_data_1, move_data_b=move_data_2, retain_sparsity=retain_sparsity)
905      ELSE
906
907         CALL dbcsr_tas_result_index(trans_1, trans_2, trans_3, tensor_contr_1%matrix_rep, tensor_contr_2%matrix_rep, &
908                                     tensor_contr_3%matrix_rep, filter_eps=filter_eps, blk_ind=result_index_2d)
909
910         nblk = SIZE(result_index_2d, 1)
911         IF (PRESENT(nblks_local)) nblks_local = nblk
912         IF (SIZE(result_index, 1) < nblk) THEN
913            CALL dbcsr_abort(__LOCATION__, &
914        "allocated size of `result_index` is too small. This error occurs due to a high load imbalance of distributed tensor data.")
915         ENDIF
916
917         DO iblk = 1, nblk
918            result_index(iblk, :) = get_nd_indices_tensor(tensor_contr_3%nd_index_blk, result_index_2d(iblk, :))
919         ENDDO
920
921         IF (new_1) THEN
922            CALL dbcsr_t_destroy(tensor_contr_1)
923            DEALLOCATE (tensor_contr_1)
924         ENDIF
925         IF (new_2) THEN
926            CALL dbcsr_t_destroy(tensor_contr_2)
927            DEALLOCATE (tensor_contr_2)
928         ENDIF
929         IF (new_3) THEN
930            CALL dbcsr_t_destroy(tensor_contr_3)
931            DEALLOCATE (tensor_contr_3)
932         ENDIF
933         IF (do_crop_1) THEN
934            CALL dbcsr_t_destroy(tensor_crop_1)
935            DEALLOCATE (tensor_crop_1)
936         ENDIF
937         IF (do_crop_2) THEN
938            CALL dbcsr_t_destroy(tensor_crop_2)
939            DEALLOCATE (tensor_crop_2)
940         ENDIF
941
942         CALL dbcsr_t_destroy(tensor_algn_1)
943         CALL dbcsr_t_destroy(tensor_algn_2)
944         CALL dbcsr_t_destroy(tensor_algn_3)
945
946         CALL timestop(handle)
947         RETURN
948      ENDIF
949
950      IF (PRESENT(pgrid_opt_1)) THEN
951         IF (.NOT. new_1) THEN
952            ALLOCATE (pgrid_opt_1)
953            pgrid_opt_1 = opt_pgrid(tensor_1, split_opt)
954         ENDIF
955      ENDIF
956
957      IF (PRESENT(pgrid_opt_2)) THEN
958         IF (.NOT. new_2) THEN
959            ALLOCATE (pgrid_opt_2)
960            pgrid_opt_2 = opt_pgrid(tensor_2, split_opt)
961         ENDIF
962      ENDIF
963
964      IF (PRESENT(pgrid_opt_3)) THEN
965         IF (.NOT. new_3) THEN
966            ALLOCATE (pgrid_opt_3)
967            pgrid_opt_3 = opt_pgrid(tensor_3, split_opt)
968         ENDIF
969      ENDIF
970
971      do_batched = tensor_small%matrix_rep%do_batched > 0
972
973      tensors_remapped = .FALSE.
974      IF (new_1 .OR. new_2 .OR. new_3) tensors_remapped = .TRUE.
975
976      IF (tensors_remapped .AND. do_batched) THEN
977         CALL dbcsr_warn(__LOCATION__, &
978                         "Internal process grid optimization disabled because tensors are not in contraction-compatible format")
979      ENDIF
980
981      CALL mp_environ(tensor_large%pgrid%mp_comm_2d, 2, pdims_2d, pcoord_2d, periods_2d)
982
983      ! optimize process grid during batched contraction
984      do_change_pgrid(:) = .FALSE.
985      IF ((.NOT. tensors_remapped) .AND. do_batched) THEN
986         ASSOCIATE (storage => tensor_small%contraction_storage)
987            DBCSR_ASSERT(storage%static)
988            split = dbcsr_tas_info(tensor_large%matrix_rep)
989            do_change_pgrid(:) = &
990               update_contraction_storage(storage, split_opt, split)
991
992            IF (ANY(do_change_pgrid)) THEN
993               mp_comm_opt = dbcsr_tas_mp_comm(tensor_small%pgrid%mp_comm_2d, split_opt%split_rowcol, NINT(storage%nsplit_avg))
994               CALL dbcsr_tas_create_split(split_opt_avg, mp_comm_opt, split_opt%split_rowcol, &
995                                           NINT(storage%nsplit_avg), own_comm=.TRUE.)
996               CALL mp_environ(split_opt_avg%mp_comm, 2, pdims_2d_opt, pcoord_2d, periods_2d)
997            ENDIF
998
999         END ASSOCIATE
1000
1001         IF (do_change_pgrid(1) .AND. .NOT. do_change_pgrid(2)) THEN
1002            ! check if new grid has better subgrid, if not there is no need to change process grid
1003            CALL mp_environ(split_opt_avg%mp_comm_group, 2, pdims_sub_opt, pcoord_2d, periods_2d)
1004            CALL mp_environ(split%mp_comm_group, 2, pdims_sub, pcoord_2d, periods_2d)
1005
1006            pdim_ratio = MAXVAL(REAL(pdims_sub, real_8))/MINVAL(pdims_sub)
1007            pdim_ratio_opt = MAXVAL(REAL(pdims_sub_opt, real_8))/MINVAL(pdims_sub_opt)
1008            IF (pdim_ratio/pdim_ratio_opt <= default_pdims_accept_ratio**2) THEN
1009               do_change_pgrid(1) = .FALSE.
1010               CALL dbcsr_tas_release_info(split_opt_avg)
1011            ENDIF
1012         ENDIF
1013      ENDIF
1014
1015      IF (unit_nr_prv /= 0) THEN
1016         do_write_3 = .TRUE.
1017         IF (tensor_contr_3%matrix_rep%do_batched > 0) THEN
1018            IF (tensor_contr_3%matrix_rep%mm_storage%batched_out) do_write_3 = .FALSE.
1019         ENDIF
1020         IF (do_write_3) THEN
1021            CALL dbcsr_t_write_tensor_info(tensor_contr_3, unit_nr_prv, full_info=log_verbose)
1022            CALL dbcsr_t_write_tensor_dist(tensor_contr_3, unit_nr_prv)
1023         ENDIF
1024      ENDIF
1025
1026      IF (new_3) THEN
1027         ! need redistribute if we created new tensor for tensor 3
1028         CALL dbcsr_t_scale(tensor_algn_3, beta)
1029         CALL dbcsr_t_copy_expert(tensor_contr_3, tensor_algn_3, summation=.TRUE., move_data=.TRUE.)
1030         IF (PRESENT(filter_eps)) CALL dbcsr_t_filter(tensor_algn_3, filter_eps)
1031         ! tensor_3 automatically has correct data because tensor_algn_3 contains a matrix
1032         ! pointer to data of tensor_3
1033      ENDIF
1034
1035      ! transfer contraction storage
1036      CALL dbcsr_t_copy_contraction_storage(tensor_contr_1, tensor_1)
1037      CALL dbcsr_t_copy_contraction_storage(tensor_contr_2, tensor_2)
1038      CALL dbcsr_t_copy_contraction_storage(tensor_contr_3, tensor_3)
1039
1040      IF (unit_nr_prv /= 0) THEN
1041         IF (new_3 .AND. do_write_3) CALL dbcsr_t_write_tensor_info(tensor_3, unit_nr_prv, full_info=log_verbose)
1042         IF (new_3 .AND. do_write_3) CALL dbcsr_t_write_tensor_dist(tensor_3, unit_nr_prv)
1043      ENDIF
1044
1045      CALL dbcsr_t_destroy(tensor_algn_1)
1046      CALL dbcsr_t_destroy(tensor_algn_2)
1047      CALL dbcsr_t_destroy(tensor_algn_3)
1048
1049      IF (do_crop_1) THEN
1050         CALL dbcsr_t_destroy(tensor_crop_1)
1051         DEALLOCATE (tensor_crop_1)
1052      ENDIF
1053
1054      IF (do_crop_2) THEN
1055         CALL dbcsr_t_destroy(tensor_crop_2)
1056         DEALLOCATE (tensor_crop_2)
1057      ENDIF
1058
1059      IF (new_1) THEN
1060         CALL dbcsr_t_destroy(tensor_contr_1)
1061         DEALLOCATE (tensor_contr_1)
1062      ENDIF
1063      IF (new_2) THEN
1064         CALL dbcsr_t_destroy(tensor_contr_2)
1065         DEALLOCATE (tensor_contr_2)
1066      ENDIF
1067      IF (new_3) THEN
1068         CALL dbcsr_t_destroy(tensor_contr_3)
1069         DEALLOCATE (tensor_contr_3)
1070      ENDIF
1071
1072      IF (PRESENT(move_data)) THEN
1073         IF (move_data) THEN
1074            CALL dbcsr_t_clear(tensor_1)
1075            CALL dbcsr_t_clear(tensor_2)
1076         ENDIF
1077      ENDIF
1078
1079      IF (unit_nr_prv > 0) THEN
1080         WRITE (unit_nr_prv, '(A)') repeat("-", 80)
1081         WRITE (unit_nr_prv, '(A)') "TENSOR CONTRACTION DONE"
1082         WRITE (unit_nr_prv, '(A)') repeat("-", 80)
1083      ENDIF
1084
1085      IF (ANY(do_change_pgrid)) THEN
1086         pgrid_changed_any = .FALSE.
1087         SELECT CASE (max_mm_dim)
1088         CASE (1)
1089            IF (ALLOCATED(tensor_1%contraction_storage) .AND. ALLOCATED(tensor_3%contraction_storage)) THEN
1090               CALL dbcsr_t_change_pgrid_2d(tensor_1, tensor_1%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
1091                                            nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
1092                                            pgrid_changed=pgrid_changed, &
1093                                            unit_nr=unit_nr_prv)
1094               IF (pgrid_changed) pgrid_changed_any = .TRUE.
1095               CALL dbcsr_t_change_pgrid_2d(tensor_3, tensor_3%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
1096                                            nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
1097                                            pgrid_changed=pgrid_changed, &
1098                                            unit_nr=unit_nr_prv)
1099               IF (pgrid_changed) pgrid_changed_any = .TRUE.
1100            ENDIF
1101            IF (pgrid_changed_any) THEN
1102               IF (tensor_2%matrix_rep%do_batched == 3) THEN
1103                  ! set flag that process grid has been optimized to make sure that no grid optimizations are done
1104                  ! in TAS multiply algorithm
1105                  CALL dbcsr_tas_batched_mm_complete(tensor_2%matrix_rep)
1106               ENDIF
1107            ENDIF
1108         CASE (2)
1109            IF (ALLOCATED(tensor_1%contraction_storage) .AND. ALLOCATED(tensor_2%contraction_storage)) THEN
1110               CALL dbcsr_t_change_pgrid_2d(tensor_1, tensor_1%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
1111                                            nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
1112                                            pgrid_changed=pgrid_changed, &
1113                                            unit_nr=unit_nr_prv)
1114               IF (pgrid_changed) pgrid_changed_any = .TRUE.
1115               CALL dbcsr_t_change_pgrid_2d(tensor_2, tensor_2%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
1116                                            nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
1117                                            pgrid_changed=pgrid_changed, &
1118                                            unit_nr=unit_nr_prv)
1119               IF (pgrid_changed) pgrid_changed_any = .TRUE.
1120            ENDIF
1121            IF (pgrid_changed_any) THEN
1122               IF (tensor_3%matrix_rep%do_batched == 3) THEN
1123                  CALL dbcsr_tas_batched_mm_complete(tensor_3%matrix_rep)
1124               ENDIF
1125            ENDIF
1126         CASE (3)
1127            IF (ALLOCATED(tensor_2%contraction_storage) .AND. ALLOCATED(tensor_3%contraction_storage)) THEN
1128               CALL dbcsr_t_change_pgrid_2d(tensor_2, tensor_2%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
1129                                            nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
1130                                            pgrid_changed=pgrid_changed, &
1131                                            unit_nr=unit_nr_prv)
1132               IF (pgrid_changed) pgrid_changed_any = .TRUE.
1133               CALL dbcsr_t_change_pgrid_2d(tensor_3, tensor_3%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
1134                                            nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
1135                                            pgrid_changed=pgrid_changed, &
1136                                            unit_nr=unit_nr_prv)
1137               IF (pgrid_changed) pgrid_changed_any = .TRUE.
1138            ENDIF
1139            IF (pgrid_changed_any) THEN
1140               IF (tensor_1%matrix_rep%do_batched == 3) THEN
1141                  CALL dbcsr_tas_batched_mm_complete(tensor_1%matrix_rep)
1142               ENDIF
1143            ENDIF
1144         END SELECT
1145         CALL dbcsr_tas_release_info(split_opt_avg)
1146      ENDIF
1147
1148      IF ((.NOT. tensors_remapped) .AND. do_batched) THEN
1149         ! freeze TAS process grids if tensor grids were optimized
1150         CALL dbcsr_tas_set_batched_state(tensor_1%matrix_rep, opt_grid=.TRUE.)
1151         CALL dbcsr_tas_set_batched_state(tensor_2%matrix_rep, opt_grid=.TRUE.)
1152         CALL dbcsr_tas_set_batched_state(tensor_3%matrix_rep, opt_grid=.TRUE.)
1153      ENDIF
1154
1155      CALL dbcsr_tas_release_info(split_opt)
1156
1157      CALL timestop(handle)
1158
1159   END SUBROUTINE
1160
1161   SUBROUTINE align_tensor(tensor_in, contract_in, notcontract_in, &
1162      !! align tensor index with data
1163                           tensor_out, contract_out, notcontract_out, indp_in, indp_out)
1164      TYPE(dbcsr_t_type), INTENT(INOUT)               :: tensor_in
1165      INTEGER, DIMENSION(:), INTENT(IN)            :: contract_in, notcontract_in
1166      TYPE(dbcsr_t_type), INTENT(OUT)              :: tensor_out
1167      INTEGER, DIMENSION(SIZE(contract_in)), &
1168         INTENT(OUT)                               :: contract_out
1169      INTEGER, DIMENSION(SIZE(notcontract_in)), &
1170         INTENT(OUT)                               :: notcontract_out
1171      CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_in)), INTENT(IN) :: indp_in
1172      CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_in)), INTENT(OUT) :: indp_out
1173      INTEGER, DIMENSION(ndims_tensor(tensor_in)) :: align
1174
1175      CALL dbcsr_t_align_index(tensor_in, tensor_out, order=align)
1176      contract_out = align(contract_in)
1177      notcontract_out = align(notcontract_in)
1178      indp_out(align) = indp_in
1179
1180   END SUBROUTINE
1181
1182   SUBROUTINE reshape_mm_compatible(tensor1, tensor2, tensor1_out, tensor2_out, ind1_free, ind1_linked, &
1183                                    ind2_free, ind2_linked, trans1, trans2, new1, new2, ref_tensor, &
1184                                    nodata1, nodata2, move_data_1, &
1185                                    move_data_2, optimize_dist, unit_nr)
1186      !! Prepare tensor for contraction: redistribute to a 2d format which can be contracted by
1187      !! matrix multiplication. This routine reshapes the two largest of the three tensors. Redistribution
1188      !! is avoided if tensors already in a consistent layout.
1189
1190      TYPE(dbcsr_t_type), TARGET, INTENT(INOUT)   :: tensor1
1191         !! tensor 1 in
1192      TYPE(dbcsr_t_type), TARGET, INTENT(INOUT)   :: tensor2
1193         !! tensor 2 in
1194      TYPE(dbcsr_t_type), POINTER, INTENT(OUT)    :: tensor1_out, tensor2_out
1195         !! tensor 1 out
1196         !! tensor 2 out
1197      INTEGER, DIMENSION(:), INTENT(IN)           :: ind1_free, ind2_free
1198         !! indices of tensor 1 that are "free" (not linked to any index of tensor 2)
1199      INTEGER, DIMENSION(:), INTENT(IN)           :: ind1_linked, ind2_linked
1200         !! indices of tensor 1 that are linked to indices of tensor 2
1201         !! 1:1 correspondence with ind1_linked
1202      CHARACTER(LEN=1), INTENT(OUT)               :: trans1, trans2
1203         !! transpose flag of matrix rep. of tensor 1
1204         !! transpose flag of matrix rep. tensor 2
1205      LOGICAL, INTENT(OUT)                        :: new1, new2
1206         !! whether a new tensor 1 was created
1207         !! whether a new tensor 2 was created
1208      INTEGER, INTENT(OUT) :: ref_tensor
1209      LOGICAL, INTENT(IN), OPTIONAL               :: nodata1, nodata2
1210         !! don't copy data of tensor 1
1211         !! don't copy data of tensor 2
1212      LOGICAL, INTENT(INOUT), OPTIONAL            :: move_data_1, move_data_2
1213         !! memory optimization: transfer data s.t. tensor1 may be empty on return
1214         !! memory optimization: transfer data s.t. tensor2 may be empty on return
1215      LOGICAL, INTENT(IN), OPTIONAL               :: optimize_dist
1216         !! experimental: optimize distribution
1217      INTEGER, INTENT(IN), OPTIONAL               :: unit_nr
1218         !! output unit
1219      INTEGER                                     :: compat1, compat1_old, compat2, compat2_old, &
1220                                                     comm_2d, unit_nr_prv
1221      TYPE(array_list)                            :: dist_list
1222      INTEGER, DIMENSION(:), ALLOCATABLE          :: mp_dims
1223      TYPE(dbcsr_t_distribution_type)             :: dist_in
1224      INTEGER(KIND=int_8)                         :: nblkrows, nblkcols
1225      LOGICAL                                     :: optimize_dist_prv
1226      INTEGER, DIMENSION(ndims_tensor(tensor1)) :: dims1
1227      INTEGER, DIMENSION(ndims_tensor(tensor2)) :: dims2
1228
1229      NULLIFY (tensor1_out, tensor2_out)
1230
1231      unit_nr_prv = prep_output_unit(unit_nr)
1232
1233      CALL blk_dims_tensor(tensor1, dims1)
1234      CALL blk_dims_tensor(tensor2, dims2)
1235
1236      IF (PRODUCT(int(dims1, int_8)) .GE. PRODUCT(int(dims2, int_8))) THEN
1237         ref_tensor = 1
1238      ELSE
1239         ref_tensor = 2
1240      ENDIF
1241
1242      IF (PRESENT(optimize_dist)) THEN
1243         optimize_dist_prv = optimize_dist
1244      ELSE
1245         optimize_dist_prv = .FALSE.
1246      ENDIF
1247
1248      compat1 = compat_map(tensor1%nd_index, ind1_linked)
1249      compat2 = compat_map(tensor2%nd_index, ind2_linked)
1250      compat1_old = compat1
1251      compat2_old = compat2
1252
1253      IF (unit_nr_prv > 0) THEN
1254         WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor1%name), ":"
1255         SELECT CASE (compat1)
1256         CASE (0)
1257            WRITE (unit_nr_prv, '(A)') "Not compatible"
1258         CASE (1)
1259            WRITE (unit_nr_prv, '(A)') "Normal"
1260         CASE (2)
1261            WRITE (unit_nr_prv, '(A)') "Transposed"
1262         END SELECT
1263         WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor2%name), ":"
1264         SELECT CASE (compat2)
1265         CASE (0)
1266            WRITE (unit_nr_prv, '(A)') "Not compatible"
1267         CASE (1)
1268            WRITE (unit_nr_prv, '(A)') "Normal"
1269         CASE (2)
1270            WRITE (unit_nr_prv, '(A)') "Transposed"
1271         END SELECT
1272      ENDIF
1273
1274      new1 = .FALSE.
1275      new2 = .FALSE.
1276
1277      IF (compat1 == 0 .OR. optimize_dist_prv) THEN
1278         new1 = .TRUE.
1279      ENDIF
1280
1281      IF (compat2 == 0 .OR. optimize_dist_prv) THEN
1282         new2 = .TRUE.
1283      ENDIF
1284
1285      IF (ref_tensor == 1) THEN ! tensor 1 is reference and tensor 2 is reshaped compatible with tensor 1
1286         IF (compat1 == 0 .OR. optimize_dist_prv) THEN ! tensor 1 is not contraction compatible --> reshape
1287            IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "Redistribution of", TRIM(tensor1%name)
1288            nblkrows = PRODUCT(INT(dims1(ind1_linked), KIND=int_8))
1289            nblkcols = PRODUCT(INT(dims1(ind1_free), KIND=int_8))
1290            comm_2d = dbcsr_tas_mp_comm(tensor1%pgrid%mp_comm_2d, nblkrows, nblkcols)
1291            ALLOCATE (tensor1_out)
1292            CALL dbcsr_t_remap(tensor1, ind1_linked, ind1_free, tensor1_out, comm_2d=comm_2d, &
1293                               nodata=nodata1, move_data=move_data_1)
1294            CALL mp_comm_free(comm_2d)
1295            compat1 = 1
1296         ELSE
1297            IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", TRIM(tensor1%name)
1298            tensor1_out => tensor1
1299         ENDIF
1300         IF (compat2 == 0 .OR. optimize_dist_prv) THEN ! tensor 2 is not contraction compatible --> reshape
1301            IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A,1X,A,1X,A)') "Redistribution of", &
1302               TRIM(tensor2%name), "compatible with", TRIM(tensor1%name)
1303            dist_in = dbcsr_t_distribution(tensor1_out)
1304            dist_list = array_sublist(dist_in%nd_dist, ind1_linked)
1305            IF (compat1 == 1) THEN ! linked index is first 2d dimension
1306               ! get distribution of linked index, tensor 2 must adopt this distribution
1307               ! get grid dimensions of linked index
1308               ALLOCATE (mp_dims(ndims_mapping_row(dist_in%pgrid%nd_index_grid)))
1309               CALL dbcsr_t_get_mapping_info(dist_in%pgrid%nd_index_grid, dims1_2d=mp_dims)
1310               ALLOCATE (tensor2_out)
1311               CALL dbcsr_t_remap(tensor2, ind2_linked, ind2_free, tensor2_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
1312                                  dist1=dist_list, mp_dims_1=mp_dims, nodata=nodata2, move_data=move_data_2)
1313            ELSEIF (compat1 == 2) THEN ! linked index is second 2d dimension
1314               ! get distribution of linked index, tensor 2 must adopt this distribution
1315               ! get grid dimensions of linked index
1316               ALLOCATE (mp_dims(ndims_mapping_column(dist_in%pgrid%nd_index_grid)))
1317               CALL dbcsr_t_get_mapping_info(dist_in%pgrid%nd_index_grid, dims2_2d=mp_dims)
1318               ALLOCATE (tensor2_out)
1319               CALL dbcsr_t_remap(tensor2, ind2_free, ind2_linked, tensor2_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
1320                                  dist2=dist_list, mp_dims_2=mp_dims, nodata=nodata2, move_data=move_data_2)
1321            ELSE
1322               DBCSR_ABORT("should not happen")
1323            ENDIF
1324            compat2 = compat1
1325         ELSE
1326            IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", TRIM(tensor2%name)
1327            tensor2_out => tensor2
1328         ENDIF
1329      ELSE ! tensor 2 is reference and tensor 1 is reshaped compatible with tensor 2
1330         IF (compat2 == 0 .OR. optimize_dist_prv) THEN ! tensor 2 is not contraction compatible --> reshape
1331            IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "Redistribution of", TRIM(tensor2%name)
1332            nblkrows = PRODUCT(INT(dims2(ind2_linked), KIND=int_8))
1333            nblkcols = PRODUCT(INT(dims2(ind2_free), KIND=int_8))
1334            comm_2d = dbcsr_tas_mp_comm(tensor2%pgrid%mp_comm_2d, nblkrows, nblkcols)
1335            ALLOCATE (tensor2_out)
1336            CALL dbcsr_t_remap(tensor2, ind2_linked, ind2_free, tensor2_out, nodata=nodata2, move_data=move_data_2)
1337            CALL mp_comm_free(comm_2d)
1338            compat2 = 1
1339         ELSE
1340            IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", TRIM(tensor2%name)
1341            tensor2_out => tensor2
1342         ENDIF
1343         IF (compat1 == 0 .OR. optimize_dist_prv) THEN ! tensor 1 is not contraction compatible --> reshape
1344            IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A,1X,A,1X,A)') "Redistribution of", TRIM(tensor1%name), &
1345               "compatible with", TRIM(tensor2%name)
1346            dist_in = dbcsr_t_distribution(tensor2_out)
1347            dist_list = array_sublist(dist_in%nd_dist, ind2_linked)
1348            IF (compat2 == 1) THEN
1349               ALLOCATE (mp_dims(ndims_mapping_row(dist_in%pgrid%nd_index_grid)))
1350               CALL dbcsr_t_get_mapping_info(dist_in%pgrid%nd_index_grid, dims1_2d=mp_dims)
1351               ALLOCATE (tensor1_out)
1352               CALL dbcsr_t_remap(tensor1, ind1_linked, ind1_free, tensor1_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
1353                                  dist1=dist_list, mp_dims_1=mp_dims, nodata=nodata1, move_data=move_data_1)
1354            ELSEIF (compat2 == 2) THEN
1355               ALLOCATE (mp_dims(ndims_mapping_column(dist_in%pgrid%nd_index_grid)))
1356               CALL dbcsr_t_get_mapping_info(dist_in%pgrid%nd_index_grid, dims2_2d=mp_dims)
1357               ALLOCATE (tensor1_out)
1358               CALL dbcsr_t_remap(tensor1, ind1_free, ind1_linked, tensor1_out, comm_2d=dist_in%pgrid%mp_comm_2d, &
1359                                  dist2=dist_list, mp_dims_2=mp_dims, nodata=nodata1, move_data=move_data_1)
1360            ELSE
1361               DBCSR_ABORT("should not happen")
1362            ENDIF
1363            compat1 = compat2
1364         ELSE
1365            IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", TRIM(tensor1%name)
1366            tensor1_out => tensor1
1367         ENDIF
1368      ENDIF
1369
1370      SELECT CASE (compat1)
1371      CASE (1)
1372         trans1 = dbcsr_no_transpose
1373      CASE (2)
1374         trans1 = dbcsr_transpose
1375      CASE DEFAULT
1376         DBCSR_ABORT("should not happen")
1377      END SELECT
1378
1379      SELECT CASE (compat2)
1380      CASE (1)
1381         trans2 = dbcsr_no_transpose
1382      CASE (2)
1383         trans2 = dbcsr_transpose
1384      CASE DEFAULT
1385         DBCSR_ABORT("should not happen")
1386      END SELECT
1387
1388      IF (unit_nr_prv > 0) THEN
1389         IF (compat1 .NE. compat1_old) THEN
1390            WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor1_out%name), ":"
1391            SELECT CASE (compat1)
1392            CASE (0)
1393               WRITE (unit_nr_prv, '(A)') "Not compatible"
1394            CASE (1)
1395               WRITE (unit_nr_prv, '(A)') "Normal"
1396            CASE (2)
1397               WRITE (unit_nr_prv, '(A)') "Transposed"
1398            END SELECT
1399         ENDIF
1400         IF (compat2 .NE. compat2_old) THEN
1401            WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor2_out%name), ":"
1402            SELECT CASE (compat2)
1403            CASE (0)
1404               WRITE (unit_nr_prv, '(A)') "Not compatible"
1405            CASE (1)
1406               WRITE (unit_nr_prv, '(A)') "Normal"
1407            CASE (2)
1408               WRITE (unit_nr_prv, '(A)') "Transposed"
1409            END SELECT
1410         ENDIF
1411      ENDIF
1412
1413      IF (new1 .AND. PRESENT(move_data_1)) move_data_1 = .TRUE.
1414      IF (new2 .AND. PRESENT(move_data_2)) move_data_2 = .TRUE.
1415
1416   END SUBROUTINE
1417
1418   SUBROUTINE reshape_mm_small(tensor_in, ind1, ind2, tensor_out, trans, new, nodata, move_data, unit_nr)
1419      !! Prepare tensor for contraction: redistribute to a 2d format which can be contracted by
1420      !! matrix multiplication. This routine reshapes the smallest of the three tensors.
1421
1422      TYPE(dbcsr_t_type), TARGET, INTENT(INOUT)   :: tensor_in
1423         !! tensor in
1424      INTEGER, DIMENSION(:), INTENT(IN)           :: ind1, ind2
1425         !! index that should be mapped to first matrix dimension
1426         !! index that should be mapped to second matrix dimension
1427      TYPE(dbcsr_t_type), POINTER, INTENT(OUT)    :: tensor_out
1428         !! tensor out
1429      CHARACTER(LEN=1), INTENT(OUT)               :: trans
1430         !! transpose flag of matrix rep.
1431      LOGICAL, INTENT(OUT)                        :: new
1432         !! whether a new tensor was created for tensor_out
1433      LOGICAL, INTENT(IN), OPTIONAL               :: nodata, move_data
1434         !! don't copy tensor data
1435         !! memory optimization: transfer data s.t. tensor_in may be empty on return
1436      INTEGER, INTENT(IN), OPTIONAL               :: unit_nr
1437         !! output unit
1438      INTEGER                                     :: compat1, compat2, compat1_old, compat2_old, unit_nr_prv
1439      LOGICAL                                     :: nodata_prv
1440
1441      NULLIFY (tensor_out)
1442      IF (PRESENT(nodata)) THEN
1443         nodata_prv = nodata
1444      ELSE
1445         nodata_prv = .FALSE.
1446      ENDIF
1447
1448      unit_nr_prv = prep_output_unit(unit_nr)
1449
1450      new = .FALSE.
1451      compat1 = compat_map(tensor_in%nd_index, ind1)
1452      compat2 = compat_map(tensor_in%nd_index, ind2)
1453      compat1_old = compat1; compat2_old = compat2
1454      IF (unit_nr_prv > 0) THEN
1455         WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor_in%name), ":"
1456         IF (compat1 == 1 .AND. compat2 == 2) THEN
1457            WRITE (unit_nr_prv, '(A)') "Normal"
1458         ELSEIF (compat1 == 2 .AND. compat2 == 1) THEN
1459            WRITE (unit_nr_prv, '(A)') "Transposed"
1460         ELSE
1461            WRITE (unit_nr_prv, '(A)') "Not compatible"
1462         ENDIF
1463      ENDIF
1464      IF (compat1 == 0 .or. compat2 == 0) THEN ! index mapping not compatible with contract index
1465
1466         IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "Redistribution of", TRIM(tensor_in%name)
1467
1468         ALLOCATE (tensor_out)
1469         CALL dbcsr_t_remap(tensor_in, ind1, ind2, tensor_out, nodata=nodata, move_data=move_data)
1470         CALL dbcsr_t_copy_contraction_storage(tensor_in, tensor_out)
1471         compat1 = 1
1472         compat2 = 2
1473         new = .TRUE.
1474      ELSE
1475         IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", TRIM(tensor_in%name)
1476         tensor_out => tensor_in
1477      ENDIF
1478
1479      IF (compat1 == 1 .AND. compat2 == 2) THEN
1480         trans = dbcsr_no_transpose
1481      ELSEIF (compat1 == 2 .AND. compat2 == 1) THEN
1482         trans = dbcsr_transpose
1483      ELSE
1484         DBCSR_ABORT("this should not happen")
1485      ENDIF
1486
1487      IF (unit_nr_prv > 0) THEN
1488         IF (compat1_old .NE. compat1 .OR. compat2_old .NE. compat2) THEN
1489            WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor_out%name), ":"
1490            IF (compat1 == 1 .AND. compat2 == 2) THEN
1491               WRITE (unit_nr_prv, '(A)') "Normal"
1492            ELSEIF (compat1 == 2 .AND. compat2 == 1) THEN
1493               WRITE (unit_nr_prv, '(A)') "Transposed"
1494            ELSE
1495               WRITE (unit_nr_prv, '(A)') "Not compatible"
1496            ENDIF
1497         ENDIF
1498      ENDIF
1499
1500   END SUBROUTINE
1501
1502   FUNCTION update_contraction_storage(storage, split_opt, split) RESULT(do_change_pgrid)
1503      !! update contraction storage that keeps track of process grids during a batched contraction
1504      !! and decide if tensor process grid needs to be optimized
1505      TYPE(dbcsr_t_contraction_storage), INTENT(INOUT) :: storage
1506      TYPE(dbcsr_tas_split_info), INTENT(IN)           :: split_opt
1507         !! optimized TAS process grid
1508      TYPE(dbcsr_tas_split_info), INTENT(IN)           :: split
1509         !! current TAS process grid
1510      INTEGER, DIMENSION(2) :: pdims_opt, coor, pdims, pdims_sub
1511      LOGICAL, DIMENSION(2) :: periods
1512      LOGICAL, DIMENSION(2) :: do_change_pgrid
1513      REAL(kind=real_8) :: change_criterion, pdims_ratio
1514      INTEGER :: nsplit_opt, nsplit
1515
1516      DBCSR_ASSERT(ALLOCATED(split_opt%ngroup_opt))
1517      nsplit_opt = split_opt%ngroup_opt
1518      nsplit = split%ngroup
1519
1520      CALL mp_environ(split_opt%mp_comm, 2, pdims_opt, coor, periods)
1521      CALL mp_environ(split%mp_comm, 2, pdims, coor, periods)
1522
1523      storage%ibatch = storage%ibatch + 1
1524
1525      storage%nsplit_avg = (storage%nsplit_avg*REAL(storage%ibatch - 1, real_8) + REAL(nsplit_opt, real_8)) &
1526                           /REAL(storage%ibatch, real_8)
1527
1528      SELECT CASE (split_opt%split_rowcol)
1529      CASE (rowsplit)
1530         pdims_ratio = REAL(pdims(1), real_8)/pdims(2)
1531      CASE (colsplit)
1532         pdims_ratio = REAL(pdims(2), real_8)/pdims(1)
1533      END SELECT
1534
1535      do_change_pgrid(:) = .FALSE.
1536
1537      ! check for process grid dimensions
1538      CALL mp_environ(split%mp_comm_group, 2, pdims_sub, coor, periods)
1539      change_criterion = MAXVAL(REAL(pdims_sub, real_8))/MINVAL(pdims_sub)
1540      IF (change_criterion > default_pdims_accept_ratio**2) do_change_pgrid(1) = .TRUE.
1541
1542      ! check for split factor
1543      change_criterion = MAX(REAL(nsplit, real_8)/storage%nsplit_avg, REAL(storage%nsplit_avg, real_8)/nsplit)
1544      IF (change_criterion > default_nsplit_accept_ratio) do_change_pgrid(2) = .TRUE.
1545
1546   END FUNCTION
1547
1548   FUNCTION compat_map(nd_index, compat_ind)
1549      !! Check if 2d index is compatible with tensor index
1550      TYPE(nd_to_2d_mapping), INTENT(IN) :: nd_index
1551      INTEGER, DIMENSION(:), INTENT(IN)  :: compat_ind
1552      INTEGER, DIMENSION(ndims_mapping_row(nd_index)) :: map1
1553      INTEGER, DIMENSION(ndims_mapping_column(nd_index)) :: map2
1554      INTEGER                            :: compat_map
1555
1556      CALL dbcsr_t_get_mapping_info(nd_index, map1_2d=map1, map2_2d=map2)
1557
1558      compat_map = 0
1559      IF (array_eq_i(map1, compat_ind)) THEN
1560         compat_map = 1
1561      ELSEIF (array_eq_i(map2, compat_ind)) THEN
1562         compat_map = 2
1563      ENDIF
1564
1565   END FUNCTION
1566
1567   SUBROUTINE invert_transpose_flag(trans_flag)
1568      CHARACTER(LEN=1), INTENT(INOUT)                    :: trans_flag
1569
1570      IF (trans_flag == dbcsr_transpose) THEN
1571         trans_flag = dbcsr_no_transpose
1572      ELSEIF (trans_flag == dbcsr_no_transpose) THEN
1573         trans_flag = dbcsr_transpose
1574      ENDIF
1575   END SUBROUTINE
1576
1577   SUBROUTINE index_linked_sort(ind_ref, ind_dep)
1578      INTEGER, DIMENSION(:), INTENT(INOUT) :: ind_ref, ind_dep
1579      INTEGER, DIMENSION(SIZE(ind_ref))    :: sort_indices
1580
1581      CALL sort(ind_ref, SIZE(ind_ref), sort_indices)
1582      ind_dep(:) = ind_dep(sort_indices)
1583
1584   END SUBROUTINE
1585
1586   FUNCTION opt_pgrid(tensor, tas_split_info)
1587      TYPE(dbcsr_t_type), INTENT(IN) :: tensor
1588      TYPE(dbcsr_tas_split_info), INTENT(IN) :: tas_split_info
1589      INTEGER, DIMENSION(ndims_matrix_row(tensor)) :: map1
1590      INTEGER, DIMENSION(ndims_matrix_column(tensor)) :: map2
1591      TYPE(dbcsr_t_pgrid_type) :: opt_pgrid
1592      INTEGER, DIMENSION(ndims_tensor(tensor)) :: dims
1593
1594      CALL dbcsr_t_get_mapping_info(tensor%pgrid%nd_index_grid, map1_2d=map1, map2_2d=map2)
1595      CALL blk_dims_tensor(tensor, dims)
1596      opt_pgrid = dbcsr_t_nd_mp_comm(tas_split_info%mp_comm, map1, map2, tdims=dims)
1597
1598      ALLOCATE (opt_pgrid%tas_split_info, SOURCE=tas_split_info)
1599      CALL dbcsr_tas_info_hold(opt_pgrid%tas_split_info)
1600   END FUNCTION
1601
1602   SUBROUTINE dbcsr_t_remap(tensor_in, map1_2d, map2_2d, tensor_out, comm_2d, dist1, dist2, &
1603                            mp_dims_1, mp_dims_2, name, nodata, move_data)
1604      !! Copy tensor to tensor with modified index mapping
1605
1606      TYPE(dbcsr_t_type), INTENT(INOUT)      :: tensor_in
1607      INTEGER, DIMENSION(:), INTENT(IN)      :: map1_2d, map2_2d
1608         !! new index mapping
1609         !! new index mapping
1610      TYPE(dbcsr_t_type), INTENT(OUT)        :: tensor_out
1611      CHARACTER(len=*), INTENT(IN), OPTIONAL :: name
1612      LOGICAL, INTENT(IN), OPTIONAL          :: nodata, move_data
1613      INTEGER, INTENT(IN), OPTIONAL          :: comm_2d
1614      TYPE(array_list), INTENT(IN), OPTIONAL :: dist1, dist2
1615      INTEGER, DIMENSION(SIZE(map1_2d)), OPTIONAL :: mp_dims_1
1616      INTEGER, DIMENSION(SIZE(map2_2d)), OPTIONAL :: mp_dims_2
1617      CHARACTER(len=default_string_length)   :: name_tmp
1618      INTEGER, DIMENSION(:), ALLOCATABLE     :: ${varlist("blk_sizes")}$, &
1619                                                ${varlist("nd_dist")}$
1620      TYPE(dbcsr_t_distribution_type)        :: dist
1621      INTEGER                                :: comm_2d_prv, handle, i
1622      INTEGER, DIMENSION(ndims_tensor(tensor_in)) :: pdims, myploc
1623      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbcsr_t_remap'
1624      LOGICAL                               :: nodata_prv
1625      TYPE(dbcsr_t_pgrid_type)              :: comm_nd
1626
1627      CALL timeset(routineN, handle)
1628
1629      IF (PRESENT(name)) THEN
1630         name_tmp = name
1631      ELSE
1632         name_tmp = tensor_in%name
1633      ENDIF
1634      IF (PRESENT(dist1)) THEN
1635         DBCSR_ASSERT(PRESENT(mp_dims_1))
1636      ENDIF
1637
1638      IF (PRESENT(dist2)) THEN
1639         DBCSR_ASSERT(PRESENT(mp_dims_2))
1640      ENDIF
1641
1642      IF (PRESENT(comm_2d)) THEN
1643         comm_2d_prv = comm_2d
1644      ELSE
1645         comm_2d_prv = tensor_in%pgrid%mp_comm_2d
1646      ENDIF
1647
1648      comm_nd = dbcsr_t_nd_mp_comm(comm_2d_prv, map1_2d, map2_2d, dims1_nd=mp_dims_1, dims2_nd=mp_dims_2)
1649      CALL mp_environ_pgrid(comm_nd, pdims, myploc)
1650
1651#:for ndim in ndims
1652      IF (ndims_tensor(tensor_in) == ${ndim}$) THEN
1653         CALL get_arrays(tensor_in%blk_sizes, ${varlist("blk_sizes", nmax=ndim)}$)
1654      ENDIF
1655#:endfor
1656
1657#:for ndim in ndims
1658      IF (ndims_tensor(tensor_in) == ${ndim}$) THEN
1659#:for idim in range(1, ndim+1)
1660         IF (PRESENT(dist1)) THEN
1661            IF (ANY(map1_2d == ${idim}$)) THEN
1662               i = MINLOC(map1_2d, dim=1, mask=map1_2d == ${idim}$) ! i is location of idim in map1_2d
1663               CALL get_ith_array(dist1, i, nd_dist_${idim}$)
1664            ENDIF
1665         ENDIF
1666
1667         IF (PRESENT(dist2)) THEN
1668            IF (ANY(map2_2d == ${idim}$)) THEN
1669               i = MINLOC(map2_2d, dim=1, mask=map2_2d == ${idim}$) ! i is location of idim in map2_2d
1670               CALL get_ith_array(dist2, i, nd_dist_${idim}$)
1671            ENDIF
1672         ENDIF
1673
1674         IF (.NOT. ALLOCATED(nd_dist_${idim}$)) THEN
1675            ALLOCATE (nd_dist_${idim}$ (SIZE(blk_sizes_${idim}$)))
1676            CALL dbcsr_t_default_distvec(SIZE(blk_sizes_${idim}$), pdims(${idim}$), blk_sizes_${idim}$, nd_dist_${idim}$)
1677         ENDIF
1678#:endfor
1679         CALL dbcsr_t_distribution_new_expert(dist, comm_nd, map1_2d, map2_2d, &
1680                                              ${varlist("nd_dist", nmax=ndim)}$, own_comm=.TRUE.)
1681      ENDIF
1682#:endfor
1683
1684#:for ndim in ndims
1685      IF (ndims_tensor(tensor_in) == ${ndim}$) THEN
1686         CALL dbcsr_t_create(tensor_out, name_tmp, dist, &
1687                             map1_2d, map2_2d, dbcsr_tas_get_data_type(tensor_in%matrix_rep), &
1688                             ${varlist("blk_sizes", nmax=ndim)}$)
1689      ENDIF
1690#:endfor
1691
1692      IF (PRESENT(nodata)) THEN
1693         nodata_prv = nodata
1694      ELSE
1695         nodata_prv = .FALSE.
1696      ENDIF
1697
1698      IF (.NOT. nodata_prv) CALL dbcsr_t_copy_expert(tensor_in, tensor_out, move_data=move_data)
1699      CALL dbcsr_t_distribution_destroy(dist)
1700
1701      CALL timestop(handle)
1702   END SUBROUTINE
1703
1704   SUBROUTINE dbcsr_t_align_index(tensor_in, tensor_out, order)
1705      !! Align index with data
1706
1707      TYPE(dbcsr_t_type), INTENT(INOUT)               :: tensor_in
1708      TYPE(dbcsr_t_type), INTENT(OUT)                 :: tensor_out
1709      INTEGER, DIMENSION(ndims_matrix_row(tensor_in)) :: map1_2d
1710      INTEGER, DIMENSION(ndims_matrix_column(tensor_in)) :: map2_2d
1711      INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
1712         INTENT(OUT), OPTIONAL                        :: order
1713         !! permutation resulting from alignment
1714      INTEGER, DIMENSION(ndims_tensor(tensor_in))     :: order_prv
1715      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbcsr_t_align_index'
1716      INTEGER                                         :: handle
1717
1718      CALL timeset(routineN, handle)
1719
1720      CALL dbcsr_t_get_mapping_info(tensor_in%nd_index_blk, map1_2d=map1_2d, map2_2d=map2_2d)
1721      order_prv = dbcsr_t_inverse_order([map1_2d, map2_2d])
1722      CALL dbcsr_t_permute_index(tensor_in, tensor_out, order=order_prv)
1723
1724      IF (PRESENT(order)) order = order_prv
1725
1726      CALL timestop(handle)
1727   END SUBROUTINE
1728
1729   SUBROUTINE dbcsr_t_permute_index(tensor_in, tensor_out, order)
1730      !! Create new tensor by reordering index, data is copied exactly (shallow copy)
1731      TYPE(dbcsr_t_type), INTENT(INOUT)                  :: tensor_in
1732      TYPE(dbcsr_t_type), INTENT(OUT)                 :: tensor_out
1733      INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
1734         INTENT(IN)                                   :: order
1735
1736      TYPE(nd_to_2d_mapping)                          :: nd_index_blk_rs, nd_index_rs
1737      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbcsr_t_permute_index'
1738      INTEGER                                         :: handle
1739      INTEGER                                         :: ndims
1740
1741      CALL timeset(routineN, handle)
1742
1743      ndims = ndims_tensor(tensor_in)
1744
1745      CALL permute_index(tensor_in%nd_index, nd_index_rs, order)
1746      CALL permute_index(tensor_in%nd_index_blk, nd_index_blk_rs, order)
1747      CALL permute_index(tensor_in%pgrid%nd_index_grid, tensor_out%pgrid%nd_index_grid, order)
1748
1749      tensor_out%matrix_rep => tensor_in%matrix_rep
1750      tensor_out%owns_matrix = .FALSE.
1751
1752      tensor_out%nd_index = nd_index_rs
1753      tensor_out%nd_index_blk = nd_index_blk_rs
1754      tensor_out%pgrid%mp_comm_2d = tensor_in%pgrid%mp_comm_2d
1755      IF (ALLOCATED(tensor_in%pgrid%tas_split_info)) THEN
1756         ALLOCATE (tensor_out%pgrid%tas_split_info, SOURCE=tensor_in%pgrid%tas_split_info)
1757      ENDIF
1758      tensor_out%refcount => tensor_in%refcount
1759      CALL dbcsr_t_hold(tensor_out)
1760
1761      CALL reorder_arrays(tensor_in%blk_sizes, tensor_out%blk_sizes, order)
1762      CALL reorder_arrays(tensor_in%blk_offsets, tensor_out%blk_offsets, order)
1763      CALL reorder_arrays(tensor_in%nd_dist, tensor_out%nd_dist, order)
1764      CALL reorder_arrays(tensor_in%blks_local, tensor_out%blks_local, order)
1765      ALLOCATE (tensor_out%nblks_local(ndims))
1766      ALLOCATE (tensor_out%nfull_local(ndims))
1767      tensor_out%nblks_local(order) = tensor_in%nblks_local(:)
1768      tensor_out%nfull_local(order) = tensor_in%nfull_local(:)
1769      tensor_out%name = tensor_in%name
1770      tensor_out%valid = .TRUE.
1771
1772      IF (ALLOCATED(tensor_in%contraction_storage)) THEN
1773         ALLOCATE (tensor_out%contraction_storage, SOURCE=tensor_in%contraction_storage)
1774         CALL destroy_array_list(tensor_out%contraction_storage%batch_ranges)
1775         CALL reorder_arrays(tensor_in%contraction_storage%batch_ranges, tensor_out%contraction_storage%batch_ranges, order)
1776      ENDIF
1777
1778      CALL timestop(handle)
1779   END SUBROUTINE
1780
1781   SUBROUTINE dbcsr_t_contract_index(alpha, tensor_1, tensor_2, beta, tensor_3, &
1782                                     contract_1, notcontract_1, &
1783                                     contract_2, notcontract_2, &
1784                                     map_1, map_2, &
1785                                     bounds_1, bounds_2, bounds_3, &
1786                                     filter_eps, &
1787                                     nblks_local, result_index)
1788      !! get indices of non-zero tensor blocks for contraction result without actually
1789      !! performing contraction.
1790      !! this is an estimate based on block norm multiplication.
1791      !! See documentation of dbcsr_t_contract.
1792      TYPE(dbcsr_scalar_type), INTENT(IN)            :: alpha
1793      TYPE(dbcsr_t_type), INTENT(INOUT), TARGET      :: tensor_1
1794      TYPE(dbcsr_t_type), INTENT(INOUT), TARGET      :: tensor_2
1795      TYPE(dbcsr_scalar_type), INTENT(IN)            :: beta
1796      INTEGER, DIMENSION(:), INTENT(IN)              :: contract_1
1797      INTEGER, DIMENSION(:), INTENT(IN)              :: contract_2
1798      INTEGER, DIMENSION(:), INTENT(IN)              :: map_1
1799      INTEGER, DIMENSION(:), INTENT(IN)              :: map_2
1800      INTEGER, DIMENSION(:), INTENT(IN)              :: notcontract_1
1801      INTEGER, DIMENSION(:), INTENT(IN)              :: notcontract_2
1802      TYPE(dbcsr_t_type), INTENT(INOUT), TARGET      :: tensor_3
1803      INTEGER, DIMENSION(2, SIZE(contract_1)), &
1804         INTENT(IN), OPTIONAL                        :: bounds_1
1805      INTEGER, DIMENSION(2, SIZE(notcontract_1)), &
1806         INTENT(IN), OPTIONAL                        :: bounds_2
1807      INTEGER, DIMENSION(2, SIZE(notcontract_2)), &
1808         INTENT(IN), OPTIONAL                        :: bounds_3
1809      REAL(KIND=real_8), INTENT(IN), OPTIONAL        :: filter_eps
1810      INTEGER, INTENT(OUT)                           :: nblks_local
1811         !! number of local blocks on this MPI rank
1812      INTEGER, DIMENSION(dbcsr_t_max_nblks_local(tensor_3), ndims_tensor(tensor_3)), &
1813         INTENT(OUT)                                 :: result_index
1814         !! indices of local non-zero tensor blocks for tensor_3
1815         !! only the elements result_index(:nblks_local, :) are relevant (all others are set to 0)
1816
1817      CALL dbcsr_t_contract_expert(alpha, tensor_1, tensor_2, beta, tensor_3, &
1818                                   contract_1, notcontract_1, &
1819                                   contract_2, notcontract_2, &
1820                                   map_1, map_2, &
1821                                   bounds_1=bounds_1, &
1822                                   bounds_2=bounds_2, &
1823                                   bounds_3=bounds_3, &
1824                                   filter_eps=filter_eps, &
1825                                   nblks_local=nblks_local, &
1826                                   result_index=result_index)
1827   END SUBROUTINE
1828
1829   SUBROUTINE dbcsr_t_map_bounds_to_tensors(tensor_1, tensor_2, &
1830                                            contract_1, notcontract_1, &
1831                                            contract_2, notcontract_2, &
1832                                            bounds_t1, bounds_t2, &
1833                                            bounds_1, bounds_2, bounds_3, &
1834                                            do_crop_1, do_crop_2)
1835      !! Map contraction bounds to bounds referring to tensor indices
1836      !! see dbcsr_t_contract for docu of dummy arguments
1837
1838      TYPE(dbcsr_t_type), INTENT(IN)      :: tensor_1, tensor_2
1839      INTEGER, DIMENSION(:), INTENT(IN)   :: contract_1, contract_2, &
1840                                             notcontract_1, notcontract_2
1841      INTEGER, DIMENSION(2, ndims_tensor(tensor_1)), &
1842         INTENT(OUT)                                 :: bounds_t1
1843         !! bounds mapped to tensor_1
1844      INTEGER, DIMENSION(2, ndims_tensor(tensor_2)), &
1845         INTENT(OUT)                                 :: bounds_t2
1846         !! bounds mapped to tensor_2
1847      INTEGER, DIMENSION(2, SIZE(contract_1)), &
1848         INTENT(IN), OPTIONAL                        :: bounds_1
1849      INTEGER, DIMENSION(2, SIZE(notcontract_1)), &
1850         INTENT(IN), OPTIONAL                        :: bounds_2
1851      INTEGER, DIMENSION(2, SIZE(notcontract_2)), &
1852         INTENT(IN), OPTIONAL                        :: bounds_3
1853      LOGICAL, INTENT(OUT), OPTIONAL                 :: do_crop_1, do_crop_2
1854         !! whether tensor 1 should be cropped
1855         !! whether tensor 2 should be cropped
1856      LOGICAL, DIMENSION(2)                          :: do_crop
1857
1858      do_crop = .FALSE.
1859
1860      bounds_t1(1, :) = 1
1861      CALL dbcsr_t_get_info(tensor_1, nfull_total=bounds_t1(2, :))
1862
1863      bounds_t2(1, :) = 1
1864      CALL dbcsr_t_get_info(tensor_2, nfull_total=bounds_t2(2, :))
1865
1866      IF (PRESENT(bounds_1)) THEN
1867         bounds_t1(:, contract_1) = bounds_1
1868         do_crop(1) = .TRUE.
1869         bounds_t2(:, contract_2) = bounds_1
1870         do_crop(2) = .TRUE.
1871      ENDIF
1872
1873      IF (PRESENT(bounds_2)) THEN
1874         bounds_t1(:, notcontract_1) = bounds_2
1875         do_crop(1) = .TRUE.
1876      ENDIF
1877
1878      IF (PRESENT(bounds_3)) THEN
1879         bounds_t2(:, notcontract_2) = bounds_3
1880         do_crop(2) = .TRUE.
1881      ENDIF
1882
1883      IF (PRESENT(do_crop_1)) do_crop_1 = do_crop(1)
1884      IF (PRESENT(do_crop_2)) do_crop_2 = do_crop(2)
1885
1886   END SUBROUTINE
1887
1888   SUBROUTINE dbcsr_t_print_contraction_index(tensor_1, indchar1, tensor_2, indchar2, tensor_3, indchar3, unit_nr)
1889      !! print tensor contraction indices in a human readable way
1890
1891      TYPE(dbcsr_t_type), INTENT(IN) :: tensor_1, tensor_2, tensor_3
1892      CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_1)), INTENT(IN) :: indchar1
1893         !! characters printed for index of tensor 1
1894      CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_2)), INTENT(IN) :: indchar2
1895         !! characters printed for index of tensor 2
1896      CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_3)), INTENT(IN) :: indchar3
1897         !! characters printed for index of tensor 3
1898      INTEGER, INTENT(IN) :: unit_nr
1899         !! output unit
1900      INTEGER, DIMENSION(ndims_matrix_row(tensor_1)) :: map11
1901      INTEGER, DIMENSION(ndims_matrix_column(tensor_1)) :: map12
1902      INTEGER, DIMENSION(ndims_matrix_row(tensor_2)) :: map21
1903      INTEGER, DIMENSION(ndims_matrix_column(tensor_2)) :: map22
1904      INTEGER, DIMENSION(ndims_matrix_row(tensor_3)) :: map31
1905      INTEGER, DIMENSION(ndims_matrix_column(tensor_3)) :: map32
1906      INTEGER :: ichar1, ichar2, ichar3, unit_nr_prv
1907
1908      unit_nr_prv = prep_output_unit(unit_nr)
1909
1910      IF (unit_nr_prv /= 0) THEN
1911         CALL dbcsr_t_get_mapping_info(tensor_1%nd_index_blk, map1_2d=map11, map2_2d=map12)
1912         CALL dbcsr_t_get_mapping_info(tensor_2%nd_index_blk, map1_2d=map21, map2_2d=map22)
1913         CALL dbcsr_t_get_mapping_info(tensor_3%nd_index_blk, map1_2d=map31, map2_2d=map32)
1914      ENDIF
1915
1916      IF (unit_nr_prv > 0) THEN
1917         WRITE (unit_nr_prv, '(T2,A)') "INDEX INFO"
1918         WRITE (unit_nr_prv, '(T15,A)', advance='no') "tensor index: ("
1919         DO ichar1 = 1, SIZE(indchar1)
1920            WRITE (unit_nr_prv, '(A1)', advance='no') indchar1(ichar1)
1921         ENDDO
1922         WRITE (unit_nr_prv, '(A)', advance='no') ") x ("
1923         DO ichar2 = 1, SIZE(indchar2)
1924            WRITE (unit_nr_prv, '(A1)', advance='no') indchar2(ichar2)
1925         ENDDO
1926         WRITE (unit_nr_prv, '(A)', advance='no') ") = ("
1927         DO ichar3 = 1, SIZE(indchar3)
1928            WRITE (unit_nr_prv, '(A1)', advance='no') indchar3(ichar3)
1929         ENDDO
1930         WRITE (unit_nr_prv, '(A)') ")"
1931
1932         WRITE (unit_nr_prv, '(T15,A)', advance='no') "matrix index: ("
1933         DO ichar1 = 1, SIZE(map11)
1934            WRITE (unit_nr_prv, '(A1)', advance='no') indchar1(map11(ichar1))
1935         ENDDO
1936         WRITE (unit_nr_prv, '(A1)', advance='no') "|"
1937         DO ichar1 = 1, SIZE(map12)
1938            WRITE (unit_nr_prv, '(A1)', advance='no') indchar1(map12(ichar1))
1939         ENDDO
1940         WRITE (unit_nr_prv, '(A)', advance='no') ") x ("
1941         DO ichar2 = 1, SIZE(map21)
1942            WRITE (unit_nr_prv, '(A1)', advance='no') indchar2(map21(ichar2))
1943         ENDDO
1944         WRITE (unit_nr_prv, '(A1)', advance='no') "|"
1945         DO ichar2 = 1, SIZE(map22)
1946            WRITE (unit_nr_prv, '(A1)', advance='no') indchar2(map22(ichar2))
1947         ENDDO
1948         WRITE (unit_nr_prv, '(A)', advance='no') ") = ("
1949         DO ichar3 = 1, SIZE(map31)
1950            WRITE (unit_nr_prv, '(A1)', advance='no') indchar3(map31(ichar3))
1951         ENDDO
1952         WRITE (unit_nr_prv, '(A1)', advance='no') "|"
1953         DO ichar3 = 1, SIZE(map32)
1954            WRITE (unit_nr_prv, '(A1)', advance='no') indchar3(map32(ichar3))
1955         ENDDO
1956         WRITE (unit_nr_prv, '(A)') ")"
1957      ENDIF
1958
1959   END SUBROUTINE
1960
1961   SUBROUTINE dbcsr_t_batched_contract_init(tensor, ${varlist("batch_range")}$)
1962      !! Initialize batched contraction for this tensor.
1963      !!
1964      !! Explanation: A batched contraction is a contraction performed in several consecutive steps by
1965      !! specification of bounds in dbcsr_t_contract. This can be used to reduce memory by a large factor.
1966      !! The routines dbcsr_t_batched_contract_init and dbcsr_t_batched_contract_finalize should be
1967      !! called to define the scope of a batched contraction as this enables important optimizations
1968      !! (adapting communication scheme to batches and adapting process grid to multiplication algorithm).
1969      !! The routines dbcsr_t_batched_contract_init and dbcsr_t_batched_contract_finalize must be called
1970      !! before the first and after the last contraction step on all 3 tensors.
1971      !!
1972      !! Requirements:
1973      !! - the tensors are in a compatible matrix layout (see documentation of `dbcsr_t_contract`, note 2 & 3).
1974      !!   If they are not, process grid optimizations are disabled and a warning is issued.
1975      !! - within the scope of a batched contraction, it is not allowed to access or change tensor data
1976      !!   except by calling the routines dbcsr_t_contract & dbcsr_t_copy.
1977      !! - the bounds affecting indices of the smallest tensor must not change in the course of a batched
1978      !!   contraction (todo: get rid of this requirement).
1979      !!
1980      !! Side effects:
1981      !! - the parallel layout (process grid and distribution) of all tensors may change. In order to
1982      !!   disable the process grid optimization including this side effect, call this routine only on the
1983      !!   smallest of the 3 tensors.
1984      !!
1985      !! @note
1986      !! Note 1: for an example of batched contraction see `examples/dbcsr_tensor_example.F`.
1987      !! (todo: the example is outdated and should be updated).
1988      !!
1989      !! Note 2: it is meaningful to use this feature if the contraction consists of one batch only
1990      !! but if multiple contractions involving the same 3 tensors are performed
1991      !! (batched_contract_init and batched_contract_finalize must then be called before/after each
1992      !! contraction call). The process grid is then optimized after the first contraction
1993      !! and future contraction may profit from this optimization.
1994      !! @endnote
1995      TYPE(dbcsr_t_type), INTENT(INOUT) :: tensor
1996      INTEGER, DIMENSION(:), OPTIONAL, INTENT(IN)        :: ${varlist("batch_range")}$
1997         !! For internal load balancing optimizations, optionally specify the index ranges of
1998         !! batched contraction.
1999         !! batch_range_i refers to the ith tensor dimension and contains all block indices starting
2000         !! a new range. The size should be the number of ranges plus one, the last element being the
2001         !! block index plus one of the last block in the last range.
2002      INTEGER, DIMENSION(ndims_tensor(tensor)) :: tdims
2003      INTEGER, DIMENSION(:), ALLOCATABLE                 :: ${varlist("batch_range_prv")}$
2004      LOGICAL :: static_range
2005
2006      CALL dbcsr_t_get_info(tensor, nblks_total=tdims)
2007
2008      static_range = .TRUE.
2009#:for idim in range(1, maxdim+1)
2010      IF (ndims_tensor(tensor) >= ${idim}$) THEN
2011         IF (PRESENT(batch_range_${idim}$)) THEN
2012            CALL allocate_any(batch_range_prv_${idim}$, source=batch_range_${idim}$)
2013            static_range = .FALSE.
2014         ELSE
2015            ALLOCATE (batch_range_prv_${idim}$ (2))
2016            batch_range_prv_${idim}$ (1) = 1
2017            batch_range_prv_${idim}$ (2) = tdims(${idim}$) + 1
2018         ENDIF
2019      ENDIF
2020#:endfor
2021
2022      ALLOCATE (tensor%contraction_storage)
2023      tensor%contraction_storage%static = static_range
2024      IF (static_range) THEN
2025         CALL dbcsr_tas_batched_mm_init(tensor%matrix_rep)
2026      ENDIF
2027      tensor%contraction_storage%nsplit_avg = 0.0_real_8
2028      tensor%contraction_storage%ibatch = 0
2029
2030#:for ndim in range(1, maxdim+1)
2031      IF (ndims_tensor(tensor) == ${ndim}$) THEN
2032         CALL create_array_list(tensor%contraction_storage%batch_ranges, ${ndim}$, &
2033                                ${varlist("batch_range_prv", nmax=ndim)}$)
2034      ENDIF
2035#:endfor
2036
2037   END SUBROUTINE
2038
2039   SUBROUTINE dbcsr_t_batched_contract_finalize(tensor, unit_nr)
2040      !! finalize batched contraction. This performs all communication that has been postponed in the
2041      !! contraction calls.
2042      TYPE(dbcsr_t_type), INTENT(INOUT) :: tensor
2043      INTEGER, INTENT(IN), OPTIONAL :: unit_nr
2044      LOGICAL :: do_write
2045      INTEGER :: unit_nr_prv, handle
2046
2047      CALL mp_sync(tensor%pgrid%mp_comm_2d)
2048      CALL timeset("dbcsr_t_total", handle)
2049      unit_nr_prv = prep_output_unit(unit_nr)
2050
2051      do_write = .FALSE.
2052
2053      IF (tensor%contraction_storage%static) THEN
2054         IF (tensor%matrix_rep%do_batched > 0) THEN
2055            IF (tensor%matrix_rep%mm_storage%batched_out) do_write = .TRUE.
2056         ENDIF
2057         CALL dbcsr_tas_batched_mm_finalize(tensor%matrix_rep)
2058      ENDIF
2059
2060      IF (do_write .AND. unit_nr_prv /= 0) THEN
2061         IF (unit_nr_prv > 0) THEN
2062            WRITE (unit_nr_prv, "(T2,A)") &
2063               "FINALIZING BATCHED PROCESSING OF MATMUL"
2064         ENDIF
2065         CALL dbcsr_t_write_tensor_info(tensor, unit_nr_prv)
2066         CALL dbcsr_t_write_tensor_dist(tensor, unit_nr_prv)
2067      ENDIF
2068
2069      CALL destroy_array_list(tensor%contraction_storage%batch_ranges)
2070      DEALLOCATE (tensor%contraction_storage)
2071      CALL mp_sync(tensor%pgrid%mp_comm_2d)
2072      CALL timestop(handle)
2073
2074   END SUBROUTINE
2075
2076   SUBROUTINE dbcsr_t_change_pgrid(tensor, pgrid, ${varlist("batch_range")}$, &
2077                                   nodata, pgrid_changed, unit_nr)
2078      !! change the process grid of a tensor
2079      TYPE(dbcsr_t_type), INTENT(INOUT)                  :: tensor
2080      TYPE(dbcsr_t_pgrid_type), INTENT(IN)               :: pgrid
2081      INTEGER, DIMENSION(:), OPTIONAL, INTENT(IN)        :: ${varlist("batch_range")}$
2082         !! For internal load balancing optimizations, optionally specify the index ranges of
2083         !! batched contraction.
2084         !! batch_range_i refers to the ith tensor dimension and contains all block indices starting
2085         !! a new range. The size should be the number of ranges plus one, the last element being the
2086         !! block index plus one of the last block in the last range.
2087      LOGICAL, INTENT(IN), OPTIONAL                      :: nodata
2088         !! optionally don't copy the tensor data (then tensor is empty on returned)
2089      LOGICAL, INTENT(OUT), OPTIONAL                     :: pgrid_changed
2090      INTEGER, INTENT(IN), OPTIONAL                      :: unit_nr
2091      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbcsr_t_change_pgrid'
2092      CHARACTER(default_string_length)                   :: name
2093      INTEGER                                            :: handle
2094      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: ${varlist("bs")}$, &
2095                                                            ${varlist("dist")}$
2096      INTEGER, DIMENSION(ndims_tensor(tensor))           :: pcoord, pcoord_ref, pdims, pdims_ref, &
2097                                                            tdims
2098      TYPE(dbcsr_t_type)                                 :: t_tmp
2099      TYPE(dbcsr_t_distribution_type)                    :: dist
2100      INTEGER, DIMENSION(ndims_matrix_row(tensor)) :: map1
2101      INTEGER, &
2102         DIMENSION(ndims_matrix_column(tensor))    :: map2
2103      LOGICAL, DIMENSION(ndims_tensor(tensor))             :: mem_aware
2104      INTEGER, DIMENSION(ndims_tensor(tensor)) :: nbatch
2105      INTEGER :: ind1, ind2, batch_size, ibatch
2106
2107      IF (PRESENT(pgrid_changed)) pgrid_changed = .FALSE.
2108      CALL mp_environ_pgrid(pgrid, pdims, pcoord)
2109      CALL mp_environ_pgrid(tensor%pgrid, pdims_ref, pcoord_ref)
2110
2111      IF (ALL(pdims == pdims_ref)) THEN
2112         IF (ALLOCATED(pgrid%tas_split_info) .AND. ALLOCATED(tensor%pgrid%tas_split_info)) THEN
2113            IF (pgrid%tas_split_info%ngroup == tensor%pgrid%tas_split_info%ngroup) THEN
2114               RETURN
2115            ENDIF
2116         ENDIF
2117      ENDIF
2118
2119      CALL timeset(routineN, handle)
2120
2121#:for idim in range(1, maxdim+1)
2122      IF (ndims_tensor(tensor) >= ${idim}$) THEN
2123         mem_aware(${idim}$) = PRESENT(batch_range_${idim}$)
2124         IF (mem_aware(${idim}$)) nbatch(${idim}$) = SIZE(batch_range_${idim}$) - 1
2125      ENDIF
2126#:endfor
2127
2128      CALL dbcsr_t_get_info(tensor, nblks_total=tdims, name=name)
2129
2130#:for idim in range(1, maxdim+1)
2131      IF (ndims_tensor(tensor) >= ${idim}$) THEN
2132         ALLOCATE (bs_${idim}$ (dbcsr_t_nblks_total(tensor, ${idim}$)))
2133         CALL get_ith_array(tensor%blk_sizes, ${idim}$, bs_${idim}$)
2134         ALLOCATE (dist_${idim}$ (tdims(${idim}$)))
2135         dist_${idim}$ = 0
2136         IF (mem_aware(${idim}$)) THEN
2137            DO ibatch = 1, nbatch(${idim}$)
2138               ind1 = batch_range_${idim}$ (ibatch)
2139               ind2 = batch_range_${idim}$ (ibatch + 1) - 1
2140               batch_size = ind2 - ind1 + 1
2141               CALL dbcsr_t_default_distvec(batch_size, pdims(${idim}$), &
2142                                            bs_${idim}$ (ind1:ind2), dist_${idim}$ (ind1:ind2))
2143            ENDDO
2144         ELSE
2145            CALL dbcsr_t_default_distvec(tdims(${idim}$), pdims(${idim}$), bs_${idim}$, dist_${idim}$)
2146         ENDIF
2147      ENDIF
2148#:endfor
2149
2150      CALL dbcsr_t_get_mapping_info(tensor%nd_index_blk, map1_2d=map1, map2_2d=map2)
2151#:for ndim in ndims
2152      IF (ndims_tensor(tensor) == ${ndim}$) THEN
2153         CALL dbcsr_t_distribution_new(dist, pgrid, ${varlist("dist", nmax=ndim)}$)
2154         CALL dbcsr_t_create(t_tmp, name, dist, map1, map2, dbcsr_type_real_8, ${varlist("bs", nmax=ndim)}$)
2155      ENDIF
2156#:endfor
2157      CALL dbcsr_t_distribution_destroy(dist)
2158
2159      IF (PRESENT(nodata)) THEN
2160         IF (.NOT. nodata) CALL dbcsr_t_copy_expert(tensor, t_tmp, move_data=.TRUE.)
2161      ELSE
2162         CALL dbcsr_t_copy_expert(tensor, t_tmp, move_data=.TRUE.)
2163      ENDIF
2164
2165      CALL dbcsr_t_copy_contraction_storage(tensor, t_tmp)
2166
2167      CALL dbcsr_t_destroy(tensor)
2168      tensor = t_tmp
2169
2170      IF (PRESENT(unit_nr)) THEN
2171         IF (unit_nr > 0) THEN
2172            WRITE (unit_nr, "(T2,A,1X,A)") "OPTIMIZED PGRID INFO FOR", TRIM(tensor%name)
2173            WRITE (unit_nr, "(T4,A,1X,3I6)") "process grid dimensions:", pdims
2174            CALL dbcsr_t_write_split_info(pgrid, unit_nr)
2175         ENDIF
2176      ENDIF
2177
2178      IF (PRESENT(pgrid_changed)) pgrid_changed = .TRUE.
2179
2180      CALL timestop(handle)
2181   END SUBROUTINE
2182
2183   SUBROUTINE dbcsr_t_change_pgrid_2d(tensor, mp_comm, pdims, nodata, nsplit, dimsplit, pgrid_changed, unit_nr)
2184      !! map tensor to a new 2d process grid for the matrix representation.
2185      TYPE(dbcsr_t_type), INTENT(INOUT)                  :: tensor
2186      INTEGER, INTENT(IN)               :: mp_comm
2187      INTEGER, DIMENSION(2), INTENT(IN), OPTIONAL :: pdims
2188      LOGICAL, INTENT(IN), OPTIONAL                      :: nodata
2189      INTEGER, INTENT(IN), OPTIONAL :: nsplit, dimsplit
2190      LOGICAL, INTENT(OUT), OPTIONAL :: pgrid_changed
2191      INTEGER, INTENT(IN), OPTIONAL                      :: unit_nr
2192      INTEGER, DIMENSION(ndims_matrix_row(tensor)) :: map1
2193      INTEGER, DIMENSION(ndims_matrix_column(tensor)) :: map2
2194      INTEGER, DIMENSION(ndims_tensor(tensor)) :: dims, nbatches
2195      TYPE(dbcsr_t_pgrid_type) :: pgrid
2196      INTEGER, DIMENSION(:), ALLOCATABLE :: ${varlist("batch_range")}$
2197      INTEGER, DIMENSION(:), ALLOCATABLE :: array
2198      INTEGER :: idim
2199
2200      CALL dbcsr_t_get_mapping_info(tensor%pgrid%nd_index_grid, map1_2d=map1, map2_2d=map2)
2201      CALL blk_dims_tensor(tensor, dims)
2202
2203      IF (ALLOCATED(tensor%contraction_storage)) THEN
2204         ASSOCIATE (batch_ranges => tensor%contraction_storage%batch_ranges)
2205            nbatches = sizes_of_arrays(tensor%contraction_storage%batch_ranges) - 1
2206            ! for good load balancing the process grid dimensions should be chosen adapted to the
2207            ! tensor dimenions. For batched contraction the tensor dimensions should be divided by
2208            ! the number of batches (number of index ranges).
2209            DO idim = 1, ndims_tensor(tensor)
2210               CALL get_ith_array(tensor%contraction_storage%batch_ranges, idim, array)
2211               dims(idim) = array(nbatches(idim) + 1) - array(1)
2212               DEALLOCATE (array)
2213               dims(idim) = dims(idim)/nbatches(idim)
2214               IF (dims(idim) <= 0) dims(idim) = 1
2215            ENDDO
2216         END ASSOCIATE
2217      ENDIF
2218
2219      pgrid = dbcsr_t_nd_mp_comm(mp_comm, map1, map2, pdims_2d=pdims, tdims=dims, nsplit=nsplit, dimsplit=dimsplit)
2220      IF (ALLOCATED(tensor%contraction_storage)) THEN
2221#:for ndim in range(1, maxdim+1)
2222         IF (ndims_tensor(tensor) == ${ndim}$) THEN
2223            CALL get_arrays(tensor%contraction_storage%batch_ranges, ${varlist("batch_range", nmax=ndim)}$)
2224            CALL dbcsr_t_change_pgrid(tensor, pgrid, ${varlist("batch_range", nmax=ndim)}$, &
2225                                      nodata=nodata, pgrid_changed=pgrid_changed, unit_nr=unit_nr)
2226         ENDIF
2227#:endfor
2228      ELSE
2229         CALL dbcsr_t_change_pgrid(tensor, pgrid, nodata=nodata, pgrid_changed=pgrid_changed, unit_nr=unit_nr)
2230      ENDIF
2231      CALL dbcsr_t_pgrid_destroy(pgrid)
2232
2233   END SUBROUTINE
2234
2235END MODULE
2236