1""" 2Ops and optimizations: sigmoid, softplus. 3 4These functions implement special cases of exp and log to improve numerical 5stability. 6 7""" 8from __future__ import absolute_import, print_function, division 9 10import warnings 11 12import numpy as np 13 14import theano 15from theano import config, gof, printing, scalar 16from theano.compat import imap 17from theano.printing import pprint 18from theano.tensor import basic as tensor 19from theano.tensor import elemwise, opt, NotScalarConstantError 20from theano.tensor.type import values_eq_approx_remove_inf 21from theano.gof.opt import copy_stack_trace 22 23############ 24# 25# SCALAR OPS 26# 27 28 29class ScalarSigmoid(scalar.UnaryScalarOp): 30 """ 31 This is just speed opt. Not for stability. 32 33 """ 34 @staticmethod 35 def st_impl(x): 36 if x < -30.0: 37 return 0.0 38 if x > 30.0: 39 return 1.0 40 # If x is an int8 or uint8, numpy.exp will compute the result in 41 # half-precision (float16), where we want float32. 42 x_dtype = str(getattr(x, 'dtype', '')) 43 if x_dtype in ('int8', 'uint8'): 44 return 1.0 / (1.0 + np.exp(-x, sig='f')) 45 return 1.0 / (1.0 + np.exp(-x)) 46 47 def impl(self, x): 48 return ScalarSigmoid.st_impl(x) 49 50 def grad(self, inp, grads): 51 x, = inp 52 gz, = grads 53 y = scalar_sigmoid(x) 54 rval = gz * y * (1.0 - y) 55 56 assert rval.type.dtype.find('float') != -1 57 58 return [rval] 59 60 def c_code(self, node, name, inp, out, sub): 61 x, = inp 62 z, = out 63 # We add boundary checks prevent exp from generating inf or 64 # 0. The reset of the logic always generate 0 or 1 in those 65 # cases. This is a speed optimization. 66 # The constants were obtained by looking at the output of 67 # python commands like: 68 # 69 # import numpy, theano 70 # dt='float32' # or float64 71 # for i in xrange(750): 72 # print i, repr(theano._asarray(1.0, dtype=dt) / 73 # (theano._asarray(1.0, dtype=dt) + 74 # numpy.exp(-theano._asarray([i,-i], dtype=dt)))) 75 76 # float16 limits: -11.0, 7.0f 77 # We use the float32 limits for float16 for now as the 78 # computation will happen in float32 anyway. 79 if (node.inputs[0].type == scalar.float32 or 80 node.inputs[0].type == scalar.float16): 81 return """%(z)s = %(x)s < -88.0f ? 0.0 : %(x)s > 15.0f ? 1.0f : 1.0f /(1.0f + exp(-%(x)s));""" % locals() 82 elif node.inputs[0].type == scalar.float64: 83 return """%(z)s = %(x)s < -709.0 ? 0.0 : %(x)s > 19.0 ? 1.0 : 1.0 /(1.0+exp(-%(x)s));""" % locals() 84 else: 85 raise NotImplementedError('only floatingpoint is implemented') 86 87 def c_code_cache_version(self): 88 v = super(ScalarSigmoid, self).c_code_cache_version() 89 if v: 90 return (2,) + v 91 else: 92 return v 93 94 # This fct is disabled as it is slower then the normal code! 95 def c_code_contiguous_disabled(self, node, name, inp, out, sub): 96 x, = inp 97 z, = out 98 if (not theano.config.lib.amdlibm or 99 node.inputs[0].dtype != node.outputs[0].dtype): 100 raise theano.gof.utils.MethodNotDefined() 101 dtype = node.inputs[0].dtype 102 if dtype == 'float32' and self.amd_float32 is not None: 103 dtype = 'float' 104 fct = "amd_vrsa_expf" 105 elif dtype == 'float64' and self.amd_float64 is not None: 106 dtype = 'double' 107 fct = "amd_vrda_exp" 108 else: 109 raise theano.gof.utils.MethodNotDefined() 110 return """ 111 npy_intp n = PyArray_SIZE(%(z)s); 112 %(dtype)s * x = (%(dtype)s*) PyArray_DATA(%(x)s); 113 %(dtype)s * z = (%(dtype)s*) PyArray_DATA(%(z)s); 114 // We block to keep the data in l1 115 // normal l1 size = 32k: 32k/2(input + output)/8(nb bytes of double)=2k 116 // We stay bellow the 2k limit to let space for 117 // This is faster than the not blocking version 118 for(int i=0;i<n;i+=2048){ 119 npy_intp nb = (n-i<2048)?n-i:2048; 120 for(int j=0;j<nb;j++){ 121 z[i+j] = -x[i+j]; 122 } 123 %(fct)s(nb, z+i, z+i); 124 for(int j=0;j<nb;j++){ 125 z[i+j] = 1.0 /(1.0+z[i+j]); 126 } 127 } 128 """ % locals() 129 raise theano.gof.utils.MethodNotDefined() 130 131 @staticmethod 132 def gen_graph(): 133 """ 134 This method was used to generate the graph: sigmoid_prec.png in the doc. 135 136 """ 137 data = np.arange(-15, 15, .1) 138 val = 1 / (1 + np.exp(-data)) 139 140 def hard_sigmoid(x): 141 return theano.tensor.nnet.hard_sigmoid(x) 142 143 def ultra_fast_sigmoid(x): 144 return theano.tensor.nnet.ultra_fast_sigmoid(x) 145 146 val_hard = hard_sigmoid(data).eval() 147 val_ultra = ultra_fast_sigmoid(data).eval() 148 149 import matplotlib.pyplot as plt 150 import os 151 fig = plt.figure() 152 ax = fig.add_subplot(111) 153 ax.plot(data, val) # , 'o-') 154 ax.plot(data, val_ultra) # , '-') 155 ax.plot(data, val_hard) # , '-') 156 ax.grid(True) 157 ax.legend(("sigmoid", "ultra_fast", "hard"), "upper left") 158 fname = os.path.join(os.path.dirname(theano.__file__), '..', 159 'doc', 'library', 'tensor', 'nnet', 160 'sigmoid_prec.png') 161 plt.savefig(fname) 162 print("New picture saved at", fname) 163 print(val_ultra.max()) 164 print(val_ultra.min()) 165 166 167scalar_sigmoid = ScalarSigmoid(scalar.upgrade_to_float, name='scalar_sigmoid') 168sigmoid = elemwise.Elemwise(scalar_sigmoid, name='sigmoid') 169 170sigmoid_inplace = elemwise.Elemwise( 171 ScalarSigmoid(scalar.transfer_type(0)), 172 inplace_pattern={0: 0}, 173 name='sigmoid_inplace', 174) 175 176pprint.assign(sigmoid, printing.FunctionPrinter('sigmoid')) 177 178 179class UltraFastScalarSigmoid(scalar.UnaryScalarOp): 180 """ 181 This is just speed opt. Not for stability. 182 183 """ 184 @staticmethod 185 def st_impl(x): 186 x = 0.5 * x 187 # The if is a tanh approximate. 188 if x >= 0: 189 if x < 1.7: 190 z = (1.5 * x / (1 + x)) 191 elif x < 3: 192 z = (0.935409070603099 + 0.0458812946797165 * (x - 1.7)) 193 else: 194 z = 0.99505475368673 195 else: 196 xx = -x 197 if xx < 1.7: 198 z = (1.5 * xx / (1 + xx)) 199 elif xx < 3: 200 z = (0.935409070603099 + 0.0458812946797165 * (xx - 1.7)) 201 else: 202 z = 0.99505475368673 203 z = -z 204 205 return 0.5 * (z + 1.) 206 207 def impl(self, x): 208 return UltraFastScalarSigmoid.st_impl(x) 209 210 def c_code(self, node, name, inp, out, sub): 211 x, = inp 212 z, = out 213 dtype = node.outputs[0].type.dtype_specs()[1] 214 215 return """ 216 %(dtype)s x = 0.5 * %(x)s; 217 // The if is a tanh approximate. 218 if(x>=0) { 219 %(z)s = (x<1.7 ? (1.5*x/(1+x)) : 220 (x<3 ? (0.935409070603099 + 0.0458812946797165*(x-1.7)): 221 0.99505475368673)); 222 } else { 223 %(dtype)s xx = -x; 224 %(z)s = -(xx<1.7 ? (1.5*xx/(1+xx)) : 225 (xx<3 ? (0.935409070603099 + 0.0458812946797165*(xx-1.7)): 226 0.99505475368673)); 227 } 228 229 //%(z)s = 0.5*(ultrafasttanh(0.5*x)+1.); 230 %(z)s = 0.5*(%(z)s+1.); 231 """ % locals() 232 233ultra_fast_scalar_sigmoid = UltraFastScalarSigmoid( 234 scalar.upgrade_to_float, name='ultra_fast_scalar_sigmoid') 235ultra_fast_sigmoid = elemwise.Elemwise(ultra_fast_scalar_sigmoid, 236 name='ultra_fast_sigmoid') 237 238ultra_fast_sigmoid_inplace = elemwise.Elemwise( 239 UltraFastScalarSigmoid(scalar.transfer_type(0)), 240 inplace_pattern={0: 0}, 241 name='ultra_fast_sigmoid_inplace', 242) 243 244pprint.assign(ultra_fast_sigmoid, 245 printing.FunctionPrinter('ultra_fast_sigmoid')) 246 247 248# @opt.register_uncanonicalize 249@gof.local_optimizer([sigmoid]) 250def local_ultra_fast_sigmoid(node): 251 """ 252 When enabled, change all sigmoid to ultra_fast_sigmoid. 253 254 For example do mode.including('local_ultra_fast_sigmoid') 255 or use the Theano flag optimizer_including=local_ultra_fast_sigmoid. 256 257 This speeds up the sigmoid op by using an approximation. 258 259 This is done after the stabilization and specialize phases 260 to avoid interacting with them. 261 262 """ 263 if (isinstance(node.op, tensor.Elemwise) and 264 node.op.scalar_op == scalar_sigmoid): 265 out = ultra_fast_sigmoid(node.inputs[0]) 266 copy_stack_trace(node.outputs[0], out) 267 268 def values_eq_approx_remove_low_prec(a, b): 269 # atol is found by trial/error. 270 # Other test could fail without good reason. 271 return tensor.TensorType.values_eq_approx(a, b, atol=0.02) 272 # Let DebugMode know that there this opt approx the values. 273 out.tag.values_eq_approx = values_eq_approx_remove_low_prec 274 return [out] 275theano.compile.optdb['uncanonicalize'].register("local_ultra_fast_sigmoid", 276 local_ultra_fast_sigmoid) 277 278 279def hard_sigmoid(x): 280 """ 281 An approximation of sigmoid. 282 283 More approximate and faster than ultra_fast_sigmoid. 284 285 Approx in 3 parts: 0, scaled linear, 1. 286 287 Removing the slope and shift does not make it faster. 288 289 """ 290 # Use the same dtype as determined by "upgrade_to_float", 291 # and perform computation in that dtype. 292 out_dtype = scalar.upgrade_to_float(scalar.Scalar(dtype=x.dtype))[0].dtype 293 slope = tensor.constant(0.2, dtype=out_dtype) 294 shift = tensor.constant(0.5, dtype=out_dtype) 295 x = (x * slope) + shift 296 x = tensor.clip(x, 0, 1) 297 return x 298 299 300# @opt.register_uncanonicalize 301@gof.local_optimizer([sigmoid]) 302def local_hard_sigmoid(node): 303 if (isinstance(node.op, tensor.Elemwise) and 304 node.op.scalar_op == scalar_sigmoid): 305 out = hard_sigmoid(node.inputs[0]) 306 copy_stack_trace(node.outputs[0], out) 307 308 def values_eq_approx_remove_low_prec(a, b): 309 # atol is found by trial/error. 310 # Other test could fail without good reason. 311 return tensor.TensorType.values_eq_approx(a, b, atol=0.1) 312 # Let DebugMode know that there this opt approx the values. 313 out.tag.values_eq_approx = values_eq_approx_remove_low_prec 314 return [out] 315theano.compile.optdb['uncanonicalize'].register("local_hard_sigmoid", 316 local_hard_sigmoid) 317 318 319class ScalarSoftplus(scalar.UnaryScalarOp): 320 """ 321 This helps numerical stability. 322 """ 323 @staticmethod 324 def static_impl(x): 325 if x < -30.0: 326 return 0.0 327 if x > 30.0: 328 return x 329 # If x is an int8 or uint8, numpy.exp will compute the result in 330 # half-precision (float16), where we want float32. 331 x_dtype = str(getattr(x, 'dtype', '')) 332 if x_dtype in ('int8', 'uint8'): 333 return np.log1p(np.exp(x, sig='f')) 334 return np.log1p(np.exp(x)) 335 336 def impl(self, x): 337 return ScalarSoftplus.static_impl(x) 338 339 def grad(self, inp, grads): 340 x, = inp 341 gz, = grads 342 return [gz * scalar_sigmoid(x)] 343 344 def c_code(self, node, name, inp, out, sub): 345 x, = inp 346 z, = out 347 # These constants were obtained by looking at the output of 348 # python commands like: 349 # for i in xrange(750): 350 # print i, repr(numpy.log1p(numpy.exp(theano._asarray([i,-i], dtype=dt)))) 351 # the boundary checks prevent us from generating inf 352 353 # float16 limits: -17.0, 6.0 354 # We use the float32 limits for float16 for now as the 355 # computation will happen in float32 anyway. 356 if (node.inputs[0].type == scalar.float32 or 357 node.inputs[0].type == scalar.float16): 358 return """%(z)s = %(x)s < -103.0f ? 0.0 : %(x)s > 14.0f ? %(x)s : log1p(exp(%(x)s));""" % locals() 359 elif node.inputs[0].type == scalar.float64: 360 return """%(z)s = %(x)s < -745.0 ? 0.0 : %(x)s > 16.0 ? %(x)s : log1p(exp(%(x)s));""" % locals() 361 else: 362 raise NotImplementedError('only floatingpoint is implemented') 363 364 def c_code_cache_version(self): 365 v = super(ScalarSoftplus, self).c_code_cache_version() 366 if v: 367 return (2,) + v 368 else: 369 return v 370scalar_softplus = ScalarSoftplus(scalar.upgrade_to_float, 371 name='scalar_softplus') 372softplus = elemwise.Elemwise(scalar_softplus, name='softplus') 373 374pprint.assign(softplus, printing.FunctionPrinter('softplus')) 375 376 377def _skip_mul_1(r): 378 if r.owner and r.owner.op == tensor.mul: 379 not_is_1 = [i for i in r.owner.inputs if not _is_1(i)] 380 if len(not_is_1) == 1: 381 return not_is_1[0] 382 383logsigm_to_softplus = gof.PatternSub( 384 (tensor.log, (sigmoid, 'x')), 385 (tensor.neg, (softplus, (tensor.neg, 'x'))), 386 allow_multiple_clients=True, 387 values_eq_approx=values_eq_approx_remove_inf, 388 skip_identities_fn=_skip_mul_1) 389 390 391def _is_1(expr): 392 """ 393 394 Returns 395 ------- 396 bool 397 True iff expr is a constant close to 1. 398 399 """ 400 try: 401 v = opt.get_scalar_constant_value(expr) 402 return np.allclose(v, 1) 403 except tensor.NotScalarConstantError: 404 return False 405 406log1msigm_to_softplus = gof.PatternSub( 407 (tensor.log, 408 (tensor.sub, 409 dict(pattern='y', constraint=_is_1), 410 (sigmoid, 'x'))), 411 (tensor.neg, (softplus, 'x')), 412 allow_multiple_clients=True, 413 values_eq_approx=values_eq_approx_remove_inf, 414 skip_identities_fn=_skip_mul_1) 415 416 417log1pexp_to_softplus = gof.PatternSub( 418 (tensor.log1p, 419 (tensor.exp, 'x')), 420 (softplus, 'x'), 421 values_eq_approx=values_eq_approx_remove_inf, 422 allow_multiple_clients=True) 423 424log1p_neg_sigmoid = gof.PatternSub( 425 (tensor.log1p, 426 (tensor.neg, (sigmoid, 'x'))), 427 (tensor.neg, (softplus, 'x')), 428 values_eq_approx=values_eq_approx_remove_inf, 429 allow_multiple_clients=True) 430 431opt.register_stabilize(logsigm_to_softplus, name='logsigm_to_softplus') 432opt.register_stabilize(log1msigm_to_softplus, name='log1msigm_to_softplus') 433opt.register_stabilize(log1pexp_to_softplus, name='log1pexp_to_softplus') 434opt.register_stabilize(log1p_neg_sigmoid, name='log1p_neg_sigmoid,') 435 436 437def is_1pexp(t, only_process_constants=True): 438 """ 439 440 Returns 441 ------- 442 object 443 If 't' is of the form (1+exp(x)), return (False, x). 444 Else return None. 445 446 """ 447 if t.owner and t.owner.op == tensor.add: 448 scalars, scalar_inputs, nonconsts = \ 449 opt.scalarconsts_rest(t.owner.inputs, 450 only_process_constants=only_process_constants) 451 # scalar_inputs are potentially dimshuffled and filled with scalars 452 if len(nonconsts) == 1: 453 maybe_exp = nonconsts[0] 454 if maybe_exp.owner and maybe_exp.owner.op == tensor.exp: 455 # Verify that the constant terms sum to 1. 456 if scalars: 457 scal_sum = scalars[0] 458 for s in scalars[1:]: 459 scal_sum = scal_sum + s 460 if np.allclose(scal_sum, 1): 461 return False, maybe_exp.owner.inputs[0] 462 # Before 7987b51 there used to be a bug where *any* constant 463 # was considered as if it was equal to 1, and thus this 464 # function would incorrectly identify it as (1 + exp(x)). 465 if config.warn.identify_1pexp_bug: 466 warnings.warn( 467 'Although your current code is fine, please note that ' 468 'Theano versions prior to 0.5 (more specifically, ' 469 'prior to commit 7987b51 on 2011-12-18) may have ' 470 'yielded an incorrect result. To remove this warning, ' 471 'either set the `warn.identify_1pexp_bug` config ' 472 'option to False, or `warn.ignore_bug_before` to at ' 473 'least \'0.4.1\'.') 474 return None 475 476 477def is_exp(var): 478 """ 479 Match a variable with either of the `exp(x)` or `-exp(x)` patterns. 480 481 Parameters 482 ---------- 483 var 484 The Variable to analyze. 485 486 Returns 487 ------- 488 tuple 489 A pair (b, x) with `b` a boolean set to True if `var` is of the 490 form `-exp(x)` and False if `var` is of the form `exp(x)`. If `var` 491 cannot be cast into either form, then return `None`. 492 493 """ 494 neg = False 495 neg_info = is_neg(var) 496 if neg_info is not None: 497 neg = True 498 var = neg_info 499 if var.owner and var.owner.op == tensor.exp: 500 return neg, var.owner.inputs[0] 501 502 503def is_mul(var): 504 """ 505 Match a variable with `x * y * z * ...`. 506 507 Parameters 508 ---------- 509 var 510 The Variable to analyze. 511 512 Returns 513 ------- 514 object 515 A list [x, y, z, ...] if `var` is of the form `x * y * z * ...`, 516 or None if `var` cannot be cast into this form. 517 518 """ 519 if var.owner and var.owner.op == tensor.mul: 520 return var.owner.inputs 521 else: 522 return None 523 524 525def partition_num_or_denom(r, f): 526 if r.owner and r.owner.op == tensor.mul: 527 a = r.owner.inputs 528 else: 529 a = [r] 530 531 # ugly 2.4-compatible thing 532 f_terms = [] 533 neg = False 534 rest = [] 535 for t in a: 536 f_t = f(t) 537 if f_t is None: 538 rest.append(t) 539 else: 540 neg_t, f_t = f_t 541 f_terms.append(f_t) 542 neg ^= neg_t # bit flip if neg_t is true 543 return f_terms, rest, neg 544 545 546def is_neg(var): 547 """ 548 Match a variable with the `-x` pattern. 549 550 Parameters 551 ---------- 552 var 553 The Variable to analyze. 554 555 Returns 556 ------- 557 object 558 `x` if `var` is of the form `-x`, or None otherwise. 559 560 """ 561 apply = var.owner 562 if not apply: 563 return None 564 # First match against `tensor.neg`. 565 if apply.op == tensor.neg: 566 return apply.inputs[0] 567 # Then match against a multiplication by -1. 568 if apply.op == tensor.mul and len(apply.inputs) >= 2: 569 for idx, mul_input in enumerate(apply.inputs): 570 try: 571 constant = opt.get_scalar_constant_value(mul_input) 572 is_minus_1 = np.allclose(constant, -1) 573 except NotScalarConstantError: 574 is_minus_1 = False 575 if is_minus_1: 576 # Found a multiplication by -1. 577 if len(apply.inputs) == 2: 578 # Only return the other input. 579 return apply.inputs[1 - idx] 580 else: 581 # Return the multiplication of all other inputs. 582 return tensor.mul(*(apply.inputs[0:idx] + 583 apply.inputs[idx + 1:])) 584 # No match. 585 return None 586 587 588@opt.register_stabilize 589@gof.local_optimizer([tensor.true_div]) 590def local_exp_over_1_plus_exp(node): 591 """ 592 exp(x)/(1+exp(x)) -> sigm(x) 593 c/(1+exp(x)) -> c*sigm(-x) 594 595 """ 596 # this optimization should be done for numerical stability 597 # so we don't care to check client counts 598 if node.op == tensor.true_div: 599 600 # find all the exp() terms in the numerator 601 num, denom = node.inputs 602 num_exp_x, num_rest, num_neg = partition_num_or_denom(num, is_exp) 603 denom_1pexp, denom_rest, \ 604 denom_neg = partition_num_or_denom(denom, is_1pexp) 605 606 sigmoids = [] 607 for t in denom_1pexp: 608 if t in num_exp_x: 609 # case: exp(x) /(1+exp(x)) 610 sigmoids.append(sigmoid(t)) 611 del num_exp_x[num_exp_x.index(t)] 612 else: 613 # case: 1/(1+exp(x)) 614 sigmoids.append(sigmoid(-t)) 615 copy_stack_trace(node.outputs[0], sigmoids[-1]) 616 617 if not sigmoids: # we didn't find any. abort 618 return 619 # put the new numerator together 620 new_num = sigmoids + [tensor.exp(t) for t in num_exp_x] + num_rest 621 if len(new_num) == 1: 622 new_num = new_num[0] 623 else: 624 new_num = tensor.mul(*new_num) 625 626 if num_neg ^ denom_neg: 627 new_num = -new_num 628 629 copy_stack_trace(num, new_num) 630 631 if len(denom_rest) == 0: 632 return [new_num] 633 elif len(denom_rest) == 1: 634 out = new_num / denom_rest[0] 635 else: 636 out = new_num / tensor.mul(*denom_rest) 637 638 copy_stack_trace(node.outputs[0], out) 639 return [out] 640 641 642def parse_mul_tree(root): 643 """ 644 Parse a tree of multiplications starting at the given root. 645 646 Parameters 647 ---------- 648 root 649 The variable at the root of the tree. 650 651 Returns 652 ------- 653 object 654 A tree where each non-leaf node corresponds to a multiplication 655 in the computation of `root`, represented by the list of its inputs. 656 Each input is a pair [n, x] with `n` a boolean value indicating whether 657 sub-tree `x` should be negated. 658 659 Examples 660 -------- 661 x * y -> [False, [[False, x], [False, y]]] 662 -(x * y) -> [True, [[False, x], [False, y]]] 663 -x * y -> [False, [[True, x], [False, y]]] 664 -x -> [True, x] 665 (x * y) * -z -> [False, [[False, [[False, x], [False, y]]], 666 [True, z]]] 667 668 """ 669 # Is it a multiplication? 670 mul_info = is_mul(root) 671 if mul_info is None: 672 # Is it a negation? 673 neg_info = is_neg(root) 674 if neg_info is None: 675 # Keep the root "as is". 676 return [False, root] 677 else: 678 # Recurse, inverting the negation. 679 neg, sub_tree = parse_mul_tree(neg_info) 680 return [not neg, sub_tree] 681 else: 682 # Recurse into inputs. 683 return [False, list(map(parse_mul_tree, mul_info))] 684 685 686def replace_leaf(arg, leaves, new_leaves, op, neg): 687 """ 688 Attempt to replace a leaf of a multiplication tree. 689 690 We search for a leaf in `leaves` whose argument is `arg`, and if we find 691 one, we remove it from `leaves` and add to `new_leaves` a leaf with 692 argument `arg` and variable `op(arg)`. 693 694 Parameters 695 ---------- 696 arg 697 The argument of the leaf we are looking for. 698 leaves 699 List of leaves to look into. Each leaf should be a pair 700 (x, l) with `x` the argument of the Op found in the leaf, and `l` the 701 actual leaf as found in a multiplication tree output by `parse_mul_tree` 702 (i.e. a pair [boolean, variable]). 703 new_leaves 704 If a replacement occurred, then the leaf is removed from `leaves` 705 and added to the list `new_leaves` (after being modified by `op`). 706 op 707 A function that, when applied to `arg`, returns the Variable 708 we want to replace the original leaf variable with. 709 neg : bool 710 If True, then the boolean value associated to the leaf should 711 be swapped. If False, then this value should remain unchanged. 712 713 Returns 714 ------- 715 bool 716 True if a replacement occurred, or False otherwise. 717 718 """ 719 for idx, x in enumerate(leaves): 720 if x[0] == arg: 721 x[1][0] ^= neg 722 x[1][1] = op(arg) 723 leaves.pop(idx) 724 new_leaves.append(x) 725 return True 726 return False 727 728 729def simplify_mul(tree): 730 """ 731 Simplify a multiplication tree. 732 733 Parameters 734 ---------- 735 tree 736 A multiplication tree (as output by `parse_mul_tree`). 737 738 Returns 739 ------- 740 object 741 A multiplication tree computing the same output as `tree` but without 742 useless multiplications by 1 nor -1 (identified by leaves of the form 743 [False, None] or [True, None] respectively). Useless multiplications 744 (with less than two inputs) are also removed from the tree. 745 746 """ 747 neg, inputs = tree 748 if isinstance(inputs, list): 749 # Recurse through inputs. 750 s_inputs = [] 751 for s_i in imap(simplify_mul, inputs): 752 if s_i[1] is None: 753 # Multiplication by +/-1. 754 neg ^= s_i[0] 755 else: 756 s_inputs.append(s_i) 757 if not s_inputs: 758 # The multiplication is empty. 759 rval = [neg, None] 760 elif len(s_inputs) == 1: 761 # The multiplication has a single input. 762 s_inputs[0][0] ^= neg 763 rval = s_inputs[0] 764 else: 765 rval = [neg, s_inputs] 766 else: 767 rval = tree 768 # print 'simplify_mul: %s -> %s' % (tree, rval) 769 return rval 770 771 772def compute_mul(tree): 773 """ 774 Compute the Variable that is the output of a multiplication tree. 775 776 This is the inverse of the operation performed by `parse_mul_tree`, i.e. 777 compute_mul(parse_mul_tree(tree)) == tree. 778 779 Parameters 780 ---------- 781 tree 782 A multiplication tree (as output by `parse_mul_tree`). 783 784 Returns 785 ------- 786 object 787 A Variable that computes the multiplication represented by the tree. 788 789 """ 790 neg, inputs = tree 791 if inputs is None: 792 raise AssertionError( 793 'Function `compute_mul` found a missing leaf, did you forget to ' 794 'call `simplify_mul` on the tree first?') 795 elif isinstance(inputs, list): 796 # Recurse through inputs. 797 rval = tensor.mul(*list(map(compute_mul, inputs))) 798 else: 799 rval = inputs 800 if neg: 801 rval = -rval 802 return rval 803 804 805def perform_sigm_times_exp(tree, exp_x=None, exp_minus_x=None, sigm_x=None, 806 sigm_minus_x=None, parent=None, child_idx=None, 807 full_tree=None): 808 """ 809 Core processing of the `local_sigm_times_exp` optimization. 810 811 This recursive function operates on a multiplication tree as output by 812 `parse_mul_tree`. It walks through the tree and modifies it in-place 813 by replacing matching pairs (exp, sigmoid) with the desired optimized 814 version. 815 816 Parameters 817 ---------- 818 tree 819 The sub-tree to operate on. 820 exp_x 821 List of arguments x so that `exp(x)` exists somewhere in the whole 822 multiplication tree. Each argument is a pair (x, leaf) with `x` the 823 argument of the exponential, and `leaf` the corresponding leaf in the 824 multiplication tree (of the form [n, exp(x)] -- see `parse_mul_tree`). 825 If None, this argument is initialized to an empty list. 826 exp_minus_x 827 Similar to `exp_x`, but for `exp(-x)`. 828 sigm_x 829 Similar to `exp_x`, but for `sigmoid(x)`. 830 sigm_minus_x 831 Similar to `exp_x`, but for `sigmoid(-x)`. 832 parent 833 Parent of `tree` (None if `tree` is the global root). 834 child_idx 835 Index of `tree` in its parent's inputs (None if `tree` is the global 836 root). 837 full_tree 838 The global multiplication tree (should not be set except by recursive 839 calls to this function). Used for debugging only. 840 841 Returns 842 ------- 843 bool 844 True if a modification was performed somewhere in the whole multiplication 845 tree, or False otherwise. 846 847 """ 848 if exp_x is None: 849 exp_x = [] 850 if exp_minus_x is None: 851 exp_minus_x = [] 852 if sigm_x is None: 853 sigm_x = [] 854 if sigm_minus_x is None: 855 sigm_minus_x = [] 856 if full_tree is None: 857 full_tree = tree 858 if False: # Debug code. 859 print('<perform_sigm_times_exp>') 860 print(' full_tree = %s' % full_tree) 861 print(' tree = %s' % tree) 862 print(' exp_x = %s' % exp_x) 863 print(' exp_minus_x = %s' % exp_minus_x) 864 print(' sigm_x = %s' % sigm_x) 865 print(' sigm_minus_x= %s' % sigm_minus_x) 866 neg, inputs = tree 867 if isinstance(inputs, list): 868 # Recurse through inputs of the multiplication. 869 rval = False 870 for sub_idx, sub_tree in enumerate(inputs): 871 rval |= perform_sigm_times_exp( 872 tree=sub_tree, parent=tree, child_idx=sub_idx, 873 exp_x=exp_x, exp_minus_x=exp_minus_x, sigm_x=sigm_x, 874 sigm_minus_x=sigm_minus_x, full_tree=full_tree) 875 return rval 876 else: 877 # Reached a leaf: if it is an exponential or a sigmoid, then we 878 # first attempt to find a match in leaves already visited. 879 # If there is such a match, we modify the already-visited leaf 880 # accordingly: for instance if we visited a leaf sigmoid(x), then 881 # find later a -exp(-x), we replace the previous leaf by 882 # -sigmoid(-x) and remove the -exp(-x) from the tree. 883 # If no match is found, then we register this leaf so that it can 884 # be found later while walking the tree. 885 var = inputs 886 keep_it = False 887 exp_info = is_exp(var) 888 if exp_info is not None: 889 exp_neg, exp_arg = exp_info 890 neg ^= exp_neg 891 neg_arg = is_neg(exp_arg) 892 if neg_arg is None: 893 if not replace_leaf(exp_arg, sigm_minus_x, sigm_x, 894 sigmoid, neg): 895 exp_x.append((exp_arg, tree)) 896 keep_it = True 897 else: 898 if not replace_leaf(neg_arg, sigm_x, sigm_minus_x, 899 lambda x: sigmoid(-x), neg): 900 exp_minus_x.append((neg_arg, tree)) 901 keep_it = True 902 elif var.owner and var.owner.op == sigmoid: 903 sigm_arg = var.owner.inputs[0] 904 neg_arg = is_neg(sigm_arg) 905 if neg_arg is None: 906 if not replace_leaf(sigm_arg, exp_minus_x, sigm_minus_x, 907 lambda x: sigmoid(-x), neg): 908 sigm_x.append((sigm_arg, tree)) 909 keep_it = True 910 else: 911 if not replace_leaf(neg_arg, exp_x, sigm_x, sigmoid, neg): 912 sigm_minus_x.append((neg_arg, tree)) 913 keep_it = True 914 else: 915 # It is not an exponential nor a sigmoid. 916 keep_it = True 917 if not keep_it: 918 # Delete this leaf, i.e. replace it by [False, None] (corresponding 919 # to a multiplication by 1). 920 assert parent is not None 921 parent[1][child_idx] = [False, None] 922 return not keep_it 923 924 925@opt.register_stabilize 926@gof.local_optimizer([tensor.mul]) 927def local_sigm_times_exp(node): 928 """ 929 exp(x) * sigm(-x) -> sigm(x) 930 exp(-x) * sigm(x) -> sigm(-x) 931 932 todo: add stack traces to the intermediate variables 933 """ 934 # Bail early if it is not a multiplication. 935 if node.op != tensor.mul: 936 return None 937 # Obtain tree of multiplications starting at this node. 938 mul_tree = parse_mul_tree(node.outputs[0]) 939 # Perform core optimization. 940 did_something = perform_sigm_times_exp(mul_tree) 941 if not did_something: 942 # No change. 943 return None 944 # The optimization may have introduced multiplications by 1 in the tree: 945 # get rid of them. 946 mul_tree = simplify_mul(mul_tree) 947 # Recompute final output based on the updated tree. 948 out = compute_mul(mul_tree) 949 # keep the stack trace 950 copy_stack_trace(node.outputs[0], out) 951 return [out] 952 953 954@opt.register_stabilize 955@gof.local_optimizer([tensor.inv]) 956def local_inv_1_plus_exp(node): 957 """ 958 1/(1+exp(x)) -> sigm(-x) 959 960 """ 961 # this optimization should be done for numerical stability 962 # so we don't care to check client counts 963 if node.op == tensor.inv: 964 inv_arg = node.inputs[0] 965 if inv_arg.owner and inv_arg.owner.op == tensor.add: 966 scalars, scalar_inputs, nonconsts = \ 967 opt.scalarconsts_rest(inv_arg.owner.inputs, only_process_constants=True) 968 # scalar_inputs are potentially dimshuffled and fill'd scalars 969 if len(nonconsts) == 1: 970 if nonconsts[0].owner and nonconsts[0].owner.op == tensor.exp: 971 if scalars and np.allclose(np.sum(scalars), 1): 972 out = opt._fill_chain( 973 sigmoid( 974 tensor.neg(nonconsts[0].owner.inputs[0])), 975 scalar_inputs) 976 # keep combined stack traces of 977 # exp(x): nonconsts[0], 978 # 1 + exp(x): inv_arg, 979 # 1 / (1 + exp(x)): node.outputs[0] 980 copy_stack_trace( 981 [nonconsts[0], inv_arg, node.outputs[0]], out) 982 return out 983 984# Registration is below, and conditional. 985 986 987@gof.local_optimizer([tensor.sub]) 988def local_1msigmoid(node): 989 """ 990 1-sigm(x) -> sigm(-x) 991 992 """ 993 if node.op == tensor.sub: 994 sub_l, sub_r = node.inputs 995 if len(sub_r.clients) > 1: 996 return # graph is using both sigm and 1-sigm 997 if sub_r.owner and sub_r.owner.op == sigmoid: 998 try: 999 val_l = opt.get_scalar_constant_value(sub_l) 1000 except tensor.NotScalarConstantError: 1001 return 1002 if np.allclose(np.sum(val_l), 1): 1003 out = sigmoid(-sub_r.owner.inputs[0]) 1004 copy_stack_trace([sub_r, node.outputs[0]], out) 1005 return [out] 1006 1007register_local_1msigmoid = False 1008# This is False because the Stabilize pattern above 1009# is looking for 1-sigm. Also Canonizer turns neg into *(-1) and so 1010# this optimization might set off an unwanted chain of things. 1011# OTH - this transformation can be seen as pushing normal arithmetic either below or above the 1012# sigmoidal nonlinearity... so if the canonicalized form had anything to say about that then it 1013# would be a consideration... anyway leaving False for now. 1014 1015if register_local_1msigmoid: 1016 opt.register_canonicalize(local_1msigmoid) 1017