1""" 2.. warning:: 3 4This directory is for the internal of Theano. 5 6You are strongly advised not to use it, except if you know 7what you are doing! 8 9If you want to use a scalar variable in a Theano graph, 10you probably want to use theano.tensor.[c,z,f,d,b,w,i,l,]scalar! 11""" 12from __future__ import absolute_import, print_function, division 13 14from itertools import chain 15import math 16import warnings 17from copy import copy 18from textwrap import dedent 19 20import numpy as np 21import six 22from six.moves import xrange 23 24import theano 25from theano.compat import imap, izip, Callable 26from theano import gof, printing 27from theano.gof import (Op, utils, Variable, Constant, Type, Apply, 28 FunctionGraph) 29from functools import partial 30from theano import config 31 32from theano.gradient import DisconnectedType 33from theano.gradient import grad_undefined 34 35from theano.printing import pprint 36 37builtin_bool = bool 38builtin_complex = complex 39builtin_int = int 40builtin_float = float 41 42 43class ComplexError(NotImplementedError): 44 """ 45 Raised if complex numbers are used in an unsupported operation. 46 47 """ 48 pass 49 50 51class IntegerDivisionError(Exception): 52 """ 53 Raised if someone tries to divide integers with '/' instead of '//'. 54 55 """ 56 pass 57 58 59def upcast(dtype, *dtypes): 60 # This tries to keep data in floatX or lower precision, unless we 61 # explicitely request a higher precision datatype. 62 keep_float32 = [(config.cast_policy == 'numpy+floatX' and 63 config.floatX == 'float32')] 64 keep_float16 = [(config.cast_policy == 'numpy+floatX' and 65 config.floatX == 'float16')] 66 67 def make_array(dt): 68 if dt == 'float64': 69 # There is an explicit float64 dtype: we cannot keep float32. 70 keep_float32[0] = False 71 keep_float16[0] = False 72 if dt == 'float32': 73 keep_float16[0] = False 74 return np.zeros((), dtype=dt) 75 z = make_array(dtype) 76 for dt in dtypes: 77 z = z + make_array(dt=dt) 78 rval = str(z.dtype) 79 if rval == 'float64': 80 if keep_float16[0]: 81 return 'float16' 82 if keep_float32[0]: 83 return 'float32' 84 elif rval == 'float32': 85 if keep_float16[0]: 86 return 'float16' 87 return rval 88 89 90def as_common_dtype(*vars): 91 """ 92 For for theano.scalar.Scalar and TensorVariable. 93 """ 94 dtype = upcast(*[v.dtype for v in vars]) 95 return (v.astype(dtype) for v in vars) 96 97 98def get_scalar_type(dtype): 99 """ 100 Return a Scalar(dtype) object. 101 102 This caches objects to save allocation and run time. 103 104 """ 105 if dtype not in get_scalar_type.cache: 106 get_scalar_type.cache[dtype] = Scalar(dtype=dtype) 107 return get_scalar_type.cache[dtype] 108get_scalar_type.cache = {} 109 110 111def as_scalar(x, name=None): 112 from ..tensor import TensorType, scalar_from_tensor 113 if isinstance(x, gof.Apply): 114 if len(x.outputs) != 1: 115 raise ValueError("It is ambiguous which output of a multi-output" 116 " Op has to be fetched.", x) 117 else: 118 x = x.outputs[0] 119 if isinstance(x, Variable): 120 if isinstance(x.type, Scalar): 121 return x 122 elif isinstance(x.type, TensorType) and x.ndim == 0: 123 return scalar_from_tensor(x) 124 else: 125 raise TypeError("Variable type field must be a Scalar.", x, x.type) 126 try: 127 return constant(x) 128 except TypeError: 129 raise TypeError("Cannot convert %s to Scalar" % x, type(x)) 130 131 132class NumpyAutocaster(object): 133 """ 134 This class is used to cast python ints and floats to numpy arrays. 
class NumpyAutocaster(object):
    """
    This class is used to cast python ints and floats to numpy arrays.

    The behavior when called on scalar `x` depends on `config.cast_policy`:
        - 'numpy' will simply use the same type as found by `numpy.asarray(x)`.
        - 'numpy+floatX' will do the same, except it will use float32 instead
          of float64 if `x` is a Python float and `config.floatX` is set to
          'float32' (note that if `x` is a numpy scalar whose data type is
          float64, it is not modified since we assume the user is purposely
          using float64).
        - 'custom' lets one define a tuple of data types such that:
            - if `x` is already a numpy scalar and its data type is in this
              tuple, then it is returned unchanged;
            - otherwise, the first data type in this tuple that can represent
              `x` without loss of precision will be used, unless `x` is a float
              and 'float32' is in the tuple (in which case `x` is cast as a
              float32);
            - if no data type can represent `x` without loss of precision, then
              the last data type in the tuple will be used.


    Parameters
    ----------
    dtypes: tuple of strings
        The ordered list of preferred data types (only used when
        `config.cast_policy` is set to 'custom', see the `NumpyAutocaster`
        help for details).

    """

    def __init__(self, dtypes):
        self.dtypes = tuple(dtypes)

    def __call__(self, x):
        # Make sure we only deal with scalars.
        assert (isinstance(x, six.integer_types) or
                isinstance(x, builtin_float) or
                (isinstance(x, np.ndarray) and x.ndim == 0))

        if config.cast_policy == 'numpy':
            return np.asarray(x)
        elif config.cast_policy == 'numpy+floatX':
            rval = np.asarray(x)
            if ((not hasattr(x, 'dtype') and
                 rval.dtype in ('float64', 'float32') and
                 rval.dtype != config.floatX)):
                rval = theano._asarray(rval, dtype=config.floatX)
            return rval

        # The following is the original code, corresponding to the 'custom'
        # option for `config.cast_policy`.
        assert config.cast_policy == 'custom'

        try:
            # Pass through numpy scalars, since they are typically typed
            # on purpose already.
            if str(x.dtype) in self.dtypes:
                # No need to cast `x` into a new dtype. Note that we still
                # need to convert it into an array, because it may not be
                # one already (e.g. if x == numpy.float64(1.1)).
                return np.asarray(x)
        except AttributeError:
            # Means `x` has no 'dtype' attribute.
            pass

        # unsafe downcast of float64 variables when config.floatX == 'float32'
        # recall: a Python float is a double
        if ((isinstance(x, float) and
             config.floatX in self.dtypes and
             config.floatX != 'float64')):
            return theano._asarray(x, dtype=config.floatX)

        # Don't autocast to float16 unless config.floatX is float16
        try_dtypes = [d for d in self.dtypes
                      if config.floatX == 'float16' or d != 'float16']

        for dtype in try_dtypes:
            x_ = theano._asarray(x, dtype=dtype)
            if np.all(x == x_):
                break
        # returns either an exact x_ == x, or the last cast x_
        return x_

autocast_int = NumpyAutocaster(('int8', 'int16', 'int32', 'int64'))
# autocast_float dtypes might be manipulated in tensor.*
autocast_float = NumpyAutocaster(('float16', 'float32', 'float64'))
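
# Illustrative sketch (not part of the original module): under the 'custom'
# cast policy, the module-level autocasters above pick the first dtype in
# their tuple that represents the value exactly. Assuming default settings
# (config.cast_policy == 'custom' and config.floatX != 'float16'):
#
# >>> autocast_int(127).dtype      # fits in int8
# dtype('int8')
# >>> autocast_int(128).dtype      # overflows int8, falls through to int16
# dtype('int16')
# >>> autocast_float(1.1).dtype    # 1.1 is not exactly representable in float32
# dtype('float64')
# >>> autocast_float(1.5).dtype    # 1.5 is exact in float32
# dtype('float32')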
class autocast_float_as(object):
    """
    Temporarily adjust autocasting behavior.

    This class makes it possible to temporarily and locally adjust autocasting
    behavior when `config.cast_policy` is set to 'custom'.
    If `config.cast_policy` is not 'custom', an exception is raised.
    This class might be convenient in some code, but it definitely
    helps to test the autocasting mechanism.

    Examples
    --------
    >>> with autocast_float_as('float32'):
    ...     assert (fvector() + 1.1).dtype == 'float32'  # temporary downcasting
    >>> assert (fvector() + 1.1).dtype == 'float64'  # back to default behaviour

    """
    def __init__(self, *dtypes):
        self.dtypes = dtypes
        assert config.cast_policy == 'custom'

    def __enter__(self):
        assert config.cast_policy == 'custom'
        self.old_dtypes = autocast_float.dtypes
        autocast_float.dtypes = self.dtypes

    def __exit__(self, *args):
        assert config.cast_policy == 'custom'
        autocast_float.dtypes = self.old_dtypes


def convert(x, dtype=None):
    """
    Convert the input to a properly typed numpy value according to the
    current casting policy. Works with scalars and tensors.

    """
    if dtype is not None:
        # in this case, the semantics are that the caller is forcing the dtype
        x_ = theano._asarray(x, dtype=dtype)
    else:
        # In this case, this function should infer the dtype according to the
        # autocasting rules. See autocasting above.
        x_ = None
        if isinstance(x, six.integer_types):
            try:
                x_ = autocast_int(x)
            except OverflowError:
                # This is to imitate numpy behavior which tries to fit
                # bigger numbers into a uint64.
                x_ = theano._asarray(x, dtype='uint64')
        elif isinstance(x, builtin_float):
            x_ = autocast_float(x)
        elif isinstance(x, np.ndarray):
            x_ = x
        else:
            # Here x is probably a list or a tuple. If it contains a
            # long, we will behave like the current NumPy version: it
            # will work if the long fits in int64 or uint64.
            x_ = np.asarray(x)
            if x_.size == 0 and not hasattr(x, 'dtype'):
                x_ = np.asarray(x, dtype=config.floatX)
    assert type(x_) in [np.ndarray, np.memmap]
    return x_


def constant(x, name=None, dtype=None):
    x = convert(x, dtype=dtype)
    assert x.ndim == 0
    return ScalarConstant(get_scalar_type(str(x.dtype)), x, name=name)
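
# Illustrative sketch (not part of the original module): `convert` applies
# the autocasting rules above, and `constant` wraps the result in a
# ScalarConstant (defined further below). Assuming the default 'custom'
# cast policy:
#
# >>> constant(1).type               # 1 fits in int8
# Scalar(int8)
# >>> constant(1.5).type             # 1.5 is exact in float32
# Scalar(float32)
# >>> constant(1, dtype='int64').type  # an explicit dtype always wins
# Scalar(int64)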
class Scalar(Type):

    """
    Internal class, should not be used by clients.

    Primarily used by tensor.elemwise and tensor.reduce.
    Analogous to TensorType, but for zero-dimensional objects.
    Maps directly to C primitives.

    TODO: refactor to be named ScalarType for consistency with TensorType.

    """
    __props__ = ('dtype',)
    ndim = 0

    def __init__(self, dtype):
        if dtype == 'floatX':
            dtype = config.floatX
        self.dtype = dtype
        self.dtype_specs()  # error checking

    @staticmethod
    def may_share_memory(a, b):
        # This class represents a basic C type, represented in Python
        # by a numpy scalar. Those are read-only, so from Python they
        # can never share memory.
        return False

    def filter(self, data, strict=False, allow_downcast=None):
        py_type = self.dtype_specs()[0]
        if strict and not isinstance(data, py_type):
            raise TypeError("%s expected a %s, got %s of type %s" % (
                self, py_type, data, type(data)), data)
        try:
            converted_data = py_type(data)
            if (allow_downcast or
                    (allow_downcast is None and
                        type(data) is float and
                        self.dtype == theano.config.floatX) or
                    data == converted_data):
                return py_type(data)
            else:
                raise TypeError('Value cannot accurately be converted to dtype'
                                ' (%s) and allow_downcast is not True' %
                                self.dtype)
        except Exception as e:
            raise TypeError("Could not convert %s (value=%s) to %s" % (
                type(data), data, self.dtype), e)

    def values_eq_approx(self, a, b, tolerance=1e-4):
        # The subtraction below risks overflow, especially with [u]int8
        if self.dtype == 'bool':
            return a == b
        diff = a - b
        if diff == 0:
            return True
        return abs(diff) <= (abs(a) * tolerance) + (abs(b) * tolerance)

    def c_element_type(self):
        return self.dtype_specs()[1]

    def c_headers(self, c_compiler):
        l = ['<math.h>']
        # These includes are needed by Scalar and TensorType,
        # we declare them here and they will be re-used by TensorType
        l.append('<numpy/arrayobject.h>')
        l.append('<numpy/arrayscalars.h>')
        if config.lib.amdlibm and c_compiler.supports_amdlibm:
            l += ['<amdlibm.h>']
        return l

    def c_libraries(self, c_compiler):
        l = []
        if config.lib.amdlibm and c_compiler.supports_amdlibm:
            l += ['amdlibm']
        return l

    def c_compile_args(self, c_compiler):
        if config.lib.amdlibm and c_compiler.supports_amdlibm:
            return ['-DREPLACE_WITH_AMDLIBM']
        else:
            return []
386 """ 387 for dtype in ['bool', 'int8', 'uint8', 'short', 'ushort', 'intc', 388 'uintc', 389 'longlong', 'ulonglong', 'single', 'double', 390 'longdouble', 'csingle', 'cdouble', 'clongdouble', 391 'float32', 'float64', 'int8', 'int16', 'int32', 392 'int64', 'uint8', 'uint16', 'uint32', 'uint64', 393 'complex64', 'complex128', 'float', 'double', 394 'int', 'uint']: 395 print(dtype, np.zeros(1, dtype=dtype).dtype.num) 396 """ 397 return { # dtype: (py_type, c_type, cls_name) 398 'float16': (np.float16, 'npy_float16', 'Float16'), 399 'float32': (np.float32, 'npy_float32', 'Float32'), 400 'float64': (np.float64, 'npy_float64', 'Float64'), 401 'complex128': (np.complex128, 'theano_complex128', 402 'Complex128'), 403 'complex64': (np.complex64, 'theano_complex64', 'Complex64'), 404 'bool': (np.bool_, 'npy_bool', 'Bool'), 405 'uint8': (np.uint8, 'npy_uint8', 'UInt8'), 406 'int8': (np.int8, 'npy_int8', 'Int8'), 407 'uint16': (np.uint16, 'npy_uint16', 'UInt16'), 408 'int16': (np.int16, 'npy_int16', 'Int16'), 409 'uint32': (np.uint32, 'npy_uint32', 'UInt32'), 410 'int32': (np.int32, 'npy_int32', 'Int32'), 411 'uint64': (np.uint64, 'npy_uint64', 'UInt64'), 412 'int64': (np.int64, 'npy_int64', 'Int64') 413 }[self.dtype] 414 except KeyError: 415 raise TypeError("Unsupported dtype for %s: %s" % ( 416 self.__class__.__name__, self.dtype)) 417 418 def upcast(self, *others): 419 return upcast(*[x.dtype for x in [self] + list(others)]) 420 421 def make_variable(self, name=None): 422 return ScalarVariable(self, name=name) 423 424 def __str__(self): 425 return str(self.dtype) 426 427 def __repr__(self): 428 return "Scalar(%s)" % self.dtype 429 430 def c_literal(self, data): 431 if 'complex' in self.dtype: 432 raise NotImplementedError("No literal for complex values.") 433 if self.dtype == 'bool': 434 return '1' if data else '0' 435 return str(data) 436 437 def c_declare(self, name, sub, check_input=True): 438 if(check_input): 439 pre = """ 440 typedef %(dtype)s dtype_%(name)s; 441 """ % dict(name=name, dtype=self.dtype_specs()[1]) 442 else: 443 pre = "" 444 return pre + """ 445 %(dtype)s %(name)s; 446 """ % dict(name=name, dtype=self.dtype_specs()[1]) 447 448 def c_init(self, name, sub): 449 return """ 450 %(name)s = 0; 451 """ % locals() 452 453 def c_extract(self, name, sub, check_input=True): 454 if self.dtype == 'float16': 455 # This doesn't work at the numpy level 456 raise NotImplementedError('float16') 457 specs = self.dtype_specs() 458 if(check_input): 459 pre = """ 460 if (!PyObject_TypeCheck(py_%(name)s, &%(pyarr_type)s)) 461 { 462 PyErr_Format(PyExc_ValueError, 463 "Scalar check failed (%(dtype)s)"); 464 %(fail)s 465 } 466 """ % dict(sub, 467 name=name, 468 dtype=specs[1], 469 pyarr_type='Py%sArrType_Type' % specs[2]) 470 else: 471 pre = "" 472 return pre + """ 473 PyArray_ScalarAsCtype(py_%(name)s, &%(name)s); 474 """ % dict(sub, name=name) 475 476 def c_sync(self, name, sub): 477 specs = self.dtype_specs() 478 return """ 479 Py_XDECREF(py_%(name)s); 480 py_%(name)s = PyArrayScalar_New(%(cls)s); 481 if (!py_%(name)s) 482 { 483 Py_XINCREF(Py_None); 484 py_%(name)s = Py_None; 485 PyErr_Format(PyExc_MemoryError, 486 "Instantiation of new Python scalar failed (%(dtype)s)"); 487 %(fail)s 488 } 489 PyArrayScalar_ASSIGN(py_%(name)s, %(cls)s, %(name)s); 490 """ % dict(sub, 491 name=name, 492 dtype=specs[1], 493 cls=specs[2]) 494 495 def c_cleanup(self, name, sub): 496 return "" 497 498 def c_support_code(self): 499 500 if self.dtype.startswith('complex'): 501 cplx_types = ['theano_complex64', 
    def c_support_code(self):

        if self.dtype.startswith('complex'):
            cplx_types = ['theano_complex64', 'theano_complex128']
            real_types = ['npy_int8', 'npy_int16', 'npy_int32', 'npy_int64',
                          'npy_float32', 'npy_float64']
            # If the 'int' C type is not exactly the same as an existing
            # 'npy_intX', some C code may not compile, e.g. when assigning
            # the value 0 (cast to 'int' in C) to a theano_complex64.
            if (np.dtype('intc').num not in
                    [np.dtype(d[4:]).num for d in real_types]):
                # In that case we add the 'int' type to the real types.
                real_types.append('int')

            template = """
            struct theano_complex%(nbits)s : public npy_complex%(nbits)s
            {
                typedef theano_complex%(nbits)s complex_type;
                typedef npy_float%(half_nbits)s scalar_type;

                complex_type operator +(const complex_type &y) const {
                    complex_type ret;
                    ret.real = this->real + y.real;
                    ret.imag = this->imag + y.imag;
                    return ret;
                }

                complex_type operator -() const {
                    complex_type ret;
                    ret.real = -this->real;
                    ret.imag = -this->imag;
                    return ret;
                }
                bool operator ==(const complex_type &y) const {
                    return (this->real == y.real) && (this->imag == y.imag);
                }
                bool operator ==(const scalar_type &y) const {
                    return (this->real == y) && (this->imag == 0);
                }
                complex_type operator -(const complex_type &y) const {
                    complex_type ret;
                    ret.real = this->real - y.real;
                    ret.imag = this->imag - y.imag;
                    return ret;
                }
                complex_type operator *(const complex_type &y) const {
                    complex_type ret;
                    ret.real = this->real * y.real - this->imag * y.imag;
                    ret.imag = this->real * y.imag + this->imag * y.real;
                    return ret;
                }
                complex_type operator /(const complex_type &y) const {
                    complex_type ret;
                    scalar_type y_norm_square = y.real * y.real + y.imag * y.imag;
                    ret.real = (this->real * y.real + this->imag * y.imag) / y_norm_square;
                    ret.imag = (this->imag * y.real - this->real * y.imag) / y_norm_square;
                    return ret;
                }
                template <typename T>
                complex_type& operator =(const T& y);

                theano_complex%(nbits)s() {}

                template <typename T>
                theano_complex%(nbits)s(const T& y) { *this = y; }

                template <typename TR, typename TI>
                theano_complex%(nbits)s(const TR& r, const TI& i) { this->real=r; this->imag=i; }
            };
            """

            def operator_eq_real(mytype, othertype):
                return '''
                template <> %(mytype)s & %(mytype)s::operator=<%(othertype)s>(const %(othertype)s & y)
                { this->real=y; this->imag=0; return *this; }
                ''' % dict(mytype=mytype, othertype=othertype)

            def operator_eq_cplx(mytype, othertype):
                return '''
                template <> %(mytype)s & %(mytype)s::operator=<%(othertype)s>(const %(othertype)s & y)
                { this->real=y.real; this->imag=y.imag; return *this; }
                ''' % dict(mytype=mytype, othertype=othertype)

            operator_eq = (''.join(operator_eq_real(ctype, rtype)
                                   for ctype in cplx_types
                                   for rtype in real_types) +
                           ''.join(operator_eq_cplx(ctype1, ctype2)
                                   for ctype1 in cplx_types
                                   for ctype2 in cplx_types))

            # We are not using C++ generic templating here, because this would
            # generate two different functions for adding a complex64 and a
            # complex128, one returning a complex64, the other a complex128,
            # and the compiler complains it is ambiguous.
            # Instead, we generate code for known and safe types only.
            def operator_plus_real(mytype, othertype):
                return '''
                const %(mytype)s operator+(const %(mytype)s &x, const %(othertype)s &y)
                { return %(mytype)s(x.real+y, x.imag); }

                const %(mytype)s operator+(const %(othertype)s &y, const %(mytype)s &x)
                { return %(mytype)s(x.real+y, x.imag); }
                ''' % dict(mytype=mytype, othertype=othertype)

            operator_plus = ''.join(operator_plus_real(ctype, rtype)
                                    for ctype in cplx_types
                                    for rtype in real_types)

            def operator_minus_real(mytype, othertype):
                return '''
                const %(mytype)s operator-(const %(mytype)s &x, const %(othertype)s &y)
                { return %(mytype)s(x.real-y, x.imag); }

                const %(mytype)s operator-(const %(othertype)s &y, const %(mytype)s &x)
                { return %(mytype)s(y-x.real, -x.imag); }
                ''' % dict(mytype=mytype, othertype=othertype)

            operator_minus = ''.join(operator_minus_real(ctype, rtype)
                                     for ctype in cplx_types
                                     for rtype in real_types)

            def operator_mul_real(mytype, othertype):
                return '''
                const %(mytype)s operator*(const %(mytype)s &x, const %(othertype)s &y)
                { return %(mytype)s(x.real*y, x.imag*y); }

                const %(mytype)s operator*(const %(othertype)s &y, const %(mytype)s &x)
                { return %(mytype)s(x.real*y, x.imag*y); }
                ''' % dict(mytype=mytype, othertype=othertype)

            operator_mul = ''.join(operator_mul_real(ctype, rtype)
                                   for ctype in cplx_types
                                   for rtype in real_types)

            return (template % dict(nbits=64, half_nbits=32) +
                    template % dict(nbits=128, half_nbits=64) +
                    operator_eq +
                    operator_plus +
                    operator_minus +
                    operator_mul)

        else:
            return ""

    def c_init_code(self):
        return ["import_array();"]

    def c_code_cache_version(self):
        return (13, np.__version__)

    def get_shape_info(self, obj):
        return obj.itemsize

    def get_size(self, shape_info):
        return shape_info
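
# Illustrative sketch (not part of the original module): Scalar types are
# interned through get_scalar_type, so equal dtypes always yield the very
# same Type object and comparisons are cheap:
#
# >>> get_scalar_type('float32') is get_scalar_type('float32')
# True
# >>> str(get_scalar_type('int16'))
# 'int16'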
# Register C code for ViewOp on Scalars.
theano.compile.register_view_op_c_code(
    Scalar,
    """
    %(oname)s = %(iname)s;
    """,
    1)


bool = get_scalar_type('bool')
int8 = get_scalar_type('int8')
int16 = get_scalar_type('int16')
int32 = get_scalar_type('int32')
int64 = get_scalar_type('int64')
uint8 = get_scalar_type('uint8')
uint16 = get_scalar_type('uint16')
uint32 = get_scalar_type('uint32')
uint64 = get_scalar_type('uint64')
float16 = get_scalar_type('float16')
float32 = get_scalar_type('float32')
float64 = get_scalar_type('float64')
complex64 = get_scalar_type('complex64')
complex128 = get_scalar_type('complex128')

int_types = int8, int16, int32, int64
uint_types = uint8, uint16, uint32, uint64
float_types = float16, float32, float64
complex_types = complex64, complex128

integer_types = int_types + uint_types
discrete_types = (bool,) + integer_types
continuous_types = float_types + complex_types
all_types = discrete_types + continuous_types


class _scalar_py_operators:
    # So that we can simplify checking code when we have a mixture of Scalar
    # variables and Tensor variables
    ndim = 0

    dtype = property(lambda self: self.type.dtype)
    """The dtype of this scalar."""

    # UNARY
    def __abs__(self):
        return abs_(self)

    def __neg__(self):
        return neg(self)

    # CASTS
    # def __int__(self): return AsInt(self).out
    # def __float__(self): return AsDouble(self).out
    # def __complex__(self): return AsComplex(self).out

    # BITWISE
    def __invert__(self):
        return invert(self)

    def __and__(self, other):
        return and_(self, other)

    def __or__(self, other):
        return or_(self, other)

    def __xor__(self, other):
        return xor(self, other)

    def __rand__(self, other):
        return and_(other, self)

    def __ror__(self, other):
        return or_(other, self)

    def __rxor__(self, other):
        return xor(other, self)

    # COMPARISONS
    def __lt__(self, other):
        return lt(self, other)

    def __le__(self, other):
        return le(self, other)

    def __gt__(self, other):
        return gt(self, other)

    def __ge__(self, other):
        return ge(self, other)

    # ARITHMETIC - NORMAL
    def __add__(self, other):
        return add(self, other)

    def __sub__(self, other):
        return sub(self, other)

    def __mul__(self, other):
        return mul(self, other)

    def __truediv__(self, other):
        return div_proxy(self, other)

    def __div__(self, other):
        return div_proxy(self, other)

    def __floordiv__(self, other):
        return int_div(self, other)

    def __mod__(self, other):
        return mod_check(self, other)

    def __pow__(self, other):
        return pow(self, other)

    # ARITHMETIC - RIGHT-OPERAND
    def __radd__(self, other):
        return add(other, self)

    def __rsub__(self, other):
        return sub(other, self)

    def __rmul__(self, other):
        return mul(other, self)

    def __rdiv__(self, other):
        return div_proxy(other, self)

    def __rmod__(self, other):
        return mod(other, self)

    def __rpow__(self, other):
        return pow(other, self)

    def zeros_like(self, dtype=None):
        # The second is needed for Elemwise ops to work right
        if dtype is None:
            dtype = str(self.type.dtype)
        return second(self, ScalarConstant(get_scalar_type(dtype), 0))

    def ones_like(self, dtype=None):
        # The second is needed for Elemwise ops to work right
        if dtype is None:
            dtype = str(self.type.dtype)
        return second(self, ScalarConstant(get_scalar_type(dtype), 1))

    def astype(self, dtype):
        return cast(self, dtype)


class ScalarVariable(_scalar_py_operators, Variable):
    pass


class ScalarConstant(_scalar_py_operators, Constant):
    pass

# Register ScalarConstant as the type of Constant corresponding to Scalar
Scalar.Constant = ScalarConstant
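
# Illustrative sketch (not part of the original module): thanks to
# _scalar_py_operators, scalar variables support the usual Python operators
# and build graphs just like tensor variables. Calling a Scalar type
# instance creates a fresh variable of that type:
#
# >>> x = float32('x')
# >>> y = float32('y')
# >>> z = x + y          # an apply of `add` (defined below)
# >>> z.dtype
# 'float32'
# >>> (x < y).dtype      # comparisons return bool scalars
# 'bool'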
# Easy constructors

def _multi(*fns):
    def f2(f, names):
        if len(names) == 1:
            return f(names)
        else:
            return [f(name) for name in names]
    if len(fns) == 1:
        return partial(f2, fns[0])
    else:
        return [partial(f2, f) for f in fns]

ints = _multi(int64)
floats = _multi(float64)
complexs = _multi(complex128)
complexs64 = _multi(complex64)
complexs128 = _multi(complex128)


def upcast_out(*types):
    dtype = Scalar.upcast(*types)
    return get_scalar_type(dtype),


def upcast_out_nobool(*types):
    type = upcast_out(*types)
    if type[0] == bool:
        raise TypeError("bool output not supported")
    return type


def upcast_out_min8(*types):
    type = upcast_out(*types)
    if type[0] == bool:
        return int8,
    return type


def upgrade_to_float(*types):
    """
    Upgrade any int types to float32 or float64 to avoid losing precision.

    """
    conv = {bool: float32,
            int8: float32,
            int16: float32,
            int32: float64,
            int64: float64,
            uint8: float32,
            uint16: float32,
            uint32: float64,
            uint64: float64}
    return get_scalar_type(Scalar.upcast(*[conv.get(type, type)
                                           for type in types])),


def upgrade_to_float64(*types):
    """
    Upgrade any int and float32 to float64, matching SciPy's behavior.

    """
    return get_scalar_type('float64'),


def same_out(type):
    return type,


def same_out_nobool(type):
    if type == bool:
        raise TypeError("bool input not supported")
    return type,


def same_out_min8(type):
    if type == bool:
        return int8,
    return type,


def upcast_out_no_complex(*types):
    if any([type in complex_types for type in types]):
        raise TypeError('complex types are not supported')
    return get_scalar_type(dtype=Scalar.upcast(*types)),


def same_out_float_only(type):
    if type not in float_types:
        raise TypeError('only float types are supported')
    return type,


class transfer_type(gof.utils.object2):
    __props__ = ('transfer',)

    def __init__(self, *transfer):
        assert all(type(x) in [int, str] or x is None for x in transfer)
        self.transfer = transfer

    def __str__(self):
        return 'transfer_type{%s}' % self.transfer

    def __call__(self, *types):
        upcast = upcast_out(*types)
        retval = []
        for i in self.transfer:
            if i is None:
                retval += [upcast]
            elif isinstance(i, str):
                retval += [i]
            else:
                retval += [types[i]]
        return retval
        # return [upcast if i is None else types[i] for i in self.transfer]


class specific_out(gof.utils.object2):
    __props__ = ('spec',)

    def __init__(self, *spec):
        self.spec = spec

    def __call__(self, *types):
        return self.spec


def int_out(*types):
    return int64,


def float_out(*types):
    return float64,
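
# Illustrative sketch (not part of the original module): the helpers above
# are output_types_preference callables; an op maps its input types through
# one of them to get its output types (note the trailing commas: each
# returns a tuple of types). Assuming default numpy promotion:
#
# >>> upgrade_to_float(int8, int16)    # small ints upgrade to float32
# (Scalar(float32),)
# >>> upgrade_to_float(int64)          # wide ints need float64
# (Scalar(float64),)
# >>> upcast_out(int8, float32)
# (Scalar(float32),)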
954 955 """ 956 for type in types: 957 if type in complex_types: 958 raise TypeError('complex argument not supported') 959 return upgrade_to_float(*types) 960 961 962def same_out_nocomplex(type): 963 if type in complex_types: 964 raise TypeError('complex argument not supported') 965 return type, 966 967 968def int_out_nocomplex(*types): 969 for type in types: 970 if type in complex_types: 971 raise TypeError('complex argument not supported') 972 return int64, 973 974 975def float_out_nocomplex(*types): 976 for type in types: 977 if type in complex_types: 978 raise TypeError('complex argument not supported') 979 return float64, 980 981 982class unary_out_lookup(gof.utils.object2): 983 """ 984 Get a output_types_preference object by passing a dictionary: 985 986 unary_out_lookup({int8:int32, float32:complex128}) 987 988 The result is an op that maps in8 to int32 and float32 to 989 complex128 and other input types lead to a TypeError. 990 991 """ 992 def __init__(self, type_table): 993 self.tbl = type_table 994 995 def __call__(self, *types): 996 if len(types) == 1: 997 types = types[0] 998 try: 999 rval = self.tbl[types] 1000 except Exception: 1001 raise TypeError(types) 1002 if isinstance(types, (list, tuple)): 1003 return rval 1004 else: 1005 return [rval] 1006 1007 def __eq__(self, other): 1008 return type(self) == type(other) and self.tbl == other.tbl 1009 1010 def __hash__(self): 1011 return hash(type(self)) # ignore hash of table 1012 1013 1014def real_out(type): 1015 if type == complex64: 1016 return float32, 1017 if type == complex128: 1018 return float64, 1019 return type, 1020 1021 1022class ScalarOp(Op): 1023 1024 nin = -1 1025 nout = 1 1026 1027 def __init__(self, output_types_preference=None, name=None): 1028 self.name = name 1029 if output_types_preference is not None: 1030 if not isinstance(output_types_preference, Callable): 1031 raise TypeError( 1032 "Expected a callable for the 'output_types_preference' argument to %s. (got: %s)" % 1033 (self.__class__, output_types_preference)) 1034 self.output_types_preference = output_types_preference 1035 1036 def make_node(self, *inputs): 1037 if self.nin >= 0: 1038 if len(inputs) != self.nin: 1039 raise TypeError("Wrong number of inputs for %s.make_node (got %i(%s), expected %i)" % 1040 (self, len(inputs), str(inputs), self.nin)) 1041 inputs = [as_scalar(input) for input in inputs] 1042 outputs = [t() for t in self.output_types([input.type 1043 for input in inputs])] 1044 if len(outputs) != self.nout: 1045 raise TypeError("Not the right number of outputs produced for %s(%s). Expected %s, got %s." 1046 % (self, ", ".join(str(input) for input in inputs), self.nout, len(outputs))) 1047 return Apply(self, inputs, outputs) 1048 1049 def output_types(self, types): 1050 if hasattr(self, 'output_types_preference'): 1051 variables = self.output_types_preference(*types) 1052 if not isinstance(variables, (list, tuple)) or any(not isinstance(x, Type) for x in variables): 1053 raise TypeError( 1054 "output_types_preference should return a list or a tuple of types", self.output_types_preference, variables) 1055 if len(variables) != self.nout: 1056 raise TypeError("Not the right number of outputs types produced for %s(%s) by %s. Expected %s, got %s." 
class ScalarOp(Op):

    nin = -1
    nout = 1

    def __init__(self, output_types_preference=None, name=None):
        self.name = name
        if output_types_preference is not None:
            if not isinstance(output_types_preference, Callable):
                raise TypeError(
                    "Expected a callable for the 'output_types_preference' argument to %s. (got: %s)" %
                    (self.__class__, output_types_preference))
            self.output_types_preference = output_types_preference

    def make_node(self, *inputs):
        if self.nin >= 0:
            if len(inputs) != self.nin:
                raise TypeError("Wrong number of inputs for %s.make_node (got %i(%s), expected %i)" %
                                (self, len(inputs), str(inputs), self.nin))
        inputs = [as_scalar(input) for input in inputs]
        outputs = [t() for t in self.output_types([input.type
                                                   for input in inputs])]
        if len(outputs) != self.nout:
            raise TypeError("Not the right number of outputs produced for %s(%s). Expected %s, got %s."
                            % (self, ", ".join(str(input) for input in inputs), self.nout, len(outputs)))
        return Apply(self, inputs, outputs)

    def output_types(self, types):
        if hasattr(self, 'output_types_preference'):
            variables = self.output_types_preference(*types)
            if not isinstance(variables, (list, tuple)) or any(not isinstance(x, Type) for x in variables):
                raise TypeError(
                    "output_types_preference should return a list or a tuple of types", self.output_types_preference, variables)
            if len(variables) != self.nout:
                raise TypeError("Not the right number of outputs types produced for %s(%s) by %s. Expected %s, got %s."
                                % (self, ", ".join(str(type) for type in variables),
                                   self.output_types_preference, self.nout, len(variables)))
            return variables
        else:
            raise NotImplementedError(
                "Cannot calculate the output types for %s" % self)

    def perform(self, node, inputs, output_storage):
        if self.nout == 1:
            output_storage[0][0] = self.impl(*inputs)
        else:
            variables = utils.from_return_values(self.impl(*inputs))
            assert len(variables) == len(output_storage)
            for storage, variable in zip(output_storage, variables):
                storage[0] = variable

    def impl(self, *inputs):
        raise utils.MethodNotDefined("impl", type(self),
                                     self.__class__.__name__)

    def grad(self, inputs, output_gradients):
        raise utils.MethodNotDefined("grad", type(self),
                                     self.__class__.__name__)

    def L_op(self, inputs, outputs, output_gradients):
        return self.grad(inputs, output_gradients)

    def __eq__(self, other):
        test = (type(self) == type(other) and
                getattr(self, 'output_types_preference', None) ==
                getattr(other, 'output_types_preference', None))
        return test

    def __hash__(self):
        return hash((type(self),
                     getattr(self, 'output_types_preference', 0)))

    def __str__(self):
        if hasattr(self, 'name') and self.name:
            return self.name
        else:
            param = [(k, v) for k, v in self.__dict__.items()
                     if k not in ["name", "_op_use_c_code", "bool",
                                  "output_types_preference"]]
            if param:
                return "%s{%s}" % (self.__class__.__name__,
                                   ", ".join("%s=%s" % (k, v)
                                             for k, v in param))
            else:
                return self.__class__.__name__

    def c_code_cache_version(self):
        return (4,)

    def c_code_contiguous(self, node, name, inp, out, sub):
        """
        This function is called by Elemwise when all inputs and outputs are
        c_contiguous. This allows the use of the SIMD version of this op.

        The inputs are the same as c_code except that:

        - inp and out must be the names of the variables associated to the
          ndarrays in the C code
        - node must be the elemwise node (this is needed to know
          the inputs/outputs types)

        """
        raise theano.gof.utils.MethodNotDefined()

    def supports_c_code(self, inputs, outputs):
        """Returns True if the current op has functioning C code for
        the given Elemwise inputs, outputs.

        """
        try:
            tmp_s_input = []
            # To keep the same aliasing between inputs
            mapping = dict()
            for ii in inputs:
                if ii in mapping:
                    tmp_s_input.append(mapping[ii])
                else:
                    tmp = get_scalar_type(ii.dtype).make_variable()
                    tmp_s_input.append(tmp)
                    mapping[ii] = tmp_s_input[-1]

            with theano.change_flags(compute_test_value='ignore'):
                s_op = self(*tmp_s_input, return_list=True)

            # If the scalar_op doesn't have a C implementation, we skip
            # its fusion to allow the fusion of the other ops.
            self.c_code(s_op[0].owner,
                        "test_presence_of_c_code",
                        ["x" for x in inputs],
                        ["z" for z in outputs],
                        {"fail": "%(fail)s"})
        except (theano.gof.utils.MethodNotDefined, NotImplementedError):
            return False
        return True
class UnaryScalarOp(ScalarOp):
    nin = 1
    amd_float32 = None
    amd_float64 = None

    def c_code_contiguous(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        if (not theano.config.lib.amdlibm or
                # We compare the dtype AND the broadcast flag
                # as this function does not broadcast
                node.inputs[0].type != node.outputs[0].type):
            raise theano.gof.utils.MethodNotDefined()

        dtype = node.inputs[0].type.dtype_specs()[1]
        fct_call = self.c_code_contiguous_raw(dtype, 'n', 'x', 'z')
        return """
{
        npy_intp n = PyArray_SIZE(%(z)s);
        %(dtype)s * x = (%(dtype)s*) PyArray_DATA(%(x)s);
        %(dtype)s * z = (%(dtype)s*) PyArray_DATA(%(z)s);
        %(fct_call)s;
}
        """ % locals()

    def c_code_contiguous_raw(self, dtype, n, i, o):
        if not config.lib.amdlibm:
            raise theano.gof.utils.MethodNotDefined()
        if dtype.startswith('npy_'):
            dtype = dtype[4:]
        if dtype == 'float32' and self.amd_float32 is not None:
            dtype = 'float'
            fct = self.amd_float32
        elif dtype == 'float64' and self.amd_float64 is not None:
            dtype = 'double'
            fct = self.amd_float64
        else:
            raise theano.gof.utils.MethodNotDefined()
        return "%(fct)s(%(n)s, %(i)s, %(o)s)" % locals()


class BinaryScalarOp(ScalarOp):
    # One may define in subclasses the following fields:
    #   - `identity`: for an associative operation, identity corresponds to
    #     the neutral element. For instance, it will be 0 for addition, 1 for
    #     multiplication, True for "and", False for "or".
    #   - `commutative`: whether op(a, b) == op(b, a)
    #   - `associative`: whether op(op(a, b), c) == op(a, op(b, c))
    nin = 2
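
# Illustrative sketch (not part of the original module): a minimal custom
# scalar op only needs an output-type rule (here the predefined `same_out`)
# and a Python `impl`; `c_code` is optional and enables C compilation and
# Elemwise fusion. The names `_ExampleSquare` / `_example_square` are
# hypothetical, chosen to avoid clashing with the real ops defined in
# this module.
class _ExampleSquare(UnaryScalarOp):
    """x -> x * x, for illustration only."""

    def impl(self, x):
        return x * x

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        return "%(z)s = %(x)s * %(x)s;" % locals()

_example_square = _ExampleSquare(same_out, name='_example_square')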
1245 1246 """ 1247 def __init__(self, *args, **kwargs): 1248 UnaryScalarOp.__init__(self, *args, **kwargs) 1249 # This is for compat with old pickles 1250 self.bool = True 1251 1252 def __eq__(self, other): 1253 return (UnaryScalarOp.__eq__(self, other) and 1254 getattr(self, 'bool', False) == getattr(other, 'bool', False)) 1255 1256 def __hash__(self): 1257 # bool should always be True 1258 return UnaryScalarOp.__hash__(self) 1259 1260 def output_types(self, *input_dtypes): 1261 return [bool] if getattr(self, 'bool', False) else [int8] 1262 1263 def L_op(self, inputs, outputs, output_gradients): 1264 x, = inputs 1265 assert outputs[0].type == bool 1266 return [x.zeros_like().astype(theano.config.floatX)] 1267 1268 def c_code_cache_version(self): 1269 super_version = super(FixedLogicalComparison, self).c_code_cache_version() 1270 return super_version + (0,) 1271 1272 1273class LT(LogicalComparison): 1274 identity = False 1275 commutative = False 1276 associative = False 1277 nfunc_spec = ('less', 2, 1) 1278 1279 def impl(self, x, y): 1280 # built-in < don't support complex 1281 return np.less(x, y) 1282 1283 def c_code(self, node, name, inputs, outputs, sub): 1284 (x, y) = inputs 1285 (z,) = outputs 1286 if node.inputs[0].type in complex_types: 1287 raise NotImplementedError() 1288 return "%(z)s = (%(x)s < %(y)s);" % locals() 1289lt = LT() 1290 1291 1292class GT(LogicalComparison): 1293 identity = False 1294 commutative = False 1295 associative = False 1296 nfunc_spec = ('greater', 2, 1) 1297 1298 def impl(self, x, y): 1299 # built-in > don't support complex 1300 return np.greater(x, y) 1301 1302 def c_code(self, node, name, inputs, outputs, sub): 1303 (x, y) = inputs 1304 (z,) = outputs 1305 if node.inputs[0].type in complex_types: 1306 raise NotImplementedError() 1307 return "%(z)s = (%(x)s > %(y)s);" % locals() 1308gt = GT() 1309 1310 1311class LE(LogicalComparison): 1312 identity = False 1313 commutative = False 1314 associative = False 1315 nfunc_spec = ('less_equal', 2, 1) 1316 1317 def impl(self, x, y): 1318 # built-in <= don't support complex 1319 return np.less_equal(x, y) 1320 1321 def c_code(self, node, name, inputs, outputs, sub): 1322 (x, y) = inputs 1323 (z,) = outputs 1324 if node.inputs[0].type in complex_types: 1325 raise NotImplementedError() 1326 return "%(z)s = (%(x)s <= %(y)s);" % locals() 1327le = LE() 1328 1329 1330class GE(LogicalComparison): 1331 identity = False 1332 commutative = False 1333 associative = False 1334 nfunc_spec = ('greater_equal', 2, 1) 1335 1336 def impl(self, x, y): 1337 # built-in >= don't support complex 1338 return np.greater_equal(x, y) 1339 1340 def c_code(self, node, name, inputs, outputs, sub): 1341 (x, y) = inputs 1342 (z,) = outputs 1343 if node.inputs[0].type in complex_types: 1344 raise NotImplementedError() 1345 return "%(z)s = (%(x)s >= %(y)s);" % locals() 1346ge = GE() 1347 1348 1349class EQ(LogicalComparison): 1350 identity = False 1351 commutative = True 1352 associative = False 1353 nfunc_spec = ('equal', 2, 1) 1354 1355 def impl(self, x, y): 1356 return x == y 1357 1358 def c_code(self, node, name, inputs, outputs, sub): 1359 (x, y) = inputs 1360 (z,) = outputs 1361 return "%(z)s = (%(x)s == %(y)s);" % locals() 1362eq = EQ() 1363 1364 1365class NEQ(LogicalComparison): 1366 identity = False 1367 commutative = True 1368 associative = False 1369 nfunc_spec = ('not_equal', 2, 1) 1370 1371 def impl(self, x, y): 1372 return x != y 1373 1374 def c_code(self, node, name, inputs, outputs, sub): 1375 (x, y) = inputs 1376 (z,) = outputs 
class NEQ(LogicalComparison):
    identity = False
    commutative = True
    associative = False
    nfunc_spec = ('not_equal', 2, 1)

    def impl(self, x, y):
        return x != y

    def c_code(self, node, name, inputs, outputs, sub):
        (x, y) = inputs
        (z,) = outputs
        if node.inputs[0].type in complex_types:
            raise NotImplementedError()
        return "%(z)s = (%(x)s != %(y)s);" % locals()
neq = NEQ()


class IsNan(FixedLogicalComparison):
    nfunc_spec = ('isnan', 1, 1)

    def impl(self, x):
        return np.isnan(x)

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        if node.inputs[0].type in complex_types:
            raise NotImplementedError()
        # Discrete type can never be nan
        if node.inputs[0].type in discrete_types:
            return "%(z)s = false;" % locals()

        # Windows tries to be different and sometimes returns -1, but we
        # want to be consistent with numpy (which returns True), hence
        # the "abs".
        return "%(z)s = abs(isnan(%(x)s));" % locals()

    def c_code_cache_version(self):
        scalarop_version = super(IsNan, self).c_code_cache_version()
        return tuple(scalarop_version) + (3,)
isnan = IsNan()


class IsInf(FixedLogicalComparison):
    nfunc_spec = ('isinf', 1, 1)

    def impl(self, x):
        return np.isinf(x)

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        if node.inputs[0].type in complex_types:
            raise NotImplementedError()
        # Discrete type can never be inf
        if node.inputs[0].type in discrete_types:
            return "%(z)s = false;" % locals()

        # Note that the C isinf returns -1 for -Inf and +1 for +Inf, while
        # numpy simply returns True: we mimic numpy's behavior here, thus
        # the absolute value.
        return "%(z)s = abs(isinf(%(x)s));" % locals()

    def c_code_cache_version(self):
        scalarop_version = super(IsInf, self).c_code_cache_version()
        return tuple(scalarop_version) + (3,)
isinf = IsInf()


class InRange(LogicalComparison):
    nin = 3

    def __init__(self, openlow, openhi):
        self.openlow = openlow
        self.openhi = openhi

    def impl(self, x, low, hi):
        if self.openlow and x <= low:
            return False
        elif not self.openlow and x < low:
            return False
        if self.openhi and x >= hi:
            return False
        elif not self.openhi and x > hi:
            return False
        return True

    def c_code(self, node, name, inputs, outputs, sub):
        (x, low, hi) = inputs
        (z,) = outputs

        cmp1 = '>' if self.openlow else '>='
        cmp2 = '<' if self.openhi else '<='

        return ("%(z)s = %(x)s %(cmp1)s %(low)s &&"
                " %(x)s %(cmp2)s %(hi)s;" % locals())

    def get_grad(self, elem):
        if elem.type in complex_types:
            msg = ("No gradient implemented for complex numbers in "
                   "class scalar.basic.InRange")
            raise NotImplementedError(msg)
        elif elem.type in discrete_types:
            return elem.zeros_like().astype(theano.config.floatX)
        else:
            return elem.zeros_like()

    def L_op(self, inputs, outputs, gout):
        (x, low, hi) = inputs
        (gz,) = gout
        grads = []
        for elem in [x, low, hi]:
            grads.append(self.get_grad(elem))
        return grads

inopenrange = InRange(True, True)
inclosedrange = InRange(False, False)
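
# Illustrative sketch (not part of the original module): the two InRange
# instances above differ only in whether the bounds are open or closed:
#
# >>> inopenrange.impl(0.5, 0., 1.)     # 0 <  x <  1
# True
# >>> inopenrange.impl(1., 0., 1.)      # upper bound excluded
# False
# >>> inclosedrange.impl(1., 0., 1.)    # 0 <= x <= 1
# True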
class Switch(ScalarOp):
    nin = 3
    nfunc_spec = ('where', 3, 1)

    def impl(self, cond, ift, iff):
        return ift if cond else iff

    def c_code(self, node, name, inputs, outputs, sub):
        (cond, ift, iff) = inputs
        (z,) = outputs
        return "%(z)s = %(cond)s ? %(ift)s : %(iff)s;" % locals()

    def L_op(self, inputs, outputs, gout):
        (cond, ift, iff) = inputs
        (gz,) = gout
        first_part = switch(cond, gz, 0.)
        second_part = switch(cond, 0., gz)

        if outputs[0].type.dtype in discrete_types:
            first_part = 0.
            second_part = 0.

        # cond does affect the elements of the output so it is connected.
        # For the sake of making the gradient convenient we assume that
        # condition + epsilon always triggers the same branch as condition
        condition_grad = cond.zeros_like().astype(theano.config.floatX)

        return (condition_grad, first_part, second_part)

    def output_types(self, types):
        (cond_t, ift_t, iff_t) = types
        return upcast_out(ift_t, iff_t)
switch = Switch()

####################
# BIT-WISE OPERATORS
####################


class UnaryBitOp(UnaryScalarOp):
    def output_types(self, *input_types):
        for i in input_types[0]:
            if i not in discrete_types:
                raise TypeError('input to a BitOp must have type (u)int8, '
                                '(u)int16, (u)int32 or (u)int64 or bool, '
                                'not %s' % i)
        return upcast_out(*input_types[0])

    def grad(self, inputs, output_gradients):
        return [inputs[0].zeros_like().astype(theano.config.floatX)]


class BinaryBitOp(BinaryScalarOp):
    def output_types(self, *input_types):
        t0, t1 = input_types[0]
        if t0 == bool and t1 == bool:
            return [bool]
        for i in input_types[0]:
            if i not in integer_types:
                raise TypeError('input to a BitOp must have type (u)int8, '
                                '(u)int16, (u)int32 or (u)int64 or '
                                'all be bools, not %s' % i)
        return upcast_out(*input_types[0])

    def grad(self, inputs, output_gradients):
        a, b = inputs
        return [a.zeros_like().astype(theano.config.floatX),
                b.zeros_like().astype(theano.config.floatX)]


class OR(BinaryBitOp):
    identity = 0
    commutative = True
    associative = True
    nfunc_spec = ('bitwise_or', 2, 1)

    def impl(self, x, y):
        return x | y

    def c_code(self, node, name, inputs, outputs, sub):
        (x, y) = inputs
        (z,) = outputs
        return "%(z)s = (%(x)s | %(y)s);" % locals()
or_ = OR()


class XOR(BinaryBitOp):
    identity = 0
    commutative = True
    associative = True
    nfunc_spec = ('bitwise_xor', 2, 1)

    def impl(self, x, y):
        return x ^ y

    def c_code(self, node, name, inputs, outputs, sub):
        (x, y) = inputs
        (z,) = outputs
        return "%(z)s = (%(x)s ^ %(y)s);" % locals()
xor = XOR()


class AND(BinaryBitOp):
    identity = -1
    commutative = True
    associative = True
    nfunc_spec = ('bitwise_and', 2, 1)

    def impl(self, x, y):
        return x & y

    def c_code(self, node, name, inputs, outputs, sub):
        (x, y) = inputs
        (z,) = outputs
        return "%(z)s = (%(x)s & %(y)s);" % locals()

    def c_code_cache_version(self):
        super_version = super(AND, self).c_code_cache_version()
        return super_version + (3,)
and_ = AND()


class Invert(UnaryBitOp):
    nfunc_spec = ('invert', 1, 1)

    def impl(self, x):
        return ~x

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        if node.outputs[0].type == bool:
            return "%(z)s = (!%(x)s);" % locals()
        return "%(z)s = (~%(x)s);" % locals()
invert = Invert()


##############
# Arithmetic
##############
class Maximum(BinaryScalarOp):
    commutative = True
    associative = True
    nfunc_spec = ('maximum', 2, 1)

    def impl(self, *inputs):
        # The built-in max function doesn't support complex types
        return np.maximum(*inputs)

    def c_code(self, node, name, inputs, outputs, sub):
        (x, y) = inputs
        (z,) = outputs
        if any([i.type in complex_types for i in node.inputs]):
            raise NotImplementedError()
        # Test for both y > x and x >= y to detect NaN
        return ('%(z)s = ((%(y)s)>(%(x)s)? (%(y)s): '
                '((%(x)s)>=(%(y)s)? (%(x)s): nan("")));' % locals())

    def L_op(self, inputs, outputs, gout):
        (x, y) = inputs
        (gz,) = gout
        if gz.type in complex_types:
            # max is currently defined for complex_types,
            # but the gradient for complex is not.
            raise NotImplementedError()

        if outputs[0].type in discrete_types:
            return [x.zeros_like().astype(theano.config.floatX),
                    y.zeros_like().astype(theano.config.floatX)]
        # This form handles the case when both values are the same.
        # In that case, gx will be gz, gy will be 0.
        e = eq(outputs[0], x)
        gx = e * gz
        gy = (constant(1, dtype=gz.dtype) - e) * gz
        return (gx, gy)

maximum = Maximum(upcast_out, name='maximum')
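
# Illustrative note (not part of the original module): the L_op above routes
# the whole gradient to `x` whenever x == y, instead of splitting it. E.g.
# with x == y == 2 and gz == 1: e = eq(max(x, y), x) == 1, so gx == 1 and
# gy == (1 - 1) * 1 == 0; the incoming gradient gz is never double-counted.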
class Minimum(BinaryScalarOp):
    commutative = True
    associative = True
    nfunc_spec = ('minimum', 2, 1)

    def impl(self, *inputs):
        # The built-in min function doesn't support complex types
        return np.minimum(*inputs)

    def c_code(self, node, name, inputs, outputs, sub):
        (x, y) = inputs
        (z,) = outputs
        if any([i.type in complex_types for i in node.inputs]):
            raise NotImplementedError()
        return ('%(z)s = ((%(y)s)<(%(x)s)? (%(y)s): '
                '((%(x)s)<=(%(y)s)? (%(x)s): nan("")));' % locals())

    def L_op(self, inputs, outputs, gout):
        (x, y) = inputs
        (gz,) = gout
        if gz.type in complex_types:
            # min is currently defined for complex_types,
            # but the gradient for complex is not.
            raise NotImplementedError()

        if outputs[0].type in discrete_types:
            return [x.zeros_like().astype(theano.config.floatX),
                    y.zeros_like().astype(theano.config.floatX)]
        # This form handles the case when both values are the same.
        # In that case, gx will be gz, gy will be 0.
        e = eq(outputs[0], x)
        gx = e * gz
        gy = (constant(1, dtype=gz.dtype) - e) * gz
        return (gx, gy)
minimum = Minimum(upcast_out, name='minimum')


class Add(ScalarOp):
    identity = 0
    commutative = True
    associative = True
    nfunc_spec = ('add', 2, 1)

    def impl(self, *inputs):
        return sum(inputs)

    def c_code(self, node, name, inputs, outputs, sub):
        (z,) = outputs
        op = " + "
        if node.outputs[0].type == bool:
            op = " || "
        if not inputs:
            return z + " = 0;"
        else:
            return z + " = " + op.join(inputs) + ";"

    def L_op(self, inputs, outputs, gout):
        (gz,) = gout
        if gz.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            assert gz is not None
            retval = []
            for ii, inp in enumerate(inputs):
                if hasattr(inp, 'zeros_like'):
                    retval.append(
                        inp.zeros_like().astype(theano.config.floatX))
                else:
                    retval.append(grad_undefined(self, ii, inp))
        else:
            retval = []
            for i in inputs:
                retval += [gz]
        return retval


add = Add(upcast_out, name='add')
class Mul(ScalarOp):
    identity = 1
    commutative = True
    associative = True
    nfunc_spec = ('multiply', 2, 1)

    def impl(self, *inputs):
        return np.product(inputs)

    def c_code(self, node, name, inputs, outputs, sub):
        (z,) = outputs
        op = " * "
        if node.outputs[0].type == bool:
            op = " && "
        if not inputs:
            return z + " = 1;"
        else:
            return z + " = " + op.join(inputs) + ";"

    def grad(self, inputs, gout):
        (gz,) = gout
        retval = []

        # The following 3 lines verify that gz is complex when the
        # output is complex. The rest of this function makes this
        # assumption.
        output_type = self.output_types([i.type for i in inputs])[0]
        if output_type in complex_types:
            if gz.type not in complex_types:
                raise TypeError(
                    'Mul with output_type ' + str(output_type) +
                    ' expected gz type to be complex, got gz with type ' +
                    str(gz.type))

        if output_type in discrete_types:
            return [ipt.zeros_like().astype(theano.config.floatX)
                    for ipt in inputs]

        for input in inputs:
            if gz.type in complex_types:
                # zr + zi = (xr + xi)(yr + yi)
                # zr + zi = (xr*yr - xi*yi) + (xr*yi + xi*yr)
                otherprod = mul(*(utils.difference(inputs, [input])))
                yr = real(otherprod)
                yi = imag(otherprod)
                if input.type in complex_types:
                    retval += [complex(yr * real(gz) + yi * imag(gz),
                                       yr * imag(gz) - yi * real(gz))]
                else:
                    retval += [yr * real(gz) + yi * imag(gz)]
            else:
                retval += [mul(*([gz] + utils.difference(inputs,
                                                         [input])))]
        return retval


mul = Mul(upcast_out, name='mul')


class Sub(BinaryScalarOp):
    nfunc_spec = ('subtract', 2, 1)

    def impl(self, x, y):
        return x - y

    def c_code(self, node, name, inputs, outputs, sub):
        (x, y) = inputs
        (z,) = outputs
        return "%(z)s = %(x)s - %(y)s;" % locals()

    def L_op(self, inputs, outputs, gout):
        (x, y) = inputs
        (gz,) = gout
        if gz.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            return [x.zeros_like().astype(theano.config.floatX),
                    y.zeros_like().astype(theano.config.floatX)]

        first_part = gz
        second_part = -gz

        return first_part, second_part
sub = Sub(upcast_out_nobool, name='sub')


def int_or_true_div(x_discrete, y_discrete):
    """
    Return `int_div` or `true_div`, depending on the type of division
    used for x / y.

    Parameters
    ----------
    x_discrete : bool
        True if `x` is discrete ([unsigned] integer).
    y_discrete : bool
        True if `y` is discrete ([unsigned] integer).

    Returns
    -------
    The scalar op to use: `int_div` if `x / y` should be an integer
    division, or `true_div` if it should be a true division.

    Raises
    ------
    IntegerDivisionError
        If both `x_discrete` and `y_discrete` are True and `config.int_division`
        is set to 'raise'.

    Notes
    -----
    This function is used by both scalar/basic.py and tensor/basic.py.

    """
    if x_discrete and y_discrete:
        if config.int_division == 'raise':
            raise IntegerDivisionError(
                "With `config.int_division` set to 'raise', dividing two "
                "integer types with '/' is forbidden to avoid confusion "
                "between integer and floating point divisions. Please "
                "use // for integer division, or if you want a float result "
                "either cast one of the arguments to a float or directly call "
                "`x.__truediv__(y)`.")
        elif config.int_division == 'int':
            warnings.warn(
                "Division of two integer types with x / y is deprecated, "
                "please use x // y for an integer division.",
                DeprecationWarning,
                stacklevel=4)
            return int_div
        elif config.int_division == 'floatX':
            return true_div
        else:
            raise NotImplementedError(config.int_division)
    else:
        return true_div
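
# Illustrative sketch (not part of the original module): which op `x / y`
# builds is decided from the discreteness of the two arguments and from
# config.int_division:
#
# >>> int_or_true_div(False, True) is true_div   # any float argument
# True
# >>> # For two integer arguments the result depends on config.int_division:
# >>> # 'int' returns int_div (with a DeprecationWarning), 'floatX' returns
# >>> # true_div, and 'raise' raises IntegerDivisionError.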
1881 1882 """ 1883 f = int_or_true_div(as_scalar(x).type in discrete_types, 1884 as_scalar(y).type in discrete_types) 1885 return f(x, y) 1886 1887 1888class TrueDiv(BinaryScalarOp): 1889 nfunc_spec = ('true_divide', 2, 1) 1890 1891 def output_types(self, types): 1892 if all(t in discrete_types for t in types): 1893 return [get_scalar_type(config.floatX)] 1894 else: 1895 return super(TrueDiv, self).output_types(types) 1896 1897 def impl(self, x, y): 1898 x = np.asarray(x) 1899 y = np.asarray(y) 1900 if all(a.dtype in discrete_types for a in (x, y)): 1901 return np.sctypeDict[config.floatX](float(x) / y) 1902 else: 1903 return x / y 1904 1905 def c_code(self, node, name, inputs, outputs, sub): 1906 # we generate good c code only when both are complex! 1907 (x, y) = inputs 1908 (z,) = outputs 1909 if sum([node.inputs[0].type in complex_types, 1910 node.inputs[1].type in complex_types]) == 1: 1911 raise NotImplementedError('type not supported', type) 1912 if (node.inputs[0].type in discrete_types and 1913 node.inputs[1].type in discrete_types): 1914 return "%(z)s = ((double)%(x)s) / %(y)s;" % locals() 1915 return "%(z)s = %(x)s / %(y)s;" % locals() 1916 1917 def grad(self, inputs, gout): 1918 1919 (x, y) = inputs 1920 (gz,) = gout 1921 if x.type in complex_types: 1922 raise NotImplementedError() 1923 1924 # If the output of this op is discrete, then it 1925 # it is locally flat everywhere, so the gradient 1926 # through it is 0. 1927 # This is different from it not being connected 1928 # to the output; x/y is still a function of x 1929 # and y; it's just a step function. 1930 if all(a.dtype in discrete_types for a in (x, y)): 1931 return [x.zeros_like(), y.zeros_like()] 1932 1933 first_part = gz / y 1934 1935 if y.type in complex_types: 1936 raise NotImplementedError() 1937 1938 second_part = -(gz * x) / (y * y) 1939 1940 return first_part, second_part 1941 1942true_div = TrueDiv(upcast_out, name='true_div') 1943 1944 1945class IntDiv(BinaryScalarOp): 1946 nfunc_spec = ('floor_divide', 2, 1) 1947 complex_error = ComplexError( 1948 "Theano does not support integer division (//) on " 1949 "complex numbers, since numpy deprecated it.") 1950 1951 def impl(self, x, y): 1952 return x // y 1953 1954 def c_support_code(self): 1955 # We use a macro as python use % as a special string character, 1956 # and the output of c_code may be run through another level 1957 # of string formatting. 1958 return "#define THEANO_MACRO_MOD(x,y) (x % y)" 1959 1960 def c_code(self, node, name, inputs, outputs, sub): 1961 (x, y) = inputs 1962 (z,) = outputs 1963 fail = sub['fail'] 1964 1965 t = node.inputs[0].type.upcast(*[i.type for i in node.inputs[1:]]) 1966 if t in imap(str, discrete_types): 1967 x_div_y_pp = '(%(x)s / %(y)s)' % locals() 1968 x_div_y_mp = '((-%(x)s) / %(y)s)' % locals() 1969 x_mod_y_mp = 'THEANO_MACRO_MOD((-%(x)s), %(y)s)' % locals() 1970 x_div_y_pm = '(%(x)s / (-%(y)s))' % locals() 1971 x_mod_y_pm = 'THEANO_MACRO_MOD(%(x)s, (-%(y)s))' % locals() 1972 x_div_y_mm = '((-%(x)s) / (-%(y)s))' % locals() 1973 # If we are in a gpuarray kernel, %(fail)s exits the kernel, 1974 # and we do not have any error report, and we cannot set 1975 # Python error messages either, so for now we just call the 1976 # cuda function, which return a binary pattern of all 1s. 
            div_zero = dedent('''
                #ifdef KERNEL
                %(z)s = %(x_div_y_pp)s;
                #else
                PyErr_SetString(PyExc_ZeroDivisionError, "integer division by zero");
                %(fail)s
                #endif
                ''') % locals()
        elif t in imap(str, float_types):
            # We need to call different functions of math.h
            # depending on the type
            if t == 'float32':
                floor = 'floorf'
                fmod = 'fmodf'
            elif t == 'float64':
                floor = 'floor'
                fmod = 'fmod'
            else:
                raise NotImplementedError('type not supported', t)

            x_div_y_pp = '%(floor)s(%(x)s / %(y)s)' % locals()
            x_div_y_mp = '%(floor)s((-%(x)s) / %(y)s)' % locals()
            x_mod_y_mp = '%(fmod)s((-%(x)s), %(y)s)' % locals()
            x_div_y_pm = '%(floor)s(%(x)s / (-%(y)s))' % locals()
            x_mod_y_pm = '%(fmod)s(%(x)s, (-%(y)s))' % locals()
            x_div_y_mm = '%(floor)s((-%(x)s) / (-%(y)s))' % locals()
            div_zero = '%(z)s = %(x_div_y_pp)s;' % locals()
        elif t in complex_types:
            raise self.complex_error
        else:
            raise NotImplementedError('type not supported', t)

        return dedent("""
            if (%(y)s == 0) {
                %(div_zero)s;
            } else if (%(y)s < 0) {
                if (%(x)s < 0) {
                    %(z)s = %(x_div_y_mm)s;
                } else {
                    %(z)s = - %(x_div_y_pm)s - ((%(x_mod_y_pm)s == 0) ? 0 : 1);
                }
            } else {
                if (%(x)s < 0) {
                    %(z)s = - %(x_div_y_mp)s - ((%(x_mod_y_mp)s == 0) ? 0 : 1);
                } else {
                    %(z)s = %(x_div_y_pp)s;
                }
            }
            """) % locals()

    def c_code_cache_version(self):
        return (6,)

    def grad(self, inputs, g_output):
        return [inp.zeros_like(dtype=theano.config.floatX)
                for inp in inputs]
int_div = IntDiv(upcast_out, name='int_div')


floor_div = int_div


def mod_check(x, y):
    if (as_scalar(x).type in complex_types or
            as_scalar(y).type in complex_types):
        # Currently forbidden.
        raise Mod.complex_error
    else:
        return mod(x, y)


class Mod(BinaryScalarOp):
    nfunc_spec = ('mod', 2, 1)
    complex_error = ComplexError(
        "Theano does not support the mod operator (%) on "
        "complex numbers, since numpy deprecated it.")

    def impl(self, x, y):
        if isinstance(x, np.complex) or isinstance(y, np.complex):
            raise self.complex_error
        return x % y

    def c_code_cache_version(self):
        return (9,)

    def c_support_code(self):
        # We use a macro because Python uses % as a special character in
        # strings, and the output of c_code may be run through another
        # level of string formatting.
        return "#define THEANO_MACRO_MOD(x, y) (x % y)"

    def c_code(self, node, name, inputs, outputs, sub):
        """
        We want the result to have the same sign as Python's %, not the
        sign that C's native % operator would give.

        """
        (x, y) = inputs
        (z,) = outputs
        fail = sub['fail']
        t = node.inputs[0].type.upcast(*[i.type for i in node.inputs[1:]])
        if (str(t) in imap(str, discrete_types) or
                t in ['uint8', 'int8', 'uint16', 'int16'] or
                t in ['uint32', 'int32', 'uint64', 'int64'] or
                t in discrete_types):
            # The above or's should not be needed anymore. However, for now
            # we keep them out of caution, and verify with an assert that
            # they are useless.
            assert str(t) in imap(str, discrete_types)
            x_mod_y = "THEANO_MACRO_MOD(%(x)s, %(y)s)" % locals()
            x_mod_ymm = "THEANO_MACRO_MOD(-%(x)s, -%(y)s)" % locals()
            x_mod_ypm = "THEANO_MACRO_MOD(%(x)s, -%(y)s)" % locals()
            x_mod_ymp = "THEANO_MACRO_MOD(-%(x)s, %(y)s)" % locals()
            # If we are in a gpuarray kernel, %(fail)s exits the kernel,
            # we do not have any error report, and we cannot set Python
            # error messages either, so for now we just call the cuda
            # function, returning a binary pattern depending on dtype.
            mod_zero = dedent('''
                #ifdef KERNEL
                %(z)s = %(x_mod_y)s;
                #else
                PyErr_SetString(PyExc_ZeroDivisionError, "integer modulo by zero");
                %(fail)s
                #endif
                ''') % locals()
        elif (str(t) in imap(str, float_types) or
                t in ['float32', 'float64'] or
                t in float_types):
            # The above or's should not be needed anymore. However, for now
            # we keep them out of caution, and verify with an assert that
            # they are useless.
            assert str(t) in imap(str, float_types)
            x_mod_y = "fmod(%(x)s, %(y)s)" % locals()
            x_mod_ymm = "fmod(-%(x)s, -%(y)s)" % locals()
            x_mod_ypm = "fmod(%(x)s, -%(y)s)" % locals()
            x_mod_ymp = "fmod(-%(x)s, %(y)s)" % locals()
            mod_zero = "%(z)s = %(x_mod_y)s;" % locals()
        elif str(t) in imap(str, complex_types):
            raise self.complex_error
        else:
            raise NotImplementedError('type not supported', t)

        return dedent("""
            if (%(y)s == 0) {
                %(mod_zero)s;
            } else if (%(y)s < 0){
                if (%(x)s < 0){
                    %(z)s = -(%(x_mod_ymm)s);
                } else {
                    %(z)s = (%(x_mod_ypm)s) + (%(x_mod_ypm)s != 0 ? %(y)s : 0);
                }
            } else {
                if (%(x)s < 0){
                    %(z)s = - %(x_mod_ymp)s + (%(x_mod_ymp)s != 0 ? %(y)s : 0);
                } else {
                    %(z)s = %(x_mod_y)s;
                }
            }
            """) % locals()

    def L_op(self, inputs, outputs, gout):
        (x, y) = inputs
        (gz,) = gout
        if outputs[0].type.dtype in discrete_types:
            # The gradient does not flow in if the output is discrete.
            return [x.zeros_like(dtype=theano.config.floatX),
                    y.zeros_like(dtype=theano.config.floatX)]
        return [gz,
                -(x // y) * gz]

mod = Mod(upcast_out, name='mod')


class Pow(BinaryScalarOp):
    nfunc_spec = ('power', 2, 1)

    def impl(self, x, y):
        return x ** y

    def c_code(self, node, name, inputs, outputs, sub):
        (x, y) = inputs
        (z,) = outputs
        if (node.inputs[0].type in complex_types or
                node.inputs[1].type in complex_types):
            raise NotImplementedError('type not supported', type)
        return "%(z)s = pow(%(x)s, %(y)s);" % locals()

    def L_op(self, inputs, outputs, gout):
        (x, y) = inputs
        (gz,) = gout
        if gz.type in complex_types:
            raise NotImplementedError()

        if outputs[0].type in discrete_types:
            return [x.zeros_like().astype(theano.config.floatX),
                    y.zeros_like().astype(theano.config.floatX)]

        first_part = gz * y * x ** (y - 1)

        second_part = gz * log(x) * x ** y
        second_part = switch(eq(x, 0), 0, second_part)

        return (first_part, second_part)

    def c_code_contiguous(self, node, name, inputs, outputs, sub):
        (x, y) = inputs
        (z,) = outputs
        if not theano.config.lib.amdlibm:
            raise theano.gof.utils.MethodNotDefined()

        # We compare the dtype AND the broadcast flag,
        # as this function does not broadcast.
        if (node.inputs[0].type == node.outputs[0].type and
                node.inputs[1].type == node.outputs[0].type and
                # amdlibm 3.0 does not have a float64 version of this
                # SIMD function
                node.inputs[0].dtype == 'float32' and
                node.inputs[1].dtype == 'float32'):
            dtype = 'float'
            fct = "amd_vrsa_powf"
            return """
        npy_intp n = PyArray_SIZE(%(z)s);
        %(dtype)s * x = (%(dtype)s*) PyArray_DATA(%(x)s);
        %(dtype)s * y = (%(dtype)s*) PyArray_DATA(%(y)s);
        %(dtype)s * z = (%(dtype)s*) PyArray_DATA(%(z)s);
        %(fct)s(n, x, y, z);
        """ % locals()
        # We compare the dtype and check that we broadcast a scalar.
        elif (node.inputs[0].type == node.outputs[0].type and
                node.inputs[1].dtype == node.outputs[0].dtype and
                all(node.inputs[1].broadcastable) and
                # amdlibm 3.0 does not have a float64 version of this
                # SIMD function
                node.inputs[0].dtype == 'float32' and
                node.inputs[1].dtype == 'float32'):
            dtype = 'float'
            fct = "amd_vrsa_powxf"
            return """
        npy_intp n = PyArray_SIZE(%(z)s);
        %(dtype)s * x = (%(dtype)s*) PyArray_DATA(%(x)s);
        %(dtype)s * y = (%(dtype)s*) PyArray_DATA(%(y)s);
        %(dtype)s * z = (%(dtype)s*) PyArray_DATA(%(z)s);
        %(fct)s(n, x, *y, z);
        """ % locals()

        raise theano.gof.utils.MethodNotDefined()


pow = Pow(upcast_out_min8, name='pow')


class Clip(ScalarOp):
    nin = 3
    # numpy.clip does not work correctly when min is bigger than max,
    # so we do not use nfunc_spec = ('clip', 3, 1).

    def impl(self, x, min, max):
        if x < min:
            return min
        elif x > max:
            return max
        else:
            return x

    def c_code(self, node, name, inputs, outputs, sub):
        (x, min, max) = inputs
        (z,) = outputs
        return "%(z)s = %(x)s < %(min)s ? %(min)s : %(x)s > %(max)s ? %(max)s : %(x)s;" % locals()

    def L_op(self, inputs, outputs, gout):
        (x, mn, mx) = inputs
        (gz,) = gout
        assert gz.type not in complex_types
        gx = ((x >= mn) & (x <= mx)) * gz
        gmn = (x < mn) * gz
        gmx = (x > mx) * gz

        def handle_int(v):
            if outputs[0].type in int_types:
                return v.zeros_like().astype(config.floatX)
            return v

        return list(map(handle_int, [gx, gmn, gmx]))

# Don't allow complex even if numpy does, as there is no mathematical
# reason for this function to exist on complex numbers.
clip = Clip(upcast_out_no_complex, name='clip')


class Second(BinaryScalarOp):
    def impl(self, x, y):
        return y

    def c_code(self, node, name, inputs, outputs, sub):
        (x, y) = inputs
        (z,) = outputs
        return "%(z)s = %(y)s;" % locals()

    def connection_pattern(self, node):
        # x is never connected because its elements are never used.
        # y is connected because its elements are copied over.
        return [[False], [True]]

    def grad(self, inputs, gout):
        (x, y) = inputs
        (gz,) = gout
        if y.type in continuous_types:
            # x is disconnected because the elements of x are not used.
            return DisconnectedType()(), gz
        else:
            # When y is discrete, we assume the function can be extended
            # to deal with real-valued inputs by rounding them to the
            # nearest integer. f(x+eps) thus equals f(x), so the gradient
            # is zero, not disconnected or undefined.
            return DisconnectedType()(), y.zeros_like()

second = Second(transfer_type(1), name='second')


class Identity(UnaryScalarOp):
    def impl(self, input):
        return input

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        return "%(z)s = %(x)s;" % locals()

    def grad(self, inputs, gout):
        (x,) = inputs
        (gz,) = gout
        if x.type in continuous_types:
            return gz,
        else:
            return x.zeros_like(dtype=theano.config.floatX),
identity = Identity(same_out, name='identity')


# CASTING OPERATIONS
class Cast(UnaryScalarOp):
    def __init__(self, o_type, name=None):
        if not isinstance(o_type, Scalar):
            raise TypeError(o_type)
        super(Cast, self).__init__(specific_out(o_type), name=name)
        self.o_type = o_type
        self.ctor = getattr(np, o_type.dtype)

    def __str__(self):
        return '%s{%s}' % (self.__class__.__name__, self.o_type.dtype)

    def clone_float32(self):
        if self.o_type == float16:
            return convert_to_float32
        return self

    def make_new_inplace(self, output_types_preference=None, name=None):
        """
        This op's __init__ does not take the same parameters as other
        scalar ops, which breaks the insert_inplace_optimizer optimization.
        This function patches that by ignoring the output_types_preference
        passed by the optimization and replacing it with the current output
        type. This should only be triggered when both input and output have
        the same dtype anyway.

        """
        return self.__class__(self.o_type, name)

    def impl(self, input):
        return self.ctor(input)

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        if node.outputs[0].type == bool:
            return "%s = (%s) ? 1 : 0;" % (z, x)
1 : 0;" % (z, x) 2353 return "%s = (%s)%s;" % (z, node.outputs[0].type.dtype_specs()[1], x) 2354 2355 def grad(self, inputs, gout): 2356 (x,) = inputs 2357 (gz,) = gout 2358 if self.o_type in continuous_types: 2359 return [gz] 2360 else: 2361 return [x.zeros_like().astype(theano.config.floatX)] 2362 2363 def c_code_cache_version(self): 2364 s = super(Cast, self).c_code_cache_version() 2365 if s: 2366 return (4,) + s 2367 else: 2368 return s 2369 2370convert_to_bool = Cast(bool, name='convert_to_bool') 2371convert_to_int8 = Cast(int8, name='convert_to_int8') 2372convert_to_int16 = Cast(int16, name='convert_to_int16') 2373convert_to_int32 = Cast(int32, name='convert_to_int32') 2374convert_to_int64 = Cast(int64, name='convert_to_int64') 2375convert_to_uint8 = Cast(uint8, name='convert_to_uint8') 2376convert_to_uint16 = Cast(uint16, name='convert_to_uint16') 2377convert_to_uint32 = Cast(uint32, name='convert_to_uint32') 2378convert_to_uint64 = Cast(uint64, name='convert_to_uint64') 2379convert_to_float16 = Cast(float16, name='convert_to_float16') 2380convert_to_float32 = Cast(float32, name='convert_to_float32') 2381convert_to_float64 = Cast(float64, name='convert_to_float64') 2382convert_to_complex64 = Cast(complex64, name='convert_to_complex64') 2383convert_to_complex128 = Cast(complex128, name='convert_to_complex128') 2384 2385_cast_mapping = { 2386 'bool': convert_to_bool, 2387 'int8': convert_to_int8, 2388 'int16': convert_to_int16, 2389 'int32': convert_to_int32, 2390 'int64': convert_to_int64, 2391 'uint8': convert_to_uint8, 2392 'uint16': convert_to_uint16, 2393 'uint32': convert_to_uint32, 2394 'uint64': convert_to_uint64, 2395 'float16': convert_to_float16, 2396 'float32': convert_to_float32, 2397 'float64': convert_to_float64, 2398 'complex64': convert_to_complex64, 2399 'complex128': convert_to_complex128} 2400 2401 2402def cast(x, dtype): 2403 """ 2404 Symbolically cast `x` to a Scalar of given `dtype`. 

    """
    if dtype == 'floatX':
        dtype = config.floatX

    _x = as_scalar(x)
    if _x.type.dtype == dtype:
        return _x
    if _x.type.dtype.startswith('complex') and not dtype.startswith('complex'):
        raise TypeError('Casting from complex to real is ambiguous: consider'
                        ' real(), imag(), angle() or abs()')
    return _cast_mapping[dtype](_x)


class Abs(UnaryScalarOp):
    nfunc_spec = ('abs', 1, 1)

    def make_node(self, x):
        inputs = [as_scalar(input) for input in [x]]
        if inputs[0].type == complex64:
            outputs = [float32()]
        elif inputs[0].type == complex128:
            outputs = [float64()]
        else:
            outputs = [t() for t in self.output_types(
                [input.type for input in inputs])]
        return Apply(self, inputs, outputs)

    def impl(self, x):
        return np.abs(x)

    def L_op(self, inputs, outputs, gout):
        (x,) = inputs
        (gz,) = gout
        if (outputs[0].type in discrete_types):
            if x.type in discrete_types:
                return [x.zeros_like(dtype=theano.config.floatX)]
            else:
                return [x.zeros_like()]

        if x.type in float_types:
            return gz * sgn(x),
        return gz * x / abs(x),  # formula works for complex and real

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        type = node.inputs[0].type
        if type in int_types:
            return "%(z)s = abs(%(x)s);" % locals()
        if type in float_types:
            return "%(z)s = fabs(%(x)s);" % locals()
        if type in complex_types:
            return "%(z)s = sqrt(%(x)s.real*%(x)s.real + %(x)s.imag*%(x)s.imag);" % locals()
        if node.outputs[0].type == bool:
            return "%(z)s = (%(x)s) ? 1 : 0;" % locals()
        if type in uint_types:
            # uint values are already their own absolute value.
            return "%(z)s = %(x)s;" % locals()
        raise NotImplementedError('type not supported', type)
abs_ = Abs(same_out)


class Sgn(UnaryScalarOp):
    nfunc_spec = ('sign', 1, 1)

    @staticmethod
    def output_types_preference(x):
        if x == bool:
            raise TypeError(x)
        return same_out_nocomplex(x)

    def impl(self, x):
        # casting to output type is handled by filter
        return np.sign(x)

    def grad(self, inputs, gout):
        (x,) = inputs
        (gz,) = gout
        rval = x.zeros_like()

        if rval.type.dtype in discrete_types:
            rval = rval.astype(theano.config.floatX)

        return [rval]

    def c_code(self, node, name, inputs, outputs, sub):
        # casting is done by compiler
        # TODO: use copysign
        (x,) = inputs
        (z,) = outputs
        type = node.inputs[0].type
        if type in float_types:
            return '%(z)s = (%(x)s > 0) ? 1. : ((%(x)s < 0) ? -1. : (isnan(%(x)s) ? NAN : 0.));' % locals()
        if type in int_types:
            return "%(z)s = (%(x)s >= 0) ? (%(x)s == 0) ? 0 : 1 : -1;" % locals()
        raise ComplexError('complex has no sgn')

    def c_code_cache_version(self):
        s = super(Sgn, self).c_code_cache_version()
        if s:
            return (4,) + s
        else:  # if parent is unversioned, we are too
            return s
sgn = Sgn(name='sgn')


class Ceil(UnaryScalarOp):
    nfunc_spec = ('ceil', 1, 1)

    def impl(self, x):
        return np.ceil(x)

    def grad(self, inputs, gout):
        (x,) = inputs
        (gz,) = gout
        rval = x.zeros_like()

        if rval.type.dtype in discrete_types:
            rval = rval.astype(theano.config.floatX)

        return [rval]

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        cast = node.outputs[0].type.dtype_specs()[1]
        return "%(z)s = ceil((%(cast)s)%(x)s);" % locals()
ceil = Ceil(upgrade_to_float_no_complex, name='ceil')


class Floor(UnaryScalarOp):
    nfunc_spec = ('floor', 1, 1)

    def impl(self, x):
        return np.floor(x)

    def grad(self, inputs, gout):
        (x,) = inputs
        (gz,) = gout
        rval = x.zeros_like()

        if rval.type.dtype in discrete_types:
            rval = rval.astype(theano.config.floatX)

        return [rval]

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        cast = node.outputs[0].type.dtype_specs()[1]
        return "%(z)s = floor((%(cast)s)%(x)s);" % locals()
floor = Floor(upgrade_to_float_no_complex, name='floor')


class Trunc(UnaryScalarOp):
    nfunc_spec = ('trunc', 1, 1)

    def impl(self, x):
        return np.trunc(x)

    def grad(self, inputs, gout):
        (x,) = inputs
        (gz,) = gout
        return [x.zeros_like().astype(theano.config.floatX)]

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        return "%(z)s = %(x)s >= 0? floor(%(x)s): -floor(-%(x)s);" % locals()
trunc = Trunc(upgrade_to_float_no_complex, name='trunc')


class RoundHalfToEven(UnaryScalarOp):
    """
    This function implements the same rounding as numpy: round half to even.

    The C/C++ round function is DIFFERENT!
    See http://en.wikipedia.org/wiki/Rounding for more details.

    """
    nfunc_spec = ('around', 1, 1)

    def impl(self, x):
        return np.round(x)

    def grad(self, inputs, gout):
        (x,) = inputs
        (gz,) = gout
        rval = x.zeros_like()

        if rval.type.dtype in discrete_types:
            rval = rval.astype(theano.config.floatX)

        return [rval]

    def c_code_cache_version(self):
        return (1,)

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        typ = node.outputs[0].type.dtype
        if typ not in ['float32', 'float64']:
            raise NotImplementedError("The output should be float32 or float64")
        if typ == 'float32':
            ctype = 'float'
            floor_function = 'floorf'
        else:
            ctype = 'double'
            floor_function = 'floor'
        return """
        /* Code inspired by the NumPy npy_rint implementation. */
        {
            %(ctype)s y, r;
            y = %(floor_function)s(%(x)s);
            r = %(x)s - y;
            if(r > 0.5) {
                y += 1;
            } else if(r == 0.5) {
                r = y - 2.0*%(floor_function)s(0.5*y);
                /*
                If y is even, then r == 0.
                If y is odd, then r == 1.
                So we can just add r to y, so that
                y is incremented only if it is odd.
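                E.g. x = 2.5: y = 2, r = 2 - 2*floor(1.0) = 0, y stays 2.
                For x = 3.5: y = 3, r = 3 - 2*floor(1.5) = 1, y becomes 4.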
2631 */ 2632 y += (int)r; 2633 } 2634 %(z)s = y; 2635 } 2636 """ % locals() 2637round_half_to_even = RoundHalfToEven(same_out_float_only) 2638 2639 2640def round_half_away_from_zero_(a): 2641 if a > 0: 2642 return np.floor(a + 0.5) 2643 else: 2644 return np.ceil(a - 0.5) 2645 2646round_half_away_from_zero_vec64 = np.vectorize( 2647 round_half_away_from_zero_, 2648 doc='round_half_away_from_zero_vec64') 2649round_half_away_from_zero_vec32 = np.vectorize( 2650 round_half_away_from_zero_, 2651 doc='round_half_away_from_zero_vec32', 2652 otypes=['float32']) 2653 2654 2655def round_half_away_from_zero_vec(a): 2656 if getattr(a, 'dtype', None) == np.float32: 2657 return round_half_away_from_zero_vec32(a) 2658 return round_half_away_from_zero_vec64(a) 2659 2660 2661class RoundHalfAwayFromZero(UnaryScalarOp): 2662 """ 2663 Implement the same rounding algo as c round() fct. 2664 2665 numpy.round fct IS DIFFERENT! 2666 See http://en.wikipedia.org/wiki/Rounding for more details. 2667 2668 """ 2669 def impl(self, x): 2670 return round_half_away_from_zero_vec(x) 2671 2672 def grad(self, inputs, gout): 2673 (x,) = inputs 2674 (gz,) = gout 2675 rval = x.zeros_like() 2676 2677 if rval.type.dtype in discrete_types: 2678 rval = rval.astype(theano.config.floatX) 2679 2680 return [rval] 2681 2682 def c_code(self, node, name, inputs, outputs, sub): 2683 (x,) = inputs 2684 (z,) = outputs 2685 if node.outputs[0].type.dtype in ['float32', 'float64']: 2686 return "%(z)s = round(%(x)s);" % locals() 2687 else: 2688 raise NotImplementedError("The output should be float32 or float64") 2689round_half_away_from_zero = RoundHalfAwayFromZero(same_out_float_only) 2690 2691 2692class Neg(UnaryScalarOp): 2693 # We can use numpy.negative here, because even if it gives unexpected 2694 # results on Boolean arrays, it will be passed other dtypes as Theano 2695 # does not have a Boolean type for tensors. 2696 nfunc_spec = ('negative', 1, 1) 2697 2698 def impl(self, x): 2699 return -x 2700 2701 def L_op(self, inputs, outputs, gout): 2702 (x,) = inputs 2703 (gz,) = gout 2704 if outputs[0].type in discrete_types: 2705 if x.type in discrete_types: 2706 return [x.zeros_like(dtype=theano.config.floatX)] 2707 else: 2708 return [x.zeros_like()] 2709 2710 return -gz, 2711 2712 def c_code(self, node, name, inputs, outputs, sub): 2713 (x,) = inputs 2714 (z,) = outputs 2715 return "%(z)s = -%(x)s;" % locals() 2716neg = Neg(same_out_nobool, name='neg') 2717 2718pprint.assign(add, printing.OperatorPrinter('+', -2, 'either')) 2719pprint.assign(mul, printing.OperatorPrinter('*', -1, 'either')) 2720pprint.assign(sub, printing.OperatorPrinter('-', -2, 'left')) 2721pprint.assign(neg, printing.OperatorPrinter('-', 0, 'either')) 2722pprint.assign(true_div, printing.OperatorPrinter('/', -1, 'left')) 2723pprint.assign(int_div, printing.OperatorPrinter('//', -1, 'left')) 2724pprint.assign(pow, printing.OperatorPrinter('**', 1, 'right')) 2725pprint.assign(mod, printing.OperatorPrinter('%', -1, 'left')) 2726 2727 2728class Inv(UnaryScalarOp): 2729 """ 2730 Multiplicative inverse. Also called reciprocal. 

    """
    def impl(self, x):
        return np.float32(1.0) / x

    def L_op(self, inputs, outputs, gout):
        (x,) = inputs
        (gz,) = gout
        if x.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            if x.type in discrete_types:
                return [x.zeros_like(dtype=theano.config.floatX)]
            else:
                return [x.zeros_like()]

        return -gz / (x * x),

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        if node.inputs[0].type in complex_types:
            raise NotImplementedError()
        return "%(z)s = 1.0 / %(x)s;" % locals()
inv = Inv(upgrade_to_float, name='inv')


class Log(UnaryScalarOp):
    """
    log base e.

    """
    nfunc_spec = ('log', 1, 1)
    amd_float32 = "amd_vrsa_logf"
    amd_float64 = "amd_vrda_log"

    def impl(self, x):
        # If x is an int8 or uint8, numpy.log will compute the result in
        # half-precision (float16), where we want float32.
        x_dtype = str(getattr(x, 'dtype', ''))
        if x_dtype in ('int8', 'uint8'):
            return np.log(x, sig='f')
        return np.log(x)

    def L_op(self, inputs, outputs, gout):
        (x,) = inputs
        (gz,) = gout
        if x.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            if x.type in discrete_types:
                return [x.zeros_like(dtype=theano.config.floatX)]
            else:
                return [x.zeros_like()]

        return gz / x,

    def c_code(self, node, name, inputs, outputs, sub):
        # todo: the version using log2 seems to be very slightly faster
        # on some machines for some reason, check if it's worth switching
        # return "%(z)s = log2(%(x)s) * 0.69314718055994529;" % locals()
        (x,) = inputs
        (z,) = outputs
        if node.inputs[0].type in complex_types:
            raise NotImplementedError('type not supported', type)
        cast = node.outputs[0].type.dtype_specs()[1]
        return "%(z)s = log((%(cast)s)%(x)s);" % locals()
log = Log(upgrade_to_float, name='log')


class Log2(UnaryScalarOp):
    """
    log base 2.

    """
    nfunc_spec = ('log2', 1, 1)
    amd_float32 = "amd_vrsa_log2f"
    amd_float64 = "amd_vrda_log2"

    def impl(self, x):
        # If x is an int8 or uint8, numpy.log2 will compute the result in
        # half-precision (float16), where we want float32.
        x_dtype = str(getattr(x, 'dtype', ''))
        if x_dtype in ('int8', 'uint8'):
            return np.log2(x, sig='f')
        return np.log2(x)

    def L_op(self, inputs, outputs, gout):
        (x,) = inputs
        (gz,) = gout
        if x.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            if x.type in discrete_types:
                return [x.zeros_like(dtype=theano.config.floatX)]
            else:
                return [x.zeros_like()]

        return gz / (x * np.asarray(math.log(2.0)).astype(x.dtype)),

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        if node.inputs[0].type in complex_types:
            raise NotImplementedError('type not supported', type)
        cast = node.outputs[0].type.dtype_specs()[1]
        return "%(z)s = log2((%(cast)s)%(x)s);" % locals()
log2 = Log2(upgrade_to_float, name='log2')


class Log10(UnaryScalarOp):
    """
    log base 10.
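
    The gradient below uses d/dx log10(x) = 1 / (x * log(10)).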

    """
    nfunc_spec = ('log10', 1, 1)
    amd_float32 = "amd_vrsa_log10f"
    amd_float64 = "amd_vrda_log10"

    def impl(self, x):
        # If x is an int8 or uint8, numpy.log10 will compute the result in
        # half-precision (float16), where we want float32.
        x_dtype = str(getattr(x, 'dtype', ''))
        if x_dtype in ('int8', 'uint8'):
            return np.log10(x, sig='f')
        return np.log10(x)

    def L_op(self, inputs, outputs, gout):
        (x,) = inputs
        (gz,) = gout
        if x.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            if x.type in discrete_types:
                return [x.zeros_like(dtype=theano.config.floatX)]
            else:
                return [x.zeros_like()]

        return gz / (x * np.log(10.0)),

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        if node.inputs[0].type in complex_types:
            raise NotImplementedError('type not supported', type)
        cast = node.outputs[0].type.dtype_specs()[1]
        return "%(z)s = log10((%(cast)s)%(x)s);" % locals()
log10 = Log10(upgrade_to_float, name='log10')


class Log1p(UnaryScalarOp):
    """
    log(1+x).

    """
    nfunc_spec = ('log1p', 1, 1)

    def impl(self, x):
        # If x is an int8 or uint8, numpy.log1p will compute the result in
        # half-precision (float16), where we want float32.
        x_dtype = str(getattr(x, 'dtype', ''))
        if x_dtype in ('int8', 'uint8'):
            return np.log1p(x, sig='f')
        return np.log1p(x)

    def L_op(self, inputs, outputs, gout):
        (x,) = inputs
        (gz,) = gout
        if gz.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            if x.type in discrete_types:
                return [x.zeros_like(dtype=theano.config.floatX)]
            else:
                return [x.zeros_like()]

        return [gz / (1 + x)]

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        if node.inputs[0].type in complex_types:
            raise NotImplementedError('type not supported', type)
        cast = node.outputs[0].type.dtype_specs()[1]
        return "%(z)s = log1p((%(cast)s)%(x)s);" % locals()
log1p = Log1p(upgrade_to_float, name='log1p')


class Exp(UnaryScalarOp):
    nfunc_spec = ('exp', 1, 1)
    amd_float32 = "amd_vrsa_expf"
    amd_float64 = "amd_vrda_exp"

    def impl(self, x):
        # If x is an int8 or uint8, numpy.exp will compute the result in
        # half-precision (float16), where we want float32.
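        # (Passing sig='f' requests numpy's float32 loop explicitly.)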
        x_dtype = str(getattr(x, 'dtype', ''))
        if x_dtype in ('int8', 'uint8'):
            return np.exp(x, sig='f')
        return np.exp(x)

    def L_op(self, inputs, outputs, gout):
        (x,) = inputs
        (gz,) = gout
        if x.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            if x.type in discrete_types:
                return [x.zeros_like(dtype=theano.config.floatX)]
            else:
                return [x.zeros_like()]

        return gz * exp(x),

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        if node.inputs[0].type in complex_types:
            raise NotImplementedError('type not supported', type)
        cast = node.outputs[0].type.dtype_specs()[1]
        return "%(z)s = exp((%(cast)s)%(x)s);" % locals()
exp = Exp(upgrade_to_float, name='exp')


class Exp2(UnaryScalarOp):
    nfunc_spec = ('exp2', 1, 1)

    def impl(self, x):
        # If x is an int8 or uint8, numpy.exp2 will compute the result in
        # half-precision (float16), where we want float32.
        x_dtype = str(getattr(x, 'dtype', ''))
        if x_dtype in ('int8', 'uint8'):
            return np.exp2(x, sig='f')
        return np.exp2(x)

    def L_op(self, inputs, outputs, gout):
        (x,) = inputs
        (gz,) = gout
        if x.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            if x.type in discrete_types:
                return [x.zeros_like(dtype=theano.config.floatX)]
            else:
                return [x.zeros_like()]

        return gz * exp2(x) * log(np.cast[x.type](2)),

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        if node.inputs[0].type in complex_types:
            raise NotImplementedError('type not supported', type)
        cast = node.outputs[0].type.dtype_specs()[1]
        return "%(z)s = exp2((%(cast)s)%(x)s);" % locals()
exp2 = Exp2(upgrade_to_float, name='exp2')


class Expm1(UnaryScalarOp):
    nfunc_spec = ('expm1', 1, 1)

    def impl(self, x):
        # If x is an int8 or uint8, numpy.expm1 will compute the result in
        # half-precision (float16), where we want float32.
        x_dtype = str(getattr(x, 'dtype', ''))
        if x_dtype in ('int8', 'uint8'):
            return np.expm1(x, sig='f')
        return np.expm1(x)

    def L_op(self, inputs, outputs, gout):
        (x,) = inputs
        (gz,) = gout
        if x.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            if x.type in discrete_types:
                return [x.zeros_like(dtype=theano.config.floatX)]
            else:
                return [x.zeros_like()]

        return gz * exp(x),

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        if node.inputs[0].type in complex_types:
            raise NotImplementedError('type not supported', type)
        cast = node.outputs[0].type.dtype_specs()[1]
        return "%(z)s = expm1((%(cast)s)%(x)s);" % locals()

    def c_code_cache_version(self):
        return (5,)
expm1 = Expm1(upgrade_to_float, name='expm1')


class Sqr(UnaryScalarOp):
    nfunc_spec = ('square', 1, 1)

    def impl(self, x):
        return x * x

    def L_op(self, inputs, outputs, gout):
        (x,) = inputs
        (gz,) = gout
        if gz.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            if x.type in discrete_types:
                return [x.zeros_like(dtype=theano.config.floatX)]
            else:
                return [x.zeros_like()]

        return gz * x * 2,

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        return "%(z)s = %(x)s * %(x)s;" % locals()
sqr = Sqr(same_out, name='sqr')


class Sqrt(UnaryScalarOp):
    nfunc_spec = ('sqrt', 1, 1)

    def impl(self, x):
        # If x is an int8 or uint8, numpy.sqrt will compute the result in
        # half-precision (float16), where we want float32.
        x_dtype = str(getattr(x, 'dtype', ''))
        if x_dtype in ('int8', 'uint8'):
            return np.sqrt(x, sig='f')
        return np.sqrt(x)

    def L_op(self, inputs, outputs, gout):
        (x,) = inputs
        (gz,) = gout
        if gz.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            if x.type in discrete_types:
                return [x.zeros_like(dtype=theano.config.floatX)]
            else:
                return [x.zeros_like()]

        return (gz * 0.5) / sqrt(x),

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        if node.inputs[0].type in complex_types:
            raise NotImplementedError('type not supported', type)
        cast = node.outputs[0].type.dtype_specs()[1]
        return "%(z)s = sqrt((%(cast)s)%(x)s);" % locals()
sqrt = Sqrt(upgrade_to_float, name='sqrt')


class Deg2Rad(UnaryScalarOp):
    nfunc_spec = ('deg2rad', 1, 1)

    def impl(self, x):
        # If x is an int8 or uint8, numpy.deg2rad will compute the result in
        # half-precision (float16), where we want float32.
        x_dtype = str(getattr(x, 'dtype', ''))
        if x_dtype in ('int8', 'uint8'):
            return np.deg2rad(x, sig='f')
        return np.deg2rad(x)

    def L_op(self, inputs, outputs, gout):
        (x,) = inputs
        (gz,) = gout
        if gz.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            if x.type in discrete_types:
                return [x.zeros_like(dtype=theano.config.floatX)]
            else:
                return [x.zeros_like()]

        return gz * np.asarray(np.pi / 180, gz.type),

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        if node.inputs[0].type in complex_types:
            raise NotImplementedError('type not supported', type)
        return "%(z)s = %(x)s * (M_PI / 180.0);" % locals()
deg2rad = Deg2Rad(upgrade_to_float, name='deg2rad')


class Rad2Deg(UnaryScalarOp):
    nfunc_spec = ('rad2deg', 1, 1)

    def impl(self, x):
        # If x is an int8 or uint8, numpy.rad2deg will compute the result in
        # half-precision (float16), where we want float32.
        x_dtype = str(getattr(x, 'dtype', ''))
        if x_dtype in ('int8', 'uint8'):
            return np.rad2deg(x, sig='f')
        return np.rad2deg(x)

    def L_op(self, inputs, outputs, gout):
        (x,) = inputs
        (gz,) = gout
        if gz.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            if x.type in discrete_types:
                return [x.zeros_like(dtype=theano.config.floatX)]
            else:
                return [x.zeros_like()]

        return gz * np.asarray(180. / np.pi, gz.type),

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        if node.inputs[0].type in complex_types:
            raise NotImplementedError('type not supported', type)
        return "%(z)s = %(x)s * (180.0 / M_PI);" % locals()
rad2deg = Rad2Deg(upgrade_to_float, name='rad2deg')


class Cos(UnaryScalarOp):
    nfunc_spec = ('cos', 1, 1)
    amd_float32 = "amd_vrsa_cosf"
    amd_float64 = "amd_vrda_cos"

    def impl(self, x):
        # If x is an int8 or uint8, numpy.cos will compute the result in
        # half-precision (float16), where we want float32.
        x_dtype = str(getattr(x, 'dtype', ''))
        if x_dtype in ('int8', 'uint8'):
            return np.cos(x, sig='f')
        return np.cos(x)

    def L_op(self, inputs, outputs, gout):
        (x,) = inputs
        (gz,) = gout
        if gz.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            if x.type in discrete_types:
                return [x.zeros_like(dtype=theano.config.floatX)]
            else:
                return [x.zeros_like()]

        return -gz * sin(x),

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        if node.inputs[0].type in complex_types:
            raise NotImplementedError('type not supported', type)
        cast = node.outputs[0].type.dtype_specs()[1]
        return "%(z)s = cos((%(cast)s)%(x)s);" % locals()
cos = Cos(upgrade_to_float, name='cos')


class ArcCos(UnaryScalarOp):
    nfunc_spec = ('arccos', 1, 1)

    def impl(self, x):
        # If x is an int8 or uint8, numpy.arccos will compute the result in
        # half-precision (float16), where we want float32.
        x_dtype = str(getattr(x, 'dtype', ''))
        if x_dtype in ('int8', 'uint8'):
            return np.arccos(x, sig='f')
        return np.arccos(x)

    def L_op(self, inputs, outputs, gout):
        (x,) = inputs
        (gz,) = gout
        if gz.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            if x.type in discrete_types:
                return [x.zeros_like(dtype=theano.config.floatX)]
            else:
                return [x.zeros_like()]

        return - gz / sqrt(np.cast[x.type](1) - sqr(x)),

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        if node.inputs[0].type in complex_types:
            raise NotImplementedError('type not supported', type)
        cast = node.outputs[0].type.dtype_specs()[1]
        return "%(z)s = acos((%(cast)s)%(x)s);" % locals()
arccos = ArcCos(upgrade_to_float, name='arccos')


class Sin(UnaryScalarOp):
    nfunc_spec = ('sin', 1, 1)
    amd_float32 = "amd_vrsa_sinf"
    amd_float64 = "amd_vrda_sin"

    def impl(self, x):
        # If x is an int8 or uint8, numpy.sin will compute the result in
        # half-precision (float16), where we want float32.
        x_dtype = str(getattr(x, 'dtype', ''))
        if x_dtype in ('int8', 'uint8'):
            return np.sin(x, sig='f')
        return np.sin(x)

    def L_op(self, inputs, outputs, gout):
        (x,) = inputs
        (gz,) = gout
        if x.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            if x.type in discrete_types:
                return [x.zeros_like(dtype=theano.config.floatX)]
            else:
                return [x.zeros_like()]

        return gz * cos(x),

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        if node.inputs[0].type in complex_types:
            raise NotImplementedError('type not supported', type)
        cast = node.outputs[0].type.dtype_specs()[1]
        return "%(z)s = sin((%(cast)s)%(x)s);" % locals()
sin = Sin(upgrade_to_float, name='sin')


class ArcSin(UnaryScalarOp):
    nfunc_spec = ('arcsin', 1, 1)

    def impl(self, x):
        # If x is an int8 or uint8, numpy.arcsin will compute the result in
        # half-precision (float16), where we want float32.
        x_dtype = str(getattr(x, 'dtype', ''))
        if x_dtype in ('int8', 'uint8'):
            return np.arcsin(x, sig='f')
        return np.arcsin(x)

    def L_op(self, inputs, outputs, gout):
        (x,) = inputs
        (gz,) = gout
        if gz.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            if x.type in discrete_types:
                return [x.zeros_like(dtype=theano.config.floatX)]
            else:
                return [x.zeros_like()]

        return gz / sqrt(np.cast[x.type](1) - sqr(x)),

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        if node.inputs[0].type in complex_types:
            raise NotImplementedError('type not supported', type)
        cast = node.outputs[0].type.dtype_specs()[1]
        return "%(z)s = asin((%(cast)s)%(x)s);" % locals()
arcsin = ArcSin(upgrade_to_float, name='arcsin')


class Tan(UnaryScalarOp):
    nfunc_spec = ('tan', 1, 1)

    def impl(self, x):
        # If x is an int8 or uint8, numpy.tan will compute the result in
        # half-precision (float16), where we want float32.
        x_dtype = str(getattr(x, 'dtype', ''))
        if x_dtype in ('int8', 'uint8'):
            return np.tan(x, sig='f')
        return np.tan(x)

    def L_op(self, inputs, outputs, gout):
        (x,) = inputs
        (gz,) = gout
        if x.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            if x.type in discrete_types:
                return [x.zeros_like(dtype=theano.config.floatX)]
            else:
                return [x.zeros_like()]

        return gz / sqr(cos(x)),

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        if node.inputs[0].type in complex_types:
            raise NotImplementedError('type not supported', type)
        cast = node.outputs[0].type.dtype_specs()[1]
        return "%(z)s = tan((%(cast)s)%(x)s);" % locals()
tan = Tan(upgrade_to_float, name='tan')


class ArcTan(UnaryScalarOp):
    nfunc_spec = ('arctan', 1, 1)

    def impl(self, x):
        # If x is an int8 or uint8, numpy.arctan will compute the result in
        # half-precision (float16), where we want float32.
        x_dtype = str(getattr(x, 'dtype', ''))
        if x_dtype in ('int8', 'uint8'):
            return np.arctan(x, sig='f')
        return np.arctan(x)

    def L_op(self, inputs, outputs, gout):
        (x,) = inputs
        (gz,) = gout
        if gz.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            if x.type in discrete_types:
                return [x.zeros_like(dtype=theano.config.floatX)]
            else:
                return [x.zeros_like()]

        return gz / (np.cast[x.type](1) + sqr(x)),

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        if node.inputs[0].type in complex_types:
            raise NotImplementedError('type not supported', type)
        cast = node.outputs[0].type.dtype_specs()[1]
        return "%(z)s = atan((%(cast)s)%(x)s);" % locals()
arctan = ArcTan(upgrade_to_float, name='arctan')


class ArcTan2(BinaryScalarOp):
    nfunc_spec = ('arctan2', 2, 1)

    def impl(self, y, x):
        # If x and y are int8 or uint8, numpy.arctan2 will compute the result
        # in half-precision (float16), where we want float32.
        x_dtype = str(getattr(x, 'dtype', ''))
        if x_dtype in ('int8', 'uint8'):
            y_dtype = str(getattr(y, 'dtype', ''))
            if y_dtype in ('int8', 'uint8'):
                return np.arctan2(y, x, sig='f')
        return np.arctan2(y, x)

    def L_op(self, inputs, outputs, gout):
        (y, x) = inputs
        (gz,) = gout
        if gz.type in complex_types:
            raise NotImplementedError()
        else:
            if outputs[0].type in discrete_types:
                if x.type in discrete_types:
                    gx = x.zeros_like(dtype=theano.config.floatX)
                else:
                    gx = x.zeros_like()
                if y.type in discrete_types:
                    gy = y.zeros_like(dtype=theano.config.floatX)
                else:
                    gy = y.zeros_like()
                return [gx, gy]

            # If the output is float, the gradient should flow,
            # even if the inputs are ints.
            return [gz * x / (sqr(x) + sqr(y)),
                    gz * neg(y) / (sqr(x) + sqr(y))]

    def c_code(self, node, name, inputs, outputs, sub):
        (y, x) = inputs
        (z,) = outputs
        if (node.inputs[0].type in complex_types or
                node.inputs[1].type in complex_types):
            raise NotImplementedError('type not supported', type)
        cast = node.outputs[0].type.dtype_specs()[1]
        return "%(z)s = atan2((%(cast)s)%(y)s, (%(cast)s)%(x)s);" % locals()
arctan2 = ArcTan2(upgrade_to_float, name='arctan2')


class Cosh(UnaryScalarOp):
    """
    cosh(x) = (exp(x) + exp(-x)) / 2.

    """
    nfunc_spec = ('cosh', 1, 1)

    def impl(self, x):
        # If x is an int8 or uint8, numpy.cosh will compute the result in
        # half-precision (float16), where we want float32.
        x_dtype = str(getattr(x, 'dtype', ''))
        if x_dtype in ('int8', 'uint8'):
            return np.cosh(x, sig='f')
        return np.cosh(x)

    def L_op(self, inputs, outputs, gout):
        (x,) = inputs
        (gz,) = gout
        if x.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            if x.type in discrete_types:
                return [x.zeros_like(dtype=theano.config.floatX)]
            else:
                return [x.zeros_like()]

        return gz * sinh(x),

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        if node.inputs[0].type in complex_types:
            raise NotImplementedError('type not supported', type)
        cast = node.outputs[0].type.dtype_specs()[1]
        return "%(z)s = cosh((%(cast)s)%(x)s);" % locals()
cosh = Cosh(upgrade_to_float, name='cosh')


class ArcCosh(UnaryScalarOp):
    nfunc_spec = ('arccosh', 1, 1)

    def impl(self, x):
        # If x is an int8 or uint8, numpy.arccosh will compute the result in
        # half-precision (float16), where we want float32.
        x_dtype = str(getattr(x, 'dtype', ''))
        if x_dtype in ('int8', 'uint8'):
            return np.arccosh(x, sig='f')
        return np.arccosh(x)

    def L_op(self, inputs, outputs, gout):
        (x,) = inputs
        (gz,) = gout
        if x.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            if x.type in discrete_types:
                return [x.zeros_like(dtype=theano.config.floatX)]
            else:
                return [x.zeros_like()]

        return gz / sqrt(sqr(x) - np.cast[x.type](1)),

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        if node.inputs[0].type in complex_types:
            raise NotImplementedError('type not supported', type)
        cast = node.outputs[0].type.dtype_specs()[1]
        return "%(z)s = acosh((%(cast)s)%(x)s);" % locals()
arccosh = ArcCosh(upgrade_to_float, name='arccosh')


class Sinh(UnaryScalarOp):
    """
    sinh(x) = (exp(x) - exp(-x)) / 2.

    """
    nfunc_spec = ('sinh', 1, 1)

    def impl(self, x):
        # If x is an int8 or uint8, numpy.sinh will compute the result in
        # half-precision (float16), where we want float32.
        x_dtype = str(getattr(x, 'dtype', ''))
        if x_dtype in ('int8', 'uint8'):
            return np.sinh(x, sig='f')
        return np.sinh(x)

    def L_op(self, inputs, outputs, gout):
        (x,) = inputs
        (gz,) = gout
        if x.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            if x.type in discrete_types:
                return [x.zeros_like(dtype=theano.config.floatX)]
            else:
                return [x.zeros_like()]

        return gz * cosh(x),

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        if node.inputs[0].type in complex_types:
            raise NotImplementedError('type not supported', type)
        cast = node.outputs[0].type.dtype_specs()[1]
        return "%(z)s = sinh((%(cast)s)%(x)s);" % locals()
sinh = Sinh(upgrade_to_float, name='sinh')


class ArcSinh(UnaryScalarOp):
    nfunc_spec = ('arcsinh', 1, 1)

    def impl(self, x):
        # If x is an int8 or uint8, numpy.arcsinh will compute the result in
        # half-precision (float16), where we want float32.
        x_dtype = str(getattr(x, 'dtype', ''))
        if x_dtype in ('int8', 'uint8'):
            return np.arcsinh(x, sig='f')
        return np.arcsinh(x)

    def L_op(self, inputs, outputs, gout):
        (x,) = inputs
        (gz,) = gout
        if x.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            if x.type in discrete_types:
                return [x.zeros_like(dtype=theano.config.floatX)]
            else:
                return [x.zeros_like()]

        return gz / sqrt(sqr(x) + np.cast[x.type](1)),

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        if node.inputs[0].type in complex_types:
            raise NotImplementedError('type not supported', type)
        cast = node.outputs[0].type.dtype_specs()[1]
        return "%(z)s = asinh((%(cast)s)%(x)s);" % locals()
arcsinh = ArcSinh(upgrade_to_float, name='arcsinh')


class Tanh(UnaryScalarOp):
    """
    tanh(x) = sinh(x) / cosh(x)
            = (exp(2*x) - 1) / (exp(2*x) + 1).

    """
    nfunc_spec = ('tanh', 1, 1)

    def impl(self, x):
        # If x is an int8 or uint8, numpy.tanh will compute the result in
        # half-precision (float16), where we want float32.
        x_dtype = str(getattr(x, 'dtype', ''))
        if x_dtype in ('int8', 'uint8'):
            return np.tanh(x, sig='f')
        return np.tanh(x)

    def L_op(self, inputs, outputs, gout):
        (x,) = inputs
        (gz,) = gout
        if x.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            if x.type in discrete_types:
                return [x.zeros_like(dtype=theano.config.floatX)]
            else:
                return [x.zeros_like()]

        return gz * (1 - sqr(tanh(x))),

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        if node.inputs[0].type in complex_types:
            raise NotImplementedError('type not supported', type)
        cast = node.outputs[0].type.dtype_specs()[1]
        return "%(z)s = tanh((%(cast)s)%(x)s);" % locals()
tanh = Tanh(upgrade_to_float, name='tanh')


class ArcTanh(UnaryScalarOp):
    nfunc_spec = ('arctanh', 1, 1)

    def impl(self, x):
        # If x is an int8 or uint8, numpy.arctanh will compute the result in
        # half-precision (float16), where we want float32.
        x_dtype = str(getattr(x, 'dtype', ''))
        if x_dtype in ('int8', 'uint8'):
            return np.arctanh(x, sig='f')
        return np.arctanh(x)

    def L_op(self, inputs, outputs, gout):
        (x,) = inputs
        (gz,) = gout
        if x.type in complex_types:
            raise NotImplementedError()
        if outputs[0].type in discrete_types:
            if x.type in discrete_types:
                return [x.zeros_like(dtype=theano.config.floatX)]
            else:
                return [x.zeros_like()]

        return gz / (np.cast[x.type](1) - sqr(x)),

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        if node.inputs[0].type in complex_types:
            raise NotImplementedError('type not supported', type)
        cast = node.outputs[0].type.dtype_specs()[1]
        return "%(z)s = atanh((%(cast)s)%(x)s);" % locals()
arctanh = ArcTanh(upgrade_to_float, name='arctanh')


class Real(UnaryScalarOp):
    """
    Extract the real coordinate of a complex number.

    """
    # numpy.real(float32) returns a view on the inputs.
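    # (This is why nfunc_spec stays disabled just below: going through the
    # numpy function could hand back a view of the input buffer.)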
    # nfunc_spec = ('real', 1, 1)

    def impl(self, x):
        return np.real(x)

    def grad(self, inputs, gout):
        (x,) = inputs
        (gz,) = gout
        return [complex(gz, 0)]

real = Real(real_out, name='real')


class Imag(UnaryScalarOp):
    nfunc_spec = ('imag', 1, 1)

    def impl(self, x):
        return np.imag(x)

    def grad(self, inputs, gout):
        (x,) = inputs
        (gz,) = gout
        if x.type in complex_types:
            return [complex(0, gz)]
        elif x.type in float_types:
            return [second(x, 0)]
        else:
            return [x.zeros_like(dtype=theano.config.floatX)]

imag = Imag(real_out, name='imag')


class Angle(UnaryScalarOp):
    nfunc_spec = ('angle', 1, 1)

    def impl(self, x):
        return np.angle(x)

    def grad(self, inputs, gout):
        # y = x.imag
        # r = sqrt(y**2 + x.real**2)
        # g = y/r
        # if x == 0 and y == 0:
        #     theta = 0
        # elif x >= 0:
        #     theta = numpy.arcsin(g)
        # else:
        #     theta = -numpy.arcsin(g) + numpy.pi

        (c,) = inputs
        (gtheta,) = gout
        x = real(c)
        y = imag(c)
        r = abs(c)

        gr = -gtheta * y / (r ** 2 * sqrt(1 - (y / r) ** 2))
        gx = gr * x / r
        gy = gr * y / r
        if c.type in complex_types:
            return [cast(complex(gx, gy), x.type.dtype)]
        elif c.type in float_types:
            return [cast(second(x, 0), x.type.dtype)]
        else:
            return [c.zeros_like(dtype=theano.config.floatX)]

angle = Angle(specific_out(float64), name='angle')


class Complex(BinaryScalarOp):
    @staticmethod
    def output_types_preference(x, y):
        if x in complex_types:
            raise TypeError(x)
        if y in complex_types:
            raise TypeError(y)

        up = Scalar.upcast(x, y)
        if up in ('float64', 'int64', 'uint64', 'int32', 'uint32'):
            return [complex128]
        else:
            return [complex64]

    def impl(self, x, y):
        return np.complex(x, y)

    def grad(self, inputs, gout):
        (x, y) = inputs
        (gz,) = gout
        return [cast(real(gz), x.type.dtype),
                cast(imag(gz), y.type.dtype)]
complex = Complex(name='complex')


class Conj(UnaryScalarOp):
    nfunc_spec = ('conj', 1, 1)

    def impl(self, x):
        return np.conj(x)

    def c_code(self, node, name, inputs, outputs, sub):
        (x,) = inputs
        (z,) = outputs
        if node.inputs[0].type in complex_types:
            # For complex types, conjugation would need dedicated C support.
            # For non-complex types, it is just the identity (below).
            raise NotImplementedError('type has no C code',
                                      node.inputs[0].type)
        return "%(z)s = %(x)s;" % locals()

conj = Conj(same_out_min8, name='conj')


class ComplexFromPolar(BinaryScalarOp):
    @staticmethod
    def output_types_preference(x, y):
        return Complex.output_types_preference(x, y)

    def impl(self, r, theta):
        if r < 0:
            raise ValueError('polar radius must be non-negative', r)
        x = r * np.cos(theta)
        y = r * np.sin(theta)
        if x.dtype == 'float32':
            return np.complex64(np.complex(x, y))
        else:
            return np.complex128(np.complex(x, y))

    def grad(self, inputs, gout):
        (r, theta) = inputs
        (gz,) = gout
        gr = gz * complex_from_polar(1, theta)
        gtheta = gz * complex_from_polar(r, -theta)
        return [gr, gtheta]
complex_from_polar = ComplexFromPolar(name='complex_from_polar')


class Composite(ScalarOp):
    """
    Composite is an Op that takes a graph of scalar operations and
    produces C code for the whole graph. Its purpose is to implement
    loop fusion.
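
    For example, a sketch (assuming float64 scalar variables)::

        x, y = float64('x'), float64('y')
        comp_op = Composite([x, y], [tanh(x + y)])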

    Composite depends on all the Ops in its graph having C code.

    """
    init_param = ('inputs', 'outputs')

    def __str__(self):
        if self.name is None:
            self.init_name()
        return self.name

    def make_new_inplace(self, output_types_preference=None, name=None):
        """
        This op's __init__ does not take the same parameters as other
        scalar ops, which breaks the insert_inplace_optimizer optimization.
        This function patches that.

        """
        d = dict([(k, getattr(self, k)) for k in self.init_param])
        out = self.__class__(**d)
        if name:
            out.name = name
        else:
            name = out.name
        super(Composite, out).__init__(output_types_preference, name)
        return out

    def init_c_code(self):
        """
        Assemble the C code for this Composite Op.

        The result is assigned to `self._c_code`.
        """
        # It was already called
        if hasattr(self, '_c_code'):
            return
        subd = dict(chain(
            ((e, "%%(i%i)s" % i) for i, e in enumerate(self.fgraph.inputs)),
            ((e, "%%(o%i)s" % i) for i, e in enumerate(self.fgraph.outputs))))

        for var in self.fgraph.variables:
            if var.owner is None:
                if var not in self.fgraph.inputs:
                    # This is an orphan
                    if isinstance(var, Constant):
                        subd[var] = var.type.c_literal(var.data)
                    else:
                        raise ValueError(
                            "All orphans in the fgraph to Composite must"
                            " be Constant instances.")
            elif (any(i.dtype == 'float16' for i in var.owner.inputs) or
                  any(o.dtype == 'float16' for o in var.owner.outputs)):
                # flag for elemwise ops to check.
                self.inner_float16 = True

        _c_code = "{\n"
        self.nodenames = ["%(nodename)s_" + ('subnode%i' % j)
                          for j, n in enumerate(self.fgraph.toposort())]

        i = 0
        for j, node in enumerate(self.fgraph.toposort()):
            for output in node.outputs:
                if output not in subd:
                    i += 1
                    name = "V%%(id)s_tmp%i" % i
                    subd[output] = name
                    _c_code += "%s %s;\n" % (
                        output.type.dtype_specs()[1], name)
            s = node.op.c_code(
                node,
                self.nodenames[j],
                [subd[input] for input in node.inputs],
                [subd[output] for output in node.outputs],
                dict(fail="%(fail)s", id="%%(id)s_%i" % j))
            _c_code += s
            _c_code += "\n"
        _c_code += "}\n"
        self._c_code = _c_code

    def init_py_impls(self):
        """
        Return a list of functions that compute each output of self.

        """
        # In the case where the graph is a DAG, but not a tree, like:
        #     add(*1 -> mul(x, y), *1)

        # we have an efficient way to build the executable: we build
        # and traverse each node only once.

        # But we don't have efficient execution: we execute like a tree,
        # so nodes that have more than one client are executed as many
        # times as they have clients. In the example above, *1 is
        # calculated twice. Doing otherwise would imply building a
        # complicated execution engine.

        # We need the fast creation of the executor, as we always create
        # it even if we will use the C code. The Python implementation is
        # already slow, so fast execution matters less there.
    def init_py_impls(self):
        """
        Return a list of functions that compute each output of self.

        """
        # Consider the case where the graph is a DAG, but not a tree,
        # such as:
        #     add(*1 -> mul(x, y), *1)
        #
        # We have an efficient way to build the executable (we build
        # and traverse each node only once).
        #
        # But we don't have an efficient execution. We will execute
        # it like a tree, so nodes that have more than one client will
        # be executed as many times as they have clients. In the
        # example above, it will calculate *1 twice. Doing otherwise
        # would imply building a complicated execution engine.
        #
        # We need the executor to be created quickly, as we always
        # build it even when the C code will be used. The Python
        # implementation is already slow, so fast execution there
        # matters less.

        memo = {}

        def compose_impl(r):
            if r in memo:
                return memo[r]
            if r in self.fgraph.inputs:
                idx = self.fgraph.inputs.index(r)

                def f(inputs):
                    return inputs[idx]
                memo[r] = f
                return f
            elif r.owner is None:  # in fgraph.orphans
                def f(inputs):
                    return r.data
                memo[r] = f
                return f
            node = r.owner
            producers = [compose_impl(input) for input in node.inputs]

            def f(inputs):
                return node.op.impl(*[p(inputs) for p in producers])
            memo[r] = f
            return f
        self._impls = [compose_impl(r) for r in self.fgraph.outputs]

    def init_name(self):
        """
        Set `self.name` to a readable string representation of
        self.fgraph.

        """
        rval = self.name
        if rval is None:
            for i, r in enumerate(self.fgraph.inputs):
                r.name = 'i%i' % i
            for i, r in enumerate(self.fgraph.outputs):
                r.name = 'o%i' % i
            io = set(self.fgraph.inputs + self.fgraph.outputs)
            for i, r in enumerate(self.fgraph.variables):
                if r not in io and len(r.clients) > 1:
                    r.name = 't%i' % i
            rval = "Composite{%s}" % ', '.join([pprint(output) for output
                                                in self.fgraph.outputs])
        self.name = rval

    def init_fgraph(self):
        # The clone done by FunctionGraph is needed as we don't want
        # the fgraph to be set on the variables: we need to pickle
        # them for the cache of the C module to work.
        fgraph = FunctionGraph(self.inputs, self.outputs)
        gof.MergeOptimizer().optimize(fgraph)
        for node in fgraph.apply_nodes:
            if not isinstance(node.op, ScalarOp):
                raise ValueError("The fgraph to Composite must be exclusively"
                                 " composed of ScalarOp instances.")
        self.fgraph = fgraph
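
    # Hedged sketch tying the init_* helpers together (comment only):
    # for a Composite whose fgraph computes [add(i0, i1)],
    #
    #     self._impls[0]([1.0, 2.0])   # -> 3.0, via init_py_impls
    #     str(self)                    # -> e.g. 'Composite{(i0 + i1)}',
    #                                  #    via init_name
    #
    # The exact name is whatever `pprint` produces for the output
    # expression.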
    def __init__(self, inputs, outputs):
        # We need to clone the graph as sometimes its nodes already
        # contain a reference to an fgraph. As we want the Composite
        # to be picklable, we can't keep a reference to an fgraph.

        # Also, if there is a Composite in the inner graph, we want to
        # remove it. In that case, we do a more complicated clone that
        # flattens the Composite. We don't need to do this recursively:
        # given the way the fusion optimizer works, there is only one
        # new Composite at the output each time.
        for i in inputs:
            assert i not in outputs  # This isn't supported, use identity
        if len(outputs) > 1 or not any([isinstance(var.owner.op, Composite)
                                        for var in outputs]):
            # No inner Composite
            inputs, outputs = gof.graph.clone(inputs, outputs)
        else:
            # Inner Composite that we need to flatten
            assert len(outputs) == 1
            # 1. Create a new graph from inputs up to the
            #    Composite
            res = theano.compile.rebuild_collect_shared(
                inputs=inputs,
                outputs=outputs[0].owner.inputs,
                copy_inputs_over=False)  # Clone also the inputs
            # 2. We continue this partial clone with the graph in
            #    the inner Composite
            res2 = theano.compile.rebuild_collect_shared(
                inputs=outputs[0].owner.op.inputs,
                outputs=outputs[0].owner.op.outputs,
                replace=dict(izip(outputs[0].owner.op.inputs, res[1]))
            )
            assert len(res2[1]) == len(outputs)
            assert len(res[0]) == len(inputs)
            assert res[0] != inputs
            inputs, outputs = res[0], res2[1]
            # The next assert is commented out only for speed.
            # assert not any([isinstance(node.op, Composite) for node in
            #                 theano.gof.graph.ops(inputs, outputs)])

        self.inputs = copy(inputs)
        self.outputs = copy(outputs)
        self.inputs_type = tuple([input.type for input in inputs])
        self.outputs_type = tuple([output.type for output in outputs])
        self.nin = len(inputs)
        self.nout = len(outputs)
        self.init_fgraph()  # self.fgraph
        # Postpone name creation in case it isn't needed.
        # self.init_name()  # self.name
        self.name = None
        self.prepare_node_called = set()

    def prepare_node(self, node, storage_map, compute_map, impl):
        if impl == 'py':
            self.init_py_impls()  # self._impls
        if impl not in self.prepare_node_called:
            for n in theano.gof.graph.list_of_nodes(self.inputs, self.outputs):
                n.op.prepare_node(n, None, None, impl)
            self.prepare_node_called.add(impl)

    def clone_float32(self):
        # This will not modify the fgraph or the nodes.
        new_ins, new_outs = composite_f32.apply(self.fgraph)
        return Composite(new_ins, new_outs)

    def output_types(self, input_types):
        if tuple(input_types) != self.inputs_type:
            raise TypeError("Wrong types for Composite. Expected %s, got %s."
                            % (self.inputs_type, tuple(input_types)))
        return self.outputs_type

    def make_node(self, *inputs):
        if (tuple([i.type for i in self.inputs]) ==
                tuple([i.type for i in inputs])):
            return super(Composite, self).make_node(*inputs)
        else:
            # Make a new op with the right input types.
            assert len(inputs) == self.nin
            res = theano.compile.rebuild_collect_shared(
                self.outputs,
                replace=dict(izip(self.inputs, inputs)),
                rebuild_strict=False)
            # After rebuild_collect_shared, the Variables in inputs
            # are not necessarily in the graph represented by res.
            # res[2][0] is a dict that maps from the original variable
            # to the cloned variable.
            cloned_inputs = [res[2][0][i] for i in inputs]
            node = Composite(cloned_inputs, res[1]).make_node(*inputs)
            return node

    def perform(self, node, inputs, output_storage):
        for storage, impl in zip(output_storage, self._impls):
            storage[0] = impl(inputs)

    def impl(self, *inputs):
        output_storage = [[None] for i in xrange(self.nout)]
        self.perform(None, inputs, output_storage)
        ret = utils.to_return_values([storage[0] for storage in
                                      output_storage])
        if self.nout > 1:
            ret = tuple(ret)
        return ret

    def grad(self, inputs, output_grads):
        raise NotImplementedError("grad is not implemented for Composite")

    def c_code(self, node, nodename, inames, onames, sub):
        self.init_c_code()

        d = dict(chain(izip(("i%i" % i for i in xrange(len(inames))), inames),
                       izip(("o%i" % i for i in xrange(len(onames))),
                            onames)), **sub)
        d['nodename'] = nodename
        if 'id' not in sub:
            # The use of a dummy id is safe as the code is in a separate
            # block. It won't generate conflicting variable names.
            d['id'] = '_DUMMY_ID_'

        return self._c_code % d
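
    # Hedged illustration of the final substitution above (comment
    # only): for one input and one output, `d` ends up mapping roughly
    #
    #     {'i0': <C name of input 0>, 'o0': <C name of output 0>,
    #      'nodename': <this node's name>, 'fail': <failure code>,
    #      'id': <unique id, or '_DUMMY_ID_'>}
    #
    # and `self._c_code % d` expands the template assembled in
    # init_c_code into the final C block.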
    def c_code_cache_version(self):
        rval = [3]
        for x in self.fgraph.toposort():
            xv = x.op.c_code_cache_version()
            if xv:
                rval.append(xv)
            else:
                return ()
        return tuple(rval)

    def c_support_code(self):
        rval = []
        for subnode in self.fgraph.toposort():
            try:
                rval.append(subnode.op.c_support_code().strip())
            except gof.utils.MethodNotDefined:
                pass
        # Remove duplicate code blocks.
        return "\n".join(sorted(set(rval)))

    def c_support_code_apply(self, node, name):
        self.init_c_code()
        rval = []
        for subnode, subnodename in zip(self.fgraph.toposort(),
                                        self.nodenames):
            try:
                subnode_support_code = subnode.op.c_support_code_apply(
                    subnode,
                    subnodename % dict(nodename=name))
                if subnode_support_code:
                    rval.append(subnode_support_code)
            except gof.utils.MethodNotDefined:
                pass
        # There should be no need to remove duplicate code blocks because
        # each block should have been specialized for the given nodename.
        # Any block that isn't specialized should be returned via
        # c_support_code instead of c_support_code_apply.
        return "\n".join(rval)

    def __eq__(self, other):
        if self is other:
            return True
        if (type(self) != type(other) or
                self.nin != other.nin or
                self.nout != other.nout):
            return False
        # See __hash__ for the comment on why there is no mention of
        # fgraph or the module cache key here.
        self.init_c_code()  # self._c_code and self.nodenames
        other.init_c_code()
        return (self._c_code == other._c_code)

    def __hash__(self):
        self.init_c_code()  # self._c_code and self.nodenames
        rval = hash((type(self),
                     self.nin,
                     self.nout,
                     self._c_code))
        # Note that in general, the configparser settings at the time
        # of code generation (__init__) affect the semantics of this Op.
        # This function assumes that all relevant info about the
        # configparser is embodied in _c_code. So the _c_code, rather
        # than self.fgraph, is the signature of the semantics of this
        # Op. _c_code is preserved through unpickling, so the Op will
        # not change semantics when it is reloaded with different
        # configparser settings.
        return rval

    def __getstate__(self):
        rval = dict(self.__dict__)
        rval.pop('_impls', None)
        rval.pop('prepare_node_called', None)
        del rval['fgraph']
        return rval

    def __setstate__(self, d):
        self.__dict__.update(d)
        # We must rebuild fgraph and _impls here, as otherwise
        # self.perform will not work.
        self.prepare_node_called = set()
        self.init_fgraph()
        self.init_py_impls()
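

# ---------------------------------------------------------------------
# Hedged sketch (comment only): because __getstate__ drops the fgraph
# and the Python impls, and __setstate__ rebuilds them, a Composite
# survives a pickling round trip.  Assuming `x` and `y` are float64
# scalar variables from this module:
#
#     from six.moves import cPickle
#     comp = Composite([x, y], [add(x, y)])
#     comp2 = cPickle.loads(cPickle.dumps(comp))
#     assert comp2 == comp    # __eq__ compares the generated _c_code
# ---------------------------------------------------------------------
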
class Compositef32(object):
    # This is a dict of scalar op classes that need special handling.
    special = {}

    def apply(self, fgraph):
        mapping = {}
        topo = fgraph.toposort()
        for i in fgraph.inputs:
            if i.dtype == 'float16':
                mapping[i] = get_scalar_type('float32')()
                if hasattr(i.tag, 'test_value'):
                    mapping[i].tag.test_value = i.tag.test_value
            else:
                mapping[i] = i
        for node in topo:
            # Patch up for constants.
            for i in node.inputs:
                if i not in mapping:
                    assert type(i) is ScalarConstant
                    if i.type == float16:
                        ni = ScalarConstant(float32, i.data)
                    else:
                        ni = i
                    mapping[i] = ni
            if type(node.op) in self.special:
                self.special[type(node.op)](node, mapping)
                continue
            new_node = node.clone_with_new_inputs(
                [mapping[inp] for inp in node.inputs],
                strict=False)
            # Make sure we don't produce any float16.
            assert not any(o.dtype == 'float16' for o in new_node.outputs)
            for o, no in zip(node.outputs, new_node.outputs):
                mapping[o] = no

        new_ins = [mapping[inp] for inp in fgraph.inputs]
        new_outs = [mapping[out] for out in fgraph.outputs]
        return new_ins, new_outs

composite_f32 = Compositef32()


def handle_cast(node, mapping):
    inp = mapping[node.inputs[0]]
    out = node.outputs[0]
    node_ok = False
    if node.op.o_type == float16:
        if node.inputs[0].type == float32:
            # Cast f32 -> f16: remove it.
            mapping[out] = inp
            return
        else:
            # Cast to f16: convert to f32 instead.
            new_out = cast(inp, 'float32')
            # Change the node for the following if.
            node = new_out.owner
            mapping[out] = new_out
            node_ok = True
    if node.inputs[0].type == float16:
        if node.op.o_type == inp.type:
            # Cast f16 to the new input type: remove it.
            mapping[out] = inp
            return
    if not node_ok:
        new_node = node.clone_with_new_inputs([inp],
                                              strict=False)
        mapping[out] = new_node.outputs[0]

Compositef32.special[Cast] = handle_cast


def handle_composite(node, mapping):
    new_op = node.op.clone_float32()
    new_outs = new_op(*[mapping[i] for i in node.inputs], return_list=True)
    assert len(new_outs) == len(node.outputs)
    for o, no in zip(node.outputs, new_outs):
        mapping[o] = no

Compositef32.special[Composite] = handle_composite
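

# ---------------------------------------------------------------------
# Hedged usage sketch (comment only): `composite_f32` rewrites a scalar
# fgraph so that no float16 variable survives, which is what
# Composite.clone_float32 relies on.  Assuming the `float16` Scalar
# type and the `add` op defined earlier in this module:
#
#     x = float16('x')
#     fg = FunctionGraph([x], [add(x, x)])
#     new_ins, new_outs = composite_f32.apply(fg)
#     assert all(v.dtype == 'float32' for v in new_ins + new_outs)
# ---------------------------------------------------------------------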