1""" 2Defines base classes `Op`, `PureOp`, and `CLinkerOp`. 3 4The `Op` class is the base interface for all operations 5compatible with `gof`'s :doc:`graph` routines. 6 7""" 8from __future__ import absolute_import, print_function, division 9 10import inspect 11import logging 12import numpy as np 13import os 14import re 15import sys 16import warnings 17 18import theano 19from theano import config 20 21import theano.gof.cc 22from six import itervalues, PY3 23from theano.gof import graph 24from theano.gof import utils 25from theano.gof.cmodule import GCC_compiler 26from theano.gof.fg import FunctionGraph 27 28__authors__ = "theano-dev" 29__copyright__ = "(c) 2010, Universite de Montreal" 30__license__ = "3-clause BSD License" 31__contact__ = "theano-dev <theano-dev@googlegroups.com>" 32 33__docformat__ = "restructuredtext en" 34 35_logger = logging.getLogger('theano.gof.op.Op') 36 37 38# Open file in "universal newline mode". 39# In Python 2, this is done by calling open(..., 'U'), but this is 40# deprected in Python 3 (where we would need to pass "newline=None", 41# which is the default). 42if PY3: 43 _open_u = open 44else: 45 def _open_u(file): 46 return open(file, 'U') 47 48 49class CLinkerObject(object): 50 """ 51 Standard elements of an Op or Type used with the CLinker. 52 53 """ 54 55 def c_headers(self): 56 """ 57 Optional: Return a list of header files required by code returned by 58 this class. 59 60 Examples 61 -------- 62 return ['<iostream>', '<math.h>', '/full/path/to/header.h'] 63 64 These strings will be prefixed with "#include " and inserted at the 65 beginning of the c source code. 66 67 Strings in this list that start neither with '<' nor '"' will be 68 enclosed in double-quotes. 69 70 Raises 71 ------ 72 MethodNotDefined 73 Subclass does not implement this method. 74 75 """ 76 raise utils.MethodNotDefined( 77 "c_headers", type(self), self.__class__.__name__) 78 79 def c_header_dirs(self): 80 """ 81 Optional: Return a list of header search paths required by code 82 returned by this class. 83 84 Examples 85 -------- 86 return ['/usr/local/include', '/opt/weirdpath/src/include'] 87 88 Provides search paths for headers, in addition to those in any relevant 89 environment variables. 90 91 Hint: for unix compilers, these are the things that get '-I' prefixed 92 in the compiler cmdline. 93 94 Raises 95 ------ 96 MethodNotDefined 97 Subclass does not implement this method. 98 99 """ 100 raise utils.MethodNotDefined( 101 "c_header_dirs", 102 type(self), 103 self.__class__.__name__) 104 105 def c_libraries(self): 106 """ 107 Optional: Return a list of libraries required by code returned by 108 this class. 109 110 Examples 111 -------- 112 return ['gsl', 'gslcblas', 'm', 'fftw3', 'g2c']. 113 114 The compiler will search the directories specified by the environment 115 variable LD_LIBRARY_PATH in addition to any returned by `c_lib_dirs`. 116 117 Hint: for unix compilers, these are the things that get '-l' prefixed 118 in the compiler cmdline. 119 120 Raises 121 ------ 122 MethodNotDefined 123 Subclass does not implement this method. 124 125 """ 126 raise utils.MethodNotDefined( 127 "c_libraries", type(self), self.__class__.__name__) 128 129 def c_lib_dirs(self): 130 """ 131 Optional: Return a list of library search paths required by code 132 returned by this class. 133 134 Examples 135 -------- 136 return ['/usr/local/lib', '/opt/weirdpath/build/libs']. 137 138 Provides search paths for libraries, in addition to those in any 139 relevant environment variables (e.g. LD_LIBRARY_PATH). 140 141 Hint: for unix compilers, these are the things that get '-L' prefixed 142 in the compiler cmdline. 143 144 Raises 145 ------ 146 MethodNotDefined 147 Subclass does not implement this method. 148 149 """ 150 raise utils.MethodNotDefined( 151 "c_lib_dirs", type(self), self.__class__.__name__) 152 153 def c_support_code(self): 154 """ 155 Optional: Return utility code (a string, or a list of strings) for use by a `Variable` or `Op` to be 156 included at global scope prior to the rest of the code for this class. 157 158 QUESTION: How many times will this support code be emitted for a graph 159 with many instances of the same type? 160 161 Raises 162 ------ 163 MethodNotDefined 164 Subclass does not implement this method. 165 166 """ 167 raise utils.MethodNotDefined( 168 "c_support_code", 169 type(self), 170 self.__class__.__name__) 171 172 def c_code_cache_version(self): 173 """ 174 Return a tuple of integers indicating the version of this Op. 175 176 An empty tuple indicates an 'unversioned' Op that will not be cached 177 between processes. 178 179 The cache mechanism may erase cached modules that have been superceded 180 by newer versions. See `ModuleCache` for details. 181 182 See Also 183 -------- 184 c_code_cache_version_apply() 185 186 """ 187 return () 188 189 def c_compile_args(self): 190 """ 191 Optional: Return a list of compile args recommended to compile the 192 code returned by other methods in this class. 193 194 Example 195 ------- 196 return ['-ffast-math'] 197 198 Compiler arguments related to headers, libraries and search paths should 199 be provided via the functions `c_headers`, `c_libraries`, 200 `c_header_dirs`, and `c_lib_dirs`. 201 202 Raises 203 ------ 204 MethodNotDefined 205 Subclass does not implement this method. 206 207 """ 208 raise utils.MethodNotDefined( 209 "c_compile_args", 210 type(self), 211 self.__class__.__name__) 212 213 def c_no_compile_args(self): 214 """ 215 Optional: return a list of incompatible gcc compiler arguments. 216 217 We will remove those arguments from the command line of gcc. So if 218 another Op adds a compile arg in the graph that is incompatible 219 with this Op, the incompatible arg will not be used. 220 Useful for instance to remove -ffast-math. 221 222 EXAMPLE 223 224 WRITEME 225 226 Raises 227 ------ 228 MethodNotDefined 229 The subclass does not override this method. 230 231 """ 232 raise utils.MethodNotDefined( 233 "c_no_compile_args", 234 type(self), 235 self.__class__.__name__) 236 237 def c_init_code(self): 238 """ 239 Optional: return a list of code snippets to be inserted in module 240 initialization. 241 242 Raises 243 ------ 244 MethodNotDefined 245 The subclass does not override this method. 246 247 """ 248 raise utils.MethodNotDefined("c_init_code", type(self), 249 self.__class__.__name__) 250 251 252class CLinkerOp(CLinkerObject): 253 """ 254 Interface definition for `Op` subclasses compiled by `CLinker`. 255 256 A subclass should implement WRITEME. 257 258 WRITEME: structure of automatically generated C code. 259 Put this in doc/code_structure.txt 260 261 """ 262 263 def c_code(self, node, name, inputs, outputs, sub): 264 """ 265 Required: return the C implementation of an Op. 266 267 Returns C code that does the computation associated to this `Op`, 268 given names for the inputs and outputs. 269 270 Parameters 271 ---------- 272 node : Apply instance 273 The node for which we are compiling the current c_code. 274 The same Op may be used in more than one node. 275 name : str 276 A name that is automatically assigned and guaranteed to be 277 unique. 278 inputs : list of strings 279 There is a string for each input of the function, and the 280 string is the name of a C variable pointing to that input. 281 The type of the variable depends on the declared type of 282 the input. There is a corresponding python variable that 283 can be accessed by prepending "py_" to the name in the 284 list. 285 outputs : list of strings 286 Each string is the name of a C variable where the Op should 287 store its output. The type depends on the declared type of 288 the output. There is a corresponding python variable that 289 can be accessed by prepending "py_" to the name in the 290 list. In some cases the outputs will be preallocated and 291 the value of the variable may be pre-filled. The value for 292 an unallocated output is type-dependent. 293 sub : dict of strings 294 Extra symbols defined in `CLinker` sub symbols (such as 'fail'). 295 WRITEME 296 297 Raises 298 ------ 299 MethodNotDefined 300 The subclass does not override this method. 301 302 """ 303 raise utils.MethodNotDefined('%s.c_code' % self.__class__.__name__) 304 305 def c_code_cache_version_apply(self, node): 306 """ 307 Return a tuple of integers indicating the version of this Op. 308 309 An empty tuple indicates an 'unversioned' Op that will not be 310 cached between processes. 311 312 The cache mechanism may erase cached modules that have been 313 superceded by newer versions. See `ModuleCache` for details. 314 315 See Also 316 -------- 317 c_code_cache_version() 318 319 Notes 320 ----- 321 This function overrides `c_code_cache_version` unless it explicitly 322 calls `c_code_cache_version`. The default implementation simply 323 calls `c_code_cache_version` and ignores the `node` argument. 324 325 """ 326 return self.c_code_cache_version() 327 328 def c_code_cleanup(self, node, name, inputs, outputs, sub): 329 """ 330 Optional: return C code to run after c_code, whether it failed or not. 331 332 This is a convenient place to clean up things allocated by c_code(). 333 334 Parameters 335 ---------- 336 node : Apply instance 337 WRITEME 338 name : str 339 A name that is automatically assigned and guaranteed to be 340 unique. 341 inputs : list of strings 342 There is a string for each input of the function, and the 343 string is the name of a C variable pointing to that input. 344 The type of the variable depends on the declared type of 345 the input. There is a corresponding python variable that 346 can be accessed by prepending "py_" to the name in the 347 list. 348 outputs : list of strings 349 Each string is the name of a C variable correspoinding to 350 one of the outputs of the Op. The type depends on the 351 declared type of the output. There is a corresponding 352 python variable that can be accessed by prepending "py_" to 353 the name in the list. 354 sub : dict of strings 355 extra symbols defined in `CLinker` sub symbols (such as 'fail'). 356 WRITEME 357 358 Raises 359 ------ 360 MethodNotDefined 361 The subclass does not override this method. 362 363 """ 364 raise utils.MethodNotDefined('%s.c_code_cleanup' % 365 self.__class__.__name__) 366 367 def c_support_code_apply(self, node, name): 368 """ 369 Optional: return utility code for use by an `Op` that will be 370 inserted at global scope, that can be specialized for the 371 support of a particular `Apply` node. 372 373 Parameters 374 ---------- 375 node: an Apply instance in the graph being compiled 376 name: str 377 A string or number that serves to uniquely identify this node. 378 Symbol names defined by this support code should include the name, 379 so that they can be called from the c_code, and so that they do not 380 cause name collisions. 381 382 Notes 383 ----- 384 This function is called in addition to c_support_code and will 385 supplement whatever is returned from there. 386 387 Raises 388 ------ 389 MethodNotDefined 390 Subclass does not implement this method. 391 392 """ 393 raise utils.MethodNotDefined("c_support_code_apply", 394 type(self), self.__class__.__name__) 395 396 def c_init_code_apply(self, node, name): 397 """ 398 Optional: return a code string specific to the apply 399 to be inserted in the module initialization code. 400 401 Parameters 402 ---------- 403 node : an Apply instance in the graph being compiled 404 name : str 405 A string or number that serves to uniquely identify this node. 406 Symbol names defined by this support code should include the name, 407 so that they can be called from the c_code, and so that they do not 408 cause name collisions. 409 410 Notes 411 ----- 412 This function is called in addition to c_init_code and will supplement 413 whatever is returned from there. 414 415 Raises 416 ------ 417 MethodNotDefined 418 The subclass does not override this method. 419 420 """ 421 raise utils.MethodNotDefined("c_init_code_apply", type(self), 422 self.__class__.__name__) 423 424 def c_init_code_struct(self, node, name, sub): 425 """ 426 Optional: return a code string specific to the apply 427 to be inserted in the struct initialization code. 428 429 Parameters 430 ---------- 431 node : an Apply instance in the graph being compiled 432 name : str 433 A unique name to distinguish variables from those of other nodes. 434 sub 435 A dictionary of values to substitute in the code. 436 Most notably it contains a 'fail' entry that you should place in 437 your code after setting a python exception to indicate an error. 438 439 Raises 440 ------ 441 MethodNotDefined 442 The subclass does not override this method. 443 444 """ 445 raise utils.MethodNotDefined("c_init_code_struct", type(self), 446 self.__class__.__name__) 447 448 def c_support_code_struct(self, node, name): 449 """ 450 Optional: return utility code for use by an `Op` that will be 451 inserted at struct scope, that can be specialized for the 452 support of a particular `Apply` node. 453 454 Parameters 455 ---------- 456 node : an Apply instance in the graph being compiled 457 name : str 458 A unique name to distinguish you variables from those of other 459 nodes. 460 461 Raises 462 ------ 463 MethodNotDefined 464 Subclass does not implement this method. 465 466 """ 467 raise utils.MethodNotDefined("c_support_code_struct", 468 type(self), self.__class__.__name__) 469 470 def c_cleanup_code_struct(self, node, name): 471 """ 472 Optional: return a code string specific to the apply to be 473 inserted in the struct cleanup code. 474 475 Parameters 476 ---------- 477 node : an Apply instance in the graph being compiled 478 name : str 479 A unique name to distinguish variables from those of other nodes. 480 481 Raises 482 ------ 483 MethodNotDefined 484 The subclass does not override this method. 485 486 """ 487 raise utils.MethodNotDefined("c_cleanup_code_struct", type(self), 488 self.__class__.__name__) 489 490 491class PureOp(object): 492 """ 493 An :term:`Op` is a type of operation. 494 495 `Op` is an abstract class that documents the interface for theano's data 496 transformations. It has many subclasses, such as 497 `sparse dot <http://pylearn.org/epydoc/theano.sparse.Dot-class.html>`__, 498 and `Shape <http://pylearn.org/epydoc/theano.tensor.Shape-class.html>`__. 499 500 These subclasses are meant to be instantiated. 501 An instance has several responsabilities: 502 503 - making `Apply` instances, which mean "apply this type of operation to some 504 particular inputs" (via `make_node`), 505 506 - performing the calculation of outputs from given inputs 507 (via the `perform`), 508 509 - [optionally] building gradient-calculating graphs (via `grad`). 510 511 To see how `Op`, `Type`, `Variable`, and `Apply` fit together see the page 512 on :doc:`graph`. 513 514 For more specifications on how these methods should behave: see the 515 `Op Contract` in the sphinx docs (advanced tutorial on Op-making). 516 517 """ 518 519 default_output = None 520 """ 521 Configuration variable for `__call__`. 522 523 A subclass should not change this class variable, but instead over-ride it with a subclass 524 variable or an instance variable. 525 526 """ 527 528 ############# 529 # make_node # 530 ############# 531 532 def make_node(self, *inputs): 533 """ 534 Required: return an Apply instance representing the 535 application of this Op to the provided inputs. 536 537 """ 538 raise utils.MethodNotDefined( 539 "make_node", type(self), self.__class__.__name__) 540 541 @classmethod 542 def _get_test_value(cls, v): 543 """ 544 Extract test value from variable v. 545 Raises AttributeError if there is none. 546 547 For a Constant, the test value is v.value. 548 For a Shared variable, it is the internal value. 549 For another Variable, it is the content of v.tag.test_value. 550 551 """ 552 # avoid circular import 553 from theano.compile.sharedvalue import SharedVariable 554 555 if isinstance(v, graph.Constant): 556 return v.value 557 elif isinstance(v, SharedVariable): 558 return v.get_value(borrow=True, return_internal_type=True) 559 elif isinstance(v, graph.Variable) and hasattr(v.tag, 'test_value'): 560 # ensure that the test value is correct 561 try: 562 ret = v.type.filter(v.tag.test_value) 563 except Exception as e: 564 # Better error message. 565 detailed_err_msg = ( 566 "For compute_test_value, one input test value does not" 567 " have the requested type.\n") 568 detailed_err_msg += utils.get_variable_trace_string(v) 569 570 detailed_err_msg += ( 571 "\nThe error when converting the test value to that" 572 " variable type:") 573 # We need to only have 1 args and it should be of type 574 # string. Otherwise, it print the tuple and so the 575 # new line do not get printed. 576 args = (detailed_err_msg,) + tuple(str(arg) for arg in e.args) 577 e.args = ("\n".join(args),) 578 raise 579 return ret 580 detailed_err_msg = utils.get_variable_trace_string(v) 581 raise AttributeError('%s has no test value %s' % (v, detailed_err_msg)) 582 583 def __call__(self, *inputs, **kwargs): 584 """ 585 Optional: return some or all output[s] of `make_node`. 586 587 It is called by code such as: 588 589 .. python:: 590 591 x = tensor.matrix() 592 593 # tensor.exp is an Op instance, calls 594 # Op.__call__(self=<instance of exp>, inputs=(x,)) 595 y = tensor.exp(x) 596 597 This class implements a convenience function (for graph-building) which 598 uses `default_output`, but subclasses are free to override this function 599 and ignore `default_output`. 600 601 Parameters 602 ---------- 603 inputs 604 The Op's inputs, forwarded to the call to `make_node()`. 605 kwargs 606 Additional keyword arguments to be forwarded to 607 `make_node()` *except* for optional argument `return_list` (which 608 defaults to False). If `return_list` is True, then the returned 609 value is always a list. Otherwise it is either a single Variable 610 when the output of `make_node()` contains a single element, or this 611 output (unchanged) when it contains multiple elements. 612 613 """ 614 return_list = kwargs.pop('return_list', False) 615 node = self.make_node(*inputs, **kwargs) 616 617 if config.compute_test_value != 'off': 618 run_perform = True 619 620 # build test input-values 621 storage_map = {} 622 compute_map = {} 623 for i, ins in enumerate(node.inputs): 624 try: 625 storage_map[ins] = [self._get_test_value(ins)] 626 compute_map[ins] = [True] 627 except AttributeError: 628 # no test-value was specified, act accordingly 629 if config.compute_test_value == 'warn': 630 warnings.warn( 631 'Warning, Cannot compute test value: input %i (%s) of Op %s missing default value' % 632 (i, ins, node), stacklevel=2) 633 run_perform = False 634 elif config.compute_test_value == 'raise': 635 detailed_err_msg = utils.get_variable_trace_string(ins) 636 637 raise ValueError( 638 'Cannot compute test value: input %i (%s) of Op %s missing default value. %s' % 639 (i, ins, node, detailed_err_msg)) 640 elif config.compute_test_value == 'ignore': 641 # silently skip test 642 run_perform = False 643 elif config.compute_test_value == 'pdb': 644 import pdb 645 pdb.post_mortem(sys.exc_info()[2]) 646 else: 647 raise ValueError( 648 '%s is invalid for option config.compute_Test_value' % 649 config.compute_test_value) 650 651 # if all inputs have test-values, run the actual op 652 if run_perform: 653 # Original values should not be destroyed: 654 # copy the values of the inputs in destroy_map 655 destroyed_inputs_idx = set() 656 if getattr(node.op, 'destroy_map', None): 657 for i_pos_list in itervalues(node.op.destroy_map): 658 destroyed_inputs_idx.update(i_pos_list) 659 for inp_idx in destroyed_inputs_idx: 660 inp = node.inputs[inp_idx] 661 storage_map[inp] = [storage_map[inp][0].copy()] 662 663 # Prepare storage_map and compute_map for the outputs 664 for o in node.outputs: 665 storage_map[o] = [None] 666 compute_map[o] = [False] 667 668 # compute output value once with test inputs to validate graph 669 thunk = node.op.make_thunk(node, storage_map, compute_map, 670 no_recycling=[]) 671 thunk.inputs = [storage_map[v] for v in node.inputs] 672 thunk.outputs = [storage_map[v] for v in node.outputs] 673 674 required = thunk() 675 assert not required # We provided all inputs 676 677 for output in node.outputs: 678 # Check that the output has been computed 679 assert compute_map[output][ 680 0], (output, storage_map[output][0]) 681 682 # add 'test_value' to output tag, so that downstream ops can use these 683 # numerical values as inputs to their perform method. 684 output.tag.test_value = storage_map[output][0] 685 686 if self.default_output is not None: 687 rval = node.outputs[self.default_output] 688 if return_list: 689 rval = [rval] 690 return rval 691 else: 692 if return_list: 693 return list(node.outputs) 694 elif len(node.outputs) == 1: 695 return node.outputs[0] 696 else: 697 return node.outputs 698 699 def __ne__(self, other): 700 return not (self == other) 701 702 # Convenience so that subclass implementers don't have to import utils 703 # just to self.add_tag_trace 704 add_tag_trace = staticmethod(utils.add_tag_trace) 705 706 ######################### 707 # Python implementation # 708 ######################### 709 710 def L_op(self, inputs, outputs, output_grads): 711 return self.grad(inputs, output_grads) 712 713 def R_op(self, inputs, eval_points): 714 """ 715 This method is primarily used by tensor.Rop 716 717 Suppose the op outputs 718 719 [ f_1(inputs), ..., f_n(inputs) ] 720 721 Parameters 722 ---------- 723 inputs : a Variable or list of Variables 724 eval_points 725 A Variable or list of Variables with the same length as inputs. 726 Each element of eval_points specifies the value of the corresponding 727 input at the point where the R op is to be evaluated. 728 729 Returns 730 ------- 731 list of n elements 732 rval[i] should be Rop(f=f_i(inputs), 733 wrt=inputs, 734 eval_points=eval_points) 735 736 """ 737 raise NotImplementedError( 738 "%s of class %s does not " 739 "implement R_op. If this is a theano op, write to the " 740 "theano-dev mailing list for assistance. If it is your " 741 "own op, implement the R_op method." % 742 (self, self.__class__.__name__)) 743 744 def perform(self, node, inputs, output_storage, params=None): 745 """ 746 Required: Calculate the function on the inputs and put the variables in 747 the output storage. Return None. 748 749 Parameters 750 ---------- 751 node : Apply instance 752 Contains the symbolic inputs and outputs. 753 inputs : list 754 Sequence of inputs (immutable). 755 output_storage : list 756 List of mutable 1-element lists (do not change the length of 757 these lists) 758 759 Notes 760 ----- 761 The `output_storage` list might contain data. If an element of 762 output_storage is not None, it has to be of the right type, 763 for instance, for a TensorVariable, it has to be a Numpy ndarray, 764 with the right number of dimensions, and the correct dtype. 765 Its shape and stride pattern, can be arbitrary. It not is 766 guaranteed that it was produced by a previous call to impl. It 767 could be allocated by another Op impl is free to reuse it as it 768 sees fit, or to discard it and allocate new memory. 769 770 Raises 771 ------ 772 MethodNotDefined 773 The subclass does not override this method. 774 775 """ 776 raise utils.MethodNotDefined( 777 "perform", type(self), self.__class__.__name__, 778 "Did you used Theano flags mode=FAST_COMPILE?" 779 " You can use optimizer=fast_compile instead.") 780 781 def do_constant_folding(self, node): 782 """ 783 This allows each op to determine if it wants to be constant 784 folded when all its inputs are constant. This allows it to 785 choose where it puts its memory/speed trade-off. Also, it 786 could make things faster as constants can't be used for inplace 787 operations (see *IncSubtensor). 788 789 """ 790 return True 791 792 793class Op(utils.object2, PureOp, CLinkerOp): 794 """ 795 Convenience class to bundle `PureOp` and `CLinkerOp`. 796 797 """ 798 799 # We add a default get_params() implementation which will try to detect params from the op 800 # if params_type is set to a ParamsType. If not, we raise a MethodNotDefined exception. 801 def get_params(self, node): 802 if hasattr(self, 'params_type') and isinstance(self.params_type, theano.gof.ParamsType): 803 wrapper = self.params_type 804 if not all(hasattr(self, field) for field in wrapper.fields): 805 # Let's print missing attributes for debugging. 806 not_found = tuple(field for field in wrapper.fields if not hasattr(self, field)) 807 raise AttributeError('%s: missing attributes %s for ParamsType.' % (type(self).__name__, not_found)) 808 # ParamsType.get_params() will apply filtering to attributes. 809 return self.params_type.get_params(self) 810 raise theano.gof.utils.MethodNotDefined('get_params') 811 812 def prepare_node(self, node, storage_map, compute_map, impl): 813 """ 814 Make any special modifications that the Op needs before doing 815 make_thunk(). 816 817 This can modify the node inplace and should return nothing. 818 819 It can be called multiple time with different impl. It is the 820 op responsibility to don't re-prepare the node when it isn't 821 good to do so. 822 823 """ 824 pass 825 826 def make_c_thunk(self, node, storage_map, compute_map, no_recycling): 827 """Like make_thunk, but will only try to make a C thunk. 828 829 """ 830 node_input_storage = [storage_map[r] for r in node.inputs] 831 node_output_storage = [storage_map[r] for r in node.outputs] 832 833 e = FunctionGraph(node.inputs, node.outputs) 834 e_no_recycling = [new_o 835 for (new_o, old_o) in zip(e.outputs, node.outputs) 836 if old_o in no_recycling] 837 cl = theano.gof.cc.CLinker().accept(e, 838 no_recycling=e_no_recycling) 839 # float16 gets special treatment since running 840 # unprepared C code will get bad results. 841 if not getattr(self, '_f16_ok', False): 842 def is_f16(t): 843 return getattr(t, 'dtype', '') == 'float16' 844 845 if (any(is_f16(i.type) for i in node.inputs) or 846 any(is_f16(o.type) for o in node.outputs)): 847 # get_dynamic_module is a subset of make_thunk that is reused. 848 # This just try to build the c code 849 # It will raise an error for ops 850 # that don't implement c code. In those cases, we 851 # don't want to print a warning. 852 cl.get_dynamic_module() 853 print("Disabling C code for %s due to unsupported " 854 "float16" % (self,)) 855 raise NotImplementedError("float16") 856 _logger.debug('Trying CLinker.make_thunk') 857 outputs = cl.make_thunk(input_storage=node_input_storage, 858 output_storage=node_output_storage) 859 thunk, node_input_filters, node_output_filters = outputs 860 861 def rval(): 862 thunk() 863 for o in node.outputs: 864 compute_map[o][0] = True 865 866 rval.thunk = thunk 867 rval.cthunk = thunk.cthunk 868 rval.inputs = node_input_storage 869 rval.outputs = node_output_storage 870 rval.lazy = False 871 return rval 872 873 def make_py_thunk(self, node, storage_map, compute_map, no_recycling, 874 debug=False): 875 """ 876 Like make_thunk() but only makes python thunks. 877 878 """ 879 node_input_storage = [storage_map[r] for r in node.inputs] 880 node_output_storage = [storage_map[r] for r in node.outputs] 881 882 if debug: 883 p = node.op.debug_perform 884 else: 885 p = node.op.perform 886 887 params = node.run_params() 888 889 if params is graph.NoParams: 890 # default arguments are stored in the closure of `rval` 891 def rval(p=p, i=node_input_storage, o=node_output_storage, n=node): 892 r = p(n, [x[0] for x in i], o) 893 for o in node.outputs: 894 compute_map[o][0] = True 895 return r 896 else: 897 params_val = node.params_type.filter(params) 898 899 def rval(p=p, i=node_input_storage, o=node_output_storage, n=node, 900 params=params_val): 901 r = p(n, [x[0] for x in i], o, params) 902 for o in node.outputs: 903 compute_map[o][0] = True 904 return r 905 906 rval.inputs = node_input_storage 907 rval.outputs = node_output_storage 908 rval.perform = p 909 rval.lazy = False 910 return rval 911 912 def make_thunk(self, node, storage_map, compute_map, no_recycling, 913 impl=None): 914 """ 915 This function must return a thunk, that is a zero-arguments 916 function that encapsulates the computation to be performed 917 by this op on the arguments of the node. 918 919 Parameters 920 ---------- 921 node 922 Something previously returned by self.make_node. 923 storage_map 924 dict variable -> one-element-list where a computed 925 value for this variable may be found. 926 compute_map 927 dict variable -> one-element-list where a boolean 928 value will be found. The boolean indicates whether the 929 variable's storage_map container contains a valid value (True) 930 or if it has not been computed yet (False). 931 no_recycling 932 List of variables for which it is forbidden to reuse memory 933 allocated by a previous call. 934 impl 935 Currently, None, 'c' or 'py'. If 'c' or 'py' we will only try 936 that version of the code. 937 938 Notes 939 ----- 940 If the thunk consults the storage_map on every call, it is safe 941 for it to ignore the no_recycling argument, because elements of the 942 no_recycling list will have a value of None in the storage map. If 943 the thunk can potentially cache return values (like CLinker does), 944 then it must not do so for variables in the no_recycling list. 945 946 self.prepare_node(node, ...) is always called. If we try 'c' and it 947 fail and we try again 'py', prepare_node will be called twice. 948 """ 949 950 if (impl is None and theano.config.cxx) or impl == 'c': 951 self.prepare_node(node, storage_map=storage_map, 952 compute_map=compute_map, impl='c') 953 try: 954 return self.make_c_thunk(node, storage_map, compute_map, 955 no_recycling) 956 except (NotImplementedError, utils.MethodNotDefined): 957 # We requested the c code, so don't catch the error. 958 if impl == 'c': 959 raise 960 _logger.debug('Falling back on perform') 961 962 # condition: either there was no c_code, or it failed or 963 # python code was requested. 964 self.prepare_node(node, storage_map=storage_map, 965 compute_map=compute_map, impl='py') 966 return self.make_py_thunk(node, storage_map, compute_map, no_recycling) 967 968 def make_node(self, *inputs): 969 """ 970 Create a "apply" nodes for the inputs in that order. 971 """ 972 if not hasattr(self, 'itypes'): 973 raise NotImplementedError("You can either define itypes and otypes,\ 974 or implement make_node") 975 976 if not hasattr(self, 'otypes'): 977 raise NotImplementedError("You can either define itypes and otypes,\ 978 or implement make_node") 979 980 if len(inputs) != len(self.itypes): 981 raise ValueError("We expected %d inputs but got %d." % 982 (len(self.itypes), len(inputs))) 983 if not all(inp.type == it for inp, it in zip(inputs, self.itypes)): 984 raise TypeError( 985 "We expected inputs of types '%s' but got types '%s' " % 986 (str(self.itypes), str([inp.type for inp in inputs]))) 987 return theano.Apply(self, inputs, [o() for o in self.otypes]) 988 989 990def get_test_value(v): 991 """ 992 Extract test value from `v`. Raises AttributeError if there is none. 993 994 If input `v` is not already a variable, it is turned into one by calling 995 `as_tensor_variable(v)`, so that this function can be applied e.g. 996 on numpy arrays or Python lists and scalars, considering them as constants. 997 998 For a Constant, the test value is v.value. 999 For a Shared variable, it is the internal value. 1000 For another Variable, it is the content of v.tag.test_value. 1001 1002 """ 1003 if not isinstance(v, graph.Variable): 1004 v_var = theano.tensor.as_tensor_variable(v) 1005 else: 1006 v_var = v 1007 return PureOp._get_test_value(v_var) 1008 1009 1010def missing_test_message(msg): 1011 """ 1012 Displays msg, a message saying that some test_value is missing, 1013 in the appropriate form based on config.compute_test_value: 1014 1015 off: The interactive debugger is off, so we do nothing. 1016 ignore: The interactive debugger is set to ignore missing inputs, 1017 so do nothing. 1018 warn: Display msg as a warning. 1019 1020 Raises 1021 ------ 1022 AttributeError 1023 With msg as the exception text. 1024 1025 """ 1026 action = config.compute_test_value 1027 if action == 'raise': 1028 raise AttributeError(msg) 1029 elif action == 'warn': 1030 warnings.warn(msg, stacklevel=2) 1031 else: 1032 assert action in ['ignore', 'off'] 1033 1034 1035def debug_error_message(msg): 1036 """ 1037 Displays a message saying that an error was found in some 1038 test_values. Becomes a warning or a ValueError depending on 1039 config.compute_test_value. 1040 1041 """ 1042 action = config.compute_test_value 1043 1044 # this message should never be called when the debugger is off 1045 assert action != 'off' 1046 1047 if action in ['raise', 'ignore']: 1048 raise ValueError(msg) 1049 else: 1050 assert action == 'warn' 1051 warnings.warn(msg, stacklevel=2) 1052 1053 1054def debug_assert(condition, msg=None): 1055 """ 1056 Customized assert with options to ignore the assert 1057 with just a warning 1058 """ 1059 if msg is None: 1060 msg = 'debug_assert failed' 1061 if not condition: 1062 action = config.compute_test_value 1063 if action in ['raise', 'ignore']: 1064 raise AssertionError(msg) 1065 else: 1066 assert action == 'warn' 1067 warnings.warn(msg, stacklevel=2) 1068 1069 1070def get_debug_values(*args): 1071 """ 1072 Intended use: 1073 1074 for val_1, ..., val_n in get_debug_values(var_1, ..., var_n): 1075 if some condition on val_1, ..., val_n is not met: 1076 debug_error_message("condition was not met") 1077 1078 Given a list of variables, get_debug_values does one of three things: 1079 1080 1. If the interactive debugger is off, returns an empty list 1081 2. If the interactive debugger is on, and all variables have 1082 debug values, returns a list containing a single element. 1083 This single element is either: 1084 a) if there is only one variable, the element is its 1085 value 1086 b) otherwise, a tuple containing debug values of all 1087 the variables. 1088 3. If the interactive debugger is on, and some variable does 1089 not have a debug value, issue a missing_test_message about 1090 the variable, and, if still in control of execution, return 1091 an empty list. 1092 1093 """ 1094 1095 if config.compute_test_value == 'off': 1096 return [] 1097 1098 rval = [] 1099 1100 for i, arg in enumerate(args): 1101 try: 1102 rval.append(get_test_value(arg)) 1103 except AttributeError: 1104 if hasattr(arg, 'name') and arg.name is not None: 1105 missing_test_message("Argument " + str(i) + "('" + arg.name + 1106 "') has no test value") 1107 else: 1108 missing_test_message("Argument " + str(i) + 1109 " has no test value") 1110 return [] 1111 1112 if len(rval) == 1: 1113 return rval 1114 1115 return [tuple(rval)] 1116 1117 1118ops_with_inner_function = {} 1119""" 1120Registry of Ops that have an inner compiled Theano function. 1121 1122The keys are Op classes (not instances), and values are the name of the 1123attribute that contains the function. For instance, if the function is 1124self.fn, the value will be 'fn'. 1125 1126We need that to be able not to run debug checks a number of times that is 1127exponential in the nesting level of those ops. 1128For instance, Scan will be registered here. 1129 1130""" 1131 1132 1133class OpenMPOp(Op): 1134 """ 1135 All op using OpenMP code should inherit from this Op. 1136 1137 This op will check that the compiler support correctly OpenMP code. 1138 If not, it will print a warning and disable openmp for this Op. 1139 Then it will generate the not OpenMP code. 1140 1141 This is needed as EPD on Windows g++ version spec information tell 1142 it support OpenMP, but does not include the OpenMP files. 1143 1144 We also add the correct compiler flags in c_compile_args. 1145 1146 """ 1147 1148 gxx_support_openmp = None 1149 """ 1150 True/False after we tested this. 1151 1152 """ 1153 1154 def __init__(self, openmp=None): 1155 if openmp is None: 1156 openmp = theano.config.openmp 1157 self.openmp = openmp 1158 1159 def __setstate__(self, d): 1160 self.__dict__.update(d) 1161 # If we unpickle old op 1162 if not hasattr(self, "openmp"): 1163 self.openmp = False 1164 1165 def c_compile_args(self): 1166 """ 1167 Return the compilation arg "fopenmp" if openMP is supported 1168 """ 1169 self.update_self_openmp() 1170 if self.openmp: 1171 return ['-fopenmp'] 1172 return [] 1173 1174 def c_headers(self): 1175 """ 1176 Return the header file name "omp.h" if openMP is supported 1177 """ 1178 self.update_self_openmp() 1179 if self.openmp: 1180 return ["omp.h"] 1181 return [] 1182 1183 @staticmethod 1184 def test_gxx_support(): 1185 """ 1186 Check if openMP is supported 1187 """ 1188 code = """ 1189 #include <omp.h> 1190int main( int argc, const char* argv[] ) 1191{ 1192 int res[10]; 1193 1194 for(int i=0; i < 10; i++){ 1195 res[i] = i; 1196 } 1197} 1198 """ 1199 default_openmp = GCC_compiler.try_compile_tmp( 1200 src_code=code, 1201 tmp_prefix='test_omp_', 1202 flags=['-fopenmp'], 1203 try_run=False) 1204 return default_openmp 1205 1206 def update_self_openmp(self): 1207 """ 1208 Make sure self.openmp is not True if there is no support in gxx. 1209 1210 """ 1211 if self.openmp: 1212 if OpenMPOp.gxx_support_openmp is None: 1213 OpenMPOp.gxx_support_openmp = OpenMPOp.test_gxx_support() 1214 if not OpenMPOp.gxx_support_openmp: 1215 # We want to warn only once. 1216 warnings.warn( 1217 "Your g++ compiler fails to compile OpenMP code. We" 1218 " know this happen with some version of the EPD mingw" 1219 " compiler and LLVM compiler on Mac OS X." 1220 " We disable openmp everywhere in Theano." 1221 " To remove this warning set the theano flags `openmp`" 1222 " to False.", 1223 stacklevel=3) 1224 if OpenMPOp.gxx_support_openmp is False: 1225 self.openmp = False 1226 theano.config.openmp = False 1227 1228 def prepare_node(self, node, storage_map, compute_map, impl): 1229 if impl == 'c': 1230 self.update_self_openmp() 1231 1232 1233def simple_meth(tag): 1234 def f(self): 1235 if tag in self.code_sections: 1236 return self.code_sections[tag] 1237 else: 1238 raise utils.MethodNotDefined( 1239 'c_' + tag, type(self), type(self).__name__) 1240 f.__name__ = 'c_' + tag 1241 return f 1242 1243 1244def apply_meth(tag): 1245 def f(self, node, name): 1246 if tag in self.code_sections: 1247 code = self.code_sections[tag] 1248 1249 define_macros, undef_macros = self.get_c_macros(node, name) 1250 return '\n'.join(['', define_macros, code, 1251 undef_macros]) 1252 else: 1253 raise utils.MethodNotDefined( 1254 'c_' + tag, type(self), type(self).__name__) 1255 f.__name__ = 'c_' + tag 1256 return f 1257 1258 1259class COp(Op): 1260 """ 1261 Class to allow an op to have an external C implementation. 1262 1263 An op can use this class by inheriting from it and calling its 1264 __init__() method, providing it with a path to an external file containing 1265 the C implementation and the name of the function, in that file, to call 1266 to perform the computations for the op. 1267 1268 """ 1269 1270 section_re = re.compile(r'^#section ([a-zA-Z0-9_]+)$', re.MULTILINE) 1271 backward_re = re.compile( 1272 r'^THEANO_(APPLY|SUPPORT)_CODE_SECTION$', 1273 re.MULTILINE) 1274 # This is the set of allowed markers 1275 SECTIONS = set([ 1276 'init_code', 'init_code_apply', 'init_code_struct', 1277 'support_code', 'support_code_apply', 'support_code_struct', 1278 'cleanup_code_struct', 1279 'code', 'code_cleanup']) 1280 1281 @classmethod 1282 def get_path(cls, f): 1283 """ 1284 Convert a path relative to the location of the class file into 1285 an aboslute path. Paths that are already absolute are passed 1286 through unchanged. 1287 1288 """ 1289 if not os.path.isabs(f): 1290 class_file = inspect.getfile(cls) 1291 class_dir = os.path.dirname(class_file) 1292 f = os.path.realpath(os.path.join(class_dir, f)) 1293 return f 1294 1295 def __init__(self, func_files, func_name=None): 1296 """ 1297 Sections are loaded from files in order with sections in later 1298 files overriding sections in previous files. 1299 1300 """ 1301 if not isinstance(func_files, list): 1302 func_files = [func_files] 1303 1304 self.func_name = func_name 1305 # Keep the original name. If we reload old pickle, we want to 1306 # find the new path and new version of the file in Theano. 1307 self.func_files = func_files 1308 self.load_c_code(func_files) 1309 1310 if len(self.code_sections) == 0: 1311 raise ValueError("No sections where defined in C files") 1312 1313 if self.func_name is not None: 1314 if 'op_code' in self.code_sections: 1315 # maybe a warning instead (and clearing the key) 1316 raise ValueError('Cannot have an "op_code" section and ' 1317 'specify the func_name') 1318 if 'op_code_cleanup' in self.code_sections: 1319 # maybe a warning instead (and clearing the key) 1320 raise ValueError('Cannot have an "op_code_cleanup" section ' 1321 'and specify the func_name') 1322 1323 def load_c_code(self, func_files): 1324 """ 1325 Loads the c code to perform the Op 1326 """ 1327 func_files = [self.get_path(f) for f in func_files] 1328 self.func_codes = [] 1329 for func_file in func_files: 1330 # U (universal) will convert all new lines format to \n. 1331 with _open_u(func_file) as f: 1332 self.func_codes.append(f.read()) 1333 1334 # If both the old section markers and the new section markers are 1335 # present, raise an error because we don't know which ones to follow. 1336 old_markers_present = False 1337 new_markers_present = False 1338 for code in self.func_codes: 1339 if self.backward_re.search(code): 1340 old_markers_present = True 1341 if self.section_re.search(code): 1342 new_markers_present = True 1343 1344 if old_markers_present and new_markers_present: 1345 raise ValueError('Both the new and the old syntax for ' 1346 'identifying code sections are present in the ' 1347 'provided C code. These two syntaxes should not ' 1348 'be used at the same time.') 1349 1350 self.code_sections = dict() 1351 for i, code in enumerate(self.func_codes): 1352 if self.backward_re.search(code): 1353 # This is backward compat code that will go away in a while 1354 1355 # Separate the code into the proper sections 1356 split = self.backward_re.split(code) 1357 n = 1 1358 while n < len(split): 1359 if split[n] == 'APPLY': 1360 self.code_sections['support_code_apply'] = split[n + 1] 1361 elif split[n] == 'SUPPORT': 1362 self.code_sections['support_code'] = split[n + 1] 1363 n += 2 1364 continue 1365 1366 elif self.section_re.search(code): 1367 1368 # Check for code outside of the supported sections 1369 split = self.section_re.split(code) 1370 if split[0].strip() != '': 1371 raise ValueError('Stray code before first #section ' 1372 'statement (in file %s): %s' % 1373 (func_files[i], split[0])) 1374 1375 # Separate the code into the proper sections 1376 n = 1 1377 while n < len(split): 1378 if split[n] not in self.SECTIONS: 1379 raise ValueError( 1380 "Unknown section type (in file %s): %s" % 1381 (func_files[i], split[n])) 1382 if split[n] not in self.code_sections: 1383 self.code_sections[split[n]] = "" 1384 self.code_sections[split[n]] += split[n + 1] 1385 n += 2 1386 1387 else: 1388 raise ValueError("No valid section marker was found in file " 1389 "%s" % func_files[i]) 1390 1391 def __get_op_params(self): 1392 """ 1393 Returns a list of (name, value) pairs that will be turned into 1394 macros for use within the op code. 1395 1396 The names must be strings that are not a C keyword and the 1397 values must be strings of literal C representations. 1398 1399 If op uses a :class:`theano.gof.params_type.ParamsType` as ``params_type``, 1400 it returns: 1401 - a default macro ``PARAMS_TYPE`` which defines the class name of the 1402 corresponding C struct. 1403 - a macro ``DTYPE_PARAM_key`` for every ``key`` in the ParamsType for which associated 1404 type implements the method :func:`theano.gof.type.CLinkerType.c_element_type`. 1405 ``DTYPE_PARAM_key`` defines the primitive C type name of an item in a variable 1406 associated to ``key``. 1407 1408 """ 1409 if hasattr(self, 'params_type') and isinstance(self.params_type, theano.gof.ParamsType): 1410 wrapper = self.params_type 1411 params = [('PARAMS_TYPE', wrapper.name)] 1412 for i in range(wrapper.length): 1413 try: 1414 # NB (reminder): These macros are currently used only in ParamsType example test 1415 # (`theano/gof/tests/test_quadratic_function.c`), to demonstrate how we can 1416 # access params dtypes when dtypes may change (e.g. if based on theano.config.floatX). 1417 # But in practice, params types generally have fixed types per op. 1418 params.append(('DTYPE_PARAM_' + wrapper.fields[i], wrapper.types[i].c_element_type())) 1419 except utils.MethodNotDefined: 1420 pass 1421 return params 1422 return [] 1423 1424 def c_code_cache_version(self): 1425 version = (hash(tuple(self.func_codes)), ) 1426 if hasattr(self, 'params_type'): 1427 version += (self.params_type.c_code_cache_version(), ) 1428 return version 1429 1430 def c_init_code(self): 1431 """ 1432 Get the code section for init_code 1433 """ 1434 if 'init_code' in self.code_sections: 1435 return [self.code_sections['init_code']] 1436 else: 1437 raise utils.MethodNotDefined( 1438 'c_init_code', type(self), type(self).__name__) 1439 1440 c_init_code_apply = apply_meth('init_code_apply') 1441 c_support_code = simple_meth('support_code') 1442 c_support_code_apply = apply_meth('support_code_apply') 1443 c_support_code_struct = apply_meth('support_code_struct') 1444 c_cleanup_code_struct = apply_meth('cleanup_code_struct') 1445 1446 def format_c_function_args(self, inp, out): 1447 # Generate an string containing the arguments sent to the external C 1448 # function. The argstring will be of format : 1449 # "input0, input1, input2, &output0, &output1" 1450 inp = list(inp) 1451 numi = getattr(self, '_cop_num_inputs', len(inp)) 1452 while len(inp) < numi: 1453 inp.append('NULL') 1454 out = ["&%s" % o for o in out] 1455 numo = getattr(self, '_cop_num_outputs', len(out)) 1456 while len(out) < numo: 1457 out.append('NULL') 1458 return ", ".join(inp + out) 1459 1460 def get_c_macros(self, node, name, check_input=None): 1461 define_template = "#define %s %s" 1462 undef_template = "#undef %s" 1463 define_macros = [] 1464 undef_macros = [] 1465 1466 if check_input is None: 1467 check_input = getattr(self, 'check_input', True) 1468 1469 if check_input: 1470 # Extract the various properties of the input and output variables 1471 variables = node.inputs + node.outputs 1472 variable_names = (["INPUT_%i" % i for i in range(len(node.inputs))] + 1473 ["OUTPUT_%i" % i for i in range(len(node.outputs))]) 1474 1475 # Generate dtype macros 1476 for i, v in enumerate(variables): 1477 if not hasattr(v, 'dtype'): 1478 continue 1479 vname = variable_names[i] 1480 1481 macro_name = "DTYPE_" + vname 1482 macro_value = "npy_" + v.dtype 1483 1484 define_macros.append( 1485 define_template % 1486 (macro_name, macro_value)) 1487 undef_macros.append(undef_template % macro_name) 1488 1489 d = np.dtype(v.dtype) 1490 1491 macro_name = "TYPENUM_" + vname 1492 macro_value = d.num 1493 1494 define_macros.append( 1495 define_template % 1496 (macro_name, macro_value)) 1497 undef_macros.append(undef_template % macro_name) 1498 1499 macro_name = "ITEMSIZE_" + vname 1500 macro_value = d.itemsize 1501 1502 define_macros.append( 1503 define_template % 1504 (macro_name, macro_value)) 1505 undef_macros.append(undef_template % macro_name) 1506 1507 # Generate a macro to mark code as being apply-specific 1508 define_macros.append(define_template % ("APPLY_SPECIFIC(str)", 1509 "str##_%s" % name)) 1510 undef_macros.append(undef_template % "APPLY_SPECIFIC") 1511 1512 for n, v in self.__get_op_params(): 1513 define_macros.append(define_template % (n, v)) 1514 undef_macros.append(undef_template % (n,)) 1515 1516 return '\n'.join(define_macros), '\n'.join(undef_macros) 1517 1518 def _lquote_macro(self, txt): 1519 res = [] 1520 spl = txt.split('\n') 1521 for l in spl[:-1]: 1522 res.append(l + ' \\') 1523 res.append(spl[-1]) 1524 return '\n'.join(res) 1525 1526 def get_sub_macros(self, sub): 1527 define_macros = [] 1528 undef_macros = [] 1529 define_macros.append("#define FAIL %s" % ( 1530 self._lquote_macro(sub['fail']),)) 1531 undef_macros.append("#undef FAIL") 1532 if 'params' in sub: 1533 define_macros.append("#define PARAMS %s" % (sub['params'],)) 1534 undef_macros.append("#undef PARAMS") 1535 1536 return '\n'.join(define_macros), '\n'.join(undef_macros) 1537 1538 def get_io_macros(self, inputs, outputs): 1539 define_macros = [] 1540 undef_macros = [] 1541 1542 for i, inp in enumerate(inputs): 1543 define_macros.append("#define INPUT_%d %s" % (i, inp)) 1544 undef_macros.append("#undef INPUT_%d" % (i,)) 1545 1546 for i, out in enumerate(outputs): 1547 define_macros.append("#define OUTPUT_%d %s" % (i, inp)) 1548 undef_macros.append("#undef OUTPUT_%d" % (i,)) 1549 1550 def c_init_code_struct(self, node, name, sub): 1551 """ 1552 Stitches all the macros and "init_code" together 1553 1554 """ 1555 if 'init_code_struct' in self.code_sections: 1556 op_code = self.code_sections['init_code_struct'] 1557 1558 def_macros, undef_macros = self.get_c_macros(node, name) 1559 def_sub, undef_sub = self.get_sub_macros(sub) 1560 1561 return '\n'.join(['', def_macros, def_sub, 1562 op_code, 1563 undef_sub, undef_macros]) 1564 else: 1565 raise utils.MethodNotDefined( 1566 'c_init_code_struct', type(self), type(self).__name__) 1567 1568 def c_code(self, node, name, inp, out, sub): 1569 if self.func_name is not None: 1570 assert 'code' not in self.code_sections 1571 1572 define_macros, undef_macros = self.get_c_macros(node, name, 1573 check_input=False) 1574 1575 params = "" 1576 if 'params' in sub: 1577 params = ", %s" % (sub['params'],) 1578 1579 # Generate the C code 1580 return """ 1581 %(define_macros)s 1582 { 1583 if (%(func_name)s(%(func_args)s%(params)s) != 0) { 1584 %(fail)s 1585 } 1586 } 1587 %(undef_macros)s 1588 """ % dict(func_name=self.func_name, 1589 fail=sub['fail'], params=params, 1590 func_args=self.format_c_function_args(inp, out), 1591 define_macros=define_macros, 1592 undef_macros=undef_macros) 1593 else: 1594 if 'code' in self.code_sections: 1595 op_code = self.code_sections['code'] 1596 1597 def_macros, undef_macros = self.get_c_macros(node, name) 1598 def_sub, undef_sub = self.get_sub_macros(sub) 1599 # FIXME: get_io_macros() doesn't return anything. Unpacking will raise a TypeError. 1600 def_io, undef_io = self.get_io_macros(inp, out) 1601 1602 return '\n'.join([def_macros, def_sub, def_io, 1603 op_code, 1604 undef_io, undef_sub, undef_macros]) 1605 else: 1606 raise utils.MethodNotDefined( 1607 'c_code', type(self), type(self).__name__) 1608 1609 def c_code_cleanup(self, node, name, inputs, outputs, sub): 1610 """ 1611 Stitches all the macros and "code_cleanup" together 1612 """ 1613 if 'code_cleanup' in self.code_sections: 1614 op_code = self.code_sections['code_cleanup'] 1615 1616 def_macros, undef_macros = self.get_c_macros(node, name) 1617 def_sub, undef_sub = self.get_sub_macros(sub) 1618 # FIXME: get_io_macros() doesn't return anything. Unpacking will raise a TypeError. 1619 def_io, undef_io = self.get_io_macros(inputs, outputs) 1620 1621 return '\n'.join([def_macros, def_sub, def_io, 1622 op_code, 1623 undef_io, undef_sub, undef_macros]) 1624 else: 1625 raise utils.MethodNotDefined( 1626 'c_code_cleanup', type(self), type(self).__name__) 1627