1from __future__ import absolute_import, division, print_function
2__metaclass__ = type
3from ansible.module_utils.basic import AnsibleModule, missing_required_lib
4from ansible.module_utils.six.moves import configparser
5from distutils.version import LooseVersion
6import traceback
7import os
8import ssl as ssl_lib
9
10MongoClient = None
11PYMONGO_IMP_ERR = None
12pymongo_found = None
13PyMongoVersion = None
14ConnectionFailure = None
15OperationFailure = None
16
17try:
18    from pymongo.errors import ConnectionFailure
19    from pymongo.errors import OperationFailure
20    from pymongo import version as PyMongoVersion
21    from pymongo import MongoClient
22    pymongo_found = True
23except ImportError:
24    PYMONGO_IMP_ERR = traceback.format_exc()
25    pymongo_found = False
26
27
def check_compatibility(module, srv_version, driver_version):
    """Check the compatibility between the driver and the database.

    Rules are evaluated from the newest server version down; the first rule
    whose server floor is reached while the driver is older than the required
    pymongo release triggers ``module.fail_json`` and nothing else is checked.

    See: https://docs.mongodb.com/ecosystem/drivers/driver-compatibility-reference/#python-driver-compatibility

    Args:
        module: Ansible module.
        srv_version (LooseVersion): MongoDB server version.
        driver_version (LooseVersion): Pymongo version.
    """
    msg = 'pymongo driver version and MongoDB version are incompatible: '

    # (minimum server version, minimum pymongo version, requirement text).
    # Every rule is "driver strictly older than the minimum" so that the
    # minimum release named in the message is itself accepted (previously
    # the 3.0 / 2.6 rules used <= and rejected pymongo 2.8 / 2.7).
    rules = (
        ('4.2', '3.9', 'you must use pymongo 3.9+ with MongoDB >= 4.2'),
        ('4.0', '3.7', 'you must use pymongo 3.7+ with MongoDB >= 4.0'),
        ('3.6', '3.6', 'you must use pymongo 3.6+ with MongoDB >= 3.6'),
        ('3.4', '3.4', 'you must use pymongo 3.4+ with MongoDB >= 3.4'),
        ('3.2', '3.2', 'you must use pymongo 3.2+ with MongoDB >= 3.2'),
        ('3.0', '2.8', 'you must use pymongo 2.8+ with MongoDB 3.0'),
        ('2.6', '2.7', 'you must use pymongo 2.7+ with MongoDB 2.6'),
    )

    for min_srv, min_driver, requirement in rules:
        if srv_version >= LooseVersion(min_srv) and driver_version < LooseVersion(min_driver):
            module.fail_json(msg=msg + requirement)
            # Report only the most specific incompatibility, as the
            # original elif chain did.
            return
67
68
def load_mongocnf(path='~/.mongodb.cnf'):
    """Load MongoDB client credentials from an INI-style file.

    The file must contain a ``[client]`` section with ``user`` and ``pass``
    options, e.g.::

        [client]
        user = admin
        pass = secret

    Args:
        path (str): Location of the credentials file; ``~`` is expanded.
            Defaults to ``~/.mongodb.cnf``.

    Returns:
        dict: ``{'user': ..., 'password': ...}`` when both options are found,
        or ``False`` when the file cannot be opened or lacks the ``[client]``
        section or either option.
    """
    config = configparser.RawConfigParser()
    mongocnf = os.path.expanduser(path)

    # read_file() superseded readfp(), which was removed in Python 3.12;
    # Python 2's ConfigParser only provides readfp().
    read_file = getattr(config, 'read_file', None) or config.readfp

    try:
        # Use a context manager so the file handle is always closed.
        with open(mongocnf) as cnf_file:
            read_file(cnf_file)
    except IOError:
        return False

    try:
        creds = dict(
            user=config.get('client', 'user'),
            password=config.get('client', 'pass')
        )
    except (configparser.NoSectionError, configparser.NoOptionError):
        # Missing section/options are raised by get(), not by reading the
        # file, so they must be handled here.
        return False

    return creds
84
85
def index_exists(client, database, collection, index_name):
    """
    Return True if the collection has an index with the given name.
    @client: MongoDB connection.
    @database: MongoDB Database - str.
    @collection: MongoDB collection - str.
    @index_name: The index name - str.
    """
    indexes = client[database][collection].list_indexes()
    # any() stops at the first match instead of consuming the whole cursor.
    return any(index["name"] == index_name for index in indexes)
100
101
def create_index(client, database, collection, keys, options):
    """
    Build an index on database.collection.
    @client: MongoDB connection.
    @database: MongoDB Database - str.
    @collection: MongoDB collection - str.
    @keys: Specification of index, field -> direction/type - dict.
    @options: Extra keyword arguments forwarded to create_index - dict.
    """
    target = client[database][collection]
    # create_index takes an ordered list of (field, direction) pairs.
    key_spec = [(field, direction) for field, direction in keys.items()]
    target.create_index(key_spec, **options)
112
113
def drop_index(client, database, collection, index_name):
    """
    Drop the named index from database.collection.
    @client: MongoDB connection.
    @database: MongoDB Database - str.
    @collection: MongoDB collection - str.
    @index_name: Name of the index to drop - str.
    """
    target_collection = client[database][collection]
    target_collection.drop_index(index_name)
116
117
def member_state(client):
    """Return the replica set state of the member this client is connected to.

    Runs the ``replSetGetStatus`` command on the ``admin`` database and
    reads ``stateStr`` from the member document that carries a ``self`` key.

    Args:
        client: MongoDB client; indexed as ``client['admin']``.

    Returns:
        str: member state i.e. PRIMARY, SECONDARY, or None when no member in
        the status document has a ``self`` key.
    """
    doc = client['admin'].command('replSetGetStatus')
    for member in doc["members"]:
        # replSetGetStatus flags the member answering the command with "self".
        if "self" in member:
            return str(member['stateStr'])
    return None
133
134
def mongodb_common_argument_spec(ssl_options=True):
    """
    Build the argument_spec options shared by the MongoDB modules.

    The login_* options are always included. When ssl_options is true the
    ssl_* options, auth_mechanism and connection_options are appended.
    A fresh dict is returned on every call.
    """
    spec = {
        'login_user': {'type': 'str', 'required': False},
        'login_password': {'type': 'str', 'required': False, 'no_log': True},
        'login_database': {'type': 'str', 'required': False, 'default': 'admin'},
        'login_host': {'type': 'str', 'required': False, 'default': 'localhost'},
        'login_port': {'type': 'int', 'required': False, 'default': 27017},
    }
    if not ssl_options:
        return spec

    spec.update({
        'ssl': {'type': 'bool', 'required': False, 'default': False},
        'ssl_cert_reqs': {
            'type': 'str',
            'required': False,
            'default': 'CERT_REQUIRED',
            'choices': ['CERT_NONE', 'CERT_OPTIONAL', 'CERT_REQUIRED'],
        },
        'ssl_ca_certs': {'type': 'str', 'default': None},
        'ssl_crlfile': {'type': 'str', 'default': None},
        'ssl_certfile': {'type': 'str', 'default': None},
        'ssl_keyfile': {'type': 'str', 'default': None, 'no_log': True},
        'ssl_pem_passphrase': {'type': 'str', 'default': None, 'no_log': True},
        'auth_mechanism': {
            'type': 'str',
            'required': False,
            'default': None,
            'choices': ['SCRAM-SHA-256', 'SCRAM-SHA-1', 'MONGODB-X509',
                        'GSSAPI', 'PLAIN'],
        },
        'connection_options': {'type': 'list', 'elements': 'raw', 'default': None},
    })
    return spec
174
175
def add_option_if_not_none(param_name, module, connection_params):
    '''
    Copy a module parameter into connection_params when it is not None.

    @param_name - Name of the module parameter; also used as the key in
                  connection_params
    @module - The ansible module object whose params are read
    @connection_params - Dict of connection params, updated in place
    Returns the same connection_params dict.
    '''
    value = module.params[param_name]
    if value is None:
        return connection_params
    connection_params[param_name] = value
    return connection_params
185
186
def ssl_connection_options(connection_params, module):
    """Add TLS, auth-mechanism and extra connection options to connection_params.

    Always sets ``ssl=True``. ``ssl_cert_reqs`` is translated to the
    constant of the same name in the standard ``ssl`` module; the other
    ``ssl_*`` module parameters are copied only when they are not None.
    ``auth_mechanism`` is stored under ``authMechanism``. Each
    ``connection_options`` entry is either a dict (merged as-is) or a
    ``"key=value"`` string (split on the first ``=``, so the value may
    itself contain ``=``).

    Args:
        connection_params (dict): Connection parameters; updated in place.
        module: Ansible module whose ``params`` are read.

    Returns:
        dict: the same ``connection_params`` object.

    Raises:
        ValueError: if a ``connection_options`` entry is neither a dict nor
            a string containing ``=``.
    """
    connection_params['ssl'] = True
    if module.params['ssl_cert_reqs'] is not None:
        connection_params['ssl_cert_reqs'] = getattr(ssl_lib, module.params['ssl_cert_reqs'])
    for param_name in ('ssl_ca_certs', 'ssl_crlfile', 'ssl_certfile',
                       'ssl_keyfile', 'ssl_pem_passphrase'):
        connection_params = add_option_if_not_none(param_name, module, connection_params)
    if module.params['auth_mechanism'] is not None:
        connection_params['authMechanism'] = module.params['auth_mechanism']
    if module.params['connection_options'] is not None:
        for item in module.params['connection_options']:
            if isinstance(item, dict):
                connection_params.update(item)
            elif isinstance(item, str) and "=" in item:
                # maxsplit=1: "k=a=b" must yield value "a=b", not truncate to "a".
                key, value = item.split('=', 1)
                connection_params[key] = value
            else:
                raise ValueError("Invalid value supplied in connection_options: {0} .".format(str(item)))
    return connection_params
208