1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
"""Defines a utility for representing deferred class instantiations as JSON."""
19
20import importlib
21import json
22import typing
23
24
# Scalar JSON-compatible value types; ClassFactory uses this as the element type
# of its positional and keyword constructor arguments.
JsonSerializable = typing.Union[int, float, str, type(None), bool]
26
27
class SerializedFactoryError(Exception):
    """Signals that ClassFactory.from_json was handed a JSON blob it cannot accept."""
30
31
class ClassFactory:
    """Describes a JSON-serializable class instantiation, for use with the RPC server.

    An instance records a callable (``cls``) plus the positional (``init_args``) and
    keyword (``init_kw``) arguments to call it with. The record can be serialized with
    ``to_json``, rebuilt with ``from_json``, and finally executed with ``instantiate``.
    """

    # When not None, the superclass (or tuple of superclasses, as accepted by
    # issubclass) from which all cls must derive. from_json enforces this, so a JSON
    # payload cannot name an arbitrary class to instantiate.
    SUPERCLASS = None

    def __init__(
        self,
        cls: typing.Callable,
        init_args: typing.List[JsonSerializable],
        init_kw: typing.Dict[str, JsonSerializable],
    ):
        # Arguments are stored as given (not copied); mutating them afterwards
        # changes what instantiate() will pass.
        self.cls = cls
        self.init_args = init_args
        self.init_kw = init_kw

    def override_kw(self, **kw_overrides):
        """Return a copy of this factory with some keyword arguments replaced.

        Parameters
        ----------
        **kw_overrides :
            Keyword arguments that replace (or are added to) ``init_kw`` in the copy.

        Returns
        -------
        ClassFactory :
            A new instance of the same (sub)class sharing ``cls`` and ``init_args``.
            ``self.init_kw`` is never mutated; when no overrides are given the copy
            reuses the same ``init_kw`` dict object.
        """
        kwargs = self.init_kw
        if kw_overrides:
            kwargs = dict(kwargs)
            kwargs.update(kw_overrides)

        return self.__class__(self.cls, self.init_args, kwargs)

    def instantiate(self):
        """Call ``cls`` with the stored positional and keyword arguments and return the result."""
        return self.cls(*self.init_args, **self.init_kw)

    @property
    def to_json(self):
        """JSON string describing this factory (a property: read, do not call).

        The class is recorded as ``"<cls.__module__>.<cls.__name__>"``. from_json
        resolves it by importing the module and looking the name up on it, so only
        classes reachable as a module-level attribute round-trip.
        """
        return json.dumps(
            {
                "cls": ".".join([self.cls.__module__, self.cls.__name__]),
                "init_args": self.init_args,
                "init_kw": self.init_kw,
            }
        )

    # Top-level keys every serialized payload must contain.
    EXPECTED_KEYS = ("cls", "init_args", "init_kw")

    @classmethod
    def from_json(cls, data):
        """Reconstruct a ClassFactory instance from its JSON representation.

        Parameters
        ----------
        data : str
            The JSON representation of the ClassFactory, as produced by ``to_json``.

        Returns
        -------
        ClassFactory :
            The reconstructed instance; it is an instance of the class this method
            was invoked on, so subclasses get back their own type.

        Raises
        ------
        SerializedFactoryError :
            If the JSON object represented by `data` is malformed: it is not an
            object, lacks one of EXPECTED_KEYS, ``cls`` is not a ``"module.Name"``
            string, ``init_args`` is not a list, ``init_kw`` is not an object, or
            (when SUPERCLASS is set) the named class does not derive from SUPERCLASS.
        json.JSONDecodeError :
            If `data` is not valid JSON.
        ImportError, AttributeError :
            If the module named in ``cls`` cannot be imported, or lacks the name.
        """
        obj = json.loads(data)
        if not isinstance(obj, dict):
            raise SerializedFactoryError(f"deserialized json payload: want dict, got: {obj!r}")

        for key in cls.EXPECTED_KEYS:
            if key not in obj:
                raise SerializedFactoryError(
                    f"deserialized json payload: expect key {key}, got: {obj!r}"
                )

        cls_path = obj["cls"]
        init_args = obj["init_args"]
        init_kw = obj["init_kw"]

        # Validate the payload shape up front so a malformed blob surfaces as
        # SerializedFactoryError here, rather than as an unrelated error from the
        # string split or as a failure/misbehavior when instantiate() unpacks args.
        if isinstance(cls_path, str):
            cls_module_name, _, cls_name = cls_path.rpartition(".")
        else:
            cls_module_name = cls_name = ""
        if not cls_module_name or not cls_name:
            raise SerializedFactoryError(
                "deserialized json payload: want cls as 'module.Name' string, "
                f"got: {cls_path!r}"
            )
        if not isinstance(init_args, list):
            raise SerializedFactoryError(
                f"deserialized json payload: want init_args list, got: {init_args!r}"
            )
        if not isinstance(init_kw, dict):
            raise SerializedFactoryError(
                f"deserialized json payload: want init_kw dict, got: {init_kw!r}"
            )

        # NOTE: importing executes the module's top-level code before the SUPERCLASS
        # check below can reject the class; the check limits what gets instantiated,
        # not what gets imported.
        cls_module = importlib.import_module(cls_module_name)
        cls_obj = getattr(cls_module, cls_name)
        if cls.SUPERCLASS is not None and not (
            isinstance(cls_obj, type) and issubclass(cls_obj, cls.SUPERCLASS)
        ):
            raise SerializedFactoryError(
                f"deserialized json payload: cls {cls_path!r} does not derive from "
                f"{cls.SUPERCLASS!r}"
            )
        return cls(cls_obj, init_args, init_kw)
105