import numpy as np

from gpaw.blacs import BlacsGrid
from gpaw.blacs import Redistributor


def collect_uX(kd, comm, a_uX, s, k):
    """Gather the array of spin ``s`` and IBZ k-point ``k`` to the master.

    Parameters
    ----------
    kd
        K-point descriptor; ``kd.comm`` is the communicator over which
        k-points are distributed.
    comm
        Communicator orthogonal to ``kd.comm`` (ie, domainband_comm).
    a_uX
        Local list of arrays, indexed by ``u = q * kd.nspins + s`` where
        ``q`` is the local k-point index returned by
        ``kd.get_rank_and_index(k)``.
    s, k
        Spin index and IBZ k-point index.

    Returns
    -------
    The array on the rank that is rank 0 of both ``kd.comm`` and ``comm``
    (presumably the world master); ``None`` on every other rank.
    """
    Xshape = a_uX[0].shape
    dtype = a_uX[0].dtype
    kpt_rank, q = kd.get_rank_and_index(k)
    u = q * kd.nspins + s
    if kd.comm.rank == kpt_rank:
        a_X = a_uX[u]
        # Comm master sends to the global master
        if comm.rank == 0:
            if kpt_rank == 0:
                # assert world.rank == 0
                return a_X
            else:
                kd.comm.ssend(a_X, 0, 2018)
    elif kd.comm.rank == 0 and comm.rank == 0:
        # The sender targets kd.comm rank 0, so only that rank (and only its
        # comm master) may post the matching receive; any other k-point
        # rank entering here would block forever.  Since
        # kd.comm.rank != kpt_rank in this branch, kpt_rank != 0 here.
        # assert world.rank == 0
        a_X = np.empty(Xshape, dtype=dtype)
        kd.comm.receive(a_X, kpt_rank, 2018)
        return a_X
    return None


def write_uX(kd, comm, writer, name, a_uX):
    """Write all (spin, IBZ k-point) arrays of ``a_uX`` as one array ``name``.

    The array is declared with shape ``(nspins, nibzkpts) + X.shape`` and
    filled in spin-major, k-point-minor order, one slice per call to
    ``collect_uX``.
    """
    ushape = (kd.nspins, kd.nibzkpts)
    Xshape = a_uX[0].shape
    dtype = a_uX[0].dtype
    writer.add_array(name, ushape + Xshape, dtype=dtype)
    for s in range(kd.nspins):
        for k in range(kd.nibzkpts):
            # collect_uX is collective; every rank must call it.
            a_X = collect_uX(kd, comm, a_uX, s, k)
            # a_X is None on non-master ranks; presumably the writer ignores
            # fill() there -- NOTE(review): confirm against the writer used.
            writer.fill(a_X)


def read_uX(kpt_u, reader, name):
    """Read the slice ``(kpt.s, kpt.k)`` of array ``name`` for each local kpt.

    Returns a list parallel to ``kpt_u``.
    """
    a_uX = []
    for kpt in kpt_u:
        indices = (kpt.s, kpt.k)
        # TODO: does this read on all the comm ranks in vain?
        a_X = reader.proxy(name, *indices)[:]
        a_uX.append(a_X)
    return a_uX


def distribute_nM(ksl, a_nM):
    """Redistribute a band x orbital matrix into the BLACS ``mm`` layout.

    Returns ``a_nM`` unchanged when BLACS is not in use.  Otherwise the
    data held on ``ksl.gd`` rank 0 is redistributed over ``ksl.block_comm``
    with ``ksl.bd.nbands`` rows and ``ksl.nao`` columns; other gd ranks
    contribute a zero placeholder.
    """
    if not ksl.using_blacs:
        return a_nM

    dtype = a_nM.dtype
    ksl.nMdescriptor.checkassert(a_nM)
    if ksl.gd.rank != 0:
        # Only the gd master's data is used; others pass a placeholder.
        a_nM = ksl.nM_unique_descriptor.zeros(dtype=dtype)

    nM2mm = Redistributor(ksl.block_comm, ksl.nM_unique_descriptor,
                          ksl.mmdescriptor)

    a_mm = ksl.mmdescriptor.empty(dtype=dtype)
    nM2mm.redistribute(a_nM, a_mm, ksl.bd.nbands, ksl.nao)
    return a_mm


def collect_MM(ksl, a_mm):
    """Gather a BLACS-distributed ``mm`` matrix into a full ``nao x nao`` one.

    Returns ``a_mm`` unchanged when BLACS is not in use.  Otherwise the
    target is a 1x1 BLACS grid on ``ksl.block_comm``, so the full matrix
    ends up on a single rank (presumably block_comm rank 0).  Collective
    over ``ksl.block_comm``.
    """
    if not ksl.using_blacs:
        return a_mm

    dtype = a_mm.dtype
    NM = ksl.nao
    grid = BlacsGrid(ksl.block_comm, 1, 1)
    MM_descriptor = grid.new_descriptor(NM, NM, NM, NM)
    mm2MM = Redistributor(ksl.block_comm,
                          ksl.mmdescriptor,
                          MM_descriptor)

    a_MM = MM_descriptor.empty(dtype=dtype)
    mm2MM.redistribute(a_mm, a_MM)
    return a_MM


def collect_uMM(kd, ksl, a_uMM, s, k):
    """Gather one (spin, k-point) ``nao x nao`` matrix; see collect_wuMM."""
    return collect_wuMM(kd, ksl, [a_uMM], 0, s, k)


def collect_wuMM(kd, ksl, a_wuMM, w, s, k):
    """Gather matrix ``a_wuMM[w][u]`` of spin ``s``, k-point ``k`` to master.

    The matrix is first collected within the BLACS grid of the owning
    k-point rank, then sent over ``kd.comm`` to rank 0.

    Returns
    -------
    The full ``(nao, nao)`` matrix on the rank that is rank 0 of both
    ``kd.comm`` and ``ksl.block_comm`` (asserted to be ``ksl.world``
    rank 0); ``None`` on every other rank.
    """
    # This function is based on
    # gpaw/wavefunctions/base.py: WaveFunctions.collect_auxiliary()

    dtype = a_wuMM[0][0].dtype
    NM = ksl.nao
    kpt_rank, q = kd.get_rank_and_index(k)
    u = q * kd.nspins + s
    if kd.comm.rank == kpt_rank:
        a_MM = a_wuMM[w][u]

        # Collect within blacs grid (collective over ksl.block_comm)
        a_MM = collect_MM(ksl, a_MM)

        # KSL master sends a_MM to the global master
        if ksl.block_comm.rank == 0:
            if kpt_rank == 0:
                assert ksl.world.rank == 0
                # I have it already
                return a_MM
            else:
                kd.comm.send(a_MM, 0, 2017)
                return None
    elif kd.comm.rank == 0 and ksl.block_comm.rank == 0:
        # Only kd.comm rank 0 is the send target; gating on
        # block_comm alone would let every other k-point master in here
        # (and trip the assert below).  kpt_rank != 0 in this branch.
        assert ksl.world.rank == 0
        a_MM = np.empty((NM, NM), dtype=dtype)
        kd.comm.receive(a_MM, kpt_rank, 2017)
        return a_MM
    return None


def distribute_MM(ksl, a_MM):
    """Scatter a full ``nao x nao`` matrix into the BLACS ``mm`` layout.

    Returns ``a_MM`` unchanged when BLACS is not in use.  Otherwise only
    ``ksl.block_comm`` rank 0's data is used as the source; the other ranks
    pass an uninitialised placeholder.  Collective over ``ksl.block_comm``.
    """
    if not ksl.using_blacs:
        return a_MM

    dtype = a_MM.dtype
    NM = ksl.nao
    grid = BlacsGrid(ksl.block_comm, 1, 1)
    MM_descriptor = grid.new_descriptor(NM, NM, NM, NM)
    MM2mm = Redistributor(ksl.block_comm,
                          MM_descriptor,
                          ksl.mmdescriptor)
    if ksl.block_comm.rank != 0:
        a_MM = MM_descriptor.empty(dtype=dtype)

    a_mm = ksl.mmdescriptor.empty(dtype=dtype)
    MM2mm.redistribute(a_MM, a_mm)
    return a_mm


def write_uMM(kd, ksl, writer, name, a_uMM):
    """Write all (spin, k-point) matrices of ``a_uMM``; see write_wuMM."""
    return write_wuMM(kd, ksl, writer, name, [a_uMM], wlist=[0])


def write_wuMM(kd, ksl, writer, name, a_wuMM, wlist):
    """Write ``a_wuMM[w]`` for each ``w`` in ``wlist`` as one array ``name``.

    The array is declared with shape
    ``(len(wlist), nspins, nibzkpts, nao, nao)`` and filled in
    ``(w, s, k)`` order.
    """
    # NOTE(review): the i-th stored slice holds a_wuMM[wlist[i]], while
    # read_wuMM reads file slice ``w`` directly; these agree only when
    # wlist == range(len(wlist)) -- confirm callers rely on that.
    NM = ksl.nao
    dtype = a_wuMM[0][0].dtype
    writer.add_array(name,
                     (len(wlist), kd.nspins, kd.nibzkpts, NM, NM),
                     dtype=dtype)
    for w in wlist:
        for s in range(kd.nspins):
            for k in range(kd.nibzkpts):
                # collect_wuMM is collective; every rank must call it.
                a_MM = collect_wuMM(kd, ksl, a_wuMM, w, s, k)
                writer.fill(a_MM)


def read_uMM(kpt_u, ksl, reader, name):
    """Read and distribute one matrix per local kpt; see read_wuMM."""
    return read_wuMM(kpt_u, ksl, reader, name, wlist=[0])[0]


def read_wuMM(kpt_u, ksl, reader, name, wlist):
    """Read slices ``(w, kpt.s, kpt.k)`` of ``name`` and distribute them.

    Returns a nested list ``a_wuMM[i][j]`` for ``wlist[i]`` and
    ``kpt_u[j]``, each matrix in the layout returned by ``distribute_MM``.
    """
    a_wuMM = []
    for w in wlist:
        a_uMM = []
        for kpt in kpt_u:
            indices = (w, kpt.s, kpt.k)
            # TODO: does this read on all the ksl ranks in vain?
            a_MM = reader.proxy(name, *indices)[:]
            a_MM = distribute_MM(ksl, a_MM)
            a_uMM.append(a_MM)
        a_wuMM.append(a_uMM)
    return a_wuMM