"""
:maintainer: Jack Kuan <kjkuan@gmail.com>
:maturity: new
:platform: all

A Python DSL for generating Salt's highstate data structure.

This module is intended to be used with the `pydsl` renderer,
but can also be used on its own. Here's what you can do with
Salt PyDSL::

    # Example translated from the online salt tutorial

    apache = state('apache')
    apache.pkg.installed()
    apache.service.running() \\
                  .watch(pkg='apache',
                         file='/etc/httpd/conf/httpd.conf',
                         user='apache')

    if __grains__['os'] == 'RedHat':
        apache.pkg.installed(name='httpd')
        apache.service.running(name='httpd')

    apache.group.present(gid=87).require(apache.pkg)
    apache.user.present(uid=87, gid=87,
                        home='/var/www/html',
                        shell='/bin/nologin') \\
               .require(apache.group)

    state('/etc/httpd/conf/httpd.conf').file.managed(
        source='salt://apache/httpd.conf',
        user='root',
        group='root',
        mode=644)


Example with ``include`` and ``extend``, translated from
the online salt tutorial::

    include('http', 'ssh')
    extend(
        state('apache').file(
            name='/etc/httpd/conf/httpd.conf',
            source='salt://http/httpd2.conf'
        ),
        state('ssh-server').service.watch(file='/etc/ssh/banner')
    )
    state('/etc/ssh/banner').file.managed(source='salt://ssh/banner')


Example of a ``cmd`` state calling a python function::

    def hello(s):
        s = "hello world, %s" % s
        return dict(result=True, changes=dict(changed=True, output=s))

    state('hello').cmd.call(hello, 'pydsl!')

"""

# Implementation note:
#  - There's a bit of terminology mix-up here:
#    - what I called a state or state declaration here is actually
#      an ID declaration.
#    - what I called a module or a state module actually corresponds
#      to a state declaration.
#    - and a state function is a function declaration.
# TODOs:
#
#   - support exclude declarations
#
#   - allow this:
#       state('X').cmd.run.cwd = '/'
#       assert state('X').cmd.run.cwd == '/'
#
#   - make it possible to remove:
#     - state declarations
#     - state module declarations
#     - state func and args
#


from uuid import uuid4 as _uuid

from salt.state import HighState
from salt.utils.odict import OrderedDict

# Names treated as requisite keywords rather than state function names.
# Each one is also installed as a chaining method on StateFunction at the
# bottom of this module.
REQUISITES = set(
    "listen require watch prereq use listen_in require_in watch_in prereq_in use_in"
    " onchanges onfail".split()
)


class PyDslError(Exception):
    """Raised for invalid PyDSL usage or failures while rendering/running states."""


class Options(dict):
    """A dict whose keys can also be read as attributes.

    Reading a missing attribute returns ``None`` instead of raising.
    """

    def __getattr__(self, name):
        return self.get(name)


# Result of ``HighState.top_matches(HighState.get_top())``; computed lazily
# by the first non-delayed ``Sls.include()`` call and reused afterwards.
SLS_MATCHES = None


class Sls:
    """Collects the PyDSL declarations made while rendering one sls module.

    :param sls: name of the sls module being rendered.
    :param saltenv: salt environment the module is rendered in.
    :param rendered_sls: a set of names of rendered sls modules; it is
        passed on to ``HighState.render_state`` by ``include()``.
    :raises PyDslError: if there is no active ``HighState``.
    """

    def __init__(self, sls, saltenv, rendered_sls):
        # Check before touching the HighState: previously the check ran
        # after ``get_active().building_highstate`` had already been
        # dereferenced, so a missing HighState raised AttributeError on
        # None instead of this PyDslError.
        active = HighState.get_active()
        if not active:
            raise PyDslError("PyDSL only works with a running high state!")

        self.name = sls
        self.saltenv = saltenv
        self.includes = []  # (saltenv, sls) pairs recorded by delayed includes
        self.included_highstate = active.building_highstate
        self.extends = []  # StateDeclarations moved here by extend()
        self.decls = []  # StateDeclarations created by this sls, in order
        self.options = Options()
        self.funcs = []  # track the ordering of state func declarations
        self.rendered_sls = rendered_sls  # a set of names of rendered sls modules

    @classmethod
    def get_all_decls(cls):
        """Return the active HighState's mapping of state id -> StateDeclaration."""
        return HighState.get_active()._pydsl_all_decls

    @classmethod
    def get_render_stack(cls):
        """Return the active HighState's stack of Sls objects being rendered."""
        return HighState.get_active()._pydsl_render_stack

    def set(self, **options):
        """Update this sls module's options (e.g. ``ordered=True``)."""
        self.options.update(options)

    def include(self, *sls_names, **kws):
        """Include other sls modules.

        :param sls_names: names of the sls modules to include.
        :param saltenv: (keyword) environment to include from; defaults to
            this module's saltenv. The legacy ``env`` keyword is ignored.
        :param delayed: (keyword) if true, only record the names; they are
            emitted as an ``include`` declaration by ``to_highstate()``
            instead of being rendered right away.
        :return: for non-delayed includes, the ``slsmod`` object stored by
            each included file's ``_slsmod_<name>`` state (``None`` for a
            file without such a state): a single value when exactly one was
            collected, a list when several were, ``None`` when none were.
        :raises PyDslError: if rendering an included sls reports errors.
        """
        # "env" is not supported; Use "saltenv".
        kws.pop("env", None)

        saltenv = kws.get("saltenv", self.saltenv)

        if kws.get("delayed", False):
            self.includes.extend((saltenv, name) for name in sls_names)
            return None

        active = HighState.get_active()

        global SLS_MATCHES
        if SLS_MATCHES is None:
            SLS_MATCHES = active.top_matches(active.get_top())

        highstate = self.included_highstate
        slsmods = []  # a list of pydsl sls modules rendered.
        for sls in sls_names:
            r_env = "{}:{}".format(saltenv, sls)
            if r_env not in self.rendered_sls:
                # needed in case the starting sls uses the pydsl renderer.
                self.rendered_sls.add(sls)
                histates, errors = active.render_state(
                    sls, saltenv, self.rendered_sls, SLS_MATCHES
                )
                active.merge_included_states(highstate, histates, errors)
                if errors:
                    raise PyDslError("\n".join(errors))
                active.clean_duplicate_extends(highstate)

            # Look up the module reference stored by to_highstate() of the
            # included (pydsl) sls.
            state_id = "_slsmod_{}".format(sls)
            if state_id not in highstate:
                slsmods.append(None)
            else:
                for arg in highstate[state_id]["stateconf"]:
                    if isinstance(arg, dict) and next(iter(arg)) == "slsmod":
                        slsmods.append(arg["slsmod"])
                        break

        if not slsmods:
            return None
        return slsmods[0] if len(slsmods) == 1 else slsmods

    def extend(self, *state_funcs):
        """Move the declarations behind *state_funcs* into the ``extend`` section.

        Each argument is a StateFunction (for example the result of
        ``state('x').file(source=...)``). Its whole state declaration is
        removed from the regular declarations and is emitted under
        ``extend`` by ``to_highstate()``.

        :raises PyDslError: if the ``ordered`` option is on, or was on when
            a state function was declared.
        """
        if self.options.ordered or self.last_func():
            raise PyDslError("Cannot extend() after the ordered option was turned on!")
        for func in state_funcs:
            state_id = func.mod._state_id
            self.extends.append(self.get_all_decls().pop(state_id))
            # Drop the most recently added declaration with that id.
            for idx in range(len(self.decls) - 1, -1, -1):
                if self.decls[idx]._id == state_id:
                    del self.decls[idx]
                    break

    def state(self, id=None):
        """Return the StateDeclaration for *id*, creating and recording it if new.

        When *id* is empty, a unique id of the form ``.<uuid4>`` is used.
        """
        if not id:
            # adds a leading dot to make use of stateconf's namespace feature.
            id = ".{}".format(_uuid())
        all_decls = self.get_all_decls()
        try:
            return all_decls[id]
        except KeyError:
            all_decls[id] = decl = StateDeclaration(id)
            self.decls.append(decl)
            return decl

    def last_func(self):
        """Return the most recently tracked StateFunction, or ``None``."""
        return self.funcs[-1] if self.funcs else None

    def track_func(self, statefunc):
        """Record *statefunc* as the latest declared state function."""
        self.funcs.append(statefunc)

    def to_highstate(self, slsmod):
        """Build the highstate data structure for this sls module.

        :param slsmod: the module object containing the DSL statements; it
            is stored in a ``stateconf.set`` state with id ``_slsmod_<name>``.
        :return: an OrderedDict with optional ``include`` and ``extend``
            entries followed by one entry per declaration, merged with
            ``self.included_highstate`` when that is non-empty.
        :raises PyDslError: if merging the included states reports errors.
        """
        # generate a state that uses the stateconf.set state, which
        # is a no-op state, to hold a reference to the sls module
        # containing the DSL statements. This is to prevent the module
        # from being GC'ed, so that objects defined in it will be
        # available while salt is executing the states.
        slsmod_id = "_slsmod_" + self.name
        self.state(slsmod_id).stateconf.set(slsmod=slsmod)
        del self.get_all_decls()[slsmod_id]

        highstate = OrderedDict()
        if self.includes:
            highstate["include"] = [{env: name} for env, name in self.includes]
        if self.extends:
            highstate["extend"] = ext_section = OrderedDict()
            for ext in self.extends:
                ext_section[ext._id] = ext._repr(context="extend")
        for decl in self.decls:
            highstate[decl._id] = decl._repr()

        if self.included_highstate:
            errors = []
            HighState.get_active().merge_included_states(
                highstate, self.included_highstate, errors
            )
            if errors:
                raise PyDslError("\n".join(errors))
        return highstate

    def load_highstate(self, highstate):
        """Replay plain highstate data as PyDSL declarations.

        :param highstate: mapping of state id -> {module name or
            ``"module.function"``: list of args}. Without a dotted key, the
            first string arg names the function. Dict args are one-item
            ``{key: value}`` keyword arguments; other args are ignored.
            The input is not modified.
        :raises PyDslError: if no function name can be determined.
        """
        for sid, decl in highstate.items():
            s = self.state(sid)
            for modname, args in decl.items():
                if "." in modname:
                    modname, funcname = modname.rsplit(".", 1)
                else:
                    funcname = next((x for x in args if isinstance(x, str)), None)
                    if funcname is None:
                        raise PyDslError(
                            "No state function specified for module: {}".format(
                                modname
                            )
                        )
                    # Work on a copy so the caller's highstate keeps its data.
                    args = list(args)
                    args.remove(funcname)
                mod = getattr(s, modname)
                named_args = {}
                for x in args:
                    if isinstance(x, dict):
                        k, v = next(iter(x.items()))
                        named_args[k] = v
                mod(funcname, **named_args)


class StateDeclaration:
    """An ID declaration: a state id plus the state modules declared under it.

    Attribute or item access returns the StateModule of that name, creating
    it on first use (e.g. ``state('apache').pkg``).
    """

    def __init__(self, id):
        self._id = id
        self._mods = []

    def __getattr__(self, name):
        for m in self._mods:
            if m._name == name:
                return m
        m = StateModule(name, self._id)
        self._mods.append(m)
        return m

    __getitem__ = __getattr__

    def __str__(self):
        return self._id

    def __iter__(self):
        return iter(self._mods)

    def _repr(self, context=None):
        """Return an OrderedDict of module name -> function data."""
        return OrderedDict(m._repr(context) for m in self)

    def __call__(self, check=True):
        """Run this state immediately (at compile time) via ``state.high``.

        If the declaration is still registered, it is first removed from
        the current sls's declarations and tracked functions, together with
        the auto-added ``require`` of its first module's function.

        :param check: if true, raise when any returned state has a false
            ``result``.
        :return: list of ``(key, result)`` pairs sorted by ``__run_num__``.
        :raises PyDslError: if another tracked state function was declared
            after this declaration's last one, if ``state.high`` returns a
            non-dict (an error list), or (with *check*) if a state failed.
        """
        sls = Sls.get_render_stack()[-1]
        if self._id in sls.get_all_decls():
            last_func = sls.last_func()
            if last_func and self._mods[-1]._func is not last_func:
                raise PyDslError(
                    "Cannot run state({}: {}) that is required by a runtime "
                    "state({}: {}), at compile time.".format(
                        self._mods[-1]._name,
                        self._id,
                        last_func.mod,
                        last_func.mod._state_id,
                    )
                )
            sls.get_all_decls().pop(self._id)
            sls.decls.remove(self)
            self._mods[0]._func._remove_auto_require()
            for m in self._mods:
                try:
                    sls.funcs.remove(m._func)
                except ValueError:
                    pass  # this module's function was never tracked

        result = HighState.get_active().state.functions["state.high"](
            {self._id: self._repr()}
        )

        if not isinstance(result, dict):
            # A list is an error
            raise PyDslError(
                "An error occurred while running highstate: {}".format(
                    "; ".join(result)
                )
            )

        result = sorted(result.items(), key=lambda t: t[1]["__run_num__"])
        if check:
            for k, v in result:
                if not v["result"]:
                    import pprint

                    raise PyDslError(
                        "Failed executing low state at compile time:\n{}".format(
                            pprint.pformat({k: v})
                        )
                    )
        return result


class StateModule:
    """A state declaration (e.g. ``pkg``) under a state id.

    Holds at most one StateFunction. Accessing a non-requisite attribute
    names (or returns) that function; requisite names are forwarded to it.
    """

    def __init__(self, name, parent_decl):
        self._state_id = parent_decl
        self._name = name
        self._func = None

    def __getattr__(self, name):
        if self._func:
            if name == self._func.name:
                return self._func
            if name not in REQUISITES:
                # The function was created nameless (via a requisite or a
                # function-less call); this access supplies its name.
                if self._func.name:
                    raise PyDslError(
                        "Multiple state functions({}) not allowed in a "
                        "state module({})!".format(name, self._name)
                    )
                self._func.name = name
                return self._func
            return getattr(self._func, name)

        if name in REQUISITES:
            self._func = f = StateFunction(None, self)
            return getattr(f, name)
        self._func = f = StateFunction(name, self)
        return f

    def __call__(self, _fname=None, *args, **kws):
        """Configure a state function of this module with the given arguments.

        ``mod('installed', ...)`` is equivalent to ``mod.installed(...)``.
        When *_fname* is omitted, the arguments are added to the module's
        current function, creating a nameless one if there is none. This
        enables the function-less form used with ``extend()``, e.g.
        ``state('apache').file(name=..., source=...)``.

        :return: the configured StateFunction.
        """
        if _fname is None:
            func = self._func
            if func is None:
                func = self._func = StateFunction(None, self)
            return func.configure(args, kws)
        return getattr(self, _fname).configure(args, kws)

    def __str__(self):
        return self._name

    def _repr(self, context=None):
        """Return a ``(module name, function data)`` pair."""
        return (self._name, self._func._repr(context))


def _generate_requsite_method(t):
    """Return a StateFunction method that adds requisites of type *t*.

    The method accepts StateModule objects positionally (their state id is
    the reference) and ``module='state id'`` keyword pairs. It appends one
    requisite entry per reference and returns the StateFunction so calls
    can be chained.
    """

    def req(self, *args, **kws):
        for mod in args:
            self.reference(t, mod, None)
        for mod_ref in kws.items():
            self.reference(t, *mod_ref)
        return self

    req.__name__ = t
    return req


class StateFunction:
    """A function declaration (e.g. ``installed``) and its argument list.

    ``args`` holds the highstate arguments: ``{key: value}`` dicts for
    keyword arguments and requisites. One method per entry in REQUISITES
    (``require()``, ``watch()``, ...) is attached after this class.
    """

    def __init__(self, name, parent_mod):
        self.mod = parent_mod
        self.name = name
        self.args = []

        # track the position of the auto-added require for easy
        # removal if run at compile time.
        self.require_index = None

        sls = Sls.get_render_stack()[-1]
        if sls.options.ordered:
            # In ordered mode each function requires the previous one's module.
            last_f = sls.last_func()
            if last_f:
                self.require(last_f.mod)
                self.require_index = len(self.args) - 1
            sls.track_func(self)

    def _remove_auto_require(self):
        """Remove the ``require`` added automatically in ordered mode, if any."""
        if self.require_index is not None:
            del self.args[self.require_index]
            self.require_index = None

    def __call__(self, *args, **kws):
        self.configure(args, kws)
        return self

    def _repr(self, context=None):
        """Return ``[name] + args``, or just ``args`` for a nameless extend.

        :raises PyDslError: if the name is missing outside an ``extend`` context.
        """
        if not self.name and context != "extend":
            raise PyDslError(
                "No state function specified for module: {}".format(self.mod._name)
            )
        if not self.name and context == "extend":
            return self.args
        return [self.name] + self.args

    def configure(self, args, kws):
        """Append positional and keyword arguments; return ``self``.

        The first positional argument becomes ``{'name': ...}``, other
        positional arguments are appended unchanged, and each keyword
        becomes a one-item dict. For ``cmd.call``/``cmd.wait_call`` with a
        callable first argument, its ``__name__`` is used as the name and
        the callable plus remaining arguments are passed as ``func``,
        ``args`` and ``kws``.
        """
        args = list(args)
        if args:
            first = args[0]
            if (
                self.mod._name == "cmd"
                and self.name in ("call", "wait_call")
                and callable(first)
            ):
                args[0] = first.__name__
                kws = dict(func=first, args=args[1:], kws=kws)
                del args[1:]

            args[0] = dict(name=args[0])

        for k, v in kws.items():
            args.append({k: v})

        self.args.extend(args)
        return self

    def reference(self, req_type, mod, ref):
        """Append a ``{req_type: [{module: ref}]}`` requisite.

        If *mod* is a StateModule, its state id is used as *ref*; otherwise
        both *mod* and *ref* must be non-empty.

        :raises PyDslError: on an invalid reference.
        """
        if isinstance(mod, StateModule):
            ref = mod._state_id
        elif not (mod and ref):
            raise PyDslError(
                "Invalid a requisite reference declaration! {}: {}".format(mod, ref)
            )
        self.args.append({req_type: [{str(mod): str(ref)}]})


# Attach one chaining method per requisite type to StateFunction, e.g.
# ``state('x').service.running().watch(pkg='x')``. (This replaces mutating
# ``locals()`` inside the class body, which depended on class-scope
# ``locals()`` returning the live namespace.)
for _req_type in REQUISITES:
    setattr(StateFunction, _req_type, _generate_requsite_method(_req_type))
del _req_type