1# -*- coding: utf-8 -*-
2# Copyright (c) 2015, René Moser <mail@renemoser.net>
3# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause)
4
5from __future__ import absolute_import, division, print_function
# Make every class defined in this file a new-style class on Python 2
# (no effect on Python 3, where all classes are new-style).
__metaclass__ = type
7
8
9import os
10import sys
11import time
12import traceback
13
14from ansible.module_utils._text import to_text, to_native
15from ansible.module_utils.basic import missing_required_lib, env_fallback
16
# python-cs is an optional third-party dependency. Record whether it could be
# imported, and keep the traceback when it could not, so the failure can be
# reported through missing_required_lib() instead of breaking module import.
try:
    from cs import CloudStack, CloudStackException
except ImportError:
    HAS_LIB_CS = False
    CS_IMP_ERR = traceback.format_exc()
else:
    HAS_LIB_CS = True
    CS_IMP_ERR = None


# Python 3 folded ``long`` into ``int``; alias it so isinstance() checks that
# mention ``long`` work on both major versions.
if sys.version_info > (3,):
    long = int
28
29
def cs_argument_spec():
    """Return the connection argument spec shared by the CloudStack modules.

    Each option may be given as a module parameter or, through
    ``env_fallback``, via the ``CLOUDSTACK_*`` environment variable listed
    next to it.
    """
    def from_env(variable):
        # Build the ``fallback`` tuple used by each spec entry below.
        return (env_fallback, [variable])

    spec = {}
    spec['api_key'] = dict(type='str', fallback=from_env('CLOUDSTACK_KEY'), required=True, no_log=False)
    spec['api_secret'] = dict(type='str', fallback=from_env('CLOUDSTACK_SECRET'), required=True, no_log=True)
    spec['api_url'] = dict(type='str', fallback=from_env('CLOUDSTACK_ENDPOINT'), required=True)
    spec['api_http_method'] = dict(type='str', fallback=from_env('CLOUDSTACK_METHOD'), choices=['get', 'post'], default='get')
    spec['api_timeout'] = dict(type='int', fallback=from_env('CLOUDSTACK_TIMEOUT'), default=10)
    spec['api_verify_ssl_cert'] = dict(type='str', fallback=from_env('CLOUDSTACK_VERIFY'))
    return spec
39
40
def cs_required_together():
    """Return the ``required_together`` constraints shared by CloudStack modules.

    There are none; a new empty list is returned on every call, so a caller
    may extend the result without affecting other callers.
    """
    return list()
43
44
class AnsibleCloudStack:
    """Base helper shared by the CloudStack Ansible modules.

    The helper lazily creates a ``cs.CloudStack`` API client from the
    ``api_*`` module parameters. It resolves resources referenced by module
    parameters (zone, domain, account, project, VPC, network, VM, ...) and
    caches them on the instance. It collects the module output in
    ``self.result``. Lookup failures are reported through ``self.fail_json``.
    """

    def __init__(self, module):
        """Initialise the helper for the Ansible ``module`` object.

        Fails the module immediately when the ``cs`` library is not importable.
        """
        if not HAS_LIB_CS:
            module.fail_json(msg=missing_required_lib('cs'), exception=CS_IMP_ERR)

        # Module result; has_changed() fills the 'diff' section.
        self.result = {
            'changed': False,
            'diff': {
                'before': dict(),
                'after': dict()
            }
        }

        # Common returns, will be merged with self.returns
        # search_for_key: replace_with_key
        self.common_returns = {
            'id': 'id',
            'name': 'name',
            'created': 'created',
            'zonename': 'zone',
            'state': 'state',
            'project': 'project',
            'account': 'account',
            'domain': 'domain',
            'displaytext': 'display_text',
            'displayname': 'display_name',
            'description': 'description',
            'tags': 'tags',
        }

        # Init returns dict for use in subclasses
        self.returns = {}
        # these values will be cast to int
        self.returns_to_int = {}
        # these keys will be compared case sensitive in self.has_changed()
        self.case_sensitive_keys = [
            'id',
            'displaytext',
            'displayname',
            'description',
        ]

        self.module = module
        # API client, created on first access of the ``cs`` property.
        self._cs = None

        # Helper for VPCs: ids of all networks attached to a VPC, filled
        # lazily by is_vpc_network().
        self._vpc_networks_ids = None

        # Per-instance caches for the get_*() helpers; None means
        # "not looked up yet".
        self.domain = None
        self.account = None
        self.project = None
        self.ip_address = None
        self.network = None
        self.physical_network = None
        self.vpc = None
        self.zone = None
        self.vm = None
        self.vm_default_nic = None
        self.os_type = None
        self.hypervisor = None
        self.capabilities = None
        self.network_acl = None

    @property
    def cs(self):
        """Return the CloudStack API client, creating it on first access."""
        if self._cs is None:
            api_config = self.get_api_config()
            self._cs = CloudStack(**api_config)
        return self._cs

    def get_api_config(self):
        """Return the keyword arguments used to create the CloudStack client.

        The values come from the ``api_*`` module parameters. All of them
        except the API secret are also copied into ``self.result``.
        """
        api_config = {
            'endpoint': self.module.params.get('api_url'),
            'key': self.module.params.get('api_key'),
            'secret': self.module.params.get('api_secret'),
            'timeout': self.module.params.get('api_timeout'),
            'method': self.module.params.get('api_http_method'),
            'verify': self.module.params.get('api_verify_ssl_cert'),
        }
        self.result.update({
            'api_url': api_config['endpoint'],
            'api_key': api_config['key'],
            'api_timeout': int(api_config['timeout']),
            'api_http_method': api_config['method'],
            'api_verify_ssl_cert': api_config['verify'],
        })
        return api_config

    def fail_json(self, **kwargs):
        """Fail the module with ``self.result`` merged with ``kwargs``."""
        self.result.update(kwargs)
        self.module.fail_json(**self.result)

    def get_or_fallback(self, key=None, fallback_key=None):
        """Return module parameter ``key``, or ``fallback_key`` if the first is falsy."""
        value = self.module.params.get(key)
        if not value:
            value = self.module.params.get(fallback_key)
        return value

    def has_changed(self, want_dict, current_dict, only_keys=None, skip_diff_for_keys=None):
        """Return True if any wanted value differs from the current one.

        :param want_dict: desired values; keys whose value is ``None`` are ignored.
        :param current_dict: values currently reported by the API. For numeric
            wanted values, the current entry is converted in place to the same
            numeric type before comparing.
        :param only_keys: if given, only these keys of ``want_dict`` are compared.
        :param skip_diff_for_keys: keys that still count as changes but are
            never written to ``self.result['diff']`` (e.g. secrets).
        :returns: True if at least one compared key differs from, or is missing
            in, ``current_dict``.
        """
        # A diff is recorded for every changed key unless it is explicitly
        # excluded; no skip list means nothing is excluded.
        skip_diff = skip_diff_for_keys or []

        def record_diff(key, before, after):
            if key not in skip_diff:
                self.result['diff']['before'][key] = before
                self.result['diff']['after'][key] = after

        result = False
        for key, value in want_dict.items():

            # Optionally limit by a list of keys
            if only_keys and key not in only_keys:
                continue

            # Skip None values
            if value is None:
                continue

            if key not in current_dict:
                record_diff(key, None, to_text(value))
                result = True
                continue

            if isinstance(value, (int, float, long, complex)):

                # ensure we compare the same type
                if isinstance(value, int):
                    current_dict[key] = int(current_dict[key])

                elif isinstance(value, float):
                    current_dict[key] = float(current_dict[key])

                elif isinstance(value, long):
                    current_dict[key] = long(current_dict[key])

                elif isinstance(value, complex):
                    current_dict[key] = complex(current_dict[key])

                if value != current_dict[key]:
                    record_diff(key, current_dict[key], value)
                    result = True
            else:
                before_value = to_text(current_dict[key])
                after_value = to_text(value)

                if self.case_sensitive_keys and key in self.case_sensitive_keys:
                    differs = before_value != after_value
                else:
                    # Test for diff in case insensitive way
                    differs = before_value.lower() != after_value.lower()

                if differs:
                    record_diff(key, before_value, after_value)
                    result = True
        return result

    def _get_by_key(self, key=None, my_dict=None):
        """Return ``my_dict[key]``, or the whole dict when ``key`` is falsy.

        Fails the module when ``key`` is given but missing from ``my_dict``.
        """
        if my_dict is None:
            my_dict = {}
        if key:
            if key in my_dict:
                return my_dict[key]
            self.fail_json(msg="Something went wrong: %s not found" % key)
        return my_dict

    def query_api(self, command, **args):
        """Call API ``command`` with ``args`` and return the response.

        Fails the module if the response contains ``errortext`` or the call
        raises an exception.
        """
        try:
            res = getattr(self.cs, command)(**args)

            if 'errortext' in res:
                self.fail_json(msg="Failed: '%s'" % res['errortext'])

        except CloudStackException as e:
            self.fail_json(msg='CloudStackException: %s' % to_native(e))

        except Exception as e:
            self.fail_json(msg=to_native(e))

        return res

    def get_network_acl(self, key=None):
        """Return the network ACL list named by the ``network_acl`` parameter.

        The listing is filtered by the VPC id from get_vpc(); the first match
        is cached. Fails the module if nothing is found.
        """
        if self.network_acl is None:
            args = {
                'name': self.module.params.get('network_acl'),
                'vpcid': self.get_vpc(key='id'),
            }
            network_acls = self.query_api('listNetworkACLLists', **args)
            if network_acls:
                self.network_acl = network_acls['networkacllist'][0]
                self.result['network_acl'] = self.network_acl['name']
        if self.network_acl:
            return self._get_by_key(key, self.network_acl)
        else:
            self.fail_json(msg="Network ACL %s not found" % self.module.params.get('network_acl'))

    def get_vpc(self, key=None):
        """Return a VPC dictionary or the value of given key of.

        The VPC comes from the ``vpc`` parameter (or ``CLOUDSTACK_VPC``) and is
        matched by name, display text or id, then cached. Returns ``None`` when
        no VPC is requested. Fails if no VPC, or more than one, matches.
        """
        if self.vpc:
            return self._get_by_key(key, self.vpc)

        vpc = self.module.params.get('vpc')
        if not vpc:
            vpc = os.environ.get('CLOUDSTACK_VPC')
        if not vpc:
            return None

        args = {
            'account': self.get_account(key='name'),
            'domainid': self.get_domain(key='id'),
            'projectid': self.get_project(key='id'),
            'zoneid': self.get_zone(key='id'),
        }
        vpcs = self.query_api('listVPCs', **args)
        if not vpcs:
            self.fail_json(msg="No VPCs available.")

        for v in vpcs['vpc']:
            if vpc in [v['name'], v['displaytext'], v['id']]:
                # Fail if the identifier matches more than one VPC
                if self.vpc:
                    self.fail_json(msg="More than one VPC found with the provided identifyer '%s'" % vpc)
                else:
                    self.vpc = v
                    self.result['vpc'] = v['name']
        if self.vpc:
            return self._get_by_key(key, self.vpc)
        self.fail_json(msg="VPC '%s' not found" % vpc)

    def is_vpc_network(self, network_id):
        """Returns True if network is in VPC."""
        # This is an efficient way to query a lot of networks at a time:
        # all VPC network ids are fetched once and cached.
        if self._vpc_networks_ids is None:
            args = {
                'account': self.get_account(key='name'),
                'domainid': self.get_domain(key='id'),
                'projectid': self.get_project(key='id'),
                'zoneid': self.get_zone(key='id'),
            }
            vpcs = self.query_api('listVPCs', **args)
            self._vpc_networks_ids = []
            if vpcs:
                for vpc in vpcs['vpc']:
                    for n in vpc.get('network', []):
                        self._vpc_networks_ids.append(n['id'])
        return network_id in self._vpc_networks_ids

    def get_physical_network(self, key=None):
        """Return the physical network named by ``physical_network`` in the zone.

        Matched by name or id and cached; fails the module if not found.
        """
        if self.physical_network:
            return self._get_by_key(key, self.physical_network)
        physical_network = self.module.params.get('physical_network')
        args = {
            'zoneid': self.get_zone(key='id')
        }
        physical_networks = self.query_api('listPhysicalNetworks', **args)
        if not physical_networks:
            self.fail_json(msg="No physical networks available.")

        for net in physical_networks['physicalnetwork']:
            if physical_network in [net['name'], net['id']]:
                self.physical_network = net
                self.result['physical_network'] = net['name']
                return self._get_by_key(key, self.physical_network)
        self.fail_json(msg="Physical Network '%s' not found" % physical_network)

    def get_network(self, key=None):
        """Return a network dictionary or the value of given key of.

        The network named by the ``network`` parameter is matched by display
        text, name or id within the account/domain/project/zone/VPC scope and
        cached. Returns ``None`` when no network is given, unless a VPC is set,
        in which case the parameter is mandatory.
        """
        if self.network:
            return self._get_by_key(key, self.network)

        network = self.module.params.get('network')
        if not network:
            vpc_name = self.get_vpc(key='name')
            if vpc_name:
                self.fail_json(msg="Could not find network for VPC '%s' due missing argument: network" % vpc_name)
            return None

        args = {
            'account': self.get_account(key='name'),
            'domainid': self.get_domain(key='id'),
            'projectid': self.get_project(key='id'),
            'zoneid': self.get_zone(key='id'),
            'vpcid': self.get_vpc(key='id')
        }
        # Resolved once; the loop below only needs to know whether a VPC is set.
        vpc_id = args['vpcid']
        networks = self.query_api('listNetworks', **args)
        if not networks:
            self.fail_json(msg="No networks available.")

        for n in networks['network']:
            # ignore any VPC network if vpc param is not given
            if 'vpcid' in n and not vpc_id:
                continue
            if network in [n['displaytext'], n['name'], n['id']]:
                self.result['network'] = n['name']
                self.network = n
                return self._get_by_key(key, self.network)
        self.fail_json(msg="Network '%s' not found" % network)

    def get_project(self, key=None):
        """Return the project named by ``project`` (or ``CLOUDSTACK_PROJECT``).

        Matched by name (case-insensitive) or id and cached. Returns ``None``
        when no project is given; fails the module if it is not found.
        """
        if self.project:
            return self._get_by_key(key, self.project)

        project = self.module.params.get('project')
        if not project:
            project = os.environ.get('CLOUDSTACK_PROJECT')
        if not project:
            return None
        args = {
            'account': self.get_account(key='name'),
            'domainid': self.get_domain(key='id')
        }
        projects = self.query_api('listProjects', **args)
        if projects:
            for p in projects['project']:
                if project.lower() in [p['name'].lower(), p['id']]:
                    self.result['project'] = p['name']
                    self.project = p
                    return self._get_by_key(key, self.project)
        self.fail_json(msg="project '%s' not found" % project)

    def get_pod(self, key=None):
        """Return the pod named by the ``pod`` parameter in the current zone.

        Returns ``None`` when no pod is given. The result is not cached. Fails
        the module (with ``self.result`` merged in) if the pod is not found.
        """
        pod_name = self.module.params.get('pod')
        if not pod_name:
            return None
        args = {
            'name': pod_name,
            'zoneid': self.get_zone(key='id'),
        }
        pods = self.query_api('listPods', **args)
        if pods:
            return self._get_by_key(key, pods['pod'][0])
        self.fail_json(msg="Pod %s not found in zone %s" % (pod_name, self.get_zone(key='name')))

    def get_ip_address(self, key=None):
        """Return the public IP address given by the required ``ip_address`` parameter.

        The first listed match within the account/domain/project/VPC scope is
        cached; fails the module if none is found.
        """
        if self.ip_address:
            return self._get_by_key(key, self.ip_address)

        ip_address = self.module.params.get('ip_address')
        if not ip_address:
            self.fail_json(msg="IP address param 'ip_address' is required")

        args = {
            'ipaddress': ip_address,
            'account': self.get_account(key='name'),
            'domainid': self.get_domain(key='id'),
            'projectid': self.get_project(key='id'),
            'vpcid': self.get_vpc(key='id'),
        }

        ip_addresses = self.query_api('listPublicIpAddresses', **args)

        if not ip_addresses:
            self.fail_json(msg="IP address '%s' not found" % args['ipaddress'])

        self.ip_address = ip_addresses['publicipaddress'][0]
        return self._get_by_key(key, self.ip_address)

    def get_vm_guest_ip(self):
        """Return the guest IP to use on the VM's default NIC.

        Without a ``vm_guest_ip`` parameter this is the NIC's primary address;
        otherwise the given address must be one of the NIC's secondary IPs.
        """
        vm_guest_ip = self.module.params.get('vm_guest_ip')
        default_nic = self.get_vm_default_nic()

        if not vm_guest_ip:
            return default_nic['ipaddress']

        # A NIC without secondary addresses may not carry 'secondaryip' at all;
        # treat that as empty so the user gets the "not assigned" error below.
        for secondary_ip in default_nic.get('secondaryip') or []:
            if vm_guest_ip == secondary_ip['ipaddress']:
                return vm_guest_ip
        self.fail_json(msg="Secondary IP '%s' not assigned to VM" % vm_guest_ip)

    def get_vm_default_nic(self):
        """Return the VM's NIC flagged ``isdefault`` (cached); fail if there is none."""
        if self.vm_default_nic:
            return self.vm_default_nic

        nics = self.query_api('listNics', virtualmachineid=self.get_vm(key='id'))
        if nics:
            for n in nics['nic']:
                if n['isdefault']:
                    self.vm_default_nic = n
                    return self.vm_default_nic
        self.fail_json(msg="No default IP address of VM '%s' found" % self.module.params.get('vm'))

    def get_vm(self, key=None, filter_zone=True):
        """Return the VM named by the required ``vm`` parameter.

        Matched case-insensitively by name or display name, or by id, and
        cached. With ``filter_zone`` the listing is restricted to the zone.
        """
        if self.vm:
            return self._get_by_key(key, self.vm)

        vm = self.module.params.get('vm')
        if not vm:
            self.fail_json(msg="Virtual machine param 'vm' is required")

        args = {
            'account': self.get_account(key='name'),
            'domainid': self.get_domain(key='id'),
            'projectid': self.get_project(key='id'),
            'zoneid': self.get_zone(key='id') if filter_zone else None,
            'fetch_list': True,
        }
        vms = self.query_api('listVirtualMachines', **args)
        if vms:
            for v in vms:
                if vm.lower() in [v['name'].lower(), v['displayname'].lower(), v['id']]:
                    self.vm = v
                    return self._get_by_key(key, self.vm)
        self.fail_json(msg="Virtual machine '%s' not found" % vm)

    def get_disk_offering(self, key=None):
        """Return the disk offering named by ``disk_offering`` (not cached).

        Matched by display text, name or id. Returns ``None`` when no offering
        is given; fails the module if it is not found.
        """
        disk_offering = self.module.params.get('disk_offering')
        if not disk_offering:
            return None

        # Do not add domain filter for disk offering listing.
        disk_offerings = self.query_api('listDiskOfferings')
        if disk_offerings:
            for d in disk_offerings['diskoffering']:
                if disk_offering in [d['displaytext'], d['name'], d['id']]:
                    return self._get_by_key(key, d)
        self.fail_json(msg="Disk offering '%s' not found" % disk_offering)

    def get_zone(self, key=None):
        """Return the zone named by ``zone`` (or ``CLOUDSTACK_ZONE``).

        Matched by name (case-insensitive) or id and cached. A zone is
        required; fails the module if none is given or it is not found.
        """
        if self.zone:
            return self._get_by_key(key, self.zone)

        zone = self.module.params.get('zone')
        if not zone:
            zone = os.environ.get('CLOUDSTACK_ZONE')
        zones = self.query_api('listZones')

        if not zones:
            self.fail_json(msg="No zones available. Please create a zone first")

        # This check is theoretically not required, as the module argument
        # specification should take care of it. However, due to the deprecated
        # default zone it is left behind just in case of non-obvious callers.
        # Some modules benefit from the check anyway, like those where the zone
        # is effectively optional, e.g. template registration (local/cross
        # zone) or configuration (zone or global).
        if not zone:
            self.fail_json(msg="Zone is required due to unreliable API.")

        if zones:
            for z in zones['zone']:
                if zone.lower() in [z['name'].lower(), z['id']]:
                    self.result['zone'] = z['name']
                    self.zone = z
                    return self._get_by_key(key, self.zone)
        self.fail_json(msg="zone '%s' not found" % zone)

    def get_os_type(self, key=None):
        """Return the OS type named by ``os_type``, matched by description or id.

        Cached after the first lookup. Returns ``None`` when no OS type is
        given; fails the module if it is not found.
        """
        if self.os_type:
            # Return the cached OS type (previously the zone was returned here).
            return self._get_by_key(key, self.os_type)

        os_type = self.module.params.get('os_type')
        if not os_type:
            return None

        os_types = self.query_api('listOsTypes')
        if os_types:
            for o in os_types['ostype']:
                if os_type in [o['description'], o['id']]:
                    self.os_type = o
                    return self._get_by_key(key, self.os_type)
        self.fail_json(msg="OS type '%s' not found" % os_type)

    def get_hypervisor(self):
        """Return the name of the hypervisor given by ``hypervisor`` (cached).

        Matched case-insensitively. Without the parameter, the first hypervisor
        listed by the API is used. Fails the module if none is available or the
        requested one is not found.
        """
        if self.hypervisor:
            return self.hypervisor

        hypervisor = self.module.params.get('hypervisor')
        hypervisors = self.query_api('listHypervisors')

        # An empty API answer would otherwise crash with a KeyError/TypeError.
        if not hypervisors or not hypervisors.get('hypervisor'):
            self.fail_json(msg="No hypervisors available.")

        # use the first hypervisor if no hypervisor param given
        if not hypervisor:
            self.hypervisor = hypervisors['hypervisor'][0]['name']
            return self.hypervisor

        for h in hypervisors['hypervisor']:
            if hypervisor.lower() == h['name'].lower():
                self.hypervisor = h['name']
                return self.hypervisor
        self.fail_json(msg="Hypervisor '%s' not found" % hypervisor)

    def get_account(self, key=None):
        """Return the account named by ``account`` (or ``CLOUDSTACK_ACCOUNT``).

        The ``domain`` module parameter must also be set. The first account
        listed in that domain is cached. Returns ``None`` when no account is
        given; fails the module if it is not found.
        """
        if self.account:
            return self._get_by_key(key, self.account)

        account = self.module.params.get('account')
        if not account:
            account = os.environ.get('CLOUDSTACK_ACCOUNT')
        if not account:
            return None

        domain = self.module.params.get('domain')
        if not domain:
            self.fail_json(msg="Account must be specified with Domain")

        args = {
            'name': account,
            'domainid': self.get_domain(key='id'),
            'listall': True
        }
        accounts = self.query_api('listAccounts', **args)
        if accounts:
            self.account = accounts['account'][0]
            self.result['account'] = self.account['name']
            return self._get_by_key(key, self.account)
        self.fail_json(msg="Account '%s' not found" % account)

    def get_domain(self, key=None):
        """Return the domain named by ``domain`` (or ``CLOUDSTACK_DOMAIN``).

        Matched case-insensitively against the domain path, either as given or
        prefixed with ``root/`` or ``root`` (for inputs with a leading slash).
        Cached. Returns ``None`` when no domain is given; fails if not found.
        """
        if self.domain:
            return self._get_by_key(key, self.domain)

        domain = self.module.params.get('domain')
        if not domain:
            domain = os.environ.get('CLOUDSTACK_DOMAIN')
        if not domain:
            return None

        args = {
            'listall': True,
        }
        domains = self.query_api('listDomains', **args)
        if domains:
            for d in domains['domain']:
                if d['path'].lower() in [domain.lower(), "root/" + domain.lower(), "root" + domain.lower()]:
                    self.domain = d
                    self.result['domain'] = d['path']
                    return self._get_by_key(key, self.domain)
        self.fail_json(msg="Domain '%s' not found" % domain)

    def query_tags(self, resource, resource_type):
        """Return the tags of ``resource`` as a list of ``{'key', 'value'}`` dicts."""
        args = {
            'resourceid': resource['id'],
            'resourcetype': resource_type,
        }
        tags = self.query_api('listTags', **args)
        return self.get_tags(resource=tags, key='tag')

    def get_tags(self, resource=None, key='tags'):
        """Return ``resource[key]`` reduced to ``{'key', 'value'}`` dicts.

        A missing resource or missing/empty tag list yields an empty list.
        """
        if not resource:
            return []
        return [{'key': tag['key'], 'value': tag['value']} for tag in resource.get(key) or []]

    def _process_tags(self, resource, resource_type, tags, operation="create"):
        """Create (or, for any other ``operation``, delete) ``tags`` on ``resource``.

        Marks the result as changed whenever ``tags`` is non-empty; the API
        call and job polling are skipped in check mode.
        """
        if tags:
            self.result['changed'] = True
            if not self.module.check_mode:
                args = {
                    'resourceids': resource['id'],
                    'resourcetype': resource_type,
                    'tags': tags,
                }
                if operation == "create":
                    response = self.query_api('createTags', **args)
                else:
                    response = self.query_api('deleteTags', **args)
                self.poll_job(response)

    def _tags_that_should_exist_or_be_updated(self, resource, tags):
        """Return the wanted ``tags`` that are not yet on ``resource``."""
        existing_tags = self.get_tags(resource)
        return [tag for tag in tags if tag not in existing_tags]

    def _tags_that_should_not_exist(self, resource, tags):
        """Return the tags on ``resource`` that are not in the wanted ``tags``."""
        existing_tags = self.get_tags(resource)
        return [tag for tag in existing_tags if tag not in tags]

    def ensure_tags(self, resource, resource_type=None):
        """Make the tags of ``resource`` match the ``tags`` module parameter.

        Acts only when ``resource`` has a ``tags`` key and the parameter is not
        None: unwanted tags are deleted, missing ones created, and the tags are
        re-read into ``resource['tags']``. Returns ``resource``.
        """
        if not resource_type or not resource:
            self.fail_json(msg="Error: Missing resource or resource_type for tags.")

        if 'tags' in resource:
            tags = self.module.params.get('tags')
            if tags is not None:
                self._process_tags(resource, resource_type, self._tags_that_should_not_exist(resource, tags), operation="delete")
                self._process_tags(resource, resource_type, self._tags_that_should_exist_or_be_updated(resource, tags))
                resource['tags'] = self.query_tags(resource=resource, resource_type=resource_type)
        return resource

    def get_capabilities(self, key=None):
        """Return the cloud's capabilities (or the one named ``key``), cached."""
        if self.capabilities:
            return self._get_by_key(key, self.capabilities)
        capabilities = self.query_api('listCapabilities')
        self.capabilities = capabilities['capability']
        return self._get_by_key(key, self.capabilities)

    def poll_job(self, job=None, key=None):
        """Wait for an async job and return its result.

        If ``job`` has a ``jobid``, the job is polled every 2 seconds until it
        reports a non-zero ``jobstatus`` together with a ``jobresult``. The
        module fails when the result contains ``errortext``. When ``key`` is
        given and present in the result, ``jobresult[key]`` is returned;
        otherwise ``job`` is returned unchanged (also for ``None`` or jobs
        without ``jobid``).
        """
        # NOTE(review): there is no overall timeout; a stuck job polls forever.
        if job and 'jobid' in job:
            while True:
                res = self.query_api('queryAsyncJobResult', jobid=job['jobid'])
                if res['jobstatus'] != 0 and 'jobresult' in res:

                    if 'errortext' in res['jobresult']:
                        self.fail_json(msg="Failed: '%s'" % res['jobresult']['errortext'])

                    if key and key in res['jobresult']:
                        job = res['jobresult'][key]

                    break
                time.sleep(2)
        return job

    def update_result(self, resource, result=None):
        """Copy returnable fields of ``resource`` into ``result`` and return it.

        Keys listed in ``common_returns`` and ``self.returns`` are copied under
        their mapped names. Keys in ``returns_to_int`` are copied as ``int``.
        """
        if result is None:
            result = dict()
        if resource:
            returns = self.common_returns.copy()
            returns.update(self.returns)
            for search_key, return_key in returns.items():
                if search_key in resource:
                    result[return_key] = resource[search_key]

            # Bad bad API does not always return int when it should.
            for search_key, return_key in self.returns_to_int.items():
                if search_key in resource:
                    result[return_key] = int(resource[search_key])

        return result

    def get_result(self, resource):
        """Merge the returnable fields of ``resource`` into ``self.result`` and return it."""
        return self.update_result(resource, self.result)

    def get_result_and_facts(self, facts_name, resource):
        """Like get_result(), but also expose a copy as ``ansible_facts[facts_name]``.

        The facts copy omits the ``diff`` and ``changed`` entries.
        """
        result = self.get_result(resource)

        ansible_facts = {
            facts_name: result.copy()
        }
        for k in ['diff', 'changed']:
            if k in ansible_facts[facts_name]:
                del ansible_facts[facts_name][k]

        result.update(ansible_facts=ansible_facts)
        return result
674