1"""
2:maintainer: Evan Borgstrom <evan@borgstrom.ca>
3
4Pythonic object interface to creating state data, see the pyobjects renderer
5for more documentation.
6"""
7
8import inspect
9import logging
10
11from salt.utils.odict import OrderedDict
12
# Requisite keyword arguments that State treats specially: single values are
# normalised into lists and StateRequisite objects are resolved to dicts.
# The first group holds the plain requisites, the second their "_in"
# variants, in the same order.
REQUISITES = (
    "listen", "onchanges", "onfail", "require", "watch", "use",
) + (
    "listen_in", "onchanges_in", "onfail_in", "require_in", "watch_in", "use_in",
)

# module-level logger
log = logging.getLogger(__name__)
29
30
class StateException(Exception):
    """Base class for the errors raised while building pyobjects state data."""


class DuplicateState(StateException):
    """Raised when an id already holds a state with the same module.function."""


class InvalidFunction(StateException):
    """Raised when a StateFactory is asked for a function outside valid_funcs."""
41
42
class Registry:
    """
    The StateRegistry holds all of the states that have been created.

    Everything is stored in class attributes, so there is a single registry
    shared by all callers. ``salt_data`` returns the collected data and then
    resets the registry. While ``enabled`` is False, ``add``/``extend``,
    ``include`` and the requisite stack operations do nothing.
    """

    # state id -> OrderedDict merging the mappings returned by calling each
    # state registered under that id (e.g. {"file.managed": [...]})
    states = OrderedDict()
    # stack of requisites pushed by ``with`` blocks; each one is added to
    # every state registered while it is on the stack
    requisites = []
    # arguments collected by include(), emitted under the "include" key
    includes = []
    # same layout as ``states``, emitted under the "extend" key
    extends = OrderedDict()
    enabled = True

    @classmethod
    def empty(cls):
        """Reset states, requisites, includes and extends to fresh containers."""
        cls.states = OrderedDict()
        cls.requisites = []
        cls.includes = []
        cls.extends = OrderedDict()

    @classmethod
    def include(cls, *args):
        """Append ``args`` to the list emitted under the ``include`` key."""
        if not cls.enabled:
            return

        cls.includes.extend(args)

    @classmethod
    def salt_data(cls):
        """
        Return the collected data and reset the registry.

        The result maps each state id to its functions, plus an ``include``
        key and an ``extend`` key when either has content.
        """
        # shallow copies: empty() below rebinds the class attributes, so the
        # returned containers are no longer shared with the registry
        states = OrderedDict(cls.states)

        if cls.includes:
            states["include"] = cls.includes

        if cls.extends:
            states["extend"] = OrderedDict(cls.extends)

        cls.empty()

        return states

    @classmethod
    def add(cls, id_, state, extend=False):
        """
        Register ``state`` under ``id_``.

        :param id_: the state id
        :param state: the state object; it must expose ``full_func`` and a
            ``kwargs`` dict, and calling it must return a mapping to merge
            into the entry for ``id_``
        :param extend: store the state in ``extends`` instead of ``states``
        :raises DuplicateState: if ``id_`` already holds a state with the
            same ``full_func``
        """
        if not cls.enabled:
            return

        attr = cls.extends if extend else cls.states

        if id_ in attr:
            if state.full_func in attr[id_]:
                raise DuplicateState(
                    "A state with id '{}', type '{}' exists".format(
                        id_, state.full_func
                    )
                )
        else:
            attr[id_] = OrderedDict()

        # every requisite currently on the stack (pushed by enclosing
        # ``with`` blocks) becomes part of this state's keyword arguments
        for req in cls.requisites:
            state.kwargs.setdefault(req.requisite, []).append(req())

        attr[id_].update(state())

    @classmethod
    def extend(cls, id_, state):
        """Register ``state`` under ``id_`` in the ``extend`` section."""
        cls.add(id_, state, extend=True)

    @classmethod
    def make_extend(cls, name):
        """Return a StateExtend marker for the state id ``name``."""
        return StateExtend(name)

    @classmethod
    def push_requisite(cls, requisite):
        """Push ``requisite`` so it is applied to subsequently added states."""
        if not cls.enabled:
            return

        cls.requisites.append(requisite)

    @classmethod
    def pop_requisite(cls):
        """Remove the most recently pushed requisite."""
        if not cls.enabled:
            return

        del cls.requisites[-1]
134
135
class StateExtend:
    """
    Marker wrapping the id of a state that should be extended.

    When a State is created with one of these as its id, it is registered
    through ``Registry.extend`` instead of ``Registry.add``.
    """

    def __init__(self, name):
        # id of the state being extended
        self.name = name
139
140
class StateRequisite:
    """
    A requisite pointing at one state, e.g. ``require`` on ``file: /path``.

    Calling it yields the ``{module: id}`` dict used in requisite lists.
    Used as a context manager, it is pushed onto the Registry requisite
    stack for the duration of the ``with`` block.
    """

    def __init__(self, requisite, module, id_):
        self.requisite, self.module, self.id_ = requisite, module, id_

    def __call__(self):
        target = {}
        target[self.module] = self.id_
        return target

    def __enter__(self):
        Registry.push_requisite(self)

    def __exit__(self, type, value, traceback):
        Registry.pop_requisite()
155
156
class StateFactory:
    """
    The StateFactory is used to generate new States through a natural syntax

    It is used by initializing it with the name of the salt module::

        File = StateFactory("file")

    Any attribute accessed on the instance returned by StateFactory is a
    function that is a short cut for generating State objects::

        File.managed('/path/', owner='root', group='root')

    The kwargs are passed through to the State object.

    Calling the instance itself returns a StateRequisite (``require`` by
    default) that points at a state of this module::

        File('/path/')

    If ``valid_funcs`` is a non-empty list, only those function names may be
    accessed and anything else raises InvalidFunction. Dunder names are never
    treated as state functions; they raise AttributeError as usual.
    """

    def __init__(self, module, valid_funcs=None):
        self.module = module
        if valid_funcs is None:
            valid_funcs = []
        self.valid_funcs = valid_funcs

    def __getattr__(self, func):
        # __getattr__ only runs after normal lookup fails. Python protocols
        # (copy, pickle, inspect, hasattr(obj, "__deepcopy__")) probe dunder
        # names this way; handing them a state-building function would make
        # e.g. copy.deepcopy() call it, so let them fail normally.
        if func.startswith("__") and func.endswith("__"):
            raise AttributeError(
                "{!r} object has no attribute {!r}".format(type(self).__name__, func)
            )

        # Read through __dict__: on an instance whose __init__ never ran
        # (e.g. one rebuilt by copy), ``self.valid_funcs`` would re-enter
        # __getattr__ and recurse without end.
        valid_funcs = self.__dict__.get("valid_funcs")
        if valid_funcs and func not in valid_funcs:
            raise InvalidFunction(
                "The function '{}' does not exist in the StateFactory for '{}'".format(
                    func, self.module
                )
            )

        def make_state(id_, **kwargs):
            return State(id_, self.module, func, **kwargs)

        return make_state

    def __call__(self, id_, requisite="require"):
        """
        When an object is called it is being used as a requisite
        """
        # return the correct data structure for the requisite
        return StateRequisite(requisite, self.module, id_)
198
199
class State:
    """
    This represents a single item in the state tree

    ``module`` and ``func`` name the salt state function (e.g. ``file`` and
    ``managed``, together ``file.managed``) and ``id_`` is the id of the
    state. All the keyword args you pass in become the properties of your
    state.

    Creating a State registers it immediately: with ``Registry.add`` for a
    plain id, or with ``Registry.extend`` when ``id_`` is a StateExtend. Used
    as a context manager, it pushes a ``require`` requisite on itself onto
    the Registry requisite stack, so states created inside the ``with``
    block require it.
    """

    def __init__(self, id_, module, func, **kwargs):
        self.id_ = id_
        self.module = module
        self.func = func

        # Requisites must be lists, but a single item is more convenient to
        # pass unwrapped. Normalise each one into a *new* list (see
        # _listify_requisite); Registry.add appends to these lists later.
        for attr in REQUISITES:
            if attr in kwargs:
                kwargs[attr] = self._listify_requisite(kwargs[attr])
        self.kwargs = kwargs

        if isinstance(self.id_, StateExtend):
            Registry.extend(self.id_.name, self)
            self.id_ = self.id_.name
        else:
            Registry.add(self.id_, self)

        self.requisite = StateRequisite("require", self.module, self.id_)

    @staticmethod
    def _listify_requisite(value):
        """
        Return ``value`` as a new list of requisite entries.

        Strings, bytes and dicts (e.g. ``{"pkg": "nginx"}``) are single
        entries even though they are iterable; iterating them would split
        them into characters or keys. Any other iterable is copied into a
        fresh list so appending to it never mutates a list (or fails on a
        tuple) owned by the caller. Non-iterables, such as StateRequisite
        objects, are wrapped in a one-element list.
        """
        if isinstance(value, (str, bytes, dict)):
            return [value]
        try:
            items = iter(value)
        except TypeError:
            return [value]
        return list(items)

    @property
    def attrs(self):
        """
        The state's properties as a list of single-key dicts, sorted by key.

        Note: StateRequisite entries in requisite lists are replaced in
        ``self.kwargs`` itself by their ``{module: id}`` dicts.
        """
        kwargs = self.kwargs

        # handle our requisites
        for attr in REQUISITES:
            if attr in kwargs:
                # rebuild the requisite list transforming any of the actual
                # StateRequisite objects into their representative dict
                kwargs[attr] = [
                    req() if isinstance(req, StateRequisite) else req
                    for req in kwargs[attr]
                ]

        # build our attrs from kwargs. we sort the kwargs by key so that we
        # have consistent ordering for tests
        return [{k: kwargs[k]} for k in sorted(kwargs.keys())]

    @property
    def full_func(self):
        """The ``module.func`` name of the state function."""
        return "{!s}.{!s}".format(self.module, self.func)

    def __str__(self):
        return "{!s} = {!s}:{!s}".format(self.id_, self.full_func, self.attrs)

    def __call__(self):
        return {self.full_func: self.attrs}

    def __enter__(self):
        Registry.push_requisite(self.requisite)

    def __exit__(self, type, value, traceback):
        Registry.pop_requisite()
266
267
class SaltObject:
    """
    Object based interface to the functions in __salt__

    .. code-block:: python
       :linenos:

        Salt = SaltObject(__salt__)
        Salt.cmd.run(bar)

    ``Salt.mod.func`` looks up ``"mod.func"`` in the wrapped mapping and
    raises AttributeError when it is missing.
    """

    def __init__(self, salt):
        self._salt = salt

    def __getattr__(self, mod):
        # Only reached when normal lookup fails. Dunder names are probed by
        # Python protocols (copy, pickle, hasattr(obj, "__deepcopy__")) and
        # must not be mistaken for a salt module name.
        if mod.startswith("__") and mod.endswith("__"):
            raise AttributeError(
                "{!r} object has no attribute {!r}".format(type(self).__name__, mod)
            )

        class __wrapper__:
            def __getattr__(wself, func):  # pylint: disable=E0213
                # same reasoning as above: never look up dunders in __salt__
                if func.startswith("__") and func.endswith("__"):
                    raise AttributeError(func)
                name = "{}.{}".format(mod, func)
                try:
                    return self._salt[name]
                except KeyError:
                    raise AttributeError(
                        "'{}' is not available in __salt__".format(name)
                    ) from None

        return __wrapper__()
291
292
class MapMeta(type):
    """
    This is the metaclass for our Map class, used for building data maps based
    off of grain data.

    When a class using this metaclass is created, every public attribute
    defined directly in its body that is itself a class is treated as a
    filter:

    * the grain compared is the filter's ``__grain__`` (default
      ``os_family``);
    * the value compared against is the filter's ``__match__`` (default:
      the attribute name of the filter);
    * on a match, the filter's own public attributes are set on the map
      class.

    An optional ``priority`` iterable of grain names reorders the filters:
    filters on grains listed earlier are applied first, and filters on
    unlisted grains follow in definition order. Later matches overwrite keys
    set by earlier ones. An optional ``merge`` attribute names a pillar key
    whose (truthy) value is applied last, over any grain matches.
    """

    @classmethod
    def __prepare__(metacls, name, bases):
        # The class body is executed into an OrderedDict so the definition
        # order of the filters is available to __new__ below.
        return OrderedDict()

    def __new__(cls, name, bases, attrs):
        c = type.__new__(cls, name, bases, attrs)
        # names from the class body, in definition order; consumed by
        # __set_attributes__
        c.__ordered_attrs__ = attrs.keys()
        return c

    def __init__(cls, name, bases, nmspc):
        # Filters are evaluated (and grains/pillar queried through
        # Map.__salt__) as soon as the class statement executes.
        cls.__set_attributes__()  # pylint: disable=no-value-for-parameter
        super().__init__(name, bases, nmspc)

    def __set_attributes__(cls):
        """
        Evaluate the filter classes and ``merge`` of ``cls`` and set the
        resulting attributes on ``cls``.
        """
        # (grain, match value, attrs to apply) per filter, in evaluation order
        match_info = []
        # every grain any filter needs, fetched with a single grains.item call
        grain_targets = set()

        # find all of our filters
        for item in cls.__ordered_attrs__:
            if item[0] == "_":
                continue

            filt = cls.__dict__[item]

            # only process classes
            if not inspect.isclass(filt):
                continue

            # which grain are we filtering on
            grain = getattr(filt, "__grain__", "os_family")
            grain_targets.add(grain)

            # does the object pointed to have a __match__ attribute?
            # if so use it, otherwise use the name of the object
            # this is so that you can match complex values, which the python
            # class name syntax does not allow
            match = getattr(filt, "__match__", item)

            # only attributes defined directly on the filter class (not
            # inherited ones), skipping names that start with "_"
            match_attrs = {}
            for name in filt.__dict__:
                if name[0] != "_":
                    match_attrs[name] = filt.__dict__[name]

            match_info.append((grain, match, match_attrs))

        # Reorder based on priority
        # A missing ``priority`` attribute raises AttributeError, which is
        # caught at the bottom and leaves definition order untouched.
        # NOTE(review): a str is iterable, so priority="os" would iterate
        # single characters rather than being reported as an error.
        try:
            if not hasattr(cls.priority, "__iter__"):
                log.error("pyobjects: priority must be an iterable")
            else:
                new_match_info = []
                for grain in cls.priority:
                    # Using list() here because we will be modifying
                    # match_info during iteration
                    for index, item in list(enumerate(match_info)):
                        try:
                            if item[0] == grain:
                                # Add item to new list
                                new_match_info.append(item)
                                # Clear item from old list
                                match_info[index] = None
                        except TypeError:
                            # Already moved this item to new list
                            # (item is the None placeholder, so item[0] fails)
                            pass
                # Add in any remaining items not defined in priority
                new_match_info.extend([x for x in match_info if x is not None])
                # Save reordered list as the match_info list
                match_info = new_match_info
        except AttributeError:
            pass

        # Check for matches and update the attrs dict accordingly
        # Entries are applied in order, so a later match overwrites keys
        # set by an earlier one. Grains absent from the result are skipped.
        # NOTE(review): this reads Map.__salt__ rather than cls.__salt__,
        # so a __salt__ defined on a subclass is not used here.
        attrs = {}
        if match_info:
            grain_vals = Map.__salt__["grains.item"](*grain_targets)
            for grain, match, match_attrs in match_info:
                if grain not in grain_vals:
                    continue
                if grain_vals[grain] == match:
                    attrs.update(match_attrs)

        # ``merge`` names a pillar key; a truthy result (presumably a dict)
        # is layered over the grain matches.
        if hasattr(cls, "merge"):
            pillar = Map.__salt__["pillar.get"](cls.merge)
            if pillar:
                attrs.update(pillar)

        for name in attrs:
            setattr(cls, name, attrs[name])
387
388
def need_salt(*_args, **_kwargs):
    """
    Stand-in for ``__salt__`` functions used before Map has a real
    ``__salt__``.

    All arguments are ignored; an error is logged and an empty dict is
    returned.
    """
    log.error("Map needs __salt__ set before it can be used!")
    placeholder = {}
    return placeholder
392
393
class Map(metaclass=MapMeta):  # pylint: disable=W0232
    """
    Base class for grain-driven data maps; MapMeta evaluates the nested
    filter classes, ``priority`` and ``merge`` of each subclass.

    ``__salt__`` starts out as placeholders that call need_salt (which logs
    an error and returns an empty dict). Replace it with the real
    ``__salt__`` before defining subclasses that use filters or ``merge``.
    """

    # Placeholders for every __salt__ function MapMeta calls ("grains.item"
    # and "pillar.get"), so a Map subclass defined before __salt__ is set
    # logs an error instead of failing with KeyError. "grains.filter_by" is
    # not called in this module but is kept for backward compatibility.
    __salt__ = {
        "grains.filter_by": need_salt,
        "grains.item": need_salt,
        "pillar.get": need_salt,
    }
396