1import numpy as np
2
3from gpaw.matrix import Matrix, create_distribution
4
5
class MatrixInFile:
    """Stand-in for a Matrix whose data has not been read into memory.

    It exposes the attributes that ``ArrayWaveFunctions`` reads from its
    ``matrix`` (``shape``, ``dtype``, ``array`` and ``dist``).  Here
    ``array`` is whatever reference to the on-file data was handed in;
    nothing is loaded.
    """

    def __init__(self, M, N, dtype, data, dist):
        # Global (rows, columns) of the matrix as it will look once loaded.
        self.shape = M, N
        self.dtype = dtype
        # Reference to the data in a file -- not an ndarray.
        self.array = data
        dist_args = tuple(dist)
        self.dist = create_distribution(M, N, *dist_args)
12
13
class ArrayWaveFunctions:
    """Wave functions stored as rows (bands) of a 2-D matrix.

    The data is either held in memory as a ``Matrix`` or referenced in a
    file through ``MatrixInFile`` until ``read_from_file`` is called.
    Subclasses in this module provide ``array``, ``dv``, ``kpt``,
    ``new()`` and ``_distribute()``, and set ``comm``.
    """

    def __init__(self, M, N, dtype, data, dist, collinear):
        """Create the container.

        M: number of bands (matrix rows).
        N: number of coefficients per band; doubled when not collinear.
        dtype: element type handed to the matrix.
        data: None or an ndarray (kept in memory); any other object is
            treated as a reference to data in a file.
        dist: distribution handed to ``Matrix`` / ``MatrixInFile``.
        collinear: False for two-component spinor wave functions.
        """
        self.collinear = collinear
        if not collinear:
            N *= 2  # both spinor components share one matrix row
        if data is None or isinstance(data, np.ndarray):
            self.matrix = Matrix(M, N, dtype, data, dist)
            self.in_memory = True
        else:
            self.matrix = MatrixInFile(M, N, dtype, data, dist)
            self.in_memory = False
        self.comm = None  # subclasses set this to their grid communicator
        self.dtype = self.matrix.dtype

    def __len__(self):
        """Return the number of bands (delegates to ``len(self.matrix)``)."""
        return len(self.matrix)

    def multiply(self, alpha, opa, b, opb, beta, c, symmetric):
        """Multiply ``self`` with ``b`` into ``c`` via ``Matrix.multiply``.

        The arguments are passed straight through (presumably BLAS-like:
        ``c = alpha * op(self) op(b) + beta * c``).  For ``opa == 'N'`` on a
        grid communicator with more than one rank, ``c`` is tagged with
        ``state = 'a sum is needed'`` so the sum over ``self.comm`` is done
        later.  ``opb`` must then be 'T' or 'C' and ``b`` must share the
        communicator.
        """
        self.matrix.multiply(alpha, opa, b.matrix, opb, beta, c, symmetric)
        if opa == 'N' and self.comm:
            if self.comm.size > 1:
                c.comm = self.comm
                c.state = 'a sum is needed'
            assert opb in 'TC' and b.comm is self.comm

    def matrix_elements(self, other=None, out=None, symmetric=False, cc=False,
                        operator=None, result=None, serial=False):
        """Compute matrix elements and return ``out``.

        * ``other is None``: Hermitian matrix of ``self`` with itself,
          band-parallel via ``operate_and_multiply``.  ``operator`` (if
          given) is applied with ``result`` as output buffer.  Requires
          ``symmetric=True``.
        * ``other`` wave functions, ``serial=False``: full matrix via
          ``operate_and_multiply_not_symmetric``.
        * ``other`` wave functions, ``serial=True``: one local multiply.
        * otherwise ``other.integrate(self.array, P_ani, self.kpt)`` is
          called.  ``out`` must then be provided and support ``items()``.

        ``cc`` must be True for the wave-function cases and False otherwise.
        """
        if other is None or isinstance(other, ArrayWaveFunctions):
            if out is None:
                # Explicit None test: ``other or self`` would silently use
                # ``self`` for a zero-band ``other`` (len() == 0 is falsy).
                ncols = len(self if other is None else other)
                out = Matrix(len(self), ncols, dtype=self.dtype,
                             dist=(self.matrix.dist.comm,
                                   self.matrix.dist.rows,
                                   self.matrix.dist.columns))
            assert cc
            if other is None:
                assert symmetric
                operate_and_multiply(self, self.dv, out, operator, result)
            elif not serial:
                assert not symmetric
                operate_and_multiply_not_symmetric(self, self.dv, out,
                                                   other)
            else:
                self.multiply(self.dv, 'N', other, 'C', 0.0, out, symmetric)
        else:
            assert not cc
            P_ani = {a: P_ni for a, P_ni in out.items()}
            other.integrate(self.array, P_ani, self.kpt)
        return out

    def add(self, lfc, coefs):
        """Call ``lfc.add`` on the data with ``coefs`` copied into a dict."""
        lfc.add(self.array, dict(coefs.items()), self.kpt)

    def apply(self, func, out=None):
        """Call ``func(self.array, out.array)`` and return ``out``.

        A new container from ``self.new()`` is used when ``out`` is None.
        ``out`` is tested against None explicitly because a zero-band
        container is falsy.
        """
        if out is None:
            out = self.new()
        func(self.array, out.array)
        return out

    def __setitem__(self, i, x):
        """Evaluate ``x`` into this container (the index ``i`` is ignored).

        Presumably used as ``psit[:] = expression``.
        """
        x.eval(self.matrix)

    def __iadd__(self, other):
        """Evaluate ``other`` into this matrix with ``beta=1.0`` (add to it)."""
        other.eval(self.matrix, 1.0)
        return self

    def eval(self, matrix, beta=0.0):
        """Write this data into ``matrix``.

        With ``beta == 0`` the data is copied.  Otherwise
        ``matrix = beta * matrix + data``, the convention ``__iadd__`` relies
        on.  Without ``beta``, ``psit += other_psit`` raised TypeError.
        """
        if beta == 0:
            matrix.array[:] = self.matrix.array
        else:
            matrix.array *= beta
            matrix.array += self.matrix.array

    def read_from_file(self):
        """Read wave functions from file into memory.

        Rank 0 of ``self.comm`` reads each global band from the file
        reference and converts it to ``self.dtype``.  Complex data is viewed
        as float pairs; float data is cast to complex.  ``_distribute``
        then scatters the band into the local row.
        """
        matrix = Matrix(*self.matrix.shape,
                        dtype=self.dtype, dist=self.matrix.dist)
        # Read band by band to save memory
        rows = matrix.dist.rows
        blocksize = (matrix.shape[0] + rows - 1) // rows  # ceil(M / rows)
        for myn, psit_G in enumerate(matrix.array):
            # Global band index; assumes contiguous row blocks of
            # ``blocksize`` bands per rank of ``matrix.dist.comm``.
            n = matrix.dist.comm.rank * blocksize + myn
            if self.comm.rank == 0:
                big_psit_G = self.array[n]
                if big_psit_G.dtype == complex and self.dtype == float:
                    big_psit_G = big_psit_G.view(float)
                elif big_psit_G.dtype == float and self.dtype == complex:
                    big_psit_G = np.asarray(big_psit_G, complex)
            else:
                big_psit_G = None
            self._distribute(big_psit_G, psit_G)
        self.matrix = matrix
        self.in_memory = True
101
102
class UniformGridWaveFunctions(ArrayWaveFunctions):
    """Wave functions sampled on a real-space uniform grid.

    Each matrix row holds one band on the local part of the grid
    (``gd.n_c`` points).  ``array`` exposes the in-memory data with shape
    ``(local bands,) + tuple(gd.n_c)``.
    """

    def __init__(self, nbands, gd, dtype=None, data=None, kpt=None, dist=None,
                 spin=0, collinear=True):
        ngpts = gd.n_c.prod()  # number of local grid points
        ArrayWaveFunctions.__init__(self, nbands, ngpts, dtype, data, dist,
                                    collinear)

        M = self.matrix

        if data is None:
            # NOTE(review): re-shapes the freshly allocated buffer to the
            # local matrix shape; the exact intent is not visible here.
            M.array = M.array.reshape(-1).reshape(M.dist.shape)

        self.myshape = (M.dist.shape[0],) + tuple(gd.n_c)
        self.gd = gd
        self.dv = gd.dv
        self.kpt = kpt
        self.spin = spin
        self.comm = gd.comm

    @property
    def array(self):
        """Grid-shaped local data, or the file reference if not loaded."""
        if self.in_memory:
            return self.matrix.array.reshape(self.myshape)
        else:
            return self.matrix.array

    def _distribute(self, big_psit_R, psit_R):
        """Distribute one global band (given on rank 0) into ``psit_R``."""
        self.gd.distribute(big_psit_R, psit_R.reshape(self.gd.n_c))

    def __repr__(self):
        # Built directly: the base class has no __repr__ of its own, and
        # ``object.__repr__`` contains no '(' to parse (that raised
        # IndexError).
        shape = self.gd.get_size_of_global_array()
        return ('UniformGridWaveFunctions(nbands={}, dtype={}, '
                'gpts={}x{}x{})'.format(len(self), np.dtype(self.dtype).name,
                                        *shape))

    def new(self, buf=None, dist='inherit', nbands=None):
        """Return a new container on the same grid, k-point and spin.

        ``dist='inherit'`` reuses this matrix's distribution.  ``nbands``
        defaults to ``len(self)``.  ``collinear`` is carried over.
        """
        if dist == 'inherit':
            dist = self.matrix.dist
        return UniformGridWaveFunctions(nbands or len(self),
                                        self.gd, self.dtype,
                                        buf,
                                        self.kpt, dist,
                                        self.spin, self.collinear)

    def view(self, n1, n2):
        """Return a container wrapping local bands ``n1:n2`` of ``array``."""
        return UniformGridWaveFunctions(n2 - n1, self.gd, self.dtype,
                                        self.array[n1:n2],
                                        self.kpt, None,
                                        self.spin, self.collinear)

    def plot(self):
        """Plot band 0 along the last axis through the grid centre."""
        import matplotlib.pyplot as plt
        ax = plt.figure().add_subplot(111)
        a, b, _ = self.array.shape[1:]
        ax.plot(self.array[0, a // 2, b // 2])
        plt.show()
159
160
class PlaneWaveExpansionWaveFunctions(ArrayWaveFunctions):
    """Wave functions expanded in plane waves.

    Coefficients are complex.  With ``dtype=float`` they are stored as
    (re, im) float pairs, doubling the number of matrix columns.  For
    non-collinear wave functions each row holds two spinor components.
    """

    def __init__(self, nbands, pd, dtype=None, data=None, kpt=0, dist=None,
                 spin=0, collinear=True):
        # Number of plane waves for this k-point (presumably local to rank).
        ng = ng0 = pd.myng_q[kpt]
        if data is not None:
            assert data.dtype == complex
        if dtype == float:
            # Store complex coefficients as (re, im) float pairs.
            ng *= 2
            if isinstance(data, np.ndarray):
                data = data.view(float)

        ArrayWaveFunctions.__init__(self, nbands, ng, dtype, data, dist,
                                    collinear)
        self.pd = pd
        self.gd = pd.gd
        self.comm = pd.gd.comm
        self.dv = pd.gd.dv / pd.gd.N_c.prod()
        self.kpt = kpt
        self.spin = spin
        if collinear:
            self.myshape = (self.matrix.dist.shape[0], ng0)
        else:
            self.myshape = (self.matrix.dist.shape[0], 2, ng0)

    @property
    def array(self):
        """Local complex coefficients, or the file reference if not loaded.

        For ``dtype == float`` the float pairs are viewed back as complex
        (no reshape to ``myshape`` is applied in that case).
        """
        if not self.in_memory:
            return self.matrix.array
        elif self.dtype == float:
            return self.matrix.array.view(complex)
        else:
            return self.matrix.array.reshape(self.myshape)

    def _distribute(self, big_psit_G, psit_G):
        """Scatter one global band into the local row ``psit_G``.

        ``big_psit_G`` is the full band on rank 0 of ``gd.comm`` and None
        on all other ranks (as passed by ``read_from_file``).
        """
        if self.collinear:
            if self.dtype == float:
                if big_psit_G is not None:
                    big_psit_G = big_psit_G.view(complex)
                psit_G = psit_G.view(complex)
            psit_G[:] = self.pd.scatter(big_psit_G, self.kpt)
        else:
            psit_sG = psit_G.reshape((2, -1))
            for s in range(2):
                # Non-root ranks receive None and must not index it.
                big_G = None if big_psit_G is None else big_psit_G[s]
                psit_sG[s] = self.pd.scatter(big_G, self.kpt)

    def matrix_elements(self, other=None, out=None, symmetric=False, cc=False,
                        operator=None, result=None, serial=False):
        """Compute matrix elements and return ``out``.

        Same dispatch as ``ArrayWaveFunctions.matrix_elements``.  In the
        serial ``dtype == float`` case only half of the (real wave function)
        coefficients are stored.  The product is therefore taken with
        ``2 * dv``, and the double-counted first coefficient is subtracted
        on rank 0 of ``gd.comm`` (presumably the G=0 component).
        """
        if other is None or isinstance(other, ArrayWaveFunctions):
            if out is None:
                # Explicit None test: ``other or self`` would silently use
                # ``self`` for a zero-band ``other`` (len() == 0 is falsy).
                ncols = len(self if other is None else other)
                out = Matrix(len(self), ncols, dtype=self.dtype,
                             dist=(self.matrix.dist.comm,
                                   self.matrix.dist.rows,
                                   self.matrix.dist.columns))
            assert cc
            if other is None:
                assert symmetric
                operate_and_multiply(self, self.dv, out, operator, result)
            elif not serial:
                assert not symmetric
                operate_and_multiply_not_symmetric(self, self.dv, out,
                                                   other)
            elif self.dtype == complex:
                self.matrix.multiply(self.dv, 'N', other.matrix, 'C',
                                     0.0, out, symmetric)
            else:
                self.matrix.multiply(2 * self.dv, 'N', other.matrix, 'T',
                                     0.0, out, symmetric)
                if self.gd.comm.rank == 0:
                    correction = np.outer(self.matrix.array[:, 0],
                                          other.matrix.array[:, 0])
                    if symmetric:
                        out.array -= 0.5 * self.dv * (correction +
                                                      correction.T)
                    else:
                        out.array -= self.dv * correction
        else:
            assert not cc
            P_ani = {a: P_ni for a, P_ni in out.items()}
            other.integrate(self.array, P_ani, self.kpt)
        return out

    def new(self, buf=None, dist='inherit', nbands=None):
        """Return a new container with the same basis, k-point and spin.

        ``buf`` (optional) is flattened, truncated to this object's
        ``array.size`` and given this object's ``array.shape``.
        """
        if buf is not None:
            array = self.array
            buf = buf.ravel()[:array.size]
            buf.shape = array.shape
        if dist == 'inherit':
            dist = self.matrix.dist
        return PlaneWaveExpansionWaveFunctions(nbands or len(self),
                                               self.pd, self.dtype,
                                               buf,
                                               self.kpt, dist,
                                               self.spin, self.collinear)

    def view(self, n1, n2):
        """Return a container wrapping local bands ``n1:n2`` of ``array``."""
        return PlaneWaveExpansionWaveFunctions(n2 - n1, self.pd, self.dtype,
                                               self.array[n1:n2],
                                               self.kpt, None,
                                               self.spin, self.collinear)
260
261
def operate_and_multiply(psit1, dv, out, operator, psit2):
    """Fill ``out`` with the Hermitian matrix of ``psit1`` (band-parallel).

    Bands of ``psit1`` are split over ``comm = psit1.matrix.dist.comm`` in
    blocks of ``n = ceil(N / comm.size)`` bands (block of rank p starts at
    ``p * n``).  Each rank fills its own rows of ``out``.  Block column
    ``q`` is ``dv * psit2_local @ psit_q^H`` (via serial
    ``matrix_elements``).  Here ``psit2`` is ``operator(psit1)``, or
    ``psit1`` itself when no operator is given.

    Algorithm: a ring of ``comm.size // 2 + 1`` steps.  At step r every
    rank sends its own bands to rank+r+1 and receives the bands of rank-r-1
    (tag 11).  Meanwhile it computes the block for the bands received in
    the previous step (column block rank-r).  Blocks not covered by the
    half ring are then obtained as conjugate transposes from the rank that
    computed them (tag 12).  From the index arithmetic, every block at or
    below the block diagonal appears to end up filled, while some blocks
    above it are left untouched.  NOTE(review): presumably callers only
    read the lower triangle -- confirm.

    psit1: distributed wave functions.
    dv: unused here; matrix_elements uses ``psit2.dv`` itself.
    out: output matrix.  If the grid communicator has >1 rank it is tagged
        ``state = 'a sum is needed'``.
    operator: optional callable ``operator(psit1.array, psit2.array)``.
    psit2: buffer receiving ``operator(psit1)``.  Required when
        ``operator`` is given; otherwise only checked for the same comm.
    """
    if psit1.comm:
        if psit2 is not None:
            assert psit2.comm is psit1.comm
        if psit1.comm.size > 1:
            out.comm = psit1.comm
            out.state = 'a sum is needed'

    comm = psit1.matrix.dist.comm
    N = len(psit1)
    n = (N + comm.size - 1) // comm.size  # bands per rank (last may be short)
    mynbands = len(psit1.matrix.array)

    # Two receive buffers of n bands each: one is being filled while the
    # other holds the bands used in the current step.
    buf1 = psit1.new(nbands=n, dist=None)
    buf2 = psit1.new(nbands=n, dist=None)
    half = comm.size // 2
    psit = psit1.view(0, mynbands)
    if psit2 is not None:
        psit2 = psit2.view(0, mynbands)

    for r in range(half + 1):
        rrequest = None
        srequest = None

        if r < half:
            srank = (comm.rank + r + 1) % comm.size
            rrank = (comm.rank - r - 1) % comm.size
            # Even comm.size: in the last exchange (r == half - 1) ranks p
            # and p + half would swap bands with each other.  Only ranks
            # < half send and ranks >= half receive, so each pair's block is
            # computed once (see the r == half test below).
            skip = (comm.size % 2 == 0 and r == half - 1)
            n1 = min(rrank * n, N)
            n2 = min(n1 + n, N)
            if not (skip and comm.rank < half) and n2 > n1:
                rrequest = comm.receive(buf1.array[:n2 - n1], rrank, 11, False)
            if not (skip and comm.rank >= half) and len(psit1.array) > 0:
                srequest = comm.send(psit1.array, srank, 11, False)

        if r == 0:
            # Apply the operator once, overlapping with the first exchange.
            if operator:
                operator(psit1.array, psit2.array)
            else:
                psit2 = psit

        if not (comm.size % 2 == 0 and r == half and comm.rank < half):
            # Block for the bands of rank (rank - r); only the diagonal
            # block (r == 0) is flagged symmetric.
            m12 = psit2.matrix_elements(psit, symmetric=(r == 0), cc=True,
                                        serial=True)
            n1 = min(((comm.rank - r) % comm.size) * n, N)
            n2 = min(n1 + n, N)
            # The buffer may hold more than n2 - n1 bands; drop the rest.
            out.array[:, n1:n2] = m12.array[:, :n2 - n1]

        if rrequest:
            comm.wait(rrequest)
        if srequest:
            comm.wait(srequest)

        # Bands just received are used in the next step; swap buffers so
        # the next receive does not overwrite them.
        psit = buf1
        buf1, buf2 = buf2, buf1

    # Transfer conjugate-transposed blocks: rank ``row`` sends block
    # (row, column) to rank ``column``, which stores it as
    # (column, row).
    requests = []
    blocks = []
    nrows = (comm.size - 1) // 2
    for row in range(nrows):
        for column in range(comm.size - nrows + row, comm.size):
            if comm.rank == row:
                n1 = min(column * n, N)
                n2 = min(n1 + n, N)
                if mynbands > 0 and n2 > n1:
                    requests.append(
                        comm.send(out.array[:, n1:n2].T.conj().copy(),
                                  column, 12, False))
            elif comm.rank == column:
                n1 = min(row * n, N)
                n2 = min(n1 + n, N)
                if mynbands > 0 and n2 > n1:
                    block = np.empty((mynbands, n2 - n1), out.dtype)
                    blocks.append((n1, n2, block))
                    requests.append(comm.receive(block, row, 12, False))

    comm.waitall(requests)
    for n1, n2, block in blocks:
        out.array[:, n1:n2] = block
341
342
def operate_and_multiply_not_symmetric(psit1, dv, out, psit2):
    """Fill ``out`` with the full matrix ``dv * psit1 @ psit2^H`` (band-parallel).

    Bands are split over ``comm = psit1.matrix.dist.comm`` in blocks of
    ``n = ceil(N / comm.size)`` with ``N = len(psit1)``.  NOTE(review):
    ``psit2`` is assumed to have the same band count and distribution --
    confirm.  A full ring of ``comm.size`` steps is used.  At step r each
    rank sends its ``psit2`` bands to rank+r+1 and receives those of
    rank-r-1 (tag 11).  Meanwhile it computes the column block for the
    bands received in the previous step.  Every rank ends with all columns
    of its own rows.

    dv: unused here; matrix_elements uses ``psit1.dv`` itself.
    out: output matrix.  If the grid communicator has >1 rank it is tagged
        ``state = 'a sum is needed'``.
    """
    if psit1.comm:
        if psit2 is not None:
            assert psit2.comm is psit1.comm
        if psit1.comm.size > 1:
            out.comm = psit1.comm
            out.state = 'a sum is needed'

    comm = psit1.matrix.dist.comm
    N = len(psit1)
    n = (N + comm.size - 1) // comm.size  # bands per rank (last may be short)
    mynbands = len(psit1.matrix.array)

    # Double buffer for received psit2 bands (n bands each).
    buf1 = psit1.new(nbands=n, dist=None)
    buf2 = psit1.new(nbands=n, dist=None)

    psit1 = psit1.view(0, mynbands)
    psit = psit2.view(0, mynbands)
    for r in range(comm.size):
        rrequest = None
        srequest = None

        if r < comm.size - 1:
            srank = (comm.rank + r + 1) % comm.size
            rrank = (comm.rank - r - 1) % comm.size
            n1 = min(rrank * n, N)
            n2 = min(n1 + n, N)
            if n2 > n1:
                rrequest = comm.receive(buf1.array[:n2 - n1], rrank, 11, False)
            # NOTE(review): the send guard checks psit1's local band count
            # while psit2's data is sent; equal counts are assumed.
            if len(psit1.array) > 0:
                srequest = comm.send(psit2.array, srank, 11, False)

        # Column block belonging to rank (rank - r); extra buffer rows
        # beyond n2 - n1 are dropped.
        m12 = psit1.matrix_elements(psit, cc=True, serial=True)
        n1 = min(((comm.rank - r) % comm.size) * n, N)
        n2 = min(n1 + n, N)
        out.array[:, n1:n2] = m12.array[:, :n2 - n1]

        if rrequest:
            comm.wait(rrequest)
        if srequest:
            comm.wait(srequest)

        # Use the freshly received bands next; receive into the other buffer.
        psit = buf1
        buf1, buf2 = buf2, buf1
387