1"""Type mapping helpers."""
2
3from __future__ import division
4
5__copyright__ = "Copyright (C) 2011 Andreas Kloeckner"
6
7__license__ = """
8Permission is hereby granted, free of charge, to any person
9obtaining a copy of this software and associated documentation
10files (the "Software"), to deal in the Software without
11restriction, including without limitation the rights to use,
12copy, modify, merge, publish, distribute, sublicense, and/or sell
13copies of the Software, and to permit persons to whom the
14Software is furnished to do so, subject to the following
15conditions:
16
17The above copyright notice and this permission notice shall be
18included in all copies or substantial portions of the Software.
19
20THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
21EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
22OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
23NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
24HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
25WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
26FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
27OTHER DEALINGS IN THE SOFTWARE.
28"""
29
30import numpy as np
31
32
class TypeNameNotKnown(RuntimeError):
    """Raised when a C type name has no registered :class:`numpy.dtype`."""
35
36
37# {{{ registry
38
class _UnregisteredTypeName(TypeNameNotKnown, KeyError):
    """A :exc:`TypeNameNotKnown` that is also a :exc:`KeyError`.

    :meth:`DTypeRegistry.get_or_register_dtype` is documented to raise
    :exc:`TypeNameNotKnown` for unknown names, but used to let a bare
    :exc:`KeyError` escape. Raising this subclass honors the documented
    contract while keeping existing ``except KeyError`` handlers working.
    """


class DTypeRegistry(object):
    """Bidirectional mapping between C type names and :class:`numpy.dtype`
    objects.
    """

    def __init__(self):
        # dtype -> first C name registered for it (also keyed by str(dtype))
        self.dtype_to_name = {}
        # C name -> dtype, for every registered name
        self.name_to_dtype = {}

    def get_or_register_dtype(self, c_names, dtype=None):
        """Get or register a :class:`numpy.dtype` associated with the C type names
        in *c_names*, which is either a single string or an iterable of strings.

        If *dtype* is `None`, no registration is performed, and every name in
        *c_names* must already have been registered. The dtype they map to
        (checked for agreement via :func:`pytools.single_valued`) is returned.
        If a name is not known, :exc:`TypeNameNotKnown` is raised. For backward
        compatibility, the exception raised is also a :exc:`KeyError`.

        If *dtype* is not `None`, registration is attempted. If an identical
        :class:`numpy.dtype` was registered before, the previously registered
        dtype object is reused and returned. Names from *c_names* that are not
        yet known are registered to it. If one of the *c_names* is already
        registered to a different dtype, :exc:`RuntimeError` is raised and the
        registry is left unchanged.

        If the dtype was not registered before, the first entry of *c_names*
        becomes the name reported for it by :meth:`dtype_to_ctype`.

        .. versionadded:: 2012.2
        """

        if isinstance(c_names, str):
            c_names = [c_names]
        else:
            # Materialize so that any iterable (tuple, generator, ...) survives
            # the multiple passes and the c_names[0] indexing below.
            c_names = list(c_names)

        if dtype is None:
            from pytools import single_valued

            known_dtypes = []
            for name in c_names:
                try:
                    known_dtypes.append(self.name_to_dtype[name])
                except KeyError:
                    raise _UnregisteredTypeName(name)
            return single_valued(known_dtypes)

        dtype = np.dtype(dtype)

        # check if we've seen an identical dtype, if so retrieve exact dtype object.
        try:
            existing_name = self.dtype_to_name[dtype]
        except KeyError:
            existed = False
        else:
            existed = True
            existing_dtype = self.name_to_dtype[existing_name]
            assert existing_dtype == dtype
            dtype = existing_dtype

        # Validate every name before mutating anything, so that a conflict
        # cannot leave the type partially registered.
        for nm in c_names:
            if nm in self.name_to_dtype and self.name_to_dtype[nm] != dtype:
                raise RuntimeError("name '%s' already registered to "
                        "different dtype" % nm)

        for nm in c_names:
            # only adds names not yet known; known ones already map to dtype
            self.name_to_dtype.setdefault(nm, dtype)

        if not existed:
            self.dtype_to_name[dtype] = c_names[0]
        if str(dtype) not in self.dtype_to_name:
            self.dtype_to_name[str(dtype)] = c_names[0]

        return dtype

    def dtype_to_ctype(self, dtype):
        """Return the C type name registered for *dtype*.

        *dtype* is anything accepted by :class:`numpy.dtype`.

        :raises ValueError: if *dtype* is `None` or has no registered name.
        """
        if dtype is None:
            raise ValueError("dtype may not be None")

        dtype = np.dtype(dtype)

        try:
            return self.dtype_to_name[dtype]
        except KeyError:
            raise ValueError("unable to map dtype '%s'" % dtype)
108
109# }}}
110
111
112# {{{ C types
113
def fill_registry_with_c_types(reg, respect_windows, include_bool=True):
    """Register the host C scalar type names with the registry *reg*.

    :arg reg: a :class:`DTypeRegistry`, or any object with a compatible
        ``get_or_register_dtype(c_names, dtype)`` method.
    :arg respect_windows: on a 64-bit host whose ``sys.platform`` contains
        ``'win32'``, use ``long long`` rather than ``long`` as the name of the
        64-bit integer types.
    :arg include_bool: whether to register ``bool`` (as :class:`numpy.bool_`).

    No 64-bit integer names are registered on a 32-bit host.
    """
    from sys import platform
    import struct

    if include_bool:
        # bool is of unspecified size in the OpenCL spec and may in fact be
        # 4-byte.
        # Use np.bool_: the np.bool alias was deprecated in NumPy 1.20 and
        # removed in 1.24. It maps to the same dtype.
        reg.get_or_register_dtype("bool", np.bool_)

    reg.get_or_register_dtype(["signed char", "char"], np.int8)
    reg.get_or_register_dtype("unsigned char", np.uint8)
    reg.get_or_register_dtype(["short", "signed short",
        "signed short int", "short signed int"], np.int16)
    reg.get_or_register_dtype(["unsigned short",
        "unsigned short int", "short unsigned int"], np.uint16)
    reg.get_or_register_dtype(["int", "signed int"], np.int32)
    reg.get_or_register_dtype(["unsigned", "unsigned int"], np.uint32)

    is_64_bit = struct.calcsize('@P') * 8 == 64
    if is_64_bit:
        if 'win32' in platform and respect_windows:
            i64_name = "long long"
        else:
            i64_name = "long"

        reg.get_or_register_dtype(
                [i64_name, "%s int" % i64_name, "signed %s int" % i64_name,
                    "%s signed int" % i64_name],
                np.int64)
        reg.get_or_register_dtype(
                ["unsigned %s" % i64_name, "unsigned %s int" % i64_name,
                    "%s unsigned int" % i64_name],
                np.uint64)

        # Also register numpy's uintp under the pointer-sized unsigned name.
        # NOTE(review): presumably a workaround for uintp not behaving like
        # uint64/uint32 in some numpy versions; see
        # http://projects.scipy.org/numpy/ticket/2017
        reg.get_or_register_dtype(["unsigned %s" % i64_name], np.uintp)
    else:
        # http://projects.scipy.org/numpy/ticket/2017
        reg.get_or_register_dtype(["unsigned"], np.uintp)

    reg.get_or_register_dtype("float", np.float32)
    reg.get_or_register_dtype("double", np.float64)
156
157
def fill_registry_with_opencl_c_types(reg):
    """Register the OpenCL C scalar type names with the registry *reg*.

    The integer widths here are fixed rather than probed from the host
    (e.g. ``long`` is always registered as 64-bit). The exceptions are
    ``intptr_t``/``uintptr_t``, which follow numpy's host ``intp``/``uintp``.
    ``bool`` is not registered.
    """
    type_table = [
        (["char", "signed char"], np.int8),
        (["uchar", "unsigned char"], np.uint8),
        (["short", "signed short",
            "signed short int", "short signed int"], np.int16),
        (["ushort", "unsigned short",
            "unsigned short int", "short unsigned int"], np.uint16),
        (["int", "signed int"], np.int32),
        (["uint", "unsigned", "unsigned int"], np.uint32),
        (["long", "long int", "signed long int",
            "long signed int"], np.int64),
        (["ulong", "unsigned long", "unsigned long int",
            "long unsigned int"], np.uint64),
        (["intptr_t"], np.intp),
        (["uintptr_t"], np.uintp),
        ("float", np.float32),
        ("double", np.float64),
    ]

    for type_names, np_type in type_table:
        reg.get_or_register_dtype(type_names, np_type)
182# }}}
183
184
185# {{{ backward compatibility
186
# The default, module-wide registry.
TYPE_REGISTRY = DTypeRegistry()

# Deprecated: these are the registry's own dicts (not copies), kept only for
# old callers. They should no longer be used.
DTYPE_TO_NAME, NAME_TO_DTYPE = (
        TYPE_REGISTRY.dtype_to_name, TYPE_REGISTRY.name_to_dtype)

# Module-level shortcuts bound to the default registry.
dtype_to_ctype, get_or_register_dtype = (
        TYPE_REGISTRY.dtype_to_ctype, TYPE_REGISTRY.get_or_register_dtype)
195
196
def _fill_dtype_registry(respect_windows, include_bool=True):
    """Populate the default :data:`TYPE_REGISTRY` with the host C types."""
    fill_registry_with_c_types(TYPE_REGISTRY,
            respect_windows=respect_windows, include_bool=include_bool)
200
201# }}}
202
203
204# {{{ c declarator parsing
205
def parse_c_arg_backend(c_arg, scalar_arg_factory, vec_arg_factory,
        name_to_dtype=None):
    """Parse one C argument declaration such as ``"const float *x"``.

    The qualifiers ``const``, ``volatile``, ``restrict``, ``__restrict`` and
    ``__restrict__`` are discarded (as whole words only). If the declarator has
    at least one ``*`` or a ``[...]`` suffix, ``vec_arg_factory(dtype, name)``
    is returned. Otherwise ``scalar_arg_factory(dtype, name)`` is returned.

    :arg name_to_dtype: a callable mapping a whitespace-normalized C type name
        to a dtype (raising :exc:`KeyError` for unknown names), a
        :class:`DTypeRegistry`, or `None` to use the default registry.
    :raises ValueError: if the declarator cannot be parsed or the type name
        is unknown.
    """
    import re

    if name_to_dtype is None:
        name_to_dtype = NAME_TO_DTYPE.__getitem__
    elif (not callable(name_to_dtype)
            and isinstance(name_to_dtype, DTypeRegistry)):
        name_to_dtype = name_to_dtype.name_to_dtype.__getitem__

    # Strip qualifiers as whole words only. Plain substring replacement would
    # also mangle identifiers such as "constant_count" or "restricted_ptr".
    c_arg = re.sub(
            r"\b(?:const|volatile|__restrict__|__restrict|restrict)\b",
            "", c_arg)

    # process and remove declarator: optional pointer stars, the argument
    # name, and optional array dimensions, anchored at the end of the string
    decl_re = re.compile(r"(\**)\s*([_a-zA-Z0-9]+)(\s*\[[ 0-9]*\])*\s*$")
    decl_match = decl_re.search(c_arg)

    if decl_match is None:
        raise ValueError("couldn't parse C declarator '%s'" % c_arg)

    name = decl_match.group(2)

    if decl_match.group(1) or decl_match.group(3) is not None:
        arg_class = vec_arg_factory
    else:
        arg_class = scalar_arg_factory

    # Everything before the declarator is the type. Collapse whitespace so
    # that e.g. "unsigned   int" matches the registered "unsigned int".
    tp = " ".join(c_arg[:decl_match.start()].split())

    try:
        dtype = name_to_dtype(tp)
    except KeyError:
        raise ValueError("unknown type '%s'" % tp)

    return arg_class(dtype, name)
243
244# }}}
245
246
def register_dtype(dtype, c_names, alias_ok=False):
    """Deprecated predecessor of :func:`get_or_register_dtype`.

    Emits a :exc:`DeprecationWarning`. Unless *alias_ok* is true, raises
    :exc:`RuntimeError` if *dtype* already has a C name in
    :data:`TYPE_REGISTRY`. Then registers *c_names* (a string or a list of
    strings) for *dtype* in :data:`TYPE_REGISTRY`.
    """
    import warnings
    warnings.warn(
            "register_dtype is deprecated. Use get_or_register_dtype instead.",
            DeprecationWarning, stacklevel=2)

    names = [c_names] if isinstance(c_names, str) else c_names
    np_dtype = np.dtype(dtype)
    registry = TYPE_REGISTRY

    # Refuse to alias an already-known dtype unless explicitly allowed.
    already_known = np_dtype in registry.dtype_to_name
    if already_known and not alias_ok:
        raise RuntimeError(
                "dtype '%s' already registered (as '%s', new names '%s')"
                % (np_dtype, registry.dtype_to_name[np_dtype],
                    ", ".join(names)))

    registry.get_or_register_dtype(names, np_dtype)
265
266
267# vim: foldmethod=marker
268