1"""
2:maintainer: Jack Kuan <kjkuan@gmail.com>
3:maturity: new
4:platform: all
5
6A Python DSL for generating Salt's highstate data structure.
7
8This module is intended to be used with the `pydsl` renderer,
9but can also be used on its own. Here's what you can do with
10Salt PyDSL::
11
12    # Example translated from the online salt tutorial
13
14    apache = state('apache')
15    apache.pkg.installed()
16    apache.service.running() \\
17                  .watch(pkg='apache',
18                         file='/etc/httpd/conf/httpd.conf',
19                         user='apache')
20
21    if __grains__['os'] == 'RedHat':
22        apache.pkg.installed(name='httpd')
23        apache.service.running(name='httpd')
24
25    apache.group.present(gid=87).require(apache.pkg)
26    apache.user.present(uid=87, gid=87,
27                        home='/var/www/html',
28                        shell='/bin/nologin') \\
29               .require(apache.group)
30
31    state('/etc/httpd/conf/httpd.conf').file.managed(
32        source='salt://apache/httpd.conf',
33        user='root',
34        group='root',
35        mode=644)
36
37
38Example with ``include`` and ``extend``, translated from
39the online salt tutorial::
40
41    include('http', 'ssh')
42    extend(
43        state('apache').file(
44            name='/etc/httpd/conf/httpd.conf',
45            source='salt://http/httpd2.conf'
46        ),
47        state('ssh-server').service.watch(file='/etc/ssh/banner')
48    )
49    state('/etc/ssh/banner').file.managed(source='salt://ssh/banner')
50
51
52Example of a ``cmd`` state calling a python function::
53
54    def hello(s):
55        s = "hello world, %s" % s
56        return dict(result=True, changes=dict(changed=True, output=s))
57
58    state('hello').cmd.call(hello, 'pydsl!')
59
60"""
61
62# Implementation note:
63#  - There's a bit of terminology mix-up here:
64#    - what I called a state or state declaration here is actually
65#      an ID declaration.
66#    - what I called a module or a state module actually corresponds
67#      to a state declaration.
68#    - and a state function is a function declaration.
69
70
71# TODOs:
72#
73#  - support exclude declarations
74#
75#  - allow this:
76#      state('X').cmd.run.cwd = '/'
77#      assert state('X').cmd.run.cwd == '/'
78#
79#  - make it possible to remove:
80#    - state declarations
81#    - state module declarations
82#    - state func and args
83#
84
85
86from uuid import uuid4 as _uuid
87
88from salt.state import HighState
89from salt.utils.odict import OrderedDict
90
# Requisite keywords understood by Salt.  Each one becomes a chaining method
# on StateFunction (e.g. ``.require(...)``, ``.watch_in(...)``) and is treated
# specially by StateModule attribute lookup.
REQUISITES = {
    "listen",
    "require",
    "watch",
    "prereq",
    "use",
    "listen_in",
    "require_in",
    "watch_in",
    "prereq_in",
    "use_in",
    "onchanges",
    "onfail",
}
95
96
class PyDslError(Exception):
    """Raised for invalid PyDSL usage and for errors while rendering or
    running PyDSL-declared states."""
99
100
class Options(dict):
    """A dict whose keys can also be read as attributes.

    Reading an absent key as an attribute yields ``None`` rather than
    raising, so option checks such as ``options.ordered`` are always safe.
    """

    def __getattr__(self, name):
        # Consulted only after normal attribute lookup has failed.
        return dict.get(self, name)
104
105
# Module-level cache of the top-file matches
# (``HIGHSTATE.top_matches(HIGHSTATE.get_top())``).  It starts as None and is
# filled lazily by ``Sls.include`` the first time a non-delayed include runs;
# later includes reuse the cached value.
SLS_MATCHES = None
107
108
class Sls:
    """Render-time bookkeeping for one sls module written with PyDSL.

    An instance collects what the module declares while it is rendered:
    includes, extends, state (ID) declarations, and render options.  When
    the ``ordered`` option is on, it also records the order of state
    function declarations.  :meth:`to_highstate` turns all of it into
    Salt's highstate data structure.
    """

    def __init__(self, sls, saltenv, rendered_sls):
        # Check for an active HighState *before* dereferencing it.  Otherwise
        # a missing highstate surfaces as an AttributeError on ``None``
        # instead of the intended PyDslError.
        active = HighState.get_active()
        if not active:
            raise PyDslError("PyDSL only works with a running high state!")

        self.name = sls
        self.saltenv = saltenv
        self.includes = []  # delayed includes, as (saltenv, sls) pairs
        self.included_highstate = active.building_highstate
        self.extends = []  # StateDeclarations moved here by extend()
        self.decls = []  # StateDeclarations created by this module, in order
        self.options = Options()
        self.funcs = []  # track the ordering of state func declarations
        self.rendered_sls = rendered_sls  # a set of names of rendered sls modules

    @classmethod
    def get_all_decls(cls):
        """Return the id -> StateDeclaration map stored on the active HighState."""
        return HighState.get_active()._pydsl_all_decls

    @classmethod
    def get_render_stack(cls):
        """Return the stack of Sls objects being rendered (current one last)."""
        return HighState.get_active()._pydsl_render_stack

    def set(self, **options):
        """Update render options (e.g. ``ordered=True``) for this module."""
        self.options.update(options)

    def include(self, *sls_names, **kws):
        """Include other sls modules.

        With ``delayed=True``, the names are only recorded and later emitted
        as an ``include`` declaration by :meth:`to_highstate`.  Otherwise each
        module is rendered right away and merged into the highstate being
        built.

        :param sls_names: names of the sls modules to include.
        :param saltenv: environment to include from (defaults to this
            module's saltenv).  The legacy ``env`` keyword is ignored.
        :returns: the ``slsmod`` object recorded for each included module, or
            None where no such object was found.  This is a single value when
            one name is given, a list (aligned with *sls_names*) when several
            are given, and None when no names are given or the include is
            delayed.
        :raises PyDslError: if rendering an included module reports errors.
        """
        if "env" in kws:
            # "env" is not supported; Use "saltenv".
            kws.pop("env")

        saltenv = kws.get("saltenv", self.saltenv)

        if kws.get("delayed", False):
            for incl in sls_names:
                self.includes.append((saltenv, incl))
            return

        HIGHSTATE = HighState.get_active()

        global SLS_MATCHES
        if SLS_MATCHES is None:
            SLS_MATCHES = HIGHSTATE.top_matches(HIGHSTATE.get_top())

        highstate = self.included_highstate
        slsmods = []  # a list of pydsl sls modules rendered.
        for sls in sls_names:
            r_env = "{}:{}".format(saltenv, sls)
            if r_env not in self.rendered_sls:
                self.rendered_sls.add(
                    sls
                )  # needed in case the starting sls uses the pydsl renderer.
                histates, errors = HIGHSTATE.render_state(
                    sls, saltenv, self.rendered_sls, SLS_MATCHES
                )
                HIGHSTATE.merge_included_states(highstate, histates, errors)
                if errors:
                    raise PyDslError("\n".join(errors))
                HIGHSTATE.clean_duplicate_extends(highstate)

            # Always append exactly one entry per name.  None marks a module
            # without a recorded slsmod, so the result stays aligned with
            # sls_names.
            state_id = "_slsmod_{}".format(sls)
            slsmod = None
            if state_id in highstate:
                for arg in highstate[state_id]["stateconf"]:
                    if isinstance(arg, dict) and next(iter(arg)) == "slsmod":
                        slsmod = arg["slsmod"]
                        break
            slsmods.append(slsmod)

        if not slsmods:
            return None
        return slsmods[0] if len(slsmods) == 1 else slsmods

    def extend(self, *state_funcs):
        """Turn already-declared states into ``extend`` declarations.

        Each argument is a state function whose owning declaration is moved
        out of the shared registry and this module's ``decls`` into
        ``extends``.

        :raises PyDslError: when the ``ordered`` option is on or a state
            function has already been tracked.
        """
        if self.options.ordered or self.last_func():
            raise PyDslError("Cannot extend() after the ordered option was turned on!")
        for f in state_funcs:
            state_id = f.mod._state_id
            self.extends.append(self.get_all_decls().pop(state_id))
            # Drop the most recent declaration with this id from self.decls.
            for i in range(len(self.decls) - 1, -1, -1):
                if self.decls[i]._id == state_id:
                    del self.decls[i]
                    break

    def state(self, id=None):
        """Return the StateDeclaration for *id*, creating it if needed.

        Without an id, a unique one is generated.
        """
        if not id:
            id = ".{}".format(_uuid())
            # adds a leading dot to make use of stateconf's namespace feature.
        try:
            return self.get_all_decls()[id]
        except KeyError:
            self.get_all_decls()[id] = s = StateDeclaration(id)
            self.decls.append(s)
            return s

    def last_func(self):
        """Return the most recently tracked state function, or None."""
        return self.funcs[-1] if self.funcs else None

    def track_func(self, statefunc):
        """Record *statefunc* as the latest declared state function."""
        self.funcs.append(statefunc)

    def to_highstate(self, slsmod):
        """Build this module's highstate dict.

        The result contains ``include``/``extend`` entries when present, one
        entry per declaration, and merges in the highstate produced by
        non-delayed includes.

        :raises PyDslError: if merging the included highstate reports errors.
        """
        # generate a state that uses the stateconf.set state, which
        # is a no-op state, to hold a reference to the sls module
        # containing the DSL statements. This is to prevent the module
        # from being GC'ed, so that objects defined in it will be
        # available while salt is executing the states.
        slsmod_id = "_slsmod_" + self.name
        self.state(slsmod_id).stateconf.set(slsmod=slsmod)
        del self.get_all_decls()[slsmod_id]

        highstate = OrderedDict()
        if self.includes:
            highstate["include"] = [{t[0]: t[1]} for t in self.includes]
        if self.extends:
            highstate["extend"] = extend = OrderedDict()
            for ext in self.extends:
                extend[ext._id] = ext._repr(context="extend")
        for decl in self.decls:
            highstate[decl._id] = decl._repr()

        if self.included_highstate:
            errors = []
            HighState.get_active().merge_included_states(
                highstate, self.included_highstate, errors
            )
            if errors:
                raise PyDslError("\n".join(errors))
        return highstate

    def load_highstate(self, highstate):
        """Re-create PyDSL declarations from a highstate-style mapping.

        *highstate* maps state ids to ``{module: args}`` dicts.  A module key
        is either ``"mod.func"``, or ``"mod"`` with the function name given
        as a string inside *args*.  ``{key: value}`` dicts in *args* become
        keyword arguments.  The input mapping is not modified.

        :raises PyDslError: if a ``"mod"`` entry names no function.
        """
        for sid, decl in highstate.items():
            s = self.state(sid)
            for modname, args in decl.items():
                if modname.startswith("__"):
                    # Not a state module: double-underscore keys are
                    # bookkeeping whose values are not argument lists.
                    # NOTE(review): presumably e.g. __sls__/__env__ from a
                    # compiled highstate -- confirm against callers.
                    continue
                if "." in modname:
                    modname, funcname = modname.rsplit(".", 1)
                else:
                    # Only dict entries are used as arguments below, so the
                    # function name is located without removing it.  That
                    # leaves the caller's list untouched.
                    funcname = next((x for x in args if isinstance(x, str)), None)
                    if funcname is None:
                        raise PyDslError(
                            "No state function specified for module: {}".format(
                                modname
                            )
                        )
                mod = getattr(s, modname)
                named_args = {}
                for x in args:
                    if isinstance(x, dict):
                        k, v = next(iter(x.items()))
                        named_args[k] = v
                mod(funcname, **named_args)
257
258
class StateDeclaration:
    """An ID declaration: a named group of :class:`StateModule` objects.

    ``decl.pkg`` (or ``decl['pkg']``) returns the state module of that
    name, creating it on first access.  Calling the declaration runs it
    right away, at render ("compile") time.
    """

    def __init__(self, id):
        self._id = id
        self._mods = []  # StateModules in declaration order

    def __getattr__(self, name):
        """Return (creating if needed) the state module called *name*."""
        # __getattr__ is only consulted when normal lookup fails.  Refuse
        # dunder names so protocol probes (e.g. copy.deepcopy looking up
        # __deepcopy__/__setstate__ on the instance, or hasattr() checks)
        # neither register a bogus state module nor recurse through
        # ``self._mods`` on an instance whose __init__ has not run.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)
        for m in self._mods:
            if m._name == name:
                return m
        m = StateModule(name, self._id)
        self._mods.append(m)
        return m

    __getitem__ = __getattr__

    def __str__(self):
        return self._id

    def __iter__(self):
        return iter(self._mods)

    def _repr(self, context=None):
        """Return ``{module_name: [func, args...]}`` for the highstate."""
        return OrderedDict(m._repr(context) for m in self)

    def __call__(self, check=True):
        """Run this state immediately, at render ("compile") time.

        If the state is registered in the shared declarations, it is first
        withdrawn from them and from the current sls's ``decls`` (and its
        tracked functions).  That keeps it out of the generated highstate.

        :param check: when true, raise PyDslError if any resulting low
            state reports a falsy ``result``.
        :returns: ``(key, result_dict)`` pairs from ``state.high``, sorted
            by ``__run_num__``.
        :raises PyDslError: if a state function was tracked after this
            state's last one (it would be required by a later state), if
            ``state.high`` returns a list of errors, or on a failed result
            when *check* is true.
        """
        sls = Sls.get_render_stack()[-1]
        if self._id in sls.get_all_decls():
            last_func = sls.last_func()
            if last_func and self._mods[-1]._func is not last_func:
                raise PyDslError(
                    "Cannot run state({}: {}) that is required by a runtime "
                    "state({}: {}), at compile time.".format(
                        self._mods[-1]._name,
                        self._id,
                        last_func.mod,
                        last_func.mod._state_id,
                    )
                )
            sls.get_all_decls().pop(self._id)
            sls.decls.remove(self)
            # Drop the require that the "ordered" option added on the first
            # function; there is no runtime predecessor once run now.
            self._mods[0]._func._remove_auto_require()
            for m in self._mods:
                try:
                    sls.funcs.remove(m._func)
                except ValueError:
                    pass

        result = HighState.get_active().state.functions["state.high"](
            {self._id: self._repr()}
        )

        if not isinstance(result, dict):
            # A list is an error
            raise PyDslError(
                "An error occurred while running highstate: {}".format(
                    "; ".join(result)
                )
            )

        result = sorted(result.items(), key=lambda t: t[1]["__run_num__"])
        if check:
            for k, v in result:
                if not v["result"]:
                    import pprint

                    raise PyDslError(
                        "Failed executing low state at compile time:\n{}".format(
                            pprint.pformat({k: v})
                        )
                    )
        return result
330
331
class StateModule:
    """A state declaration (e.g. ``pkg``) inside an ID declaration.

    It holds at most one :class:`StateFunction`.  Attribute access either
    names that function (``pkg.installed``) or, for requisite keywords,
    forwards to the function's requisite methods (``pkg.require(...)``).
    """

    def __init__(self, name, parent_decl):
        self._state_id = parent_decl  # id of the owning StateDeclaration
        self._name = name
        self._func = None

    def __getattr__(self, name):
        # Only reached when normal lookup fails.  Refuse dunder names so
        # protocol probes (copy.deepcopy looking up __deepcopy__ /
        # __setstate__, hasattr() checks, ...) neither create a state
        # function nor recurse through ``self._func`` on an instance whose
        # __init__ has not run.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)

        if self._func:
            if name == self._func.name:
                return self._func
            else:
                if name not in REQUISITES:
                    if self._func.name:
                        raise PyDslError(
                            "Multiple state functions({}) not allowed in a "
                            "state module({})!".format(name, self._name)
                        )
                    # A requisite was declared first; this names the function.
                    self._func.name = name
                    return self._func
                return getattr(self._func, name)

        if name in REQUISITES:
            # Requisite before any function name: create an unnamed function
            # to carry it (a name can be supplied later).
            self._func = f = StateFunction(None, self)
            return getattr(f, name)
        else:
            self._func = f = StateFunction(name, self)
            return f

    def __call__(self, _fname, *args, **kws):
        """Declare state function *_fname* with the given arguments."""
        return getattr(self, _fname).configure(args, kws)

    def __str__(self):
        return self._name

    def _repr(self, context=None):
        """Return ``(module_name, [func, args...])`` for the highstate."""
        return (self._name, self._func._repr(context))
368
369
370def _generate_requsite_method(t):
371    def req(self, *args, **kws):
372        for mod in args:
373            self.reference(t, mod, None)
374        for mod_ref in kws.items():
375            self.reference(t, *mod_ref)
376        return self
377
378    return req
379
380
class StateFunction:
    """A function declaration inside a state module, e.g. ``pkg.installed``.

    It holds the function name and the argument list (mostly ``{key: value}``
    dicts) that ends up in the highstate.  A chaining method (``require``,
    ``watch``, ...) is generated for every keyword in ``REQUISITES``.
    """

    def __init__(self, name, parent_mod):
        self.mod = parent_mod
        self.name = name
        self.args = []

        # Position of the require that the "ordered" option adds
        # automatically.  It is remembered so the require can be removed if
        # this state is run at compile time.
        self.require_index = None

        current_sls = Sls.get_render_stack()[-1]
        if not current_sls.options.ordered:
            return
        previous = current_sls.last_func()
        if previous:
            self.require(previous.mod)
            self.require_index = len(self.args) - 1
        current_sls.track_func(self)

    def _remove_auto_require(self):
        """Drop the automatically added require, if there is one."""
        idx = self.require_index
        if idx is None:
            return
        del self.args[idx]
        self.require_index = None

    def __call__(self, *args, **kws):
        self.configure(args, kws)
        return self

    def _repr(self, context=None):
        """Return ``[name, args...]``; just the args for a nameless extend."""
        if self.name:
            return [self.name] + self.args
        if context == "extend":
            return self.args
        raise PyDslError(
            "No state function specified for module: {}".format(self.mod._name)
        )

    def configure(self, args, kws):
        """Append positional and keyword arguments to this declaration.

        The first positional argument becomes ``{'name': ...}``.  For
        ``cmd.call``/``cmd.wait_call`` with a callable first argument, the
        callable's ``__name__`` is used as the name.  The callable, the
        remaining positionals and the keywords are passed as ``func``,
        ``args`` and ``kws``.
        """
        positional = list(args)
        if positional:
            head = positional[0]
            is_python_call = (
                self.mod._name == "cmd"
                and self.name in ("call", "wait_call")
                and callable(head)
            )
            if is_python_call:
                kws = dict(func=head, args=positional[1:], kws=kws)
                positional = [head.__name__]
            positional[0] = dict(name=positional[0])

        positional.extend({key: val} for key, val in kws.items())
        self.args.extend(positional)
        return self

    def reference(self, req_type, mod, ref):
        """Record a ``req_type`` requisite on module *mod* of state *ref*.

        When *mod* is a StateModule, its own state id is used as *ref*.
        """
        if isinstance(mod, StateModule):
            ref = mod._state_id
        elif not mod or not ref:
            raise PyDslError(
                "Invalid a requisite reference declaration! {}: {}".format(mod, ref)
            )
        self.args.append({req_type: [{str(mod): str(ref)}]})

    # Install one chaining method per requisite keyword directly into the
    # class namespace; the generator leaves no helper names behind.
    locals().update((req, _generate_requsite_method(req)) for req in REQUISITES)
453