"""Support for Abinit input variables."""

import collections
import collections.abc
import math
import string

import numpy as np


__all__ = [
    'InputVariable',
]

_SPECIAL_DATASET_INDICES = (':', '+', '?')

# Characters that may trail a variable name to form its dataset index.
_DATASET_INDICES = string.digits + ''.join(_SPECIAL_DATASET_INDICES)

_INTERNAL_DATASET_INDICES = ('__s', '__i', '__a')

# Pairs (internal name, special dataset index).
# Materialized as a tuple: a bare ``zip`` object is a one-shot iterator in
# Python 3 and would silently be empty after its first traversal.
_SPECIAL_CONVERSION = tuple(zip(_INTERNAL_DATASET_INDICES, _SPECIAL_DATASET_INDICES))

# Unit strings recognized as the trailing element of a list-like value.
# Values are the factors to atomic units (bohr / hartree); only the keys are
# used in this module.
_UNITS = {
    'bohr': 1.0,
    'angstrom': 1.8897261328856432,
    'hartree': 1.0,
    'Ha': 1.0,
    'eV': 0.03674932539796232,
}


class InputVariable:
    """
    An Abinit input variable.
    """

    def __init__(self, name, value, units='', valperline=3):
        """
        Args:
            name: Name of the variable, possibly followed by dataset-index
                characters (digits, ':', '+', '?').
            value: Value of the variable: a scalar, a string, a (nested)
                list/tuple or a numpy array. If the value is list-like and its
                last element is one of the supported unit strings, that element
                is removed from the value and stored as the units.
            units: String specifying one of the units supported by Abinit. Default: atomic units.
            valperline: Number of items printed per line by ``format_list``.
                ``None`` (or 0) disables line wrapping. Forced to 2 for 'bdgw'.
        """
        self._name = name
        self.value = value
        self._units = units

        # Maximum number of values per line.
        self.valperline = valperline
        if name in ('bdgw',):
            self.valperline = 2

        # Split off a trailing unit string, e.g. [1.0, 2.0, 3.0, 'angstrom'].
        # Empty or non-indexable iterables ([], np.array([]), sets, dicts,
        # 0-d arrays) have no usable last element and are left untouched.
        if is_iter(self.value):
            try:
                last = self.value[-1]
            except (IndexError, KeyError, TypeError):
                last = None
            if isinstance(last, str) and last in _UNITS:
                self.value = list(self.value)
                self._units = self.value.pop(-1)

    def get_value(self):
        """
        Return the value, with the units appended when they are set.

        Without units the stored value is returned unchanged. With units,
        list-like values are returned as a list with the units as an extra
        trailing element, while scalars and strings (which must not be split
        into characters) are returned as ``[value, units]``.
        """
        if not self.units:
            return self.value
        if is_iter(self.value) and not isinstance(self.value, str):
            return list(self.value) + [self.units]
        return [self.value, self.units]

    @property
    def name(self):
        """Name of the variable."""
        return self._name

    @property
    def basename(self):
        """Return the name trimmed of any trailing dataset-index characters."""
        return self.name.rstrip(_DATASET_INDICES)

    @property
    def dataset(self):
        """Return the dataset index in string form ('' when there is none)."""
        # basename is a prefix of name, so the index is whatever follows it.
        return self.name[len(self.basename):]

    @property
    def units(self):
        """Return the units."""
        return self._units

    def __str__(self):
        """
        Declaration of the variable in the input file.

        Returns '' when there is nothing to declare: a ``None`` value, an empty
        string, or an empty list/tuple/array.
        """
        value = self.value
        if value is None or not str(value):
            return ''

        var = self.name
        line = ' ' + var

        # By default, do not impose a number of decimal points
        floatdecimal = 0

        # For some inputs, enforce number of decimal points...
        if any(inp in var for inp in ('xred', 'xcart', 'rprim', 'qpt', 'kpt')):
            floatdecimal = 16

        # ...but not for those
        if any(inp in var for inp in ('ngkpt', 'kptrlatt', 'ngqpt', 'ng2qpt')):
            floatdecimal = 0

        # numpy arrays are flattened (row-major) into a plain 1D list.
        if isinstance(value, np.ndarray):
            value = list(np.ravel(value))

        # values in lists
        if isinstance(value, (list, tuple)):
            if not value:
                return ''

            # A list of lists is printed one inner list per line.
            if all(isinstance(v, (list, tuple)) for v in value):
                formatted = self.format_list2d(value, floatdecimal)
                if not formatted:
                    # Only empty rows: nothing to declare.
                    return ''
                line += formatted
            else:
                line += self.format_list(value, floatdecimal)

        # scalar values
        else:
            line += ' ' + str(value)

        # Add units
        if self.units:
            line += ' ' + self.units

        return line

    def format_scalar(self, val, floatdecimal=0):
        """
        Format a single numerical value into a string
        with the appropriate number of decimal.

        Integers (when ``floatdecimal`` is 0), values that cannot be converted
        to float, and non-finite floats (inf, nan) are returned as ``str(val)``.
        Other numbers use fixed-point notation when x == 0 or
        1e-3 < |x| < 1e4, and exponent notation with a Fortran-style 'd'
        exponent otherwise. At most 10 decimals are printed.
        """
        sval = str(val)
        if sval.lstrip('-').lstrip('+').isdigit() and floatdecimal == 0:
            return sval

        try:
            fval = float(val)
        except (TypeError, ValueError, OverflowError):
            return sval

        # int(inf) / int(nan) would raise while counting decimals.
        if not math.isfinite(fval):
            return sval

        if fval == 0 or 1e-3 < abs(fval) < 1e4:
            form, addlen = 'f', 5
        else:
            form, addlen = 'e', 8

        ndec = min(max(_num_decimals(fval), floatdecimal), 10)

        sval = '{v:>{l}.{p}{f}}'.format(v=fval, l=ndec + addlen, p=ndec, f=form)

        return sval.replace('e', 'd')

    def format_list2d(self, values, floatdecimal=0):
        """
        Format a list of lists, one inner list per output line.

        All items share one format chosen from the whole table. Integers and
        strings are right-aligned to a common width. Numbers, including
        numeric strings such as '1.5', use a common number of decimals in
        fixed-point or exponent notation. The result starts with a newline.
        A table with no items gives ''.
        """
        lvals = flatten(values)
        if not lvals:
            return ''

        # Determine the representation
        if all(isinstance(v, int) for v in lvals):
            type_all = int
        else:
            try:
                for v in lvals:
                    float(v)
                type_all = float
            except (TypeError, ValueError, OverflowError):
                type_all = str

        def as_item(v):
            # Numeric strings must be parsed for the float arithmetic/format.
            if type_all is float and isinstance(v, str):
                return float(v)
            return v

        # Determine the format
        width = max(len(str(s)) for s in lvals)
        if type_all is int:
            formatspec = '>{0}d'.format(width)
        elif type_all is str:
            formatspec = '>{0}'.format(width)
        else:
            nums = [as_item(v) for v in lvals]

            # Number of decimal
            maxdec = max(_num_decimals(f) for f in nums)
            ndec = min(max(maxdec, floatdecimal), 10)

            if all(f == 0 or 1e-3 < abs(f) < 1e4 for f in nums):
                formatspec = '>{w}.{p}f'.format(w=ndec + 5, p=ndec)
            else:
                formatspec = '>{w}.{p}e'.format(w=ndec + 8, p=ndec)

        rows = (
            ''.join(' {v:{f}}'.format(v=as_item(val), f=formatspec) for val in row)
            for row in values
        )
        return ('\n' + ''.join(row + '\n' for row in rows)).rstrip('\n')

    def format_list(self, values, floatdecimal=0):
        """
        Format a list of values into a string.
        The result might be spread among several lines.

        A newline is inserted after every ``valperline`` items (no wrapping
        when it is ``None`` or 0). If the result spans several lines, it
        starts with a newline.
        """
        line = ''

        # Format the line declaring the value
        for i, val in enumerate(values):
            line += ' ' + self.format_scalar(val, floatdecimal)
            if self.valperline and (i + 1) % self.valperline == 0:
                line += '\n'

        # Add a carriage return in case of several lines
        if '\n' in line.rstrip('\n'):
            line = '\n' + line

        return line.rstrip('\n')


def _num_decimals(num):
    """
    Heuristic number of decimals of ``num``: len(str(fractional part)) - 2.

    Returns 0 for non-finite values, for which ``int()`` would raise.
    """
    if not math.isfinite(num):
        return 0
    return len(str(num - int(num))) - 2


def is_iter(obj):
    """
    Return True if the argument is list-like.

    Any object defining ``__iter__`` qualifies, including strings.
    """
    return hasattr(obj, '__iter__')


def flatten(iterable):
    """
    Make an iterable flat, i.e. a 1d iterable object.

    Nested iterables are expanded depth-first; strings are kept as single
    items. Returns a tuple.
    """
    iterator = iter(iterable)
    array, stack = collections.deque(), collections.deque()
    while True:
        try:
            value = next(iterator)
        except StopIteration:
            if not stack:
                return tuple(array)
            iterator = stack.pop()
        else:
            if (not isinstance(value, str)
                    and isinstance(value, collections.abc.Iterable)):
                stack.append(iterator)
                iterator = iter(value)
            else:
                array.append(value)