1# coding: utf-8
2"""
3JSON serialization and deserialization utilities.
4"""
5
6import os
7import json
8import types
9import datetime
10
11from hashlib import sha1
12from collections import OrderedDict, defaultdict
13from enum import Enum
14
15from importlib import import_module
16
17from inspect import getfullargspec
18from uuid import UUID
19
20try:
21    import numpy as np
22except ImportError:
23    np = None  # type: ignore
24
25try:
26    import pandas as pd
27except ImportError:
28    pd = None  # type: ignore
29
30try:
31    import pydantic
32except ImportError:
33    pydantic = None  # type: ignore
34
35try:
36    import bson
37except ImportError:
38    bson = None
39
40try:
41    from ruamel import yaml
42except ImportError:
43    try:
44        import yaml  # type: ignore
45    except ImportError:
46        yaml = None  # type: ignore
47
48__version__ = "3.0.0"
49
50
def _load_redirect(redirect_file):
    """
    Build the class-redirect table from a YAML file.

    Each entry of the file maps an old fully qualified class path to a new
    one, e.g. ``old_module.OldClass: new_module.NewClass``.

    Args:
        redirect_file: Path of the YAML file to read.

    Returns:
        dict: ``{old_module: {old_class: {"@module": new_module,
        "@class": new_class}}}``. An empty dict is returned when the file
        cannot be opened or holds no entries.
    """
    try:
        with open(redirect_file, "rt") as f:
            d = yaml.safe_load(f)
    except IOError:
        # If we can't find (or read) the file, just use an empty redirect dict
        return {}

    # An empty or comment-only YAML document loads as None; treat it the same
    # as "no redirects" instead of failing on ``None.items()``.
    if not d:
        return {}

    # Convert the full paths to module/class
    redirect_dict = defaultdict(dict)
    for old_path, new_path in d.items():
        # rpartition splits on the last dot: "pkg.mod.Class" -> ("pkg.mod", ".", "Class");
        # a path without any dot yields an empty module name.
        old_module, _, old_class = old_path.rpartition(".")
        new_module, _, new_class = new_path.rpartition(".")

        redirect_dict[old_module][old_class] = {
            "@module": new_module,
            "@class": new_class,
        }

    return dict(redirect_dict)
75
76
class MSONable:
    """
    This is a mix-in base class specifying an API for msonable objects. MSON
    is Monty JSON. Essentially, MSONable objects must implement an as_dict
    method, which must return a json serializable dict and must also support
    no arguments (though optional arguments to finetune the output is ok),
    and a from_dict class method that regenerates the object from the dict
    generated by the as_dict method. The as_dict method should contain the
    "@module" and "@class" keys which will allow the MontyDecoder to
    dynamically deserialize the class. E.g.::

        d["@module"] = self.__class__.__module__
        d["@class"] = self.__class__.__name__

    A default implementation is provided in MSONable, which automatically
    determines if the class already contains self.argname or self._argname
    attributes for every arg. If so, these will be used for serialization in
    the dict format. Similarly, the default from_dict will deserialize
    classes of such form. An example is given below::

        class MSONClass(MSONable):

            def __init__(self, a, b, c, d=1, **kwargs):
                self.a = a
                self.b = b
                self._c = c
                self._d = d
                self.kwargs = kwargs

    For such classes, you merely need to inherit from MSONable and you do not
    need to implement your own as_dict or from_dict protocol.

    New to Monty V2.0.6....
    Classes can be redirected to moved implementations by putting in the old
    fully qualified path and new fully qualified path into .monty.yaml in the
    home folder

    Example:
    old_module.old_class: new_module.new_class
    """

    # Redirect table, loaded once at class-definition time from ~/.monty.yaml
    REDIRECT = _load_redirect(os.path.join(os.path.expanduser("~"), ".monty.yaml"))

    def as_dict(self) -> dict:
        """
        A JSON serializable dict representation of an object.

        Every positional ``__init__`` argument is read back from
        ``self.<arg>`` (or ``self._<arg>``); ``kwargs``/``_kwargs``
        attributes, a stored ``*args`` attribute and, for Enums, the member
        value are merged into the result as well.

        Raises:
            NotImplementedError: if an ``__init__`` argument is stored under
                neither ``self.<arg>`` nor ``self._<arg>``.
        """
        d = {"@module": self.__class__.__module__, "@class": self.__class__.__name__}

        try:
            # The version is taken from the top-level package of the class' module
            parent_module = self.__class__.__module__.split(".", maxsplit=1)[0]
            module_version = import_module(parent_module).__version__  # type: ignore
            d["@version"] = "{}".format(module_version)
        except (AttributeError, ImportError):
            d["@version"] = None  # type: ignore

        spec = getfullargspec(self.__class__.__init__)

        def recursive_as_dict(obj):
            # Walk lists/tuples/dicts and convert anything exposing as_dict()
            if isinstance(obj, (list, tuple)):
                return [recursive_as_dict(it) for it in obj]
            if isinstance(obj, dict):
                return {kk: recursive_as_dict(vv) for kk, vv in obj.items()}
            if hasattr(obj, "as_dict"):
                return obj.as_dict()
            return obj

        for c in spec.args:
            if c == "self":
                continue
            try:
                a = self.__getattribute__(c)
            except AttributeError:
                try:
                    # Fall back to the "private" spelling of the attribute
                    a = self.__getattribute__("_" + c)
                except AttributeError as exc:
                    raise NotImplementedError(
                        "Unable to automatically determine as_dict "
                        "format from class. MSONAble requires all "
                        "args to be present as either self.argname or "
                        "self._argname, and kwargs to be present under "
                        "a self.kwargs variable to automatically "
                        "determine the dict format. Alternatively, "
                        "you can implement both as_dict and from_dict."
                    ) from exc
            d[c] = recursive_as_dict(a)
        if hasattr(self, "kwargs"):
            # type: ignore
            d.update(**getattr(self, "kwargs"))  # pylint: disable=E1101
        if spec.varargs is not None and getattr(self, spec.varargs, None) is not None:
            d.update({spec.varargs: getattr(self, spec.varargs)})
        if hasattr(self, "_kwargs"):
            d.update(**getattr(self, "_kwargs"))  # pylint: disable=E1101
        if isinstance(self, Enum):
            d.update({"value": self.value})  # pylint: disable=E1101
        return d

    @classmethod
    def from_dict(cls, d):
        """
        :param d: Dict representation.
        :return: MSONable class.
        """
        # One decoder is enough for all values; keys starting with "@" are metadata
        decoder = MontyDecoder()
        decoded = {k: decoder.process_decoded(v) for k, v in d.items() if not k.startswith("@")}
        return cls(**decoded)

    def to_json(self) -> str:
        """
        Returns a json string representation of the MSONable object.
        """
        return json.dumps(self, cls=MontyEncoder)

    def unsafe_hash(self):
        """
        Returns an hash of the current object. This uses a generic but low
        performance method of converting the object to a dictionary, flattening
        any nested keys, and then performing a hash on the resulting object.

        The returned value is the ``hashlib`` sha1 object itself (call
        ``hexdigest()`` on it for a string); keys containing "@" are ignored.
        """

        def flatten(obj, separator="."):
            # Flattens a nested dictionary into {"a.b.0": value} form. The
            # separator is forwarded on recursion so nested levels use it too.
            flat_dict = {}
            for key, value in obj.items():
                if isinstance(value, dict):
                    flat_dict.update(
                        {separator.join([key, _key]): _value for _key, _value in flatten(value, separator).items()}
                    )
                elif isinstance(value, list):
                    list_dict = {"{}{}{}".format(key, separator, num): item for num, item in enumerate(value)}
                    flat_dict.update(flatten(list_dict, separator))
                else:
                    flat_dict[key] = value

            return flat_dict

        ordered_keys = sorted(flatten(jsanitize(self.as_dict())).items(), key=lambda x: x[0])
        ordered_keys = [item for item in ordered_keys if "@" not in item[0]]
        return sha1(json.dumps(OrderedDict(ordered_keys)).encode("utf-8"))

    @classmethod
    def __get_validators__(cls):
        """Return validators for use in pydantic"""
        yield cls.validate_monty

    @classmethod
    def validate_monty(cls, v):
        """
        pydantic Validator for MSONable pattern.

        Accepts an instance of the class, its as_dict form, or a dict of
        constructor keyword arguments.

        Raises:
            ValueError: if ``v`` is neither an instance of the class nor a dict.
        """
        if isinstance(v, cls):
            return v
        if isinstance(v, dict):
            new_obj = MontyDecoder().process_decoded(v)
            if isinstance(new_obj, cls):
                return new_obj

            # Not an as_dict payload for this class: treat it as init kwargs
            new_obj = cls(**v)
            return new_obj

        raise ValueError(
            f"Must provide {cls.__name__}, the as_dict form, or the proper keyword arguments as a dict"
        )

    @classmethod
    def __modify_schema__(cls, field_schema):
        """JSON schema for MSONable pattern"""
        field_schema.update(
            {
                "type": "object",
                "properties": {
                    "@class": {"enum": [cls.__name__], "type": "string"},
                    "@module": {"enum": [cls.__module__], "type": "string"},
                    "@version": {"type": "string"},
                },
                "required": ["@class", "@module"],
            }
        )
251
252
class MontyEncoder(json.JSONEncoder):
    """
    JSON encoder that understands the MSONable protocol and additionally
    handles datetime objects, UUIDs, numpy arrays and scalars, pandas
    DataFrames, bson ObjectIds (when those packages are importable),
    pydantic models and callables.

    Usage::

        # Pass it as the *cls* keyword of json.dump / json.dumps
        json.dumps(object, cls=MontyEncoder)
    """

    def default(self, o) -> dict:  # pylint: disable=E0202
        """
        Hook called by the json machinery for objects it cannot encode.

        Special-cased types are converted directly. Any other object is
        encoded through its as_dict() method (or dict() for pydantic
        models), and the "@module", "@class" and "@version" keys are added
        when the returned dict lacks them. If that raises AttributeError,
        the stock JSONEncoder.default is invoked instead.

        Args:
            o: Python object.

        Return:
            JSON-encodable representation of ``o`` (a dict for everything
            except numpy scalars, which become plain Python scalars).
        """
        if isinstance(o, datetime.datetime):
            return {"@module": "datetime", "@class": "datetime", "string": str(o)}
        if isinstance(o, UUID):
            return {"@module": "uuid", "@class": "UUID", "string": str(o)}

        if np is not None:
            if isinstance(o, np.ndarray):
                dtype_name = str(o.dtype)
                # Complex arrays are stored as a [real part, imaginary part] pair
                if dtype_name.startswith("complex"):
                    payload = [o.real.tolist(), o.imag.tolist()]
                else:
                    payload = o.tolist()
                return {
                    "@module": "numpy",
                    "@class": "array",
                    "dtype": dtype_name,
                    "data": payload,
                }
            if isinstance(o, np.generic):
                return o.item()

        if pd is not None and isinstance(o, pd.DataFrame):
            frame_json = o.to_json(default_handler=MontyEncoder().encode)
            return {"@module": "pandas", "@class": "DataFrame", "data": frame_json}

        if bson is not None and isinstance(o, bson.objectid.ObjectId):
            return {"@module": "bson.objectid", "@class": "ObjectId", "oid": str(o)}

        if callable(o) and not isinstance(o, MSONable):
            return _serialize_callable(o)

        try:
            is_pydantic_model = pydantic is not None and isinstance(o, pydantic.BaseModel)
            encoded = o.dict() if is_pydantic_model else o.as_dict()

            # Fill in whichever metadata keys the object did not provide itself
            if "@module" not in encoded:
                encoded["@module"] = f"{o.__class__.__module__}"
            if "@class" not in encoded:
                encoded["@class"] = f"{o.__class__.__name__}"
            if "@version" not in encoded:
                try:
                    top_level_package = o.__class__.__module__.split(".")[0]
                    package_version = import_module(top_level_package).__version__  # type: ignore
                    encoded["@version"] = f"{package_version}"
                except (AttributeError, ImportError):
                    encoded["@version"] = None
            return encoded
        except AttributeError:
            # No as_dict() available (or it failed): defer to the base encoder
            return json.JSONEncoder.default(self, o)
336
337
class MontyDecoder(json.JSONDecoder):
    """
    A Json Decoder which supports the MSONable API. By default, the
    decoder attempts to find a module and name associated with a dict. If
    found, the decoder will generate the corresponding object as a priority.
    If that fails, the original decoded dictionary from the string is
    returned. Note that nested lists and dicts containing such objects will
    be decoded correctly as well.

    NOTE: decoding imports whatever module the "@module" key names and calls
    into it (``from_dict``, attribute lookup for callables). Only decode data
    from sources you trust.

    Usage:

        # Add it as a *cls* keyword when using json.load
        json.loads(json_string, cls=MontyDecoder)
    """

    def process_decoded(self, d):
        """
        Recursive method to support decoding dicts and lists containing
        serialized objects.

        Args:
            d: Output of the plain JSON decoder (dict, list or scalar).

        Returns:
            The reconstructed object when ``d`` is a dict describing a class
            ("@module"/"@class") or a callable ("@module"/"@callable") that
            can be resolved; otherwise ``d`` with its keys, values and list
            items processed recursively.
        """
        if isinstance(d, dict):
            modname = None
            classname = None
            if "@module" in d and "@class" in d:
                modname = d["@module"]
                classname = d["@class"]
                redirects = MSONable.REDIRECT.get(modname, {})
                if classname in redirects:
                    # Read both targets from the entry of the *old* location
                    # before overwriting modname/classname.
                    target = redirects[classname]
                    modname = target["@module"]
                    classname = target["@class"]
            elif "@module" in d and "@callable" in d:
                objname = d["@callable"]
                if d.get("@bound") is not None:
                    # if the function is bound to an instance or class, first
                    # deserialize the bound object and then remove the object name
                    # from the function name.
                    obj = self.process_decoded(d["@bound"])
                    objname = objname.split(".")[1:]
                else:
                    # if the function is not bound to an object, import the
                    # function from the module name
                    obj = __import__(d["@module"], globals(), locals(), [objname], 0)
                    objname = objname.split(".")
                try:
                    # the function could be nested. e.g., MyClass.NestedClass.function
                    # so iteratively access the nesting
                    for attr in objname:
                        obj = getattr(obj, attr)
                except AttributeError:
                    # Could not resolve the callable: modname/classname stay
                    # None so the dict is returned (recursively processed) below.
                    pass
                else:
                    return obj

            if modname and modname not in ["bson.objectid", "numpy", "pandas"]:
                if modname == "datetime" and classname == "datetime":
                    try:
                        dt = datetime.datetime.strptime(d["string"], "%Y-%m-%d %H:%M:%S.%f")
                    except ValueError:
                        # str(datetime) omits the fractional part when microsecond == 0
                        dt = datetime.datetime.strptime(d["string"], "%Y-%m-%d %H:%M:%S")
                    return dt

                if modname == "uuid" and classname == "UUID":
                    return UUID(d["string"])

                mod = __import__(modname, globals(), locals(), [classname], 0)
                if hasattr(mod, classname):
                    cls_ = getattr(mod, classname)
                    data = {k: v for k, v in d.items() if not k.startswith("@")}
                    if hasattr(cls_, "from_dict"):
                        return cls_.from_dict(data)
                    if pydantic is not None and issubclass(cls_, pydantic.BaseModel):
                        return cls_(**data)
            elif modname == "numpy" and classname == "array" and np is not None:
                if d["dtype"].startswith("complex"):
                    # Complex arrays are stored as [real part, imaginary part];
                    # combining whole arrays also covers 0-d (scalar) payloads.
                    real, imag = d["data"]
                    return np.array(np.array(real) + np.array(imag) * 1j, dtype=d["dtype"])
                return np.array(d["data"], dtype=d["dtype"])
            elif modname == "pandas" and classname == "DataFrame" and pd is not None:
                decoded_data = MontyDecoder().decode(d["data"])
                return pd.DataFrame(decoded_data)
            elif modname == "bson.objectid" and classname == "ObjectId" and bson is not None:
                return bson.objectid.ObjectId(d["oid"])

            # Nothing could be reconstructed: keep the dict, decoding its content
            return {self.process_decoded(k): self.process_decoded(v) for k, v in d.items()}

        if isinstance(d, list):
            return [self.process_decoded(x) for x in d]

        return d

    def decode(self, s):
        """
        Overrides decode from JSONDecoder.

        :param s: string
        :return: Object.
        """
        d = json.JSONDecoder.decode(self, s)
        return self.process_decoded(d)
440
441
class MSONError(Exception):
    """Error type for problems encountered during MSON (de)serialization."""
446
447
def jsanitize(obj, strict=False, allow_bson=False, enum_values=False):
    """
    This method cleans an input json-like object, either a list or a dict or
    some sequence, nested or otherwise, by converting all non-string
    dictionary keys (such as int and float) to strings, and by recursively
    converting contained values (numpy arrays and scalars, callables and,
    in strict mode, objects supporting Monty's as_dict() protocol).

    Args:
        obj: input json-like object.
        strict (bool): This parameters sets the behavior when jsanitize
            encounters an object it does not understand. If strict is True,
            jsanitize will try to get the as_dict() attribute of the object. If
            no such attribute is found, an attribute error will be thrown. If
            strict is False, jsanitize will simply call str(object) to convert
            the object to a string representation.
        allow_bson (bool): This parameters sets the behavior when jsanitize
            encounters an bson supported type such as objectid and datetime. If
            True, such bson types will be ignored, allowing for proper
            insertion into MongoDb databases.
        enum_values (bool): Convert Enums to their values.

    Returns:
        Sanitized dict that can be json serialized.
    """

    def _sanitize(value):
        # Recurse with the caller's options
        return jsanitize(value, strict=strict, allow_bson=allow_bson, enum_values=enum_values)

    if isinstance(obj, Enum) and enum_values:
        return obj.value

    if allow_bson and (
        isinstance(obj, (datetime.datetime, bytes)) or (bson is not None and isinstance(obj, bson.objectid.ObjectId))
    ):
        return obj
    if isinstance(obj, (list, tuple)):
        return [_sanitize(i) for i in obj]
    if np is not None and isinstance(obj, np.ndarray):
        # tolist() gives a (nested) list, or a bare scalar for 0-d arrays;
        # recursing on the result handles both.
        return _sanitize(obj.tolist())
    # Must precede the int/float check below: numpy float scalars such as
    # np.float64 are also instances of float.
    if np is not None and isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, dict):
        return {str(k): _sanitize(v) for k, v in obj.items()}
    if isinstance(obj, (int, float)):
        return obj
    if obj is None:
        return None

    if callable(obj) and not isinstance(obj, MSONable):
        try:
            return _serialize_callable(obj)
        except (TypeError, AttributeError):
            # Not serializable as a callable (e.g. a callable instance with
            # no __name__): fall through to the generic handling below.
            pass

    if not strict:
        return str(obj)

    if isinstance(obj, str):
        return str(obj)

    if pydantic is not None and isinstance(obj, pydantic.BaseModel):
        return _sanitize(MontyEncoder().default(obj))

    return _sanitize(obj.as_dict())
511
512
def _serialize_callable(o):
    """
    Build the dict representation of a callable.

    Args:
        o: Function, method or other named callable.

    Returns:
        dict with the keys "@module" (``o.__module__``), "@callable" (the
        qualified name, or the plain name when no qualified name exists) and
        "@bound" (the encoded object the callable is bound to, or None).

    Raises:
        TypeError: if the object the callable is bound to cannot be encoded,
            or if the callable has neither ``__qualname__`` nor ``__name__``.
    """
    if isinstance(o, types.BuiltinFunctionType):
        # don't care about what builtin functions (sum, open, etc) are bound to
        bound = None
    else:
        # bound methods (i.e., instance methods) have a __self__ attribute
        # that points to the class/module/instance
        bound = getattr(o, "__self__", None)

    # we are only able to serialize bound methods if the object the method is
    # bound to is itself serializable
    if bound is not None:
        try:
            bound = MontyEncoder().default(bound)
        except TypeError as exc:
            raise TypeError("Only bound methods of classes or MSONable instances are supported.") from exc

    # Look the names up lazily: __name__ is only consulted when __qualname__
    # is missing, so a missing __name__ alone is not an error.
    name = getattr(o, "__qualname__", None)
    if name is None:
        name = getattr(o, "__name__", None)
    if name is None:
        raise TypeError("Unable to serialize a callable that has neither __qualname__ nor __name__.")

    return {
        "@module": o.__module__,
        "@callable": name,
        "@bound": bound,
    }
535