# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Defines a utility for representing deferred class instantiations as JSON."""

import importlib
import json
import typing


# Scalar value types permitted as constructor arguments of a serialized ClassFactory.
JsonSerializable = typing.Union[int, float, str, None, bool]


class SerializedFactoryError(Exception):
    """Raised when ClassFactory.from_json is invoked with an invalid JSON blob."""


class ClassFactory:
    """Describes a JSON-serializable class instantiation, for use with the RPC server.

    A ClassFactory records a callable together with the positional and keyword
    arguments to pass to it. The call can be serialized with `to_json`, rebuilt with
    `from_json`, and finally performed with `instantiate`.

    Subclasses may set `SUPERCLASS` to restrict which classes may be wrapped. The
    restriction is checked both on construction and on deserialization.
    """

    # When not None, the superclass from which all cls must derive. Checked by __init__
    # and by from_json (the latter matters when the JSON comes from an untrusted peer).
    SUPERCLASS: typing.Optional[type] = None

    def __init__(
        self,
        cls: typing.Callable,
        init_args: typing.List[JsonSerializable],
        init_kw: typing.Dict[str, JsonSerializable],
    ):
        """Record a deferred call ``cls(*init_args, **init_kw)``.

        Parameters
        ----------
        cls : Callable
            The class (or other callable) invoked by `instantiate`. When `SUPERCLASS`
            is set, this must be a class deriving from it.
        init_args : list
            Positional arguments for `cls`. They must be JSON-serializable for
            `to_json` to succeed.
        init_kw : dict
            Keyword arguments for `cls`. They must be JSON-serializable for `to_json`
            to succeed.

        Raises
        ------
        TypeError :
            If `SUPERCLASS` is set and `cls` is not a subclass of it.
        """
        if not type(self)._satisfies_superclass(cls):
            raise TypeError(
                f"{type(self).__name__} requires cls to derive from {self.SUPERCLASS!r}, "
                f"got: {cls!r}"
            )
        self.cls = cls
        self.init_args = init_args
        self.init_kw = init_kw

    @classmethod
    def _satisfies_superclass(cls, candidate) -> bool:
        """Return True if `candidate` may be wrapped by this factory class."""
        if cls.SUPERCLASS is None:
            return True
        # issubclass() raises TypeError for non-class callables (e.g. plain functions),
        # so reject those first: they can never derive from SUPERCLASS.
        return isinstance(candidate, type) and issubclass(candidate, cls.SUPERCLASS)

    def override_kw(self, **kw_overrides):
        """Return a new factory of the same type with some keyword arguments replaced.

        Parameters
        ----------
        **kw_overrides :
            Keyword arguments that replace (or are added to) this factory's `init_kw`.

        Returns
        -------
        ClassFactory :
            A new instance of ``type(self)`` with the same `cls` and `init_args`. Its
            `init_kw` is a fresh dict, so mutating it never affects this factory. Note
            that `init_args` is passed through unchanged (not copied).
        """
        # Always copy, even without overrides, so the two factories never share a dict.
        kwargs = dict(self.init_kw)
        kwargs.update(kw_overrides)
        return self.__class__(self.cls, self.init_args, kwargs)

    def instantiate(self):
        """Perform the deferred call and return ``cls(*init_args, **init_kw)``."""
        return self.cls(*self.init_args, **self.init_kw)

    @property
    def to_json(self):
        """JSON string describing this factory (a property: access it, don't call it).

        The class is recorded as ``"<cls.__module__>.<cls.__name__>"``. This is the
        form `from_json` resolves: it imports the module and looks up a single attribute
        on it. Classes that are not top-level attributes of their module (nested or
        function-local classes) therefore generally will not round-trip through
        `from_json`.

        Raises
        ------
        TypeError :
            If `init_args` or `init_kw` contain values `json.dumps` cannot serialize.
        """
        return json.dumps(
            {
                "cls": ".".join([self.cls.__module__, self.cls.__name__]),
                "init_args": self.init_args,
                "init_kw": self.init_kw,
            }
        )

    # Keys that every serialized ClassFactory JSON object must contain.
    EXPECTED_KEYS = ("cls", "init_args", "init_kw")

    @classmethod
    def from_json(cls, data):
        """Reconstruct a ClassFactory instance from its JSON representation.

        Parameters
        ----------
        data : str
            The JSON representation of the ClassFactory.

        Returns
        -------
        ClassFactory :
            The reconstructed ClassFactory instance (an instance of `cls`).

        Raises
        ------
        SerializedFactoryError :
            If the JSON object represented by `data` is malformed. This covers:
            - the payload is not an object, or lacks one of `EXPECTED_KEYS`;
            - "cls" is not a "module.Name" string;
            - "init_args" is not a list, or "init_kw" is not an object;
            - the named class does not derive from `SUPERCLASS`.
        json.JSONDecodeError :
            If `data` is not valid JSON at all.
        ImportError :
            If the module named in "cls" cannot be imported.
        AttributeError :
            If that module has no attribute with the given name.
        """
        obj = json.loads(data)
        if not isinstance(obj, dict):
            raise SerializedFactoryError(f"deserialized json payload: want dict, got: {obj!r}")

        for key in cls.EXPECTED_KEYS:
            if key not in obj:
                raise SerializedFactoryError(
                    f"deserialized json payload: expect key {key}, got: {obj!r}"
                )

        cls_path = obj["cls"]
        if not isinstance(cls_path, str):
            raise SerializedFactoryError(
                f"deserialized json payload: want str for key cls, got: {cls_path!r}"
            )
        cls_package_name, _, cls_name = cls_path.rpartition(".")
        # Reject "Name", ".Name", "pkg." and relative paths up front. Otherwise they
        # surface as ValueError/TypeError from unpacking or importlib instead of the
        # documented SerializedFactoryError.
        if not cls_package_name or not cls_name or cls_package_name.startswith("."):
            raise SerializedFactoryError(
                "deserialized json payload: want fully-qualified 'module.Name' for key cls, "
                f"got: {cls_path!r}"
            )

        init_args = obj["init_args"]
        if not isinstance(init_args, list):
            raise SerializedFactoryError(
                f"deserialized json payload: want list for key init_args, got: {init_args!r}"
            )
        init_kw = obj["init_kw"]
        if not isinstance(init_kw, dict):
            raise SerializedFactoryError(
                f"deserialized json payload: want dict for key init_kw, got: {init_kw!r}"
            )

        # SECURITY NOTE(review): the payload chooses which module gets imported, and
        # importing runs that module's top-level code. The SUPERCLASS check below only
        # limits what can be *instantiated*, not what can be imported. Only feed JSON
        # from trusted peers, or validate cls_package_name before this point.
        cls_package = importlib.import_module(cls_package_name)
        cls_obj = getattr(cls_package, cls_name)
        if not cls._satisfies_superclass(cls_obj):
            raise SerializedFactoryError(
                f"deserialized json payload: cls {cls_path!r} does not derive from "
                f"{cls.SUPERCLASS!r}"
            )
        return cls(cls_obj, init_args, init_kw)