1#
2# -*- coding: utf-8 -*-
3# Copyright 2019 Red Hat Inc.
4# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
5"""
6The asa_ogs class
7It is in this file where the current configuration (as dict)
8is compared to the provided configuration (as dict) and the command set
necessary to bring the current configuration to its desired end-state is
created.
11"""
12
13from __future__ import absolute_import, division, print_function
14
15__metaclass__ = type
16
17import copy
18from ansible.module_utils.six import iteritems
19from ansible_collections.cisco.asa.plugins.module_utils.network.asa.facts.facts import (
20    Facts,
21)
22from ansible_collections.cisco.asa.plugins.module_utils.network.asa.rm_templates.ogs import (
23    OGsTemplate,
24)
25from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.utils import (
26    dict_merge,
27)
28from ansible_collections.ansible.netcommon.plugins.module_utils.network.common.resource_module import (
29    ResourceModule,
30)
31
32
class OGs(ResourceModule):
    """
    The asa_ogs class

    Diffs the requested object groups (``want``) against the ones gathered
    from the device (``have``) and accumulates, in ``self.commands``, the
    ``object-group`` commands needed to reach the requested ``state``.
    """

    gather_subset = ["!all", "!min"]

    gather_network_resources = ["ogs"]

    def __init__(self, module):
        super(OGs, self).__init__(
            empty_fact_val={},
            facts_module=Facts(module),
            module=module,
            resource="ogs",
            tmplt=OGsTemplate(),
        )

    def execute_module(self):
        """Execute the module
        :rtype: A dictionary
        :returns: The result from module execution
        """
        self.gen_config()
        self.run_commands()
        return self.result

    @staticmethod
    def _index_object_groups(by_type):
        """Re-key ``{object_type: {"object_groups": [...]}}`` by group name.

        Every object type that has groups becomes
        ``{group_name: group, ..., "object_type": object_type}``; types
        without groups are dropped. When no type has any group at all, each
        type maps to just ``{"object_type": object_type}`` so that callers can
        still see which types were requested.
        """
        indexed = {}
        for object_type, value in iteritems(by_type):
            groups = value.get("object_groups") or []
            if groups:
                named = {group.get("name"): group for group in groups}
                named["object_type"] = object_type
                indexed[object_type] = named
        if indexed:
            return indexed
        return {object_type: {"object_type": object_type} for object_type in by_type}

    def gen_config(self):
        """Generate the commands needed for the requested state.

        ``want`` and ``have`` are indexed per object type and group name, then
        every object type is compared through :meth:`_compare`. Generated
        commands are accumulated in ``self.commands``.
        """
        wantd = self._index_object_groups(
            {entry["object_type"]: entry for entry in self.want or []}
        )
        haved = self._index_object_groups(
            {entry["object_type"]: entry for entry in self.have or []}
        )

        # if state is merged, merge want onto have
        if self.state == "merged":
            wantd = dict_merge(haved, wantd)

        # if state is deleted, keep only the existing groups that were
        # requested (every group when nothing was requested), then empty out
        # wantd so that all of the kept groups are removed below
        if self.state == "deleted":
            temp = {}
            for k, v in iteritems(haved):
                if k in wantd or not wantd:
                    temp[k] = {
                        key: val
                        for key, val in iteritems(v)
                        if not wantd or key in wantd[k]
                    }
            haved = temp
            wantd = {}

        # delete groups first so we don't run into "more than one" errors
        if self.state in ["overridden", "deleted"]:
            for k, have in iteritems(haved):
                if k not in wantd:
                    for each_key, each_val in iteritems(have):
                        if each_key != "object_type":
                            each_val.update(
                                {"object_type": have.get("object_type")}
                            )
                            self.addcmd(each_val, "og_name", True)

        for k, want in iteritems(wantd):
            self._compare(want=want, have=haved.pop(k, {}))

    def _compare(self, want, have):
        """Compare the groups of one object type and dispatch by type.

        Stamps ``object_type`` onto every group dict of *want* and *have*, then
        calls the matching ``_<type>_object_compare`` method.
        """
        if want == have:
            return
        object_type = want.get("object_type")
        for groups in (want, have):
            for k, v in iteritems(groups):
                if k != "object_type":
                    v.update({"object_type": object_type})

        handler = {
            "icmp-type": self._icmp_object_compare,
            "network": self._network_object_compare,
            "protocol": self._protocol_object_compare,
            "security": self._security_object_compare,
            "service": self._service_object_compare,
            "user": self._user_object_compare,
        }.get(object_type)
        if handler:
            handler(want, have)

    def get_list_diff(self, want, have, object, param):
        """Return the items of ``want[object][param]`` missing from *have*."""
        diff = [
            item
            for item in want[object][param]
            if item not in have[object][param]
        ]
        return diff

    def check_for_have_and_overidden(self, have):
        """Negate every group left in *have* when the state is overridden."""
        if have and self.state == "overridden":
            for name, entry in iteritems(have):
                if name != "object_type":
                    self.addcmd(entry, "og_name", True)

    def _start_group(self, entry, h_item, object, object_elements):
        """Run the steps every member-based group type starts with.

        Handles ``group_object`` members and opens the group context
        (``og_name``) when needed.

        :returns: ``None`` when the group was fully handled through its
                  ``group_object`` members (the caller must skip it);
                  otherwise whether ``og_name`` has been emitted for it.
        """
        if h_item and entry.get("group_object"):
            self.addcmd(entry, "og_name", False)
            self._add_group_object_cmd(entry, h_item)
            return None
        if h_item:
            og_added = self._add_object_cmd(
                entry, h_item, object, object_elements
            )
        else:
            self.addcmd(entry, "og_name", False)
            self.compare(["description"], entry, h_item)
            og_added = True
        if entry.get("group_object"):
            self._add_group_object_cmd(entry, h_item)
            return None
        return og_added

    def _compare_object_elements(
        self, entry, h_item, object, elements, parsers, og_added
    ):
        """Diff each member list of one group, in the order given.

        :param entry: wanted group.
        :param h_item: existing group (``{}`` for a new group).
        :param object: option holding the member lists, e.g. ``network_object``.
        :param elements: ``(member_key, add_parser)`` pairs.
        :param parsers: every template parser of *object*, used to negate.
        :param og_added: whether ``og_name`` was already emitted for this
                         group; it is emitted at most once here.
        """
        want_obj = entry.get(object) or {}
        have_obj = (h_item.get(object) or {}) if h_item else {}
        for element, parser in elements:
            if want_obj.get(element):
                self._compare_object_diff(
                    entry, h_item, object, element, parsers, parser
                )
            elif have_obj.get(element):
                # The list only exists on the device: negate a copy narrowed
                # to this list. h_item itself is left untouched so that the
                # remaining lists are still diffed correctly afterwards.
                if not og_added:
                    self.addcmd(entry, "og_name", False)
                    og_added = True
                narrowed = dict(h_item)
                narrowed[object] = {element: have_obj[element]}
                self.compare(parsers, {}, narrowed)

    def _icmp_object_compare(self, want, have):
        """Diff icmp-type object groups."""
        icmp_obj = "icmp_type"
        for name, entry in iteritems(want):
            h_item = have.pop(name, {})
            if (
                entry == h_item
                or name == "object_type"
                or not entry[icmp_obj].get("icmp_object")
            ):
                continue
            if self._start_group(entry, h_item, icmp_obj, ["icmp_type"]) is None:
                continue
            if self.state in ("overridden", "replaced") and h_item:
                self.compare(["icmp_type"], {}, h_item)
            if h_item and h_item[icmp_obj].get("icmp_object"):
                li_diff = self.get_list_diff(
                    entry, h_item, icmp_obj, "icmp_object"
                )
            else:
                li_diff = entry[icmp_obj].get("icmp_object")
            entry[icmp_obj]["icmp_object"] = li_diff
            self.addcmd(entry, "icmp_type", False)
        self.check_for_have_and_overidden(have)

    def _network_object_compare(self, want, have):
        """Diff network object groups (address/host/ipv6_address/object)."""
        network_obj = "network_object"
        parsers = [
            "network_object.host",
            "network_object.address",
            "network_object.ipv6_address",
            "network_object.object",
        ]
        elements = [
            ("address", "network_object.address"),
            ("host", "network_object.host"),
            ("ipv6_address", "network_object.ipv6_address"),
            ("object", "network_object.object"),
        ]
        for name, entry in iteritems(want):
            h_item = have.pop(name, {})
            if entry == h_item or name == "object_type":
                continue
            # Tracked per group: a stale flag from a previous group would
            # emit this group's negations under the previous group's context.
            og_added = self._start_group(
                entry, h_item, network_obj, [each for each, _ in elements]
            )
            if og_added is None:
                continue
            self._compare_object_elements(
                entry, h_item, network_obj, elements, parsers, og_added
            )
        self.check_for_have_and_overidden(have)

    def _protocol_object_compare(self, want, have):
        """Diff protocol object groups."""
        protocol_obj = "protocol_object"
        for name, entry in iteritems(want):
            h_item = have.pop(name, {})
            if entry == h_item or name == "object_type":
                continue
            if self._start_group(entry, h_item, protocol_obj, ["protocol"]) is None:
                continue
            if (entry.get(protocol_obj) or {}).get("protocol"):
                self._compare_object_diff(
                    entry,
                    h_item,
                    protocol_obj,
                    "protocol",
                    [protocol_obj],
                    protocol_obj,
                )
        self.check_for_have_and_overidden(have)

    def _security_object_compare(self, want, have):
        """Diff security object groups (sec_name/tag)."""
        security_obj = "security_group"
        parsers = ["security_group.sec_name", "security_group.tag"]
        elements = [
            ("sec_name", "security_group.sec_name"),
            ("tag", "security_group.tag"),
        ]
        for name, entry in iteritems(want):
            h_item = have.pop(name, {})
            if entry == h_item or name == "object_type":
                continue
            og_added = self._start_group(
                entry, h_item, security_obj, [each for each, _ in elements]
            )
            if og_added is None:
                continue
            self._compare_object_elements(
                entry, h_item, security_obj, elements, parsers, og_added
            )
        self.check_for_have_and_overidden(have)

    def _service_object_compare(self, want, have):
        """Diff service object groups.

        A group carries a ``service_object``, a ``services_object`` list or
        a ``port_object`` list; each shape is handled separately.
        """
        service_obj = "service_object"
        services_obj = "services_object"
        port_obj = "port_object"
        for name, entry in iteritems(want):
            h_item = have.pop(name, {})
            if entry == h_item or name == "object_type":
                continue
            if self._start_group(entry, h_item, service_obj, ["protocol"]) is None:
                continue
            if entry.get(service_obj):
                if entry[service_obj].get("protocol"):
                    self._compare_object_diff(
                        entry,
                        h_item,
                        service_obj,
                        "protocol",
                        ["service_object"],
                        service_obj,
                    )
            elif entry.get(services_obj):
                self._services_object_compare(name, entry, h_item)
            elif entry.get(port_obj):
                self._port_object_compare(name, entry, h_item)
        self.check_for_have_and_overidden(have)

    def _services_object_compare(self, name, entry, h_item):
        """Diff the ``services_object`` list of service group *name*."""
        services_obj = "services_object"
        if h_item:
            h_item = self.convert_list_to_dict(
                val=h_item,
                source="source_port",
                destination="destination_port",
            )
        entry = self.convert_list_to_dict(
            val=entry,
            source="source_port",
            destination="destination_port",
        )
        command_len = len(self.commands)
        for k, v in iteritems(entry):
            if h_item:
                h_service_item = h_item.pop(k, {})
                if h_service_item != v:
                    self.compare(
                        [services_obj],
                        want={services_obj: v},
                        have={services_obj: h_service_item},
                    )
            else:
                temp_want = {"name": name, services_obj: v}
                # NOTE(review): negate=True while adding a service to a group
                # with no (remaining) have entries looks suspicious — confirm
                # what the og_name template renders for this dict.
                self.addcmd(temp_want, "og_name", True)

                self.compare([services_obj], want=temp_want, have={})
        if h_item and self.state in ["overridden", "replaced"]:
            for k, v in iteritems(h_item):
                temp_have = {"name": name, services_obj: v}
                self.compare([services_obj], want={}, have=temp_have)
        if command_len < len(self.commands):
            cmd = "object-group service {0}".format(name)
            if cmd not in self.commands:
                self.commands.insert(command_len, cmd)

    def _port_object_compare(self, name, entry, h_item):
        """Diff the ``port_object`` list of service group *name*."""
        port_obj = "port_object"
        protocol = entry.get("protocol")
        if h_item:
            h_item = self.convert_list_to_dict(
                val=h_item,
                source="source_port",
                destination="destination_port",
            )
        entry = self.convert_list_to_dict(
            val=entry,
            source="source_port",
            destination="destination_port",
        )
        # Snapshot used for lookups; h_item only tracks what is still
        # unmatched (and is therefore removed under replaced/overridden).
        have_ports = dict(h_item)
        command_len = len(self.commands)
        for k, v in iteritems(entry):
            h_port_item = have_ports.get(k, {})
            h_item.pop(k, None)
            if not h_port_item and "http" in k and "_" in k:
                # The device converts a configured "http" port to "www", so
                # an existing "www" entry already satisfies a wanted "http".
                www_key = k.replace("http", "www")
                if www_key in have_ports:
                    h_port_item = v
                    h_item.pop(www_key, None)
            if h_port_item != v:
                self.compare(
                    [port_obj],
                    want={port_obj: v},
                    have={port_obj: h_port_item},
                )
            elif not h_port_item:
                temp_want = {"name": name, port_obj: v}
                self.compare([port_obj], want=temp_want, have={})
        if h_item and self.state in ["overridden", "replaced"]:
            for k, v in iteritems(h_item):
                temp_have = {"name": name, port_obj: v}
                self.compare([port_obj], want={}, have=temp_have)
        if command_len < len(self.commands):
            self.commands.insert(
                command_len,
                "object-group service {0} {1}".format(name, protocol),
            )

    @staticmethod
    def _port_key(port):
        """Return a string identifying a port match dict.

        ``{"range": {"start": a, "end": b}}`` becomes ``range_a_b``;
        ``{op: value}`` becomes ``op_value``.
        """
        if "range" in port:
            return "range_{0}_{1}".format(
                port["range"]["start"], port["range"]["end"]
            )
        operator = list(port)[0]
        return "{0}_{1}".format(operator, port[operator])

    def convert_list_to_dict(self, *args, **kwargs):
        """Key a group's ``services_object``/``port_object`` entries.

        :param val: the object-group dict to convert (keyword, required).
        :param source: source-port option name (default ``source_port``).
        :param destination: destination-port option name
                            (default ``destination_port``).
        :returns: dict mapping a string key derived from each entry to the
                  entry. The key includes the protocol and the port direction,
                  so distinct entries never collide. Returns ``{}`` when
                  ``val`` has neither list.
        """
        val = kwargs["val"]
        source = kwargs.get("source", "source_port")
        destination = kwargs.get("destination", "destination_port")
        temp = {}
        if val.get("services_object"):
            for every in val["services_object"]:
                key_parts = ["{0}".format(every["protocol"])]
                for direction in (source, destination):
                    if direction in every:
                        key_parts.append(direction)
                        key_parts.append(self._port_key(every[direction]))
                temp["_".join(key_parts)] = every
        elif val.get("port_object"):
            for every in val["port_object"]:
                temp[self._port_key(every)] = every
        return temp

    def _user_object_compare(self, want, have):
        """Diff user object groups (user/user_group)."""
        user_obj = "user_object"
        parsers = ["user_object.user", "user_object.user_gp"]
        elements = [
            ("user", "user_object.user"),
            ("user_group", "user_object.user_gp"),
        ]
        for name, entry in iteritems(want):
            h_item = have.pop(name, {})
            if entry == h_item or name == "object_type":
                continue
            og_added = self._start_group(
                entry, h_item, user_obj, [each for each, _ in elements]
            )
            if og_added is None:
                continue
            self._compare_object_elements(
                entry, h_item, user_obj, elements, parsers, og_added
            )
        self.check_for_have_and_overidden(have)

    def _add_object_cmd(self, want, have, object, object_elements):
        """Open an existing group's context when one of its lists changes.

        Emits ``og_name`` (plus a description diff) the first time one of
        *object_elements* is wanted but differs from, or is missing in, the
        existing group.

        :returns: True if ``og_name`` was emitted, False otherwise.
        """
        for each in object_elements:
            want_element = want[object].get(each) if want.get(object) else want
            have_element = have[object].get(each) if have.get(object) else have
            if not want_element:
                continue
            if not have_element:
                # A new list for an existing group: its context must still be
                # opened before the members are added.
                differs = bool(want.get(object))
            elif isinstance(want_element, list) and isinstance(
                want_element[0], dict
            ):
                differs = want_element != have_element
            else:
                differs = set(want_element) != set(have_element)
            if differs:
                self.addcmd(want, "og_name", False)
                self.compare(["description"], want, have)
                return True
        return False

    def _add_group_object_cmd(self, want, have):
        """Add missing ``group-object`` members and, for replaced/overridden,
        remove the ones that are no longer wanted.

        Both diffs are computed against the original lists (in their original
        order), so members present on both sides are left alone.
        """
        if have and have.get("group_object"):
            want_groups = want.get("group_object") or []
            have_groups = have.get("group_object")
            want["group_object"] = [
                each for each in want_groups if each not in have_groups
            ]
            have["group_object"] = [
                each for each in have_groups if each not in want_groups
            ]
        for each in want.get("group_object") or []:
            self.compare(["group_object"], {"group_object": each}, dict())
        if (
            self.state in ("replaced", "overridden")
            and have
            and have.get("group_object")
        ):
            for each in have["group_object"]:
                self.compare(["group_object"], dict(), {"group_object": each})

    def _compare_object_diff(
        self, want, have, object, object_type, parsers, val
    ):
        """Emit the commands for one member list of a group.

        :param want: wanted group; ``want[object][object_type]`` is expected
                     to be a non-empty list (all callers check this).
        :param have: existing group (``{}`` for a new group).
        :param object: option holding the list, e.g. ``network_object``.
        :param object_type: the list inside it, e.g. ``host``.
        :param parsers: template parsers used to negate stale members.
        :param val: template parser used to add the missing members.
        """
        want_items = want[object].get(object_type)
        have_items = (have.get(object) or {}).get(object_type) if have else None
        if have_items:
            want_diff = [item for item in want_items if item not in have_items]
            have_diff = [item for item in have_items if item not in want_items]
        else:
            want_diff = want_items
            have_diff = []
        if have_diff:
            # Negate only the stale members of this list.
            temp_have = copy.copy(have)
            temp_have[object] = {object_type: have_diff}
            self.compare(parsers, {}, temp_have)
        # Copy the nested dict too so the caller's group is not mutated.
        temp_want = copy.copy(want)
        temp_want[object] = dict(want[object])
        temp_want[object][object_type] = want_diff
        self.addcmd(temp_want, val, False)