1import json
2
3import webdriver
4
5
6"""WebDriver wire protocol codecs."""
7
8
class Encoder(json.JSONEncoder):
    """JSON encoder that knows how to serialise ``webdriver.Element``.

    An element is written out as a one-entry object that maps
    ``webdriver.Element.identifier`` to the element's ``id``.
    """

    def __init__(self, *args, **kwargs):
        """Create the encoder.

        The ``session`` keyword argument is mandatory (a ``KeyError`` is
        raised when it is missing) but the encoder does not keep it.  Every
        other argument is handed to ``json.JSONEncoder`` unchanged.
        """
        # The base class does not accept ``session``, so it has to be removed
        # before delegating.
        del kwargs["session"]
        super(Encoder, self).__init__(*args, **kwargs)

    def default(self, obj):
        """Return a JSON-serialisable stand-in for ``obj``.

        Lists and tuples are converted member by member through this same
        method, elements become ``{identifier: id}``, and anything else is
        deferred to the base class implementation.
        """
        if isinstance(obj, (list, tuple)):
            converted = []
            for member in obj:
                converted.append(self.default(member))
            return converted

        if isinstance(obj, webdriver.Element):
            return {webdriver.Element.identifier: obj.id}

        return super(Encoder, self).default(obj)
20
21
class Decoder(json.JSONDecoder):
    """JSON decoder that turns element references into ``webdriver.Element``.

    Any decoded JSON object containing the ``webdriver.Element.identifier``
    key is replaced by the result of
    ``webdriver.Element.from_json(payload, session)``.
    """

    def __init__(self, *args, **kwargs):
        """Create the decoder.

        The ``session`` keyword argument is mandatory (a ``KeyError`` is
        raised when it is missing); it is stored on ``self.session`` and later
        passed to ``webdriver.Element.from_json``.  Every other argument is
        handed to ``json.JSONDecoder`` unchanged, with ``object_hook`` bound
        to this instance's ``object_hook`` method.
        """
        self.session = kwargs.pop("session")
        super(Decoder, self).__init__(
            *args, object_hook=self.object_hook, **kwargs)

    def object_hook(self, payload):
        """Convert a decoded JSON value, reviving element references.

        ``payload`` may be any decoded value.  Lists and tuples are rebuilt as
        lists with every member converted, a dict carrying the element
        identifier becomes a ``webdriver.Element``, any other dict is rebuilt
        with every value converted, and everything else is returned as is.
        """
        if isinstance(payload, (list, tuple)):
            return [self.object_hook(x) for x in payload]

        if isinstance(payload, dict):
            if webdriver.Element.identifier in payload:
                return webdriver.Element.from_json(payload, self.session)
            # ``items()`` exists on both Python 2 and Python 3, whereas
            # ``iteritems()`` is Python 2 only and raises AttributeError on
            # Python 3 for every non-element JSON object.
            return {k: self.object_hook(v) for k, v in payload.items()}

        return payload
36