1"""Support for Abinit input variables."""
2
import collections
import collections.abc
import string

import numpy as np
6
7
# Public names exported by ``from <this module> import *``.
__all__ = ["InputVariable"]
11
# Non-digit characters that may appear in a dataset index after a variable name.
_SPECIAL_DATASET_INDICES = (':', '+', '?')

# Every character that may appear in a dataset index (digits plus the special ones).
_DATASET_INDICES = string.digits + ''.join(_SPECIAL_DATASET_INDICES)

# Internal spellings of the special dataset indices, paired positionally with
# _SPECIAL_DATASET_INDICES. NOTE(review): presumably used where ':', '+', '?'
# cannot appear in an identifier — confirm against users of this constant.
_INTERNAL_DATASET_INDICES = ('__s', '__i', '__a')

# (internal, special) pairs. Materialised as a tuple on purpose: in Python 3
# zip() returns a one-shot iterator, so a bare zip object would be empty after
# the first loop over it.
_SPECIAL_CONVERSION = tuple(zip(_INTERNAL_DATASET_INDICES, _SPECIAL_DATASET_INDICES))

# Unit names recognised as the trailing item of a value, mapped to the size of
# one such unit in atomic units (lengths in bohr, energies in Hartree).
_UNITS = {
    'bohr': 1.0,
    'angstrom': 1.8897261328856432,
    'hartree': 1.0,
    'Ha': 1.0,
    'eV': 0.03674932539796232,
}
27
28
class InputVariable(object):
    """
    An Abinit input variable.

    Holds a variable name (optionally ending with a dataset index such as
    ``2``, ``:`` or ``?``), its value and optional units. ``str()`` renders
    the declaration as it should appear in an Abinit input file.
    """
    def __init__(self, name, value, units='', valperline=3):
        """
        Args:
            name: Name of the variable, optionally followed by a dataset index.
            value: Value of the variable. If it is a non-string sequence whose
                last item is a unit name known to this module
                (e.g. ``[1.0, 'angstrom']``), that item is removed from the
                value and used as the units.
            units: String specifying one of the units supported by Abinit. Default: atomic units.
            valperline: Number of items printed per line. ``None`` or ``0``
                writes all items on one line. Always 2 for ``bdgw``
                (with or without a dataset index).
        """
        self._name = name
        self.value = value
        self._units = units

        # Maximum number of values per line.
        self.valperline = valperline
        # Compare the basename so dataset-indexed forms (``bdgw2``, ``bdgw:``)
        # get the same layout as plain ``bdgw``.
        if self.basename == 'bdgw':
            self.valperline = 2

        # Split a trailing unit off the value, e.g. [1.0, 2.0, 'eV'].
        # Strings are treated as scalar values, and empty or non-indexable
        # iterables are left untouched instead of raising.
        if is_iter(value) and not isinstance(value, str):
            try:
                last = value[-1]
            except (IndexError, KeyError, TypeError):
                last = None
            if isinstance(last, str) and last in _UNITS:
                self.value = list(value)
                self._units = self.value.pop(-1)

    def get_value(self):
        """
        Return the value, with the units appended as a trailing item if set.

        A scalar (or string) value with units is returned as ``[value, units]``;
        a sequence value with units as ``list(value) + [units]``.
        """
        if not self.units:
            return self.value
        if is_iter(self.value) and not isinstance(self.value, str):
            return list(self.value) + [self.units]
        return [self.value, self.units]

    @property
    def name(self):
        """Name of the variable."""
        return self._name

    @property
    def basename(self):
        """Return the name trimmed of any dataset index."""
        basename = self.name
        return basename.rstrip(_DATASET_INDICES)

    @property
    def dataset(self):
        """Return the dataset index in string form ('' when there is none)."""
        # The basename is always a prefix of the name; slicing avoids
        # str.split, which raises on an empty separator.
        return self.name[len(self.basename):]

    @property
    def units(self):
        """Return the units."""
        return self._units

    def __str__(self):
        """
        Declaration of the variable in the input file.

        Returns '' when there is nothing to declare (value is None, renders
        as an empty string, or is an empty sequence/array).
        """
        value = self.value
        if value is None or not str(value):
            return ''

        # Numpy arrays are written as a flat list of values.
        if isinstance(value, np.ndarray):
            value = list(np.ravel(value))

        # An empty sequence has nothing to declare.
        if isinstance(value, (list, tuple)) and not value:
            return ''

        var = self.name
        line = ' ' + var

        # By default, do not impose a number of decimal points
        floatdecimal = 0

        # For some inputs, enforce number of decimal points...
        if any(inp in var for inp in ('xred', 'xcart', 'rprim', 'qpt', 'kpt')):
            floatdecimal = 16

        # ...but not for those (integer grids/lattices whose names contain
        # one of the substrings above)
        if any(inp in var for inp in ('ngkpt', 'kptrlatt', 'ngqpt', 'ng2qpt', 'qptrlatt')):
            floatdecimal = 0

        # values in lists
        if isinstance(value, (list, tuple)):

            # A list of lists is written one inner list per line
            if all(isinstance(v, (list, tuple)) for v in value):
                formatted = self.format_list2d(value, floatdecimal)
            else:
                formatted = self.format_list(value, floatdecimal)

            # e.g. a list of empty lists: nothing to declare either
            if not formatted:
                return ''
            line += formatted

        # scalar values
        else:
            line += ' ' + str(value)

        # Add units
        if self.units:
            line += ' ' + self.units

        return line

    def format_scalar(self, val, floatdecimal=0):
        """
        Format a single numerical value into a string
        with the appropriate number of decimal.

        Integer-looking values are returned verbatim when ``floatdecimal`` is 0,
        and values that cannot be converted to float are returned as
        ``str(val)``. Other values use fixed notation when zero or when
        1e-3 < |x| < 1e4, and exponent notation otherwise, with the exponent
        marker written as ``d``. The number of decimals is at least
        ``floatdecimal`` and at most 10.
        """
        sval = str(val)
        if sval.lstrip('-').lstrip('+').isdigit() and floatdecimal == 0:
            return sval

        try:
            fval = float(val)
        except Exception:
            return sval

        if fval == 0 or (abs(fval) > 1e-3 and abs(fval) < 1e4):
            form = 'f'
            addlen = 5
        else:
            form = 'e'
            addlen = 8

        # Heuristic: number of characters in the fractional part's repr.
        ndec = max(len(str(fval-int(fval)))-2, floatdecimal)
        ndec = min(ndec, 10)

        sval = '{v:>{l}.{p}{f}}'.format(v=fval, l=ndec+addlen, p=ndec, f=form)

        sval = sval.replace('e', 'd')

        return sval

    def format_list2d(self, values, floatdecimal=0):
        """
        Format a list of lists, writing each inner list on its own line.

        One column format is chosen from all items: integers are right-aligned
        with ``d``, numbers (numeric strings included) are written as floats,
        anything else as right-aligned text.

        Args:
            values: list/tuple of lists/tuples.
            floatdecimal: Minimum number of decimals for float output (capped at 10).

        Returns:
            str: The formatted block starting with a newline, or '' when there
            are no items at all.
        """
        lvals = flatten(values)
        if not lvals:
            return ''

        # Determine the representation
        if all(isinstance(v, int) for v in lvals):
            type_all = int
        else:
            try:
                for v in lvals:
                    float(v)
                type_all = float
            except (TypeError, ValueError, OverflowError):
                type_all = str

        def as_number(v):
            # Numeric strings pass the float() probe above but must be
            # converted before arithmetic and float formatting.
            return float(v) if isinstance(v, str) else v

        # Determine the format
        width = max(len(str(s)) for s in lvals)
        if type_all is int:
            formatspec = '>{0}d'.format(width)
        elif type_all is str:
            formatspec = '>{0}'.format(width)
        else:
            nums = [as_number(v) for v in lvals]

            # Number of decimal
            maxdec = max(len(str(f-int(f)))-2 for f in nums)
            ndec = min(max(maxdec, floatdecimal), 10)

            if all(f == 0 or (abs(f) > 1e-3 and abs(f) < 1e4) for f in nums):
                formatspec = '>{w}.{p}f'.format(w=ndec+5, p=ndec)
            else:
                formatspec = '>{w}.{p}e'.format(w=ndec+8, p=ndec)

        line = '\n'
        for row in values:
            for val in row:
                if type_all is float:
                    val = as_number(val)
                line += ' {v:{f}}'.format(v=val, f=formatspec)
            line += '\n'

        return line.rstrip('\n')

    def format_list(self, values, floatdecimal=0):
        """
        Format a list of values into a string.
        The result might be spread among several lines.

        A line break is inserted after every ``valperline`` items (``None``
        or ``0`` disables breaking). When the output spans several lines it
        starts with a newline so that the values begin below the name.
        """
        line = ''

        # Format the line declaring the value
        for i, val in enumerate(values):
            line += ' ' + self.format_scalar(val, floatdecimal)
            # Truthiness check also guards against valperline == 0 (modulo by zero).
            if self.valperline and (i + 1) % self.valperline == 0:
                line += '\n'

        # Add a carriage return in case of several lines
        if '\n' in line.rstrip('\n'):
            line = '\n' + line

        return line.rstrip('\n')
218
219
def is_iter(obj):
    """
    Tell whether ``obj`` is list-like, i.e. exposes an ``__iter__`` attribute.

    Note that strings count as list-like here.
    """
    missing = object()
    return getattr(obj, '__iter__', missing) is not missing
223
224
def flatten(iterable):
    """
    Make an iterable flat, i.e. a 1d iterable object.

    Nested iterables are expanded depth-first, preserving order. Strings found
    inside the structure are kept whole (not split into characters); a
    top-level string is still iterated character by character. A stack of
    suspended iterators replaces recursion, so deep nesting does not hit the
    recursion limit.

    Returns:
        tuple: The leaf items, in order.
    """
    # Needs the ``collections.abc`` submodule to be imported, not just
    # ``collections``, for ``collections.abc.Iterable`` to resolve reliably.
    flat = []
    stack = []  # iterators of the enclosing levels, resumed when the current one ends
    iterator = iter(iterable)
    while True:
        try:
            value = next(iterator)
        except StopIteration:
            if not stack:
                return tuple(flat)
            iterator = stack.pop()
            continue
        if not isinstance(value, str) and isinstance(value, collections.abc.Iterable):
            stack.append(iterator)
            iterator = iter(value)
        else:
            flat.append(value)
243