1# Copyright: (c) 2012-2014, Michael DeHaan <michael.dehaan@gmail.com>
2# Copyright: (c) 2017, Ansible Project
3# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
4
5from __future__ import (absolute_import, division, print_function)
6__metaclass__ = type
7
8import itertools
9import operator
10
11from copy import copy as shallowcopy
12from functools import partial
13
14from jinja2.exceptions import UndefinedError
15
16from ansible import constants as C
17from ansible import context
18from ansible.module_utils.six import iteritems, string_types, with_metaclass
19from ansible.module_utils.parsing.convert_bool import boolean
20from ansible.errors import AnsibleParserError, AnsibleUndefinedVariable, AnsibleAssertionError
21from ansible.module_utils._text import to_text, to_native
22from ansible.playbook.attribute import Attribute, FieldAttribute
23from ansible.parsing.dataloader import DataLoader
24from ansible.utils.display import Display
25from ansible.utils.sentinel import Sentinel
26from ansible.utils.vars import combine_vars, isidentifier, get_unique_id
27
# Module-wide Display instance; used below for debug dumps (dump_me) and
# warnings emitted during post_validate.
display = Display()
29
30
def _generic_g(prop_name, self):
    '''
    Plain property getter installed by BaseMeta for a FieldAttribute.

    Returns the value stored in ``self._attributes``; when that value is the
    Sentinel placeholder (i.e. never set), the per-instance default from
    ``self._attr_defaults`` is returned instead.

    :raises AttributeError: if ``prop_name`` is not a known attribute.
    '''
    try:
        stored = self._attributes[prop_name]
    except KeyError:
        raise AttributeError("'%s' object has no attribute '%s'" % (self.__class__.__name__, prop_name))

    return self._attr_defaults[prop_name] if stored is Sentinel else stored
41
42
def _generic_g_method(prop_name, self):
    '''
    Property getter for FieldAttributes that have a ``_get_attr_<name>`` method.

    A squashed object returns the value stored in ``self._attributes``
    directly. Otherwise the class's ``_get_attr_<name>`` method is called
    and its result is returned.

    :raises AttributeError: if a KeyError escapes either the stored-value
        lookup or the custom getter method.
    '''
    try:
        if not self._squashed:
            custom_getter = getattr(self, "_get_attr_%s" % prop_name)
            return custom_getter()
        return self._attributes[prop_name]
    except KeyError:
        raise AttributeError("'%s' object has no attribute '%s'" % (self.__class__.__name__, prop_name))
51
52
def _generic_g_parent(prop_name, self):
    '''
    Property getter for FieldAttributes that may inherit a parent's value.

    Once the object is squashed or finalized, the locally stored value is
    used as-is. Before that, the value is resolved through
    ``self._get_parent_attribute``; if that raises AttributeError, the
    locally stored value is used instead. A Sentinel result (never set) is
    replaced by the per-instance default.

    :raises AttributeError: if ``prop_name`` is not a known attribute.
    '''
    try:
        if not (self._squashed or self._finalized):
            try:
                resolved = self._get_parent_attribute(prop_name)
            except AttributeError:
                resolved = self._attributes[prop_name]
        else:
            resolved = self._attributes[prop_name]
    except KeyError:
        raise AttributeError("'%s' object has no attribute '%s'" % (self.__class__.__name__, prop_name))

    return self._attr_defaults[prop_name] if resolved is Sentinel else resolved
69
70
def _generic_s(prop_name, self, value):
    '''Property setter installed by BaseMeta: store ``value`` under ``prop_name`` in ``self._attributes``.'''
    attribute_store = self._attributes
    attribute_store[prop_name] = value
73
74
def _generic_d(prop_name, self):
    '''Property deleter installed by BaseMeta: drop ``prop_name`` from ``self._attributes`` (KeyError if absent).'''
    self._attributes.pop(prop_name)
77
78
class BaseMeta(type):

    """
    Metaclass for the Base object, which is used to construct the class
    attributes based on the FieldAttributes available.

    For every class attribute that is an ``Attribute`` (conventionally
    declared as ``_<name> = FieldAttribute(...)``), a ``<name>`` property is
    added to the class, and the attribute is registered in the class-level
    bookkeeping dicts ``_valid_attrs``, ``_attributes`` (initialised to
    Sentinel), ``_attr_defaults`` and, for aliases, ``_alias_attrs``.
    """

    def __new__(cls, name, parents, dct):
        def _create_attrs(src_dict, dst_dict):
            '''
            Helper method which creates the attributes based on those in the
            source dictionary of attributes. This also populates the other
            attributes used to keep track of these attributes and via the
            getter/setter/deleter methods.

            :arg src_dict: mapping scanned for Attribute instances (the new
                class's namespace, or a parent class's ``__dict__``).
            :arg dst_dict: namespace that receives the generated properties;
                its ``_valid_attrs``/``_attributes``/``_attr_defaults``/
                ``_alias_attrs`` dicts are updated in place. Entries are
                overwritten unconditionally if they already exist.
            '''
            # snapshot the keys so src_dict may safely be the same object as
            # dst_dict (it is, for the class's own namespace) while we add to it
            keys = list(src_dict.keys())
            for attr_name in keys:
                value = src_dict[attr_name]
                if isinstance(value, Attribute):
                    # '_name' is exposed publicly as 'name'
                    if attr_name.startswith('_'):
                        attr_name = attr_name[1:]

                    # here we selectively assign the getter based on a few
                    # things, such as whether we have a _get_attr_<name>
                    # method, or if the attribute is marked as not inheriting
                    # its value from a parent object
                    method = "_get_attr_%s" % attr_name
                    if method in src_dict or method in dst_dict:
                        getter = partial(_generic_g_method, attr_name)
                    elif ('_get_parent_attribute' in dst_dict or '_get_parent_attribute' in src_dict) and value.inherit:
                        getter = partial(_generic_g_parent, attr_name)
                    else:
                        getter = partial(_generic_g, attr_name)

                    setter = partial(_generic_s, attr_name)
                    deleter = partial(_generic_d, attr_name)

                    dst_dict[attr_name] = property(getter, setter, deleter)
                    dst_dict['_valid_attrs'][attr_name] = value
                    # Sentinel marks "never set"; getters substitute the default
                    dst_dict['_attributes'][attr_name] = Sentinel
                    dst_dict['_attr_defaults'][attr_name] = value.default

                    # an alias gets its own property sharing the same accessors
                    # (so it reads/writes the canonical name's storage) and is
                    # recorded so callers can map alias -> canonical name
                    if value.alias is not None:
                        dst_dict[value.alias] = property(getter, setter, deleter)
                        dst_dict['_valid_attrs'][value.alias] = value
                        dst_dict['_alias_attrs'][value.alias] = attr_name

        def _process_parents(parents, dst_dict):
            '''
            Helper method which creates attributes from all parent objects
            recursively on through grandparent objects

            NOTE(review): the class's own namespace is processed before its
            parents (see the two calls at the end of __new__), and
            _create_attrs overwrites existing entries. A parent's
            ``__dict__`` still holds its ``_<name>`` Attribute, so a
            FieldAttribute redeclared in a subclass appears to end up with
            the parent's declaration in ``_valid_attrs``/``_attr_defaults``.
            Confirm whether subclass overrides are meant to win before
            relying on them.
            '''
            for parent in parents:
                if hasattr(parent, '__dict__'):
                    _create_attrs(parent.__dict__, dst_dict)
                    # shallow copy: the bookkeeping dicts taken from dst_dict
                    # are the same objects, so grandparent attributes are still
                    # registered in them, while properties generated for
                    # grandparents are stored only in this throwaway copy
                    new_dst_dict = parent.__dict__.copy()
                    new_dst_dict.update(dst_dict)
                    _process_parents(parent.__bases__, new_dst_dict)

        # create some additional class attributes
        dct['_attributes'] = {}
        dct['_attr_defaults'] = {}
        dct['_valid_attrs'] = {}
        dct['_alias_attrs'] = {}

        # now create the attributes based on the FieldAttributes
        # available, including from parent (and grandparent) objects
        _create_attrs(dct, dct)
        _process_parents(parents, dct)

        return super(BaseMeta, cls).__new__(cls, name, parents, dct)
150
151
class FieldAttributeBase(with_metaclass(BaseMeta, object)):
    '''
    Base class for objects whose data is described by FieldAttribute
    declarations.

    BaseMeta turns each ``_<name> = FieldAttribute(...)`` class attribute
    into a ``<name>`` property backed by ``self._attributes``. This class
    provides the behavior built on top of that: loading from a parsed
    datastructure, parse-time validation, post-validation (templating and
    type coercion), squashing, copying and (de)serialization.
    '''

    def __init__(self):

        # initialize the data loader and variable manager, which will be provided
        # later when the object is actually loaded
        self._loader = None
        self._variable_manager = None

        # other internal params
        self._validated = False
        self._squashed = False
        self._finalized = False

        # every object gets a random uuid:
        self._uuid = get_unique_id()

        # we create a copy of the attributes here due to the fact that
        # it was initialized as a class param in the meta class, so we
        # need a unique object here (all members contained within are
        # unique already).
        self._attributes = self.__class__._attributes.copy()
        self._attr_defaults = self.__class__._attr_defaults.copy()
        # callable defaults are invoked once per instance to obtain the real
        # default value; only values are replaced, the key set is unchanged
        for key, value in self._attr_defaults.items():
            if callable(value):
                self._attr_defaults[key] = value()

        # and init vars, avoid using defaults in field declaration as it lives across plays
        self.vars = dict()

    def dump_me(self, depth=0):
        ''' this is never called from production code, it is here to be used when debugging as a 'complex print' '''
        if depth == 0:
            display.debug("DUMPING OBJECT ------------------------------------------------------")
        display.debug("%s- %s (%s, id=%s)" % (" " * depth, self.__class__.__name__, self, id(self)))
        if hasattr(self, '_parent') and self._parent:
            self._parent.dump_me(depth + 2)
            dep_chain = self._parent.get_dep_chain()
            if dep_chain:
                for dep in dep_chain:
                    dep.dump_me(depth + 2)
        if hasattr(self, '_play') and self._play:
            self._play.dump_me(depth + 2)

    def preprocess_data(self, ds):
        ''' infrequently used method to do some pre-processing of legacy terms '''
        return ds

    def load_data(self, ds, variable_manager=None, loader=None):
        '''
        Walk the input datastructure and assign any values.

        :arg ds: parsed datastructure (mapping of attribute name to value);
            cached on the object as ``_ds``.
        :kwarg variable_manager: stored on the object for later use.
        :kwarg loader: data loader; a new DataLoader is created when omitted.
        :returns: ``self``, to allow chaining.
        :raises AnsibleAssertionError: if ``ds`` is None.
        :raises AnsibleParserError: if ``ds`` contains unknown keys or
            fails early validation.
        '''

        if ds is None:
            raise AnsibleAssertionError('ds (%s) should not be None but it is.' % ds)

        # cache the datastructure internally
        setattr(self, '_ds', ds)

        # the variable manager class is used to manage and merge variables
        # down to a single dictionary for reference in templating, etc.
        self._variable_manager = variable_manager

        # the data loader class is used to parse data from strings and files
        if loader is not None:
            self._loader = loader
        else:
            self._loader = DataLoader()

        # call the preprocess_data() function to massage the data into
        # something we can more easily parse, and then call the validation
        # function on it to ensure there are no incorrect key values
        ds = self.preprocess_data(ds)
        self._validate_attributes(ds)

        # Walk all attributes in the class. We sort them based on their priority
        # so that certain fields can be loaded before others, if they are dependent.
        for name, attr in sorted(iteritems(self._valid_attrs), key=operator.itemgetter(1)):
            # copy the value over unless a _load_field method is defined
            target_name = name
            if name in self._alias_attrs:
                # aliases store their value under the canonical attribute name
                target_name = self._alias_attrs[name]
            if name in ds:
                method = getattr(self, '_load_%s' % name, None)
                if method:
                    self._attributes[target_name] = method(name, ds[name])
                else:
                    self._attributes[target_name] = ds[name]

        # run early, non-critical validation
        self.validate()

        # return the constructed object
        return self

    def get_ds(self):
        ''' Return the datastructure cached by load_data(), or None if there is none. '''
        return getattr(self, '_ds', None)

    def get_loader(self):
        ''' Return the data loader assigned in load_data() (None before loading). '''
        return self._loader

    def get_variable_manager(self):
        ''' Return the variable manager assigned in load_data() (None before loading). '''
        return self._variable_manager

    def _post_validate_debugger(self, attr, value, templar):
        '''
        Template the ``debugger`` value and check it against the allowed keywords.

        :raises AnsibleParserError: if the templated value is a non-empty
            string that is not one of the allowed keywords.
        '''
        value = templar.template(value)
        # a tuple keeps the listing in the error message in a stable order
        valid_values = ('always', 'on_failed', 'on_unreachable', 'on_skipped', 'never')
        if value and isinstance(value, string_types) and value not in valid_values:
            raise AnsibleParserError("'%s' is not a valid value for debugger. Must be one of %s" % (value, ', '.join(valid_values)), obj=self.get_ds())
        return value

    def _validate_attributes(self, ds):
        '''
        Ensures that there are no keys in the datastructure which do
        not map to attributes for this object.

        :raises AnsibleParserError: on the first unknown key.
        '''

        valid_attrs = frozenset(self._valid_attrs.keys())
        for key in ds:
            if key not in valid_attrs:
                raise AnsibleParserError("'%s' is not a valid attribute for a %s" % (key, self.__class__.__name__), obj=ds)

    def validate(self, all_vars=None):
        '''
        Validation that is done at parse time, not load time.

        Runs a ``_validate_<name>`` method for each field when one exists;
        otherwise only checks that a ``string`` field was not given a list or
        dict. Performed once per object (guarded by ``_validated``).

        :kwarg all_vars: accepted for subclasses; unused here.
        :raises AnsibleParserError: if a string field holds a list or dict.
        '''
        all_vars = {} if all_vars is None else all_vars

        if not self._validated:
            # walk all fields in the object
            for (name, attribute) in iteritems(self._valid_attrs):

                if name in self._alias_attrs:
                    name = self._alias_attrs[name]

                # run validator only if present
                method = getattr(self, '_validate_%s' % name, None)
                if method:
                    method(attribute, name, getattr(self, name))
                else:
                    # and make sure the attribute is of the type it should be
                    value = self._attributes[name]
                    if value is not None:
                        if attribute.isa == 'string' and isinstance(value, (list, dict)):
                            raise AnsibleParserError(
                                "The field '%s' is supposed to be a string type,"
                                " however the incoming data structure is a %s" % (name, type(value)), obj=self.get_ds()
                            )

        self._validated = True

    def squash(self):
        '''
        Evaluates all attributes and sets them to the evaluated version,
        so that all future accesses of attributes do not need to evaluate
        parent attributes.
        '''
        if not self._squashed:
            for name in self._valid_attrs.keys():
                self._attributes[name] = getattr(self, name)
            self._squashed = True

    def copy(self):
        '''
        Create a copy of this object and return it.

        Attribute values and defaults are shallow-copied; loader, variable
        manager, validation/finalization flags, uuid and the cached ``_ds``
        are carried over.
        '''

        new_me = self.__class__()

        for name in self._valid_attrs.keys():
            # aliases share storage with their canonical attribute
            if name in self._alias_attrs:
                continue
            new_me._attributes[name] = shallowcopy(self._attributes[name])
            new_me._attr_defaults[name] = shallowcopy(self._attr_defaults[name])

        new_me._loader = self._loader
        new_me._variable_manager = self._variable_manager
        new_me._validated = self._validated
        new_me._finalized = self._finalized
        new_me._uuid = self._uuid

        # if the ds value was set on the object, copy it to the new copy too
        if hasattr(self, '_ds'):
            new_me._ds = self._ds

        return new_me

    def get_validated_value(self, name, attribute, value, templar):
        '''
        Coerce ``value`` to the type declared by ``attribute.isa``.

        :arg name: field name, used in error messages.
        :arg attribute: the FieldAttribute describing the field.
        :arg value: the (already templated) value to coerce.
        :arg templar: passed to ``post_validate`` of ``class`` values.
        :returns: the coerced value.
        :raises TypeError/ValueError: when conversion fails.
        :raises AnsibleParserError: for ``listof`` violations and empty
            items in required string lists.
        '''
        if attribute.isa == 'string':
            value = to_text(value)
        elif attribute.isa == 'int':
            value = int(value)
        elif attribute.isa == 'float':
            value = float(value)
        elif attribute.isa == 'bool':
            value = boolean(value, strict=True)
        elif attribute.isa == 'percent':
            # special value, which may be an integer or float
            # with an optional '%' at the end
            if isinstance(value, string_types) and '%' in value:
                value = value.replace('%', '')
            value = float(value)
        elif attribute.isa == 'list':
            if value is None:
                value = []
            elif not isinstance(value, list):
                value = [value]
            if attribute.listof is not None:
                for item in value:
                    if not isinstance(item, attribute.listof):
                        raise AnsibleParserError("the field '%s' should be a list of %s, "
                                                 "but the item '%s' is a %s" % (name, attribute.listof, item, type(item)), obj=self.get_ds())
                    elif attribute.required and attribute.listof == string_types:
                        if item is None or item.strip() == "":
                            raise AnsibleParserError("the field '%s' is required, and cannot have empty values" % (name,), obj=self.get_ds())
        elif attribute.isa == 'set':
            if value is None:
                value = set()
            elif not isinstance(value, (list, set)):
                if isinstance(value, string_types):
                    value = value.split(',')
                else:
                    # Making a list like this handles strings of
                    # text and bytes properly
                    value = [value]
            if not isinstance(value, set):
                value = set(value)
        elif attribute.isa == 'dict':
            if value is None:
                value = dict()
            elif not isinstance(value, dict):
                raise TypeError("%s is not a dictionary" % value)
        elif attribute.isa == 'class':
            if not isinstance(value, attribute.class_type):
                raise TypeError("%s is not a valid %s (got a %s instead)" % (name, attribute.class_type, type(value)))
            value.post_validate(templar=templar)
        return value

    def post_validate(self, templar):
        '''
        we can't tell that everything is of the right type until we have
        all the variables.  Run basic types (from isa) as well as
        any _post_validate_<foo> functions.

        Static fields are never templated (a warning is shown if one looks
        like a template, except ``vars``). Unset non-required fields are
        skipped, and fields without ``always_post_validate`` are skipped
        unless this object is a Task, Handler or PlayContext. A value that
        templates to the ``omit`` placeholder is reset to the field's
        default. Sets ``_finalized`` when done.

        :arg templar: templar used for templating; its
            ``available_variables`` supply the ``omit`` placeholder.
        :raises AnsibleParserError: for missing required fields, values that
            cannot be converted, or undefined variables (when the templar
            fails on undefined, and the field is not ``name``).
        '''

        # save the omit value for later checking
        omit_value = templar.available_variables.get('omit')

        for (name, attribute) in iteritems(self._valid_attrs):

            if attribute.static:
                value = getattr(self, name)

                # we don't template 'vars' but allow template as values for later use
                if name not in ('vars',) and templar.is_template(value):
                    display.warning('"%s" is not templatable, but we found: %s, '
                                    'it will not be templated and will be used "as is".' % (name, value))
                continue

            if getattr(self, name) is None:
                if not attribute.required:
                    continue
                else:
                    raise AnsibleParserError("the field '%s' is required but was not set" % name)
            elif not attribute.always_post_validate and self.__class__.__name__ not in ('Task', 'Handler', 'PlayContext'):
                # Intermediate objects like Play() won't have their fields validated by
                # default, as their values are often inherited by other objects and validated
                # later, so we don't want them to fail out early
                continue

            try:
                # Run the post-validator if present. These methods are responsible for
                # using the given templar to template the values, if required.
                method = getattr(self, '_post_validate_%s' % name, None)
                if method:
                    value = method(attribute, getattr(self, name), templar)
                elif attribute.isa == 'class':
                    value = getattr(self, name)
                else:
                    # if the attribute contains a variable, template it now
                    value = templar.template(getattr(self, name))

                # if this evaluated to the omit value, set the value back to
                # the default specified in the FieldAttribute and move on
                if omit_value is not None and value == omit_value:
                    if callable(attribute.default):
                        setattr(self, name, attribute.default())
                    else:
                        setattr(self, name, attribute.default)
                    continue

                # and make sure the attribute is of the type it should be
                if value is not None:
                    value = self.get_validated_value(name, attribute, value, templar)

                # and assign the massaged value back to the attribute field
                setattr(self, name, value)
            except (TypeError, ValueError) as e:
                value = getattr(self, name)
                # the two literals are concatenated, so the separating space
                # must live at the end of the first one
                raise AnsibleParserError("the field '%s' has an invalid value (%s), and could not be converted to an %s. "
                                         "The error was: %s" % (name, value, attribute.isa, e), obj=self.get_ds(), orig_exc=e)
            except (AnsibleUndefinedVariable, UndefinedError) as e:
                if templar._fail_on_undefined_errors and name != 'name':
                    if name == 'args':
                        msg = "The task includes an option with an undefined variable. The error was: %s" % (to_native(e))
                    else:
                        msg = "The field '%s' has an invalid value, which includes an undefined variable. The error was: %s" % (name, to_native(e))
                    raise AnsibleParserError(msg, obj=self.get_ds(), orig_exc=e)

        self._finalized = True

    def _load_vars(self, attr, ds):
        '''
        Vars in a play can be specified either as a dictionary directly, or
        as a list of dictionaries. If the latter, this method will turn the
        list into a single dictionary.

        Each mapping is merged onto the current ``self.vars`` with
        combine_vars; None yields an empty dict.

        :raises AnsibleParserError: if ``ds`` is not a dict, a list of dicts
            or None, or if any key is not a valid identifier.
        '''

        def _validate_variable_keys(ds):
            for key in ds:
                if not isidentifier(key):
                    raise TypeError("'%s' is not a valid variable name" % key)

        try:
            if isinstance(ds, dict):
                _validate_variable_keys(ds)
                return combine_vars(self.vars, ds)
            elif isinstance(ds, list):
                all_vars = self.vars
                for item in ds:
                    if not isinstance(item, dict):
                        raise ValueError
                    _validate_variable_keys(item)
                    all_vars = combine_vars(all_vars, item)
                return all_vars
            elif ds is None:
                return {}
            else:
                raise ValueError
        except ValueError as e:
            raise AnsibleParserError("Vars in a %s must be specified as a dictionary, or a list of dictionaries" % self.__class__.__name__,
                                     obj=ds, orig_exc=e)
        except TypeError as e:
            raise AnsibleParserError("Invalid variable name in vars specified for %s: %s" % (self.__class__.__name__, e), obj=ds, orig_exc=e)

    def _extend_value(self, value, new_value, prepend=False):
        '''
        Combine ``value`` and ``new_value`` into a single list.

        Both inputs are wrapped in a list if they are not one already, and
        Sentinel entries are dropped. ``new_value`` goes after ``value``
        unless ``prepend`` is True. Runs of *adjacent* equal items are then
        collapsed to one (itertools.groupby) and None items are removed;
        duplicates that are not adjacent are kept and order is preserved, so
        items do not need to be hashable.
        '''

        if not isinstance(value, list):
            value = [value]
        if not isinstance(new_value, list):
            new_value = [new_value]

        # Due to where _extend_value may run for some attributes
        # it is possible to end up with Sentinel in the list of values
        # ensure we strip them
        value = [v for v in value if v is not Sentinel]
        new_value = [v for v in new_value if v is not Sentinel]

        if prepend:
            combined = new_value + value
        else:
            combined = value + new_value

        return [i for i, _ in itertools.groupby(combined) if i is not None]

    def dump_attrs(self):
        '''
        Dumps all attributes to a dictionary

        ``class`` attributes that provide ``serialize()`` are stored in their
        serialized form.
        '''
        attrs = {}
        for (name, attribute) in iteritems(self._valid_attrs):
            attr = getattr(self, name)
            if attribute.isa == 'class' and hasattr(attr, 'serialize'):
                attrs[name] = attr.serialize()
            else:
                attrs[name] = attr
        return attrs

    def from_attrs(self, attrs):
        '''
        Loads attributes from a dictionary

        Unknown keys are ignored; dict values for ``class`` attributes are
        deserialized into a new instance of ``attribute.class_type``.
        '''
        for (attr, value) in iteritems(attrs):
            if attr in self._valid_attrs:
                attribute = self._valid_attrs[attr]
                if attribute.isa == 'class' and isinstance(value, dict):
                    obj = attribute.class_type()
                    obj.deserialize(value)
                    setattr(self, attr, obj)
                else:
                    setattr(self, attr, value)

        # from_attrs is only used to create a finalized task
        # from attrs from the Worker/TaskExecutor
        # Those attrs are finalized and squashed in the TE
        # and controller side use needs to reflect that
        self._finalized = True
        self._squashed = True

    def serialize(self):
        '''
        Serializes the object derived from the base object into
        a dictionary of values. This only serializes the field
        attributes for the object, so this may need to be overridden
        for any classes which wish to add additional items not stored
        as field attributes.

        The result also carries the ``uuid``, ``finalized`` and ``squashed``
        bookkeeping values.
        '''

        data = self.dump_attrs()

        # serialize the uuid field
        data['uuid'] = self._uuid
        data['finalized'] = self._finalized
        data['squashed'] = self._squashed

        return data

    def deserialize(self, data):
        '''
        Given a dictionary of values, load up the field attributes for
        this object. As with serialize(), if there are any non-field
        attribute data members, this method will need to be overridden
        and extended.

        Fields absent from ``data`` are reset to their FieldAttribute
        default (calling it if it is callable).

        :raises AnsibleAssertionError: if ``data`` is not a dict.
        '''

        if not isinstance(data, dict):
            raise AnsibleAssertionError('data (%s) should be a dict but is a %s' % (data, type(data)))

        for (name, attribute) in iteritems(self._valid_attrs):
            if name in data:
                setattr(self, name, data[name])
            else:
                if callable(attribute.default):
                    setattr(self, name, attribute.default())
                else:
                    setattr(self, name, attribute.default)

        # restore the UUID field
        setattr(self, '_uuid', data.get('uuid'))
        self._finalized = data.get('finalized', False)
        self._squashed = data.get('squashed', False)
597
598
class Base(FieldAttributeBase):
    '''
    Declares the FieldAttributes shared by every class deriving from Base.

    BaseMeta exposes each ``_<name>`` declaration below as a ``<name>``
    property. Keep the declaration order stable: load_data() sorts
    ``_valid_attrs`` by priority with a stable sort, so fields of equal
    priority are presumably loaded in declaration order.
    '''

    _name = FieldAttribute(isa='string', default='', always_post_validate=True, inherit=False)

    # connection/transport
    # NOTE(review): context.cliargs_deferred_get(...) defaults are presumably
    # deferred CLI-option lookups; FieldAttributeBase.__init__ calls any
    # callable default once per instance.
    _connection = FieldAttribute(isa='string', default=context.cliargs_deferred_get('connection'))
    _port = FieldAttribute(isa='int')
    _remote_user = FieldAttribute(isa='string', default=context.cliargs_deferred_get('remote_user'))

    # variables
    # static=True: post_validate never templates this field
    _vars = FieldAttribute(isa='dict', priority=100, inherit=False, static=True)

    # module default params
    # extend/prepend flags are consumed elsewhere (presumably via _extend_value)
    _module_defaults = FieldAttribute(isa='list', extend=True, prepend=True)

    # flags and misc. settings
    _environment = FieldAttribute(isa='list', extend=True, prepend=True)
    _no_log = FieldAttribute(isa='bool')
    _run_once = FieldAttribute(isa='bool')
    _ignore_errors = FieldAttribute(isa='bool')
    _ignore_unreachable = FieldAttribute(isa='bool')
    _check_mode = FieldAttribute(isa='bool', default=context.cliargs_deferred_get('check'))
    _diff = FieldAttribute(isa='bool', default=context.cliargs_deferred_get('diff'))
    _any_errors_fatal = FieldAttribute(isa='bool', default=C.ANY_ERRORS_FATAL)
    _throttle = FieldAttribute(isa='int', default=0)
    _timeout = FieldAttribute(isa='int', default=C.TASK_TIMEOUT)

    # explicitly invoke a debugger on tasks
    # (its value is checked by FieldAttributeBase._post_validate_debugger)
    _debugger = FieldAttribute(isa='string')

    # Privilege escalation
    _become = FieldAttribute(isa='bool', default=context.cliargs_deferred_get('become'))
    _become_method = FieldAttribute(isa='string', default=context.cliargs_deferred_get('become_method'))
    _become_user = FieldAttribute(isa='string', default=context.cliargs_deferred_get('become_user'))
    _become_flags = FieldAttribute(isa='string', default=context.cliargs_deferred_get('become_flags'))
    _become_exe = FieldAttribute(isa='string', default=context.cliargs_deferred_get('become_exe'))

    # used to hold sudo/su stuff
    DEPRECATED_ATTRIBUTES = []
638