import numpy as np

from gpaw.matrix import Matrix, create_distribution


class MatrixInFile:
    """Lightweight stand-in for ``Matrix`` whose data still lives in a file.

    Only the global shape, the dtype, a handle to the data and a
    distribution object are stored.
    """

    def __init__(self, M, N, dtype, data, dist):
        self.shape = (M, N)
        self.dtype = dtype
        self.array = data  # pointer to data in a file
        # NOTE(review): ``*dist`` needs an iterable, so ``dist=None`` is not
        # accepted here -- confirm callers always pass a tuple.
        self.dist = create_distribution(M, N, *dist)

    def __len__(self):
        """Return the number of rows (M).

        ``ArrayWaveFunctions.__len__`` calls ``len(self.matrix)``; without
        this method that call fails for file-backed wave functions.
        """
        return self.shape[0]


class ArrayWaveFunctions:
    """Common base for wave functions stored as an (M, N) matrix.

    Subclasses are expected to set ``comm``, ``dv`` and ``kpt`` and to
    provide ``array``, ``new()``, ``view()`` and ``_distribute()``; the
    methods below use all of them.
    """

    def __init__(self, M, N, dtype, data, dist, collinear):
        self.collinear = collinear
        if not collinear:
            # Twice as many columns per row in the non-collinear case.
            N *= 2
        if data is None or isinstance(data, np.ndarray):
            self.matrix = Matrix(M, N, dtype, data, dist)
            self.in_memory = True
        else:
            # Anything that is not an ndarray is treated as a file handle.
            self.matrix = MatrixInFile(M, N, dtype, data, dist)
            self.in_memory = False
        self.comm = None
        self.dtype = self.matrix.dtype

    def __len__(self):
        return len(self.matrix)

    def __repr__(self):
        # Subclasses parse the text between the outer parentheses, so the
        # argument list must not itself contain '('.
        return '{}(nbands={}, dtype={})'.format(
            type(self).__name__,
            self.matrix.shape[0],
            getattr(self.dtype, '__name__', self.dtype))

    def multiply(self, alpha, opa, b, opb, beta, c, symmetric):
        """Forward to ``self.matrix.multiply`` and tag *c* if needed.

        When ``opa == 'N'`` and a communicator with more than one rank is
        set, *c* is marked with ``state = 'a sum is needed'``.
        """
        self.matrix.multiply(alpha, opa, b.matrix, opb, beta, c, symmetric)
        if opa == 'N' and self.comm:
            if self.comm.size > 1:
                c.comm = self.comm
                c.state = 'a sum is needed'
            assert opb in 'TC' and b.comm is self.comm

    def matrix_elements(self, other=None, out=None, symmetric=False, cc=False,
                        operator=None, result=None, serial=False):
        """Compute matrix elements with *other* (or with itself).

        If *other* is ``None`` or an ``ArrayWaveFunctions`` the result is
        written to the matrix *out* (allocated when ``None``).  Otherwise
        *other* must provide ``integrate()`` and *out* must provide
        ``items()``.  Returns *out*.
        """
        if out is None:
            # Explicit None test: ``other or self`` would go through
            # ``__len__`` and replace an empty *other* by *self*.
            out = Matrix(len(self), len(self if other is None else other),
                         dtype=self.dtype,
                         dist=(self.matrix.dist.comm,
                               self.matrix.dist.rows,
                               self.matrix.dist.columns))
        if other is None or isinstance(other, ArrayWaveFunctions):
            assert cc
            if other is None:
                assert symmetric
                operate_and_multiply(self, self.dv, out, operator, result)
            elif not serial:
                assert not symmetric
                operate_and_multiply_not_symmetric(self, self.dv, out,
                                                   other)
            else:
                self.multiply(self.dv, 'N', other, 'C', 0.0, out, symmetric)
        else:
            assert not cc
            P_ani = {a: P_ni for a, P_ni in out.items()}
            other.integrate(self.array, P_ani, self.kpt)
        return out

    def add(self, lfc, coefs):
        """Call ``lfc.add`` with this array, a dict copy of *coefs* and kpt."""
        lfc.add(self.array, dict(coefs.items()), self.kpt)

    def apply(self, func, out=None):
        """Call ``func(self.array, out.array)`` and return *out*.

        A new object from ``self.new()`` is used when *out* is ``None``.
        """
        # ``out or self.new()`` would discard a caller-supplied object with
        # zero bands, because truthiness is evaluated through ``__len__``.
        if out is None:
            out = self.new()
        func(self.array, out.array)
        return out

    def __setitem__(self, i, x):
        # The index *i* is not used; *x* writes itself into the matrix.
        x.eval(self.matrix)

    def __iadd__(self, other):
        # NOTE(review): *other* must accept ``eval(matrix, 1.0)``; the
        # ``eval`` defined on this class takes only one argument.
        other.eval(self.matrix, 1.0)
        return self

    def eval(self, matrix):
        """Copy this object's matrix data into *matrix*."""
        matrix.array[:] = self.matrix.array

    def read_from_file(self):
        """Read wave functions from file into memory."""
        matrix = Matrix(*self.matrix.shape,
                        dtype=self.dtype, dist=self.matrix.dist)
        # Read band by band to save memory
        rows = matrix.dist.rows
        blocksize = (matrix.shape[0] + rows - 1) // rows
        for myn, psit_G in enumerate(matrix.array):
            n = matrix.dist.comm.rank * blocksize + myn
            if self.comm.rank == 0:
                # Only rank 0 touches the file; other ranks pass None.
                big_psit_G = self.array[n]
                if big_psit_G.dtype == complex and self.dtype == float:
                    # Reinterpret complex data as floats for a real matrix.
                    big_psit_G = big_psit_G.view(float)
                elif big_psit_G.dtype == float and self.dtype == complex:
                    big_psit_G = np.asarray(big_psit_G, complex)
            else:
                big_psit_G = None
            self._distribute(big_psit_G, psit_G)
        self.matrix = matrix
        self.in_memory = True


class UniformGridWaveFunctions(ArrayWaveFunctions):
    """Wave functions on the grid described by the descriptor *gd*."""

    def __init__(self, nbands, gd, dtype=None, data=None, kpt=None, dist=None,
                 spin=0, collinear=True):
        ngpts = gd.n_c.prod()
        ArrayWaveFunctions.__init__(self, nbands, ngpts, dtype, data, dist,
                                    collinear)

        M = self.matrix

        if data is None:
            M.array = M.array.reshape(-1).reshape(M.dist.shape)

        # Local shape: (rows held by this rank,) + grid points per axis.
        self.myshape = (M.dist.shape[0],) + tuple(gd.n_c)
        self.gd = gd
        self.dv = gd.dv
        self.kpt = kpt
        self.spin = spin
        self.comm = gd.comm

    @property
    def array(self):
        """In-memory data reshaped to ``myshape``, else the file handle."""
        if self.in_memory:
            return self.matrix.array.reshape(self.myshape)
        else:
            return self.matrix.array

    def _distribute(self, big_psit_R, psit_R):
        self.gd.distribute(big_psit_R, psit_R.reshape(self.gd.n_c))

    def __repr__(self):
        # Reuse the base-class argument list (text between the outer
        # parentheses).  ``partition`` never raises, unlike ``split(...)[1]``.
        s = ArrayWaveFunctions.__repr__(self).partition('(')[2][:-1]
        shape = self.gd.get_size_of_global_array()
        s = 'UniformGridWaveFunctions({}, gpts={}x{}x{})'.format(s, *shape)
        return s

    def new(self, buf=None, dist='inherit', nbands=None):
        """Create an object of the same kind, optionally backed by *buf*.

        ``dist='inherit'`` reuses this object's distribution.
        """
        if dist == 'inherit':
            dist = self.matrix.dist
        return UniformGridWaveFunctions(nbands or len(self),
                                        self.gd, self.dtype,
                                        buf,
                                        self.kpt, dist,
                                        self.spin)

    def view(self, n1, n2):
        """Return wave functions built on the slice ``self.array[n1:n2]``."""
        return UniformGridWaveFunctions(n2 - n1, self.gd, self.dtype,
                                        self.array[n1:n2],
                                        self.kpt, None,
                                        self.spin)

    def plot(self):
        """Plot band 0 along the last axis through the middle of the grid."""
        import matplotlib.pyplot as plt
        ax = plt.figure().add_subplot(111)
        a, b, _ = self.array.shape[1:]
        ax.plot(self.array[0, a // 2, b // 2])
        plt.show()


class PlaneWaveExpansionWaveFunctions(ArrayWaveFunctions):
    """Wave functions stored as coefficients for the descriptor *pd*."""

    def __init__(self, nbands, pd, dtype=None, data=None, kpt=0, dist=None,
                 spin=0, collinear=True):
        ng = ng0 = pd.myng_q[kpt]
        if data is not None:
            assert data.dtype == complex
        if dtype == float:
            # Complex coefficients are stored as pairs of floats.
            ng *= 2
            if isinstance(data, np.ndarray):
                data = data.view(float)

        ArrayWaveFunctions.__init__(self, nbands, ng, dtype, data, dist,
                                    collinear)
        self.pd = pd
        self.gd = pd.gd
        self.comm = pd.gd.comm
        self.dv = pd.gd.dv / pd.gd.N_c.prod()
        self.kpt = kpt
        self.spin = spin
        if collinear:
            self.myshape = (self.matrix.dist.shape[0], ng0)
        else:
            self.myshape = (self.matrix.dist.shape[0], 2, ng0)

    @property
    def array(self):
        """Coefficients as a complex array, or the file handle if on disk."""
        if not self.in_memory:
            return self.matrix.array
        elif self.dtype == float:
            return self.matrix.array.view(complex)
        else:
            return self.matrix.array.reshape(self.myshape)

    def _distribute(self, big_psit_G, psit_G):
        """Scatter one band into the local row *psit_G*.

        ``read_from_file`` passes ``big_psit_G=None`` on every rank except
        rank 0, so both branches must tolerate ``None``.
        """
        if self.collinear:
            if self.dtype == float:
                if big_psit_G is not None:
                    big_psit_G = big_psit_G.view(complex)
                psit_G = psit_G.view(complex)
            psit_G[:] = self.pd.scatter(big_psit_G, self.kpt)
        else:
            psit_sG = psit_G.reshape((2, -1))
            for s in range(2):
                # Indexing None would raise TypeError on non-root ranks.
                part = None if big_psit_G is None else big_psit_G[s]
                psit_sG[s] = self.pd.scatter(part, self.kpt)

    def matrix_elements(self, other=None, out=None, symmetric=False, cc=False,
                        operator=None, result=None, serial=False):
        """Compute matrix elements with *other* (or with itself).

        Same contract as the base-class method.  For ``dtype == float`` in
        the serial branch the product is taken with a factor 2 and rank 0
        of ``gd.comm`` subtracts a correction built from column 0.
        """
        if other is None or isinstance(other, ArrayWaveFunctions):
            if out is None:
                # Explicit None test: ``other or self`` would go through
                # ``__len__`` and replace an empty *other* by *self*.
                out = Matrix(len(self),
                             len(self if other is None else other),
                             dtype=self.dtype,
                             dist=(self.matrix.dist.comm,
                                   self.matrix.dist.rows,
                                   self.matrix.dist.columns))
            assert cc
            if other is None:
                assert symmetric
                operate_and_multiply(self, self.dv, out, operator, result)
            elif not serial:
                assert not symmetric
                operate_and_multiply_not_symmetric(self, self.dv, out,
                                                   other)
            elif self.dtype == complex:
                self.matrix.multiply(self.dv, 'N', other.matrix, 'C',
                                     0.0, out, symmetric)
            else:
                self.matrix.multiply(2 * self.dv, 'N', other.matrix, 'T',
                                     0.0, out, symmetric)
                if self.gd.comm.rank == 0:
                    correction = np.outer(self.matrix.array[:, 0],
                                          other.matrix.array[:, 0])
                    if symmetric:
                        out.array -= 0.5 * self.dv * (correction +
                                                      correction.T)
                    else:
                        out.array -= self.dv * correction
        else:
            assert not cc
            P_ani = {a: P_ni for a, P_ni in out.items()}
            other.integrate(self.array, P_ani, self.kpt)
        return out

    def new(self, buf=None, dist='inherit', nbands=None):
        """Create an object of the same kind, optionally backed by *buf*.

        A given *buf* is flattened, truncated to ``self.array.size``
        elements and reshaped to ``self.array.shape``.
        """
        if buf is not None:
            array = self.array
            buf = buf.ravel()[:array.size]
            buf.shape = array.shape
        if dist == 'inherit':
            dist = self.matrix.dist
        return PlaneWaveExpansionWaveFunctions(nbands or len(self),
                                               self.pd, self.dtype,
                                               buf,
                                               self.kpt, dist,
                                               self.spin, self.collinear)

    def view(self, n1, n2):
        """Return wave functions built on the slice ``self.array[n1:n2]``."""
        return PlaneWaveExpansionWaveFunctions(n2 - n1, self.pd, self.dtype,
                                               self.array[n1:n2],
                                               self.kpt, None,
                                               self.spin, self.collinear)


def operate_and_multiply(psit1, dv, out, operator, psit2):
    """Fill *out* with matrix elements of *psit1* with itself.

    If *operator* is truthy it is called once as
    ``operator(psit1.array, psit2.array)`` and *psit2* is used on the left
    of the product, so *psit2* must not be ``None`` in that case.  Without
    an operator *psit2* is replaced by a view of *psit1*.

    Row blocks of *psit1* are exchanged between the ranks of
    ``psit1.matrix.dist.comm`` with non-blocking send/receive (tag 11).
    Only ``comm.size // 2 + 1`` rounds are run; the blocks not computed
    locally are afterwards received as conjugate transposes from the rank
    that did compute them (tag 12).

    *dv* is accepted but not used here.
    """
    if psit1.comm:
        if psit2 is not None:
            assert psit2.comm is psit1.comm
        if psit1.comm.size > 1:
            out.comm = psit1.comm
            out.state = 'a sum is needed'

    comm = psit1.matrix.dist.comm
    N = len(psit1)
    n = (N + comm.size - 1) // comm.size  # rows per rank, rounded up
    mynbands = len(psit1.matrix.array)

    # Two receive buffers: one is being filled while the other is in use.
    buf1 = psit1.new(nbands=n, dist=None)
    buf2 = psit1.new(nbands=n, dist=None)
    half = comm.size // 2
    psit = psit1.view(0, mynbands)
    if psit2 is not None:
        psit2 = psit2.view(0, mynbands)

    for r in range(half + 1):
        rrequest = None
        srequest = None

        if r < half:
            srank = (comm.rank + r + 1) % comm.size
            rrank = (comm.rank - r - 1) % comm.size
            # For an even number of ranks the last exchange is one-way.
            skip = (comm.size % 2 == 0 and r == half - 1)
            n1 = min(rrank * n, N)
            n2 = min(n1 + n, N)
            if not (skip and comm.rank < half) and n2 > n1:
                rrequest = comm.receive(buf1.array[:n2 - n1],
                                        rrank, 11, False)
            if not (skip and comm.rank >= half) and len(psit1.array) > 0:
                srequest = comm.send(psit1.array, srank, 11, False)

        if r == 0:
            if operator:
                operator(psit1.array, psit2.array)
            else:
                psit2 = psit

        if not (comm.size % 2 == 0 and r == half and comm.rank < half):
            m12 = psit2.matrix_elements(psit, symmetric=(r == 0), cc=True,
                                        serial=True)
            n1 = min(((comm.rank - r) % comm.size) * n, N)
            n2 = min(n1 + n, N)
            out.array[:, n1:n2] = m12.array[:, :n2 - n1]

        if rrequest:
            comm.wait(rrequest)
        if srequest:
            comm.wait(srequest)

        # The block just received becomes the right-hand side next round.
        psit = buf1
        buf1, buf2 = buf2, buf1

    requests = []
    blocks = []
    nrows = (comm.size - 1) // 2
    for row in range(nrows):
        for column in range(comm.size - nrows + row, comm.size):
            if comm.rank == row:
                n1 = min(column * n, N)
                n2 = min(n1 + n, N)
                if mynbands > 0 and n2 > n1:
                    requests.append(
                        comm.send(out.array[:, n1:n2].T.conj().copy(),
                                  column, 12, False))
            elif comm.rank == column:
                n1 = min(row * n, N)
                n2 = min(n1 + n, N)
                if mynbands > 0 and n2 > n1:
                    block = np.empty((mynbands, n2 - n1), out.dtype)
                    blocks.append((n1, n2, block))
                    requests.append(comm.receive(block, row, 12, False))

    comm.waitall(requests)
    for n1, n2, block in blocks:
        out.array[:, n1:n2] = block


def operate_and_multiply_not_symmetric(psit1, dv, out, psit2):
    """Fill *out* with matrix elements between *psit1* and *psit2*.

    Row blocks of *psit2* are passed around the ranks of
    ``psit1.matrix.dist.comm`` with non-blocking send/receive (tag 11).
    All ``comm.size`` rounds are run, one column block of *out* per round.

    *dv* is accepted but not used here.
    """
    if psit1.comm:
        if psit2 is not None:
            assert psit2.comm is psit1.comm
        if psit1.comm.size > 1:
            out.comm = psit1.comm
            out.state = 'a sum is needed'

    comm = psit1.matrix.dist.comm
    N = len(psit1)
    n = (N + comm.size - 1) // comm.size  # rows per rank, rounded up
    mynbands = len(psit1.matrix.array)

    # Two receive buffers: one is being filled while the other is in use.
    buf1 = psit1.new(nbands=n, dist=None)
    buf2 = psit1.new(nbands=n, dist=None)

    psit1 = psit1.view(0, mynbands)
    psit = psit2.view(0, mynbands)
    for r in range(comm.size):
        rrequest = None
        srequest = None

        if r < comm.size - 1:
            srank = (comm.rank + r + 1) % comm.size
            rrank = (comm.rank - r - 1) % comm.size
            n1 = min(rrank * n, N)
            n2 = min(n1 + n, N)
            if n2 > n1:
                rrequest = comm.receive(buf1.array[:n2 - n1],
                                        rrank, 11, False)
            if len(psit1.array) > 0:
                srequest = comm.send(psit2.array, srank, 11, False)

        m12 = psit1.matrix_elements(psit, cc=True, serial=True)
        n1 = min(((comm.rank - r) % comm.size) * n, N)
        n2 = min(n1 + n, N)
        out.array[:, n1:n2] = m12.array[:, :n2 - n1]

        if rrequest:
            comm.wait(rrequest)
        if srequest:
            comm.wait(srequest)

        # The block just received becomes the right-hand side next round.
        psit = buf1
        buf1, buf2 = buf2, buf1