1# coding: utf-8
2# Copyright (c) Pymatgen Development Team.
3# Distributed under the terms of the MIT License.
4"""
Objects and helper functions used to store the results in a MongoDB database.
6"""
7import collections
8import copy
9
10from .utils import as_bool
11
12
13#def mongo_getattr(rec, key):
14#    """
15#    Get value from dict using MongoDB dot-separated path semantics.
16#    For example:
17#
18#    >>> assert mongo_getattr({'a': {'b': 1}, 'x': 2}, 'a.b') == 1
19#    >>> assert mongo_getattr({'a': {'b': 1}, 'x': 2}, 'x') == 2
20#    >>> assert mongo_getattr({'a': {'b': 1}, 'x': 2}, 'a.b.c') is None
21#
22#    :param rec: mongodb document
23#    :param key: path to mongo value
24#    :param default: default to return if not found
25#    :return: value, potentially nested, or default if not found
26#    :raise: AttributeError, if record is not a dict or key is not found.
27#    """
#    if not isinstance(rec, collections.abc.Mapping):
29#        raise AttributeError('input record must act like a dict')
30#    if not rec:
31#        raise AttributeError('Empty dict')
32#
33#    if not '.' in key:
34#        return rec.get(key)
35#
36#    for key_part in key.split('.'):
37#        if not isinstance(rec, collections.abc.Mapping):
38#            raise AttributeError('not a mapping for rec_part %s' % key_part)
39#        if not key_part in rec:
#            raise AttributeError('key %s not in dict %s' % (key_part, rec))
41#        rec = rec[key_part]
42#
43#    return rec
44
45
def scan_nestdict(d, key):
    """
    Scan a nested dict d, and return the first value associated to the given key.
    Returns None if key is not found.

    The search is depth-first. At each mapping the key is looked up directly
    before descending into the values; lists and tuples are scanned item by
    item; any other object is ignored. Since None signals "not found", a key
    whose value is None cannot be distinguished from a missing key.

    >>> d = {0: 1, 1: {"hello": {"world": {None: [1,2,3]}}}, "foo": [{"bar": 1}, {"color": "red"}]}
    >>> assert scan_nestdict(d, 1) == {"hello": {"world": {None: [1,2,3]}}}
    >>> assert scan_nestdict(d, "hello") == {"world": {None: [1,2,3]}}
    >>> assert scan_nestdict(d, "world") == {None: [1,2,3]}
    >>> assert scan_nestdict(d, None) == [1,2,3]
    >>> assert scan_nestdict(d, "color") == "red"

    :param d: nested structure made of mappings, lists and tuples.
    :param key: (hashable) key to look for.
    :return: first value found for key, None if key is not found.
    """
    # Import the ABC from its submodule explicitly: a bare ``import collections``
    # does not guarantee that the ``collections.abc`` submodule has been loaded.
    from collections.abc import Mapping

    def _scan(node):
        """Return the first non-None value associated to key inside node."""
        if isinstance(node, (list, tuple)):
            children = node
        elif isinstance(node, Mapping):
            if key in node:
                # Direct hit: returned as is, even if the value is None.
                return node[key]
            children = node.values()
        else:
            # Scalars and any other objects cannot contain the key.
            return None

        for child in children:
            found = _scan(child)
            if found is not None:
                return found
        return None

    return _scan(d)
76
77
class DBConnector(object):
    """
    Store the parameters needed to connect to the MongoDB database in which
    the results are saved, and give access to the collection.

    The object is built from the options of the database section (see
    :meth:`autodoc`). A connector created without options is disabled and
    evaluates to False, but it still exposes the default value of each option.
    """

    @classmethod
    def autodoc(cls):
        """Return a string documenting the options accepted by the constructor."""
        return """
     enabled:     # yes or no (default yes)
     database:    # Name of the mongodb database (default abinit)
     collection:  # Name of the collection (default test)
     host:        # host address e.g. 0.0.0.0 (default None)
     port:        # port e.g. 8080 (default None)
     user:        # user name (default None)
     password:    # password for authentication (default None)
     """

    def __init__(self, **kwargs):
        """
        Args:
            enabled: yes or no (default yes if at least one option is given,
                no if the connector is created without options).
            database: Name of the mongodb database (default abinit)
            collection: Name of the collection (default test)
            host: host address (default None)
            port: port (default None)
            user: user name (default None)
            password: password for authentication (default None)

        Raises:
            ValueError: if unknown options are found in kwargs.
        """
        # No option at all means that the database has not been configured:
        # the connector is disabled. The other attributes are nevertheless
        # initialized below so that every method can be called safely.
        self.enabled = as_bool(kwargs.pop("enabled", True)) if kwargs else False
        self.dbname = kwargs.pop("database", "abinit")
        self.collection = kwargs.pop("collection", "test")
        self.host = kwargs.pop("host", None)
        self.port = kwargs.pop("port", None)
        self.user = kwargs.pop("user", None)
        self.password = kwargs.pop("password", None)

        # Whatever is left has not been recognized.
        if kwargs:
            raise ValueError("Found invalid keywords in the database section:\n %s" % list(kwargs))

    def __bool__(self):
        """True if the connection to the database is enabled."""
        return self.enabled

    # Python 2 spelling of __bool__.
    __nonzero__ = __bool__

    def __repr__(self):
        return "<%s object at %s>" % (self.__class__.__name__, id(self))

    def deepcopy(self):
        """Return a deep copy of the connector."""
        return copy.deepcopy(self)

    def set_collection_name(self, value):
        """Set the name of the collection, return old value"""
        old = self.collection
        self.collection = str(value)
        return old

    def get_collection(self, **kwargs):
        """
        Establish a connection with the database.

        The keyword arguments are accepted for backward compatibility and are
        currently ignored.

        Returns MongoDb collection
        """
        from pymongo import MongoClient

        client_options = {}

        # Use the default host/port of MongoClient unless both are specified.
        if self.host and self.port:
            client_options.update(host=self.host, port=self.port)

        # Authenticate if needed. The credentials are passed to the client
        # (Database.authenticate has been removed in PyMongo 4) and are checked
        # against the selected database, as Database.authenticate used to do.
        if self.user and self.password:
            client_options.update(
                username=self.user,
                password=self.password,
                authSource=self.dbname,
            )

        client = MongoClient(**client_options)
        db = client[self.dbname]

        return db[self.collection]
156
157
if __name__ == "__main__":
    # Manual smoke test: build a connector without options, then fetch and
    # print the collection before and after printing the connector itself.
    db_connector = DBConnector()

    first_collection = db_connector.get_collection()
    print(first_collection)

    # Uncomment to check that the collection name can be changed:
    #db_connector.set_collection_name("foo")
    print(db_connector)

    second_collection = db_connector.get_collection()
    print(second_collection)

    #import unittest
    #unittest.main()
167