1import json
2
3import webdriver
4
5
6"""WebDriver wire protocol codecs."""
7
8
class Encoder(json.JSONEncoder):
    """JSON encoder that serialises WebDriver remote references.

    ``webdriver.Element``, ``webdriver.Frame``, ``webdriver.Window`` and
    ``webdriver.ShadowRoot`` instances are encoded as the single-key object
    ``{<class>.identifier: obj.id}``. Anything else is delegated to
    :class:`json.JSONEncoder`.
    """

    def __init__(self, *args, **kwargs):
        # ``session`` is accepted (callers pass it symmetrically with Decoder)
        # but is not needed for encoding. A missing ``session`` is tolerated
        # rather than raising KeyError.
        kwargs.pop("session", None)
        super(Encoder, self).__init__(*args, **kwargs)

    def default(self, obj):
        """Return a JSON-serialisable stand-in for *obj*.

        Lists and tuples come back as plain lists. Supported WebDriver
        reference objects come back as ``{identifier: id}`` dicts.

        Raises:
            TypeError: from :meth:`json.JSONEncoder.default` when *obj* is of
                an unsupported type.
        """
        if isinstance(obj, (list, tuple)):
            # json already serialises lists/tuples natively, so this is only
            # reached on direct calls. Return the items unchanged and let the
            # encoder call ``default`` just for the items it cannot handle.
            # Routing every item through ``default`` would reject plain
            # values such as ints and strings.
            return list(obj)
        # The reference types are looked up at call time, not at import time.
        # NOTE(review): presumably this guards against ``webdriver`` being
        # only partially initialised when this module is imported; confirm.
        # The first match wins, in the same order as before. The key is taken
        # from the matched class, so a Window is encoded with the Window
        # identifier (previously it wrongly used the Frame identifier).
        for ref_cls in (webdriver.Element, webdriver.Frame,
                        webdriver.Window, webdriver.ShadowRoot):
            if isinstance(obj, ref_cls):
                return {ref_cls.identifier: obj.id}
        return super(Encoder, self).default(obj)
26
27
class Decoder(json.JSONDecoder):
    """JSON decoder that rebuilds WebDriver remote references.

    Any decoded object whose keys contain the identifier of
    ``webdriver.Element``, ``webdriver.Frame``, ``webdriver.Window`` or
    ``webdriver.ShadowRoot`` (checked in that order) is turned into an
    instance via that class's ``from_json``. The instance is bound to the
    ``session`` given at construction.
    """

    def __init__(self, *args, **kwargs):
        # ``session`` is required here (a missing key raises KeyError) because
        # every rebuilt reference needs it.
        self.session = kwargs.pop("session")
        super(Decoder, self).__init__(
            *args, object_hook=self.object_hook, **kwargs)

    def object_hook(self, payload):
        """Convert *payload* (and anything nested in it) to reference objects.

        Lists and tuples become lists of converted items. Dicts carrying a
        known reference identifier become the matching reference object.
        Other dicts are rebuilt with converted values. Anything else is
        returned unchanged.
        """
        if isinstance(payload, (list, tuple)):
            return [self.object_hook(item) for item in payload]
        if not isinstance(payload, dict):
            return payload
        for ref_cls in (webdriver.Element, webdriver.Frame,
                        webdriver.Window, webdriver.ShadowRoot):
            if ref_cls.identifier in payload:
                return ref_cls.from_json(payload, self.session)
        # NOTE(review): json already invokes this hook on every nested object,
        # so this walk revisits values that were converted. It is kept so
        # that direct calls on raw data still convert deeply.
        return {key: self.object_hook(value) for key, value in payload.items()}
48