# coding: utf-8
# Copyright (c) Pymatgen Development Team.
# Distributed under the terms of the MIT License.
"""
Objects and helper functions used to store results in a MongoDB database.
"""
import collections.abc
import copy

from .utils import as_bool


def scan_nestdict(d, key):
    """
    Scan a nested dict d, and return the first value associated to the given key.
    Returns None if key is not found.

    The search is depth-first: the keys of the current mapping are checked
    before descending into its values, and lists/tuples are scanned element
    by element in order. Objects that are neither mappings nor lists/tuples
    are skipped. Note that a stored value of None cannot be distinguished
    from "not found".

    >>> d = {0: 1, 1: {"hello": {"world": {None: [1,2,3]}}}, "foo": [{"bar": 1}, {"color": "red"}]}
    >>> assert scan_nestdict(d, 1) == {"hello": {"world": {None: [1,2,3]}}}
    >>> assert scan_nestdict(d, "hello") == {"world": {None: [1,2,3]}}
    >>> assert scan_nestdict(d, "world") == {None: [1,2,3]}
    >>> assert scan_nestdict(d, None) == [1,2,3]
    >>> assert scan_nestdict(d, "color") == "red"
    """
    if isinstance(d, (list, tuple)):
        for item in d:
            res = scan_nestdict(item, key)
            if res is not None:
                return res
        return None

    if not isinstance(d, collections.abc.Mapping):
        return None

    if key in d:
        return d[key]

    for v in d.values():
        res = scan_nestdict(v, key)
        if res is not None:
            return res
    return None


class DBConnector(object):
    """
    Hold the parameters needed to connect to a MongoDB collection.

    A connector built without keyword arguments is disabled (falsy) but
    still carries the default connection parameters, so get_collection()
    can be called on it.
    """

    @classmethod
    def autodoc(cls):
        """Return a text description of the supported configuration keys."""
        return """
    enabled: # yes or no (default yes)
    database: # Name of the mongodb database (default abinit)
    collection: # Name of the collection (default test)
    host: # host address e.g. 0.0.0.0 (default None)
    port: # port e.g. 8080 (default None)
    user: # user name (default None)
    password: # password for authentication (default None)
"""

    def __init__(self, **kwargs):
        """
        Args:
            kwargs: Options of the database section (see autodoc()).

        Raises:
            ValueError: if unknown keywords are passed.
        """
        # No configuration at all means "disabled"; otherwise "enabled"
        # defaults to True. The check must happen before any pop().
        self.enabled = as_bool(kwargs.pop("enabled", True)) if kwargs else False

        # Always define the connection parameters so that every method works
        # even on a disabled connector.
        self.dbname = kwargs.pop("database", "abinit")
        self.collection = kwargs.pop("collection", "test")
        self.host = kwargs.pop("host", None)
        self.port = kwargs.pop("port", None)
        self.user = kwargs.pop("user", None)
        self.password = kwargs.pop("password", None)

        if kwargs:
            raise ValueError("Found invalid keywords in the database section:\n %s" % list(kwargs))

    def __bool__(self):
        return self.enabled

    __nonzero__ = __bool__  # Python 2 spelling of __bool__

    def __repr__(self):
        return "<%s object at %s>" % (self.__class__.__name__, id(self))

    def deepcopy(self):
        """Return a deep copy of the connector."""
        return copy.deepcopy(self)

    def set_collection_name(self, value):
        """Set the name of the collection, return old value"""
        old = self.collection
        self.collection = str(value)
        return old

    def get_collection(self, **kwargs):
        """
        Establish a connection with the database.

        Args:
            kwargs: Currently unused; accepted for backward compatibility.

        Returns:
            The pymongo collection `self.collection` of database `self.dbname`.
        """
        from pymongo import MongoClient

        # Only forward the parameters that were actually given so that
        # pymongo applies its own defaults for the others.
        client_kwargs = {}
        if self.host:
            client_kwargs["host"] = self.host
        if self.port:
            client_kwargs["port"] = self.port

        # Authenticate if needed. Database.authenticate() no longer exists in
        # PyMongo 4, so credentials go to the client, with the target
        # database as the authentication source.
        if self.user and self.password:
            client_kwargs.update(username=self.user,
                                 password=self.password,
                                 authSource=self.dbname)

        client = MongoClient(**client_kwargs)
        return client[self.dbname][self.collection]


if __name__ == "__main__":
    connector = DBConnector()
    print(connector.get_collection())
    print(connector)
    print(connector.get_collection())