1# This Source Code Form is subject to the terms of the Mozilla Public
2# License, v. 2.0. If a copy of the MPL was not distributed with this
3# file, You can obtain one at http://mozilla.org/MPL/2.0/.
4
5from __future__ import absolute_import
6
7import inspect
8import six
9from six.moves import range
10from six.moves import zip
11
# Registry of log actions: maps each decorated method's name to the
# log_action instance that validates and converts its arguments.
convertor_registry = dict()

# Sentinel returned by DataType.__call__ for an optional field whose value
# equals its default; log_action.convert leaves such fields out of the output.
missing = object()

# Sentinel default for DataType meaning "no default was given"; fields with
# this default are filled from positional arguments first.
no_default = object()
15
16
class log_action(object):
    """Decorator that turns a logger method into a typed log action.

    Each positional argument to the constructor is a DataType instance
    describing one field of the log message. The decorated method is replaced
    by a wrapper that converts the caller's arguments into a single ``data``
    dict (see ``convert``) and calls the original method as ``f(self, data)``.

    :raises ValueError: if two field descriptors share the same name.
    """

    def __init__(self, *args):
        # Field name -> DataType for every accepted field.
        self.args = {}

        # Names of fields without a default, in declaration order; positional
        # call arguments fill these first.
        self.args_no_default = []
        # Names of fields with a default, in declaration order; filled by any
        # remaining positional arguments.
        self.args_with_default = []
        # Names of fields declared optional (dropped when left at default).
        self.optional_args = set()

        # These are the required fields in a log message that usually aren't
        # supplied by the caller, but can be in the case of log_raw
        self.default_args = [
            Unicode("action"),
            Int("time"),
            Unicode("thread"),
            Int("pid", default=None),
            Unicode("source"),
            Unicode("component"),
        ]

        for arg in args:
            if arg.default is no_default:
                self.args_no_default.append(arg.name)
            else:
                self.args_with_default.append(arg.name)

            if arg.optional:
                self.optional_args.add(arg.name)

            if arg.name in self.args:
                raise ValueError("Repeated argument name %s" % arg.name)

            self.args[arg.name] = arg

        # The standard fields are accepted only as keyword arguments: they are
        # registered in self.args but not in the positional name lists.
        for extra in self.default_args:
            self.args[extra.name] = extra

    def __call__(self, f):
        """Register ``f`` under its name and return a converting wrapper.

        The wrapper carries ``f``'s name, docstring and module (via
        ``functools.wraps``) so introspection and help() show the real action
        rather than an anonymous ``inner`` function.
        """
        import functools

        convertor_registry[f.__name__] = self
        converter = self

        @functools.wraps(f)
        def inner(self, *args, **kwargs):
            data = converter.convert(*args, **kwargs)
            return f(self, data)

        return inner

    def convert(self, *args, **kwargs):
        """Validate call arguments and convert them into a message dict.

        Positional arguments first fill the fields without defaults (skipping
        any already supplied by keyword), then the fields with defaults, in
        declaration order. Fields with defaults that were not supplied take
        their default. Every value is passed through its DataType; a field
        whose DataType returns ``missing``, or an optional field left at its
        default, is omitted from the result.

        :raises TypeError: on too few or too many positional arguments, a
            defaulted field given both positionally and by keyword, or an
            unrecognised keyword argument.
        :raises ValueError: propagated from a DataType whose conversion fails.
        """
        data = {}
        values = {}
        values.update(kwargs)

        # Required fields not given by keyword must come positionally.
        positional_no_default = [
            item for item in self.args_no_default if item not in values
        ]

        num_no_default = len(positional_no_default)

        if len(args) < num_no_default:
            raise TypeError("Too few arguments")

        if len(args) > num_no_default + len(self.args_with_default):
            raise TypeError("Too many arguments")

        for i, name in enumerate(positional_no_default):
            values[name] = args[i]

        # Any surplus positional arguments map onto the defaulted fields.
        positional_with_default = [
            self.args_with_default[i] for i in range(len(args) - num_no_default)
        ]

        for i, name in enumerate(positional_with_default):
            if name in values:
                raise TypeError("Argument %s specified twice" % name)
            values[name] = args[i + num_no_default]

        # Fill in missing arguments
        for name in self.args_with_default:
            if name not in values:
                values[name] = self.args[name].default

        for key, value in six.iteritems(values):
            if key in self.args:
                out_value = self.args[key](value)
                if out_value is not missing:
                    if key in self.optional_args and value == self.args[key].default:
                        pass
                    else:
                        data[key] = out_value
            else:
                raise TypeError("Unrecognised argument %s" % key)

        return data

    def convert_known(self, **kwargs):
        """Like ``convert`` but silently drop keywords that are not fields."""
        known_kwargs = {
            name: value for name, value in six.iteritems(kwargs) if name in self.args
        }
        return self.convert(**known_kwargs)
118
119
class DataType(object):
    """Base class for one typed field of a structured log message.

    Subclasses provide ``convert(value)``, which coerces a raw value into the
    field's canonical form and raises on bad input. Calling an instance runs
    that conversion, except that values equal to the default short-circuit.

    :param name: field name, reported in conversion error messages.
    :param default: value used when the field is not supplied; the
        ``no_default`` sentinel means the field has no default.
    :param optional: when true, a value equal to ``default`` yields the
        ``missing`` sentinel so the field is left out. Requires a default.
    :raises ValueError: if ``optional`` is set without a ``default``.
    """

    def __init__(self, name, default=no_default, optional=False):
        if optional is not False and default is no_default:
            raise ValueError("optional arguments require a default value")

        self.name = name
        self.default = default
        self.optional = optional

    def __call__(self, value):
        """Convert ``value``; defaults pass through (or become ``missing``).

        Any exception raised by ``convert`` is reported as a ValueError that
        names the value, its type, the field and the target type.
        """
        if value == self.default:
            # Optional fields at their default are flagged for omission.
            return missing if self.optional else self.default

        try:
            return self.convert(value)
        except Exception:
            template = "Failed to convert value %s of type %s for field %s to type %s"
            details = (value, type(value).__name__, self.name, self.__class__.__name__)
            raise ValueError(template % details)
143
144
class ContainerType(DataType):
    """A DataType that contains other DataTypes.

    ContainerTypes must specify which other DataType they will contain. ContainerTypes
    may contain other ContainerTypes.

    Some examples:

        List(Int, 'numbers')
        Tuple((Unicode, Int, Any), 'things')
        Dict(Unicode, 'config')
        Dict({TestId: Status}, 'results')
        Dict(List(Unicode), 'stuff')

    :param item_type: a DataType class or instance (subclasses may accept
        richer specifiers) describing the contained values.
    :param name: field name; usually None for nested containers.
    """

    def __init__(self, item_type, name=None, **kwargs):
        DataType.__init__(self, name, **kwargs)
        self.item_type = self._format_item_type(item_type)

    def _format_item_type(self, item_type):
        # A bare class such as ``Int`` becomes an unnamed instance ``Int(None)``;
        # anything else is assumed to be an instance already.
        return item_type(None) if inspect.isclass(item_type) else item_type
168
169
class Unicode(DataType):
    """A text field; values are coerced to a unicode string."""

    def convert(self, data):
        if isinstance(data, six.text_type):
            return data
        # Decode raw bytes explicitly. Checking ``str`` here (as before) only
        # works on Python 2: on Python 3 ``str`` is the text type, so bytes
        # fell through to text_type(b"x"), which yields the repr "b'x'".
        # On Python 2 ``bytes is str``, so behaviour there is unchanged.
        if isinstance(data, bytes):
            return data.decode("utf8", "replace")
        return six.text_type(data)
177
178
class TestId(DataType):
    """A test identifier.

    Accepts text as-is, bytes decoded as UTF-8 (undecodable bytes replaced),
    or a tuple/list whose items are each coerced to text (returned as a
    tuple). Any other type is rejected with ValueError.
    """

    def convert(self, data):
        if isinstance(data, six.text_type):
            return data
        if isinstance(data, bytes):
            return data.decode("utf-8", "replace")
        if not isinstance(data, (tuple, list)):
            raise ValueError
        # This is really a bit of a hack; should really split out convertors
        # from the fields they operate on
        to_text = Unicode(None).convert
        return tuple(map(to_text, data))
192
193
class Status(DataType):
    """A status field: upper-cased, then checked against ``allowed``."""

    allowed = [
        "PASS",
        "FAIL",
        "OK",
        "ERROR",
        "TIMEOUT",
        "CRASH",
        "ASSERT",
        "PRECONDITION_FAILED",
        "SKIP",
    ]

    def convert(self, data):
        # Non-string input fails on .upper(); DataType.__call__ reports it.
        normalized = data.upper()
        if normalized in self.allowed:
            return normalized
        raise ValueError
212
213
class SubStatus(Status):
    """A Status variant with its own vocabulary.

    Uses the same upper-casing validation as Status, but adds NOTRUN and does
    not accept OK or CRASH.
    """

    allowed = [
        "PASS", "FAIL", "ERROR", "TIMEOUT",
        "ASSERT", "PRECONDITION_FAILED", "NOTRUN", "SKIP",
    ]
225
226
class Dict(ContainerType):
    """A mapping field.

    The item type is either a single-entry dict ``{KeyType: ValueType}`` or a
    bare value type, in which case keys are passed through unchanged (Any).
    """

    def _format_item_type(self, item_type):
        fmt = super(Dict, self)._format_item_type

        if not isinstance(item_type, dict):
            return Any(None), fmt(item_type)

        if len(item_type) != 1:
            raise ValueError(
                "Dict item type specifier must contain a single entry."
            )
        ((key_type, value_type),) = item_type.items()
        return fmt(key_type), fmt(value_type)

    def convert(self, data):
        key_type, value_type = self.item_type
        converted = {}
        for raw_key, raw_value in dict(data).items():
            # Convert the key before the value, as the dict comprehension did.
            new_key = key_type.convert(raw_key)
            converted[new_key] = value_type.convert(raw_value)
        return converted
245
246
class List(ContainerType):
    """A list field; every element is converted with ``item_type``."""

    def convert(self, data):
        # while dicts and strings _can_ be cast to lists,
        # doing so is likely not intentional behaviour. Byte strings are
        # rejected too: on Python 3 they are not in six.string_types and
        # would otherwise be silently split into a list of integers.
        if isinstance(data, (six.string_types, bytes, dict)):
            raise ValueError("Expected list but got %s" % type(data))
        return [self.item_type.convert(item) for item in data]
255
class TestList(DataType):
    """A TestList is a list of tests that can be either keyed by a group name,
    or specified as a flat list.

    A flat list or tuple is placed under the group name "default"; the result
    is always a dict mapping text group names to lists of text.
    """

    def convert(self, data):
        grouped = {"default": data} if isinstance(data, (list, tuple)) else data
        return Dict({Unicode: List(Unicode)}).convert(grouped)
265
266
class Int(DataType):
    """An integer field; values are coerced with ``int()``."""

    def convert(self, data):
        as_int = int(data)
        return as_int
270
271
class Any(DataType):
    """A field that accepts any value and returns it unchanged."""

    def convert(self, data):
        # Identity conversion: no validation or coercion.
        passthrough = data
        return passthrough
275
276
class Boolean(DataType):
    """A boolean field; values are coerced by their truthiness."""

    def convert(self, data):
        return True if data else False
280
281
class Tuple(ContainerType):
    """A fixed-length tuple field with a DataType per position.

    The item type is a tuple/list of types (one per position) or a single
    type, which describes a one-element tuple.
    """

    def _format_item_type(self, item_type):
        fmt = super(Tuple, self)._format_item_type

        if not isinstance(item_type, (tuple, list)):
            return (fmt(item_type),)
        return [fmt(entry) for entry in item_type]

    def convert(self, data):
        expected, actual = len(self.item_type), len(data)
        if expected != actual:
            raise ValueError("Expected %i items got %i" % (expected, actual))
        return tuple(
            position_type.convert(element)
            for position_type, element in zip(self.item_type, data)
        )
298
299
class Nullable(ContainerType):
    """Wraps another type so that None is accepted and passed through."""

    def convert(self, data):
        return data if data is None else self.item_type.convert(data)
305