1# This Source Code Form is subject to the terms of the Mozilla Public
2# License, v. 2.0. If a copy of the MPL was not distributed with this
3# file, You can obtain one at http://mozilla.org/MPL/2.0/.
4
5from __future__ import absolute_import
6
7import inspect
8import six
9from six.moves import range
10from six.moves import zip
11
# Maps each decorated log-action method name to the log_action instance that
# converts its arguments; filled in by log_action.__call__.
convertor_registry = dict()
# ``missing``: returned by DataType.__call__ for an optional field whose value
# equals its default, so log_action.convert leaves that field out.
# ``no_default``: default-value sentinel meaning "this field is required";
# distinct from None, which is itself a usable default (see the ``pid`` field).
missing, no_default = object(), object()
15
16
class log_action(object):
    """Decorator describing the fields of one structured-log action.

    It is constructed with the DataType instances for the action's own
    fields. When applied to a logger method it:

    * registers itself in ``convertor_registry`` under the method's name, and
    * returns a wrapper that turns the caller's positional/keyword arguments
      into a converted data dict and passes that dict to the original method.

    The common message fields (action, time, thread, pid, source, component)
    are always accepted as keyword arguments in addition to the declared ones.
    """

    def __init__(self, *args):
        self.args = {}
        self.args_no_default = []
        self.args_with_default = []
        self.optional_args = set()

        # Required message fields that callers normally do not pass, but which
        # may be supplied explicitly (e.g. by log_raw).
        self.default_args = [
            Unicode("action"),
            Int("time"),
            Unicode("thread"),
            Int("pid", default=None),
            Unicode("source"),
            Unicode("component")]

        for field in args:
            # Fields without a default come first positionally; keep the two
            # groups separate, each in declaration order.
            if field.default is no_default:
                bucket = self.args_no_default
            else:
                bucket = self.args_with_default
            bucket.append(field.name)

            if field.optional:
                self.optional_args.add(field.name)

            if field.name in self.args:
                raise ValueError("Repeated argument name %s" % field.name)
            self.args[field.name] = field

        for field in self.default_args:
            self.args[field.name] = field

    def __call__(self, f):
        convertor_registry[f.__name__] = self
        convertor = self

        def inner(self, *args, **kwargs):
            return f(self, convertor.convert(*args, **kwargs))

        # Every Python object exposes ``__doc__``, so the docstring is
        # always copied onto the wrapper.
        inner.__doc__ = f.__doc__
        return inner

    def convert(self, *args, **kwargs):
        """Map call arguments onto fields and convert them to a data dict.

        Positional arguments fill, in order, the fields without defaults that
        were not given by keyword, then the fields with defaults. Unfilled
        fields with defaults take their default value. Optional fields whose
        value equals their default are omitted from the result.

        :raises TypeError: too few or too many positional arguments, an
            argument given both positionally and by keyword, or an unknown
            keyword.
        :raises ValueError: propagated from a field that fails to convert.
        """
        supplied = dict(kwargs)

        required_positional = [name for name in self.args_no_default
                               if name not in supplied]
        n_required = len(required_positional)
        n_positional = len(args)

        if n_positional < n_required:
            raise TypeError("Too few arguments")
        if n_positional > n_required + len(self.args_with_default):
            raise TypeError("Too many arguments")

        for name, value in zip(required_positional, args):
            supplied[name] = value

        # Positional arguments left over after the required ones go, in
        # order, to the fields that have defaults.
        for name, value in zip(self.args_with_default, args[n_required:]):
            if name in supplied:
                raise TypeError("Argument %s specified twice" % name)
            supplied[name] = value

        for name in self.args_with_default:
            supplied.setdefault(name, self.args[name].default)

        data = {}
        for key, value in six.iteritems(supplied):
            if key not in self.args:
                raise TypeError("Unrecognised argument %s" % key)
            field = self.args[key]
            converted = field(value)
            if converted is missing:
                continue
            if key in self.optional_args and value == field.default:
                continue
            data[key] = converted

        return data

    def convert_known(self, **kwargs):
        """Like ``convert``, but silently ignore keywords that are not fields."""
        recognised = {key: val for key, val in six.iteritems(kwargs)
                      if key in self.args}
        return self.convert(**recognised)
115
116
class DataType(object):
    """Base class for a named, typed field of a log message.

    Calling an instance converts a raw value with the subclass's ``convert``
    method and wraps any failure in a ``ValueError``.

    :param name: the field name.
    :param default: value used when the field is not supplied; the
        ``no_default`` sentinel marks the field as required.
    :param optional: when true, a value equal to ``default`` is reported as
        ``missing`` rather than converted. Requires a default.
    """

    def __init__(self, name, default=no_default, optional=False):
        self.name = name
        self.default = default

        wants_optional = optional is not False
        if wants_optional and default is no_default:
            raise ValueError("optional arguments require a default value")

        self.optional = optional

    def __call__(self, value):
        if value == self.default:
            # Values equal to the default are never converted: optional
            # fields are flagged as ``missing``, others get the raw default.
            return missing if self.optional else self.default

        try:
            return self.convert(value)
        except Exception:
            details = (value, type(value).__name__, self.name, self.__class__.__name__)
            raise ValueError("Failed to convert value %s of type %s for field %s to type %s" %
                             details)
139
140
class ContainerType(DataType):
    """A DataType whose values are collections of other DataTypes.

    ``item_type`` states what the container holds. It may be a DataType
    class, which is instantiated here with no field name, or an existing
    DataType instance, including another ContainerType. Subclasses override
    ``_format_item_type`` to accept richer specifiers.

    Some examples:

        List(Int, 'numbers')
        Tuple((Unicode, Int, Any), 'things')
        Dict(Unicode, 'config')
        Dict({TestId: Status}, 'results')
        Dict(List(Unicode), 'stuff')
    """

    def __init__(self, item_type, name=None, **kwargs):
        DataType.__init__(self, name, **kwargs)
        self.item_type = self._format_item_type(item_type)

    def _format_item_type(self, item_type):
        # Bare classes become anonymous instances; instances pass through.
        return item_type(None) if inspect.isclass(item_type) else item_type
164
165
class Unicode(DataType):
    """Field whose values are converted to text (``six.text_type``).

    Text is returned unchanged. Byte strings are decoded as UTF-8, with
    undecodable sequences replaced. Any other value is converted with
    ``six.text_type``.
    """

    def convert(self, data):
        if isinstance(data, six.text_type):
            return data
        # Test for ``bytes``, not ``str``. On Python 3 ``str`` is the text
        # type and was already handled above, so the old ``str`` check never
        # matched. Byte strings then fell through to ``six.text_type(data)``,
        # which produced their repr (e.g. "b'abc'") instead of decoded text.
        # On Python 2 ``bytes is str``, so behaviour there is unchanged.
        if isinstance(data, bytes):
            return data.decode("utf8", "replace")
        return six.text_type(data)
174
175
class TestId(DataType):
    """Field holding a test identifier.

    Text is returned unchanged and byte strings are decoded as UTF-8, with
    undecodable bytes replaced. A tuple or list is converted element by
    element with ``Unicode``'s convertor and returned as a tuple. Any other
    value raises ``ValueError``.
    """

    def convert(self, data):
        if isinstance(data, six.text_type):
            return data
        if isinstance(data, bytes):
            return data.decode("utf-8", "replace")
        if not isinstance(data, (tuple, list)):
            raise ValueError
        # HACK: borrow Unicode's convertor for each element. Convertors should
        # really be split out from the fields they operate on.
        to_text = Unicode(None).convert
        return tuple([to_text(part) for part in data])
190
191
class Status(DataType):
    """Field holding a test status.

    The value is upper-cased and must then be one of ``allowed``; otherwise
    ``ValueError`` is raised. Subclasses change the accepted set by
    overriding ``allowed``.
    """
    allowed = ["PASS", "FAIL", "OK", "ERROR", "TIMEOUT", "CRASH", "ASSERT",
               "PRECONDITION_FAILED", "SKIP"]

    def convert(self, data):
        normalised = data.upper()
        if normalised in self.allowed:
            return normalised
        raise ValueError
201
202
class SubStatus(Status):
    """Status of a subtest.

    Conversion is the same as ``Status``, but with a different accepted set:
    no ``OK`` or ``CRASH``, plus ``NOTRUN``.
    """
    allowed = [
        "PASS",
        "FAIL",
        "ERROR",
        "TIMEOUT",
        "ASSERT",
        "PRECONDITION_FAILED",
        "NOTRUN",
        "SKIP",
    ]
206
207
class Dict(ContainerType):
    """Mapping field whose keys and values are converted by their own types.

    ``item_type`` is either a one-entry dict ``{key_type: value_type}`` or a
    single value type, in which case keys are passed through via ``Any``.
    """

    def _format_item_type(self, item_type):
        fmt = super(Dict, self)._format_item_type

        if not isinstance(item_type, dict):
            return Any(None), fmt(item_type)

        if len(item_type) != 1:
            raise ValueError("Dict item type specifier must contain a single entry.")
        key_spec, value_spec = next(iter(item_type.items()))
        return fmt(key_spec), fmt(value_spec)

    def convert(self, data):
        key_type, value_type = self.item_type
        converted = {}
        for raw_key, raw_value in dict(data).items():
            new_key = key_type.convert(raw_key)
            converted[new_key] = value_type.convert(raw_value)
        return converted
223
224
class List(ContainerType):
    """Sequence field whose items are each converted by ``item_type``.

    Returns a new list. Text, byte strings and dicts are rejected with
    ``ValueError`` even though they are iterable.
    """

    def convert(self, data):
        # While dicts and strings _can_ be cast to lists, doing so is likely
        # not intentional behaviour. ``bytes`` is listed explicitly because on
        # Python 3 it is not part of ``six.string_types``. Without it, a byte
        # string would be silently iterated as a sequence of integers. On
        # Python 2 ``bytes is str``, so this adds nothing there.
        if isinstance(data, (six.string_types, bytes, dict)):
            raise ValueError("Expected list but got %s" % type(data))
        return [self.item_type.convert(item) for item in data]
233
234
class TestList(DataType):
    """Field holding tests, either grouped by name or as a flat list.

    A flat list or tuple is placed under the single group name 'default'.
    The result is a dict mapping group names to lists of test names, all
    converted to text.
    """

    def convert(self, data):
        groups = data
        if isinstance(groups, (list, tuple)):
            groups = {'default': groups}
        convertor = Dict({Unicode: List(Unicode)})
        return convertor.convert(groups)
244
245
class Int(DataType):
    """Field whose values are converted with the built-in ``int``."""

    def convert(self, data):
        as_int = int(data)
        return as_int
250
251
class Any(DataType):
    """Field that accepts any value and returns it unchanged."""

    def convert(self, data):
        # Identity conversion: no validation or coercion is performed.
        unchanged = data
        return unchanged
256
257
class Boolean(DataType):
    """Field whose values are reduced to ``True``/``False`` by truthiness."""

    def convert(self, data):
        return True if data else False
262
263
class Tuple(ContainerType):
    """Fixed-length sequence field with one DataType per position.

    ``item_type`` is a tuple/list of types, one per position, or a single
    type describing a one-element tuple. Values of the wrong length raise
    ``ValueError``; otherwise a tuple of converted items is returned.
    """

    def _format_item_type(self, item_type):
        fmt = super(Tuple, self)._format_item_type

        if not isinstance(item_type, (tuple, list)):
            return (fmt(item_type),)
        return list(map(fmt, item_type))

    def convert(self, data):
        expected = len(self.item_type)
        actual = len(data)
        if actual != expected:
            raise ValueError("Expected %i items got %i" % (expected, actual))
        return tuple(field.convert(item)
                     for field, item in zip(self.item_type, data))
278
279
class Nullable(ContainerType):
    """Wrap another DataType so that ``None`` is also accepted.

    ``None`` is returned as-is; any other value is converted by
    ``item_type``.
    """

    def convert(self, data):
        return data if data is None else self.item_type.convert(data)
285