1import numpy as np
2
3from gpaw.blacs import BlacsGrid
4from gpaw.blacs import Redistributor
5
6
def collect_uX(kd, comm, a_uX, s, k):
    """Collect the array for spin ``s`` and k-point ``k`` on the master rank.

    The owning rank and local k-point index come from
    ``kd.get_rank_and_index(k)``; the local list index is
    ``q * kd.nspins + s``.  If the owner is not rank 0 of ``kd.comm``, the
    owner's ``comm`` master sends the array to rank 0 of ``kd.comm``
    (tag 2018), which receives it into a freshly allocated buffer.

    Parameters
    ----------
    kd
        Object providing ``comm``, ``nspins`` and ``get_rank_and_index``.
    comm
        Communicator orthogonal to ``kd.comm`` (ie, domainband_comm).
    a_uX
        Local list of arrays.  Entry 0 is read on every calling rank to get
        the shape and dtype of the receive buffer, so all entries are
        assumed to share them.
    s, k
        Spin index and the k-point index passed to
        ``kd.get_rank_and_index``.

    Returns
    -------
    The array on the rank where both ``kd.comm.rank`` and ``comm.rank`` are
    0, ``None`` on every other rank.
    """
    Xshape = a_uX[0].shape
    dtype = a_uX[0].dtype
    kpt_rank, q = kd.get_rank_and_index(k)
    u = q * kd.nspins + s

    if kd.comm.rank == kpt_rank:
        a_X = a_uX[u]
        # Comm master send to the global master
        if comm.rank == 0:
            if kpt_rank == 0:
                # Owner is already the global master; nothing to transfer.
                return a_X
            kd.comm.ssend(a_X, 0, 2018)
        return None

    # Only rank 0 of kd.comm may post the matching receive.  Other k-point
    # ranks that merely do not own (s, k) must stay out of the exchange,
    # otherwise they would wait for a message that is never sent.
    if kd.comm.rank == 0 and comm.rank == 0:
        a_X = np.empty(Xshape, dtype=dtype)
        kd.comm.receive(a_X, kpt_rank, 2018)
        return a_X
    return None
27
28
def write_uX(kd, comm, writer, name, a_uX):
    """Write one array per (spin, k-point) pair to ``writer`` under ``name``.

    The array is declared with shape
    ``(kd.nspins, kd.nibzkpts) + a_uX[0].shape`` and the dtype of
    ``a_uX[0]``.  Each pair is then gathered with ``collect_uX`` and handed
    to ``writer.fill``, with spin as the slow index.  ``writer.fill`` is
    called with whatever ``collect_uX`` returns on the calling rank.
    """
    template = a_uX[0]
    full_shape = (kd.nspins, kd.nibzkpts) + template.shape
    writer.add_array(name, full_shape, dtype=template.dtype)

    # Same traversal order as the declared layout: spin outermost.
    pairs = [(spin, kpt_index)
             for spin in range(kd.nspins)
             for kpt_index in range(kd.nibzkpts)]
    for spin, kpt_index in pairs:
        writer.fill(collect_uX(kd, comm, a_uX, spin, kpt_index))
38
39
def read_uX(kpt_u, reader, name):
    """Read the ``name`` entry for every k-point object in ``kpt_u``.

    For each ``kpt`` the slice ``reader.proxy(name, kpt.s, kpt.k)[:]`` is
    taken.  The results are returned as a list in the order of ``kpt_u``.
    """
    # TODO: does this read on all the comm ranks in vain?
    return [reader.proxy(name, kpt.s, kpt.k)[:] for kpt in kpt_u]
48
49
def distribute_nM(ksl, a_nM):
    """Redistribute ``a_nM`` into the layout of ``ksl.mmdescriptor``.

    When ``ksl.using_blacs`` is false the input is returned untouched.
    Otherwise the input is checked against ``ksl.nMdescriptor`` and
    redistributed over ``ksl.block_comm`` from ``ksl.nM_unique_descriptor``
    to ``ksl.mmdescriptor``; the new array is returned.
    """
    if not ksl.using_blacs:
        return a_nM

    element_type = a_nM.dtype
    source_desc = ksl.nM_unique_descriptor
    target_desc = ksl.mmdescriptor

    ksl.nMdescriptor.checkassert(a_nM)
    # Ranks with ksl.gd.rank != 0 pass a zero-filled array built from the
    # source descriptor instead of their own input.
    source = a_nM if ksl.gd.rank == 0 else source_desc.zeros(
        dtype=element_type)

    redistributor = Redistributor(ksl.block_comm, source_desc, target_desc)
    target = target_desc.empty(dtype=element_type)
    redistributor.redistribute(source, target, ksl.bd.nbands, ksl.nao)
    return target
65
66
def collect_MM(ksl, a_mm):
    """Redistribute ``a_mm`` from ``ksl.mmdescriptor`` to a single-block layout.

    When ``ksl.using_blacs`` is false the input is returned untouched.
    Otherwise a ``BlacsGrid(ksl.block_comm, 1, 1)`` descriptor with matrix
    and block sizes all equal to ``ksl.nao`` is created, and the
    redistributed array is returned.
    """
    if not ksl.using_blacs:
        return a_mm

    nao = ksl.nao
    block_comm = ksl.block_comm
    # Presumably a 1x1 process grid holding the full nao x nao matrix in
    # one block -- confirm against gpaw.blacs.
    full_desc = BlacsGrid(block_comm, 1, 1).new_descriptor(nao, nao, nao, nao)
    gather = Redistributor(block_comm, ksl.mmdescriptor, full_desc)

    gathered = full_desc.empty(dtype=a_mm.dtype)
    gather.redistribute(a_mm, gathered)
    return gathered
82
83
def collect_uMM(kd, ksl, a_uMM, s, k):
    """Collect the matrix for spin ``s`` and k-point ``k`` from ``a_uMM``.

    Delegates to ``collect_wuMM`` with ``a_uMM`` wrapped as the only entry
    of the outer list and outer index 0.
    """
    wrapped = [a_uMM]
    return collect_wuMM(kd, ksl, wrapped, 0, s, k)
86
87
def collect_wuMM(kd, ksl, a_wuMM, w, s, k):
    """Collect the matrix ``a_wuMM[w][u]`` for spin ``s`` and k-point ``k``.

    The owning rank and local k-point index come from
    ``kd.get_rank_and_index(k)``; the local list index is
    ``u = q * kd.nspins + s``.  The owner first gathers the matrix with
    ``collect_MM``.  If the owner is not rank 0 of ``kd.comm``, its
    ``ksl.block_comm`` master sends the result to rank 0 of ``kd.comm``
    (tag 2017).

    Parameters
    ----------
    kd
        Object providing ``comm``, ``nspins`` and ``get_rank_and_index``.
    ksl
        Object providing ``nao``, ``block_comm`` and ``world`` (and whatever
        ``collect_MM`` needs).
    a_wuMM
        Nested list indexed as ``a_wuMM[w][u]``.  ``a_wuMM[0][0]`` is read
        on every calling rank to get the dtype of the receive buffer.
    w, s, k
        Outer list index, spin index and the k-point index passed to
        ``kd.get_rank_and_index``.

    Returns
    -------
    The ``(ksl.nao, ksl.nao)`` matrix on the rank where both
    ``kd.comm.rank`` and ``ksl.block_comm.rank`` are 0, ``None`` elsewhere.
    """
    # This function is based on
    # gpaw/wavefunctions/base.py: WaveFunctions.collect_auxiliary()

    dtype = a_wuMM[0][0].dtype
    NM = ksl.nao
    kpt_rank, q = kd.get_rank_and_index(k)
    u = q * kd.nspins + s

    if kd.comm.rank == kpt_rank:
        # Collect within blacs grid
        a_MM = collect_MM(ksl, a_wuMM[w][u])

        # KSL master send a_MM to the global master
        if ksl.block_comm.rank == 0:
            if kpt_rank == 0:
                assert ksl.world.rank == 0
                # I have it already
                return a_MM
            kd.comm.send(a_MM, 0, 2017)
        return None

    # Only rank 0 of kd.comm may post the matching receive.  Other k-point
    # ranks that merely do not own (s, k) must not take part: they are not
    # the global master and no message is ever sent to them.
    if kd.comm.rank == 0 and ksl.block_comm.rank == 0:
        assert ksl.world.rank == 0
        a_MM = np.empty((NM, NM), dtype=dtype)
        kd.comm.receive(a_MM, kpt_rank, 2017)
        return a_MM
    return None
116
117
def distribute_MM(ksl, a_MM):
    """Redistribute ``a_MM`` from a single-block layout to ``ksl.mmdescriptor``.

    When ``ksl.using_blacs`` is false the input is returned untouched.
    Otherwise a ``BlacsGrid(ksl.block_comm, 1, 1)`` descriptor with matrix
    and block sizes all equal to ``ksl.nao`` is used as the source layout,
    and the redistributed array is returned.
    """
    if not ksl.using_blacs:
        return a_MM

    element_type = a_MM.dtype
    nao = ksl.nao
    block_comm = ksl.block_comm
    # Presumably a 1x1 process grid holding the full nao x nao matrix in
    # one block -- confirm against gpaw.blacs.
    full_desc = BlacsGrid(block_comm, 1, 1).new_descriptor(nao, nao, nao, nao)
    scatter = Redistributor(block_comm, full_desc, ksl.mmdescriptor)

    # Ranks other than block_comm rank 0 pass an uninitialised array built
    # from the source descriptor instead of their own input.
    if block_comm.rank == 0:
        source = a_MM
    else:
        source = full_desc.empty(dtype=element_type)

    distributed = ksl.mmdescriptor.empty(dtype=element_type)
    scatter.redistribute(source, distributed)
    return distributed
135
136
def write_uMM(kd, ksl, writer, name, a_uMM):
    """Write the matrices in ``a_uMM`` to ``writer`` under ``name``.

    Delegates to ``write_wuMM`` with ``a_uMM`` wrapped as the only entry of
    the outer list and ``wlist=[0]``; returns whatever ``write_wuMM``
    returns.
    """
    wrapped = [a_uMM]
    return write_wuMM(kd, ksl, writer, name, wrapped, wlist=[0])
139
140
def write_wuMM(kd, ksl, writer, name, a_wuMM, wlist):
    """Write one matrix per (w, spin, k-point) triple to ``writer``.

    The array is declared with shape
    ``(len(wlist), kd.nspins, kd.nibzkpts, ksl.nao, ksl.nao)`` and the dtype
    of ``a_wuMM[0][0]``.  For every ``w`` in ``wlist``, every spin and every
    k-point index, the matrix is gathered with ``collect_wuMM`` and handed
    to ``writer.fill``.  ``writer.fill`` is called with whatever
    ``collect_wuMM`` returns on the calling rank.
    """
    nao = ksl.nao
    element_type = a_wuMM[0][0].dtype
    full_shape = (len(wlist), kd.nspins, kd.nibzkpts, nao, nao)
    writer.add_array(name, full_shape, dtype=element_type)

    # Loop nesting follows the declared axis order: w, spin, k-point.
    for w_index in wlist:
        for spin in range(kd.nspins):
            for kpt_index in range(kd.nibzkpts):
                writer.fill(
                    collect_wuMM(kd, ksl, a_wuMM, w_index, spin, kpt_index))
152
153
def read_uMM(kpt_u, ksl, reader, name):
    """Read the ``name`` matrices for every k-point object in ``kpt_u``.

    Delegates to ``read_wuMM`` with ``wlist=[0]`` and returns the first
    (only) entry of its result.
    """
    a_wuMM = read_wuMM(kpt_u, ksl, reader, name, wlist=[0])
    return a_wuMM[0]
156
157
def read_wuMM(kpt_u, ksl, reader, name, wlist):
    """Read and distribute the ``name`` matrices for each ``w`` and k-point.

    For every ``w`` in ``wlist`` and every ``kpt`` in ``kpt_u`` the slice
    ``reader.proxy(name, w, kpt.s, kpt.k)[:]`` is taken and passed through
    ``distribute_MM``.  Returns a nested list indexed first by position in
    ``wlist`` and then by position in ``kpt_u``.
    """
    a_wuMM = []
    for w in wlist:
        # TODO: does this read on all the ksl ranks in vain?
        per_kpt = [
            distribute_MM(ksl, reader.proxy(name, w, kpt.s, kpt.k)[:])
            for kpt in kpt_u
        ]
        a_wuMM.append(per_kpt)
    return a_wuMM
170