1"""
2Ops and optimizations: sigmoid, softplus.
3
4These functions implement special cases of exp and log to improve numerical
5stability.
6
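For instance (a sketch, assuming the standard compilation pipeline): the
stabilization pass rewrites log(sigmoid(x)) into the numerically safer
-softplus(-x), so the compiled graph no longer underflows to log(0) = -inf
for very negative x:

    import theano
    import theano.tensor as T

    x = T.vector('x')
    f = theano.function([x], T.log(T.nnet.sigmoid(x)))
    # f's optimized graph computes -softplus(-x) internally.
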
7"""
8from __future__ import absolute_import, print_function, division
9
10import warnings
11
12import numpy as np
13
14import theano
15from theano import config, gof, printing, scalar
16from theano.compat import imap
17from theano.printing import pprint
18from theano.tensor import basic as tensor
19from theano.tensor import elemwise, opt, NotScalarConstantError
20from theano.tensor.type import values_eq_approx_remove_inf
21from theano.gof.opt import copy_stack_trace
22
23############
24#
25# SCALAR OPS
26#
27
28
class ScalarSigmoid(scalar.UnaryScalarOp):
    """
    This is just a speed optimization; it does not improve stability.

    """
    @staticmethod
    def st_impl(x):
        if x < -30.0:
            return 0.0
        if x > 30.0:
            return 1.0
        # If x is an int8 or uint8, numpy.exp will compute the result in
        # half-precision (float16), but we want float32.
        x_dtype = str(getattr(x, 'dtype', ''))
        if x_dtype in ('int8', 'uint8'):
            return 1.0 / (1.0 + np.exp(-x, sig='f'))
        return 1.0 / (1.0 + np.exp(-x))

    def impl(self, x):
        return ScalarSigmoid.st_impl(x)

    def grad(self, inp, grads):
        x, = inp
        gz, = grads
        y = scalar_sigmoid(x)
        rval = gz * y * (1.0 - y)

        assert rval.type.dtype.find('float') != -1

        return [rval]
    def c_code(self, node, name, inp, out, sub):
        x, = inp
        z, = out
        # We add boundary checks to prevent exp from generating inf or
        # 0. The rest of the logic then always generates exactly 0 or 1
        # in those cases. This is a speed optimization.
        # The constants were obtained by looking at the output of
        # python commands like:
        #
        # import numpy, theano
        # dt = 'float32'  # or float64
        # for i in range(750):
        #     print(i, repr(theano._asarray(1.0, dtype=dt) /
        #                   (theano._asarray(1.0, dtype=dt) +
        #                    numpy.exp(-theano._asarray([i, -i], dtype=dt)))))

        # float16 limits: -11.0, 7.0
        # We use the float32 limits for float16 for now, as the
        # computation will happen in float32 anyway.
        if (node.inputs[0].type == scalar.float32 or
                node.inputs[0].type == scalar.float16):
            return """%(z)s = %(x)s < -88.0f ? 0.0f : %(x)s > 15.0f ? 1.0f : 1.0f / (1.0f + exp(-%(x)s));""" % locals()
        elif node.inputs[0].type == scalar.float64:
            return """%(z)s = %(x)s < -709.0 ? 0.0 : %(x)s > 19.0 ? 1.0 : 1.0 / (1.0 + exp(-%(x)s));""" % locals()
        else:
            raise NotImplementedError('only floating point is implemented')

    def c_code_cache_version(self):
        v = super(ScalarSigmoid, self).c_code_cache_version()
        if v:
            return (2,) + v
        else:
            return v
    # This function is disabled as it is slower than the generic code!
    def c_code_contiguous_disabled(self, node, name, inp, out, sub):
        x, = inp
        z, = out
        if (not theano.config.lib.amdlibm or
                node.inputs[0].dtype != node.outputs[0].dtype):
            raise theano.gof.utils.MethodNotDefined()
        dtype = node.inputs[0].dtype
        if dtype == 'float32' and self.amd_float32 is not None:
            dtype = 'float'
            fct = "amd_vrsa_expf"
        elif dtype == 'float64' and self.amd_float64 is not None:
            dtype = 'double'
            fct = "amd_vrda_exp"
        else:
            raise theano.gof.utils.MethodNotDefined()
        return """
        npy_intp n = PyArray_SIZE(%(z)s);
        %(dtype)s * x = (%(dtype)s*) PyArray_DATA(%(x)s);
        %(dtype)s * z = (%(dtype)s*) PyArray_DATA(%(z)s);
        // We block to keep the data in the L1 cache.
        // Typical L1 size = 32k: 32k / 2 (input + output) / 8 (bytes per double) = 2k
        // We stay below the 2k limit to leave room for other data.
        // This is faster than the non-blocking version.
        for(int i=0;i<n;i+=2048){
            npy_intp nb = (n-i<2048)?n-i:2048;
            for(int j=0;j<nb;j++){
                z[i+j] = -x[i+j];
            }
            %(fct)s(nb, z+i, z+i);
            for(int j=0;j<nb;j++){
                z[i+j] = 1.0 /(1.0+z[i+j]);
            }
        }
        """ % locals()

    @staticmethod
    def gen_graph():
        """
        This method was used to generate the figure sigmoid_prec.png in the doc.

        """
        data = np.arange(-15, 15, .1)
        val = 1 / (1 + np.exp(-data))

        def hard_sigmoid(x):
            return theano.tensor.nnet.hard_sigmoid(x)

        def ultra_fast_sigmoid(x):
            return theano.tensor.nnet.ultra_fast_sigmoid(x)

        val_hard = hard_sigmoid(data).eval()
        val_ultra = ultra_fast_sigmoid(data).eval()

        import matplotlib.pyplot as plt
        import os
        fig = plt.figure()
        ax = fig.add_subplot(111)
        ax.plot(data, val)
        ax.plot(data, val_ultra)
        ax.plot(data, val_hard)
        ax.grid(True)
        ax.legend(("sigmoid", "ultra_fast", "hard"), loc="upper left")
        fname = os.path.join(os.path.dirname(theano.__file__), '..',
                             'doc', 'library', 'tensor', 'nnet',
                             'sigmoid_prec.png')
        plt.savefig(fname)
        print("New picture saved at", fname)
        print(val_ultra.max())
        print(val_ultra.min())


scalar_sigmoid = ScalarSigmoid(scalar.upgrade_to_float, name='scalar_sigmoid')
sigmoid = elemwise.Elemwise(scalar_sigmoid, name='sigmoid')

sigmoid_inplace = elemwise.Elemwise(
    ScalarSigmoid(scalar.transfer_type(0)),
    inplace_pattern={0: 0},
    name='sigmoid_inplace',
)

pprint.assign(sigmoid, printing.FunctionPrinter('sigmoid'))


class UltraFastScalarSigmoid(scalar.UnaryScalarOp):
    """
    This is just a speed optimization; it does not improve stability.

    """
    @staticmethod
    def st_impl(x):
        x = 0.5 * x
        # The branches below approximate tanh.
        if x >= 0:
            if x < 1.7:
                z = (1.5 * x / (1 + x))
            elif x < 3:
                z = (0.935409070603099 + 0.0458812946797165 * (x - 1.7))
            else:
                z = 0.99505475368673
        else:
            xx = -x
            if xx < 1.7:
                z = (1.5 * xx / (1 + xx))
            elif xx < 3:
                z = (0.935409070603099 + 0.0458812946797165 * (xx - 1.7))
            else:
                z = 0.99505475368673
            z = -z

        return 0.5 * (z + 1.)

    def impl(self, x):
        return UltraFastScalarSigmoid.st_impl(x)

    def c_code(self, node, name, inp, out, sub):
        x, = inp
        z, = out
        dtype = node.outputs[0].type.dtype_specs()[1]

        return """
        %(dtype)s x = 0.5 * %(x)s;
        // The branches below approximate tanh.
        if(x >= 0) {
            %(z)s = (x < 1.7 ? (1.5 * x / (1 + x)) :
                     (x < 3 ? (0.935409070603099 + 0.0458812946797165 * (x - 1.7)) :
                      0.99505475368673));
        } else {
            %(dtype)s xx = -x;
            %(z)s = -(xx < 1.7 ? (1.5 * xx / (1 + xx)) :
                      (xx < 3 ? (0.935409070603099 + 0.0458812946797165 * (xx - 1.7)) :
                       0.99505475368673));
        }

        // %(z)s = 0.5 * (ultrafasttanh(0.5 * x) + 1.);
        %(z)s = 0.5 * (%(z)s + 1.);
        """ % locals()

ultra_fast_scalar_sigmoid = UltraFastScalarSigmoid(
    scalar.upgrade_to_float, name='ultra_fast_scalar_sigmoid')
ultra_fast_sigmoid = elemwise.Elemwise(ultra_fast_scalar_sigmoid,
                                       name='ultra_fast_sigmoid')

ultra_fast_sigmoid_inplace = elemwise.Elemwise(
    UltraFastScalarSigmoid(scalar.transfer_type(0)),
    inplace_pattern={0: 0},
    name='ultra_fast_sigmoid_inplace',
)

pprint.assign(ultra_fast_sigmoid,
              printing.FunctionPrinter('ultra_fast_sigmoid'))


# @opt.register_uncanonicalize
@gof.local_optimizer([sigmoid])
def local_ultra_fast_sigmoid(node):
    """
    When enabled, replace all sigmoid ops with ultra_fast_sigmoid.

    For example, use mode.including('local_ultra_fast_sigmoid')
    or the Theano flag optimizer_including=local_ultra_fast_sigmoid.

    This speeds up the sigmoid op by using an approximation.

    This is done after the stabilization and specialization phases
    to avoid interacting with them.

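    A sketch of enabling it for one compiled function (the name passed to
    `including` is the registered name above):

        import theano
        import theano.tensor as T

        x = T.matrix('x')
        mode = theano.compile.get_default_mode().including(
            'local_ultra_fast_sigmoid')
        f = theano.function([x], T.nnet.sigmoid(x), mode=mode)
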
262    """
263    if (isinstance(node.op, tensor.Elemwise) and
264            node.op.scalar_op == scalar_sigmoid):
265        out = ultra_fast_sigmoid(node.inputs[0])
266        copy_stack_trace(node.outputs[0], out)
267
268        def values_eq_approx_remove_low_prec(a, b):
269            # atol is found by trial/error.
270            # Other test could fail without good reason.
271            return tensor.TensorType.values_eq_approx(a, b, atol=0.02)
272        # Let DebugMode know that there this opt approx the values.
273        out.tag.values_eq_approx = values_eq_approx_remove_low_prec
274        return [out]
275theano.compile.optdb['uncanonicalize'].register("local_ultra_fast_sigmoid",
276                                                local_ultra_fast_sigmoid)
277
278
def hard_sigmoid(x):
    """
    An approximation of sigmoid.

    More approximate and faster than ultra_fast_sigmoid.

    The approximation has 3 parts: 0, a scaled linear segment, and 1.

    Removing the slope and shift does not make it faster.

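    Examples
    --------
    With the slope 0.2 and shift 0.5 used below:

        hard_sigmoid(-2.5) -> 0.0
        hard_sigmoid(0.0)  -> 0.5
        hard_sigmoid(2.5)  -> 1.0
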
289    """
290    # Use the same dtype as determined by "upgrade_to_float",
291    # and perform computation in that dtype.
292    out_dtype = scalar.upgrade_to_float(scalar.Scalar(dtype=x.dtype))[0].dtype
293    slope = tensor.constant(0.2, dtype=out_dtype)
294    shift = tensor.constant(0.5, dtype=out_dtype)
295    x = (x * slope) + shift
296    x = tensor.clip(x, 0, 1)
297    return x
298
299
# @opt.register_uncanonicalize
@gof.local_optimizer([sigmoid])
def local_hard_sigmoid(node):
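    """
    When enabled, replace all sigmoid ops with hard_sigmoid.

    As with local_ultra_fast_sigmoid above, enable it with
    mode.including('local_hard_sigmoid') or with the Theano flag
    optimizer_including=local_hard_sigmoid.

    """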
    if (isinstance(node.op, tensor.Elemwise) and
            node.op.scalar_op == scalar_sigmoid):
        out = hard_sigmoid(node.inputs[0])
        copy_stack_trace(node.outputs[0], out)

        def values_eq_approx_remove_low_prec(a, b):
            # atol was found by trial and error; with a tighter tolerance,
            # other tests could fail without good reason.
            return tensor.TensorType.values_eq_approx(a, b, atol=0.1)
        # Let DebugMode know that this opt approximates the values.
        out.tag.values_eq_approx = values_eq_approx_remove_low_prec
        return [out]
theano.compile.optdb['uncanonicalize'].register("local_hard_sigmoid",
                                                local_hard_sigmoid)


class ScalarSoftplus(scalar.UnaryScalarOp):
    """
    This helps numerical stability.
    """
    @staticmethod
    def static_impl(x):
        if x < -30.0:
            return 0.0
        if x > 30.0:
            return x
        # If x is an int8 or uint8, numpy.exp will compute the result in
        # half-precision (float16), but we want float32.
        x_dtype = str(getattr(x, 'dtype', ''))
        if x_dtype in ('int8', 'uint8'):
            return np.log1p(np.exp(x, sig='f'))
        return np.log1p(np.exp(x))

    def impl(self, x):
        return ScalarSoftplus.static_impl(x)

    def grad(self, inp, grads):
        x, = inp
        gz, = grads
        return [gz * scalar_sigmoid(x)]

    def c_code(self, node, name, inp, out, sub):
        x, = inp
        z, = out
        # These constants were obtained by looking at the output of
        # python commands like:
        #
        # dt = 'float32'  # or float64
        # for i in range(750):
        #     print(i, repr(numpy.log1p(numpy.exp(theano._asarray([i, -i], dtype=dt)))))
        #
        # The boundary checks prevent us from generating inf.

        # float16 limits: -17.0, 6.0
        # We use the float32 limits for float16 for now, as the
        # computation will happen in float32 anyway.
        if (node.inputs[0].type == scalar.float32 or
                node.inputs[0].type == scalar.float16):
            return """%(z)s = %(x)s < -103.0f ? 0.0 : %(x)s > 14.0f ? %(x)s : log1p(exp(%(x)s));""" % locals()
        elif node.inputs[0].type == scalar.float64:
            return """%(z)s = %(x)s < -745.0 ? 0.0 : %(x)s > 16.0 ? %(x)s : log1p(exp(%(x)s));""" % locals()
        else:
            raise NotImplementedError('only floating point is implemented')

    def c_code_cache_version(self):
        v = super(ScalarSoftplus, self).c_code_cache_version()
        if v:
            return (2,) + v
        else:
            return v
scalar_softplus = ScalarSoftplus(scalar.upgrade_to_float,
                                 name='scalar_softplus')
softplus = elemwise.Elemwise(scalar_softplus, name='softplus')

pprint.assign(softplus, printing.FunctionPrinter('softplus'))


def _skip_mul_1(r):
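    """
    If `r` is a multiplication of several 1s and exactly one other term,
    return that term; otherwise return None (implicitly).

    Used as `skip_identities_fn` by the PatternSub rewrites below so that,
    presumably, matching can look through such trivial multiplications.
    """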
    if r.owner and r.owner.op == tensor.mul:
        not_is_1 = [i for i in r.owner.inputs if not _is_1(i)]
        if len(not_is_1) == 1:
            return not_is_1[0]

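# The four rewrites below rely on these identities of the logistic sigmoid:
#   log(sigmoid(x))     = -softplus(-x)
#   log(1 - sigmoid(x)) = -softplus(x)
#   log1p(exp(x))       = softplus(x)
#   log1p(-sigmoid(x))  = -softplus(x)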
logsigm_to_softplus = gof.PatternSub(
    (tensor.log, (sigmoid, 'x')),
    (tensor.neg, (softplus, (tensor.neg, 'x'))),
    allow_multiple_clients=True,
    values_eq_approx=values_eq_approx_remove_inf,
    skip_identities_fn=_skip_mul_1)


def _is_1(expr):
    """

    Returns
    -------
    bool
        True iff `expr` is a constant close to 1.

    """
    try:
        v = opt.get_scalar_constant_value(expr)
        return np.allclose(v, 1)
    except tensor.NotScalarConstantError:
        return False

log1msigm_to_softplus = gof.PatternSub(
    (tensor.log,
        (tensor.sub,
            dict(pattern='y', constraint=_is_1),
            (sigmoid, 'x'))),
    (tensor.neg, (softplus, 'x')),
    allow_multiple_clients=True,
    values_eq_approx=values_eq_approx_remove_inf,
    skip_identities_fn=_skip_mul_1)


log1pexp_to_softplus = gof.PatternSub(
    (tensor.log1p,
     (tensor.exp, 'x')),
    (softplus, 'x'),
    values_eq_approx=values_eq_approx_remove_inf,
    allow_multiple_clients=True)

log1p_neg_sigmoid = gof.PatternSub(
    (tensor.log1p,
     (tensor.neg, (sigmoid, 'x'))),
    (tensor.neg, (softplus, 'x')),
    values_eq_approx=values_eq_approx_remove_inf,
    allow_multiple_clients=True)

opt.register_stabilize(logsigm_to_softplus, name='logsigm_to_softplus')
opt.register_stabilize(log1msigm_to_softplus, name='log1msigm_to_softplus')
opt.register_stabilize(log1pexp_to_softplus, name='log1pexp_to_softplus')
opt.register_stabilize(log1p_neg_sigmoid, name='log1p_neg_sigmoid')


def is_1pexp(t, only_process_constants=True):
    """

    Returns
    -------
    object
        If `t` is of the form (1 + exp(x)), return (False, x).
        Else return None.

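    For instance, 1 + exp(x) matches, giving (False, x); the False mirrors
    the (neg, arg) convention used by `is_exp`. The constant part may be
    split across several terms, as long as the constants sum to 1.
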
446    """
447    if t.owner and t.owner.op == tensor.add:
448        scalars, scalar_inputs, nonconsts = \
449            opt.scalarconsts_rest(t.owner.inputs,
450                                  only_process_constants=only_process_constants)
451        # scalar_inputs are potentially dimshuffled and filled with scalars
452        if len(nonconsts) == 1:
453            maybe_exp = nonconsts[0]
454            if maybe_exp.owner and maybe_exp.owner.op == tensor.exp:
455                # Verify that the constant terms sum to 1.
456                if scalars:
457                    scal_sum = scalars[0]
458                    for s in scalars[1:]:
459                        scal_sum = scal_sum + s
460                    if np.allclose(scal_sum, 1):
461                        return False, maybe_exp.owner.inputs[0]
462                # Before 7987b51 there used to be a bug where *any* constant
463                # was considered as if it was equal to 1, and thus this
464                # function would incorrectly identify it as (1 + exp(x)).
465                if config.warn.identify_1pexp_bug:
466                    warnings.warn(
467                        'Although your current code is fine, please note that '
468                        'Theano versions prior to 0.5 (more specifically, '
469                        'prior to commit 7987b51 on 2011-12-18) may have '
470                        'yielded an incorrect result. To remove this warning, '
471                        'either set the `warn.identify_1pexp_bug` config '
472                        'option to False, or `warn.ignore_bug_before` to at '
473                        'least \'0.4.1\'.')
474    return None
475
476
def is_exp(var):
    """
    Match a variable with either of the `exp(x)` or `-exp(x)` patterns.

    Parameters
    ----------
    var
        The Variable to analyze.

    Returns
    -------
    tuple
        A pair (b, x) with `b` a boolean set to True if `var` is of the
        form `-exp(x)` and False if `var` is of the form `exp(x)`. If `var`
        cannot be cast into either form, then return `None`.

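    Examples
    --------
        exp(x)        -> (False, x)
        -exp(x)       -> (True, x)
        (-1) * exp(x) -> (True, x)
        x             -> None
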
493    """
494    neg = False
495    neg_info = is_neg(var)
496    if neg_info is not None:
497        neg = True
498        var = neg_info
499    if var.owner and var.owner.op == tensor.exp:
500        return neg, var.owner.inputs[0]
501
502
def is_mul(var):
    """
    Match a variable with `x * y * z * ...`.

    Parameters
    ----------
    var
        The Variable to analyze.

    Returns
    -------
    object
        A list [x, y, z, ...] if `var` is of the form `x * y * z * ...`,
        or None if `var` cannot be cast into this form.

    """
    if var.owner and var.owner.op == tensor.mul:
        return var.owner.inputs
    else:
        return None


def partition_num_or_denom(r, f):
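    """
    Partition the multiplicative terms of `r` according to `f`.

    `r` is split into its multiplicative terms (or taken as a single term
    if it is not a multiplication). Each term `t` is passed to `f`, which
    must return either None (no match) or a pair (neg, arg), as `is_exp`
    and `is_1pexp` do.

    Returns
    -------
    tuple
        A triple (f_terms, rest, neg) with `f_terms` the list of matched
        arguments, `rest` the unmatched terms, and `neg` the parity of the
        negations absorbed from the matched terms.
    """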
    if r.owner and r.owner.op == tensor.mul:
        a = r.owner.inputs
    else:
        a = [r]

    f_terms = []
    neg = False
    rest = []
    for t in a:
        f_t = f(t)
        if f_t is None:
            rest.append(t)
        else:
            neg_t, f_t = f_t
            f_terms.append(f_t)
            neg ^= neg_t  # bit flip if neg_t is true
    return f_terms, rest, neg


def is_neg(var):
    """
    Match a variable with the `-x` pattern.

    Parameters
    ----------
    var
        The Variable to analyze.

    Returns
    -------
    object
        `x` if `var` is of the form `-x`, or None otherwise.

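    Examples
    --------
        -x            -> x
        (-1) * x      -> x
        (-1) * x * y  -> x * y
        x             -> None
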
560    """
561    apply = var.owner
562    if not apply:
563        return None
564    # First match against `tensor.neg`.
565    if apply.op == tensor.neg:
566        return apply.inputs[0]
567    # Then match against a multiplication by -1.
568    if apply.op == tensor.mul and len(apply.inputs) >= 2:
569        for idx, mul_input in enumerate(apply.inputs):
570            try:
571                constant = opt.get_scalar_constant_value(mul_input)
572                is_minus_1 = np.allclose(constant, -1)
573            except NotScalarConstantError:
574                is_minus_1 = False
575            if is_minus_1:
576                # Found a multiplication by -1.
577                if len(apply.inputs) == 2:
578                    # Only return the other input.
579                    return apply.inputs[1 - idx]
580                else:
581                    # Return the multiplication of all other inputs.
582                    return tensor.mul(*(apply.inputs[0:idx] +
583                                        apply.inputs[idx + 1:]))
584    # No match.
585    return None
586
587
@opt.register_stabilize
@gof.local_optimizer([tensor.true_div])
def local_exp_over_1_plus_exp(node):
    """
    exp(x)/(1+exp(x)) -> sigm(x)
    c/(1+exp(x)) -> c*sigm(-x)

    Both rewrites follow from sigm(x) = 1/(1+exp(-x)) = exp(x)/(1+exp(x)).

    """
    # This optimization should be done for numerical stability,
    # so we don't bother checking client counts.
    if node.op == tensor.true_div:

        # Find all the exp() terms in the numerator.
        num, denom = node.inputs
        num_exp_x, num_rest, num_neg = partition_num_or_denom(num, is_exp)
        denom_1pexp, denom_rest, \
            denom_neg = partition_num_or_denom(denom, is_1pexp)

        sigmoids = []
        for t in denom_1pexp:
            if t in num_exp_x:
                # case: exp(x) / (1+exp(x))
                sigmoids.append(sigmoid(t))
                num_exp_x.remove(t)
            else:
                # case: 1 / (1+exp(x))
                sigmoids.append(sigmoid(-t))
            copy_stack_trace(node.outputs[0], sigmoids[-1])

        if not sigmoids:  # We didn't find any; abort.
            return
        # Put the new numerator together.
        new_num = sigmoids + [tensor.exp(t) for t in num_exp_x] + num_rest
        if len(new_num) == 1:
            new_num = new_num[0]
        else:
            new_num = tensor.mul(*new_num)

        if num_neg ^ denom_neg:
            new_num = -new_num

        copy_stack_trace(num, new_num)

        if len(denom_rest) == 0:
            return [new_num]
        elif len(denom_rest) == 1:
            out = new_num / denom_rest[0]
        else:
            out = new_num / tensor.mul(*denom_rest)

        copy_stack_trace(node.outputs[0], out)
        return [out]


def parse_mul_tree(root):
    """
    Parse a tree of multiplications starting at the given root.

    Parameters
    ----------
    root
        The variable at the root of the tree.

    Returns
    -------
    object
        A tree where each non-leaf node corresponds to a multiplication
        in the computation of `root`, represented by the list of its inputs.
        Each input is a pair [n, x] with `n` a boolean value indicating whether
        sub-tree `x` should be negated.

    Examples
    --------
        x * y               -> [False, [[False, x], [False, y]]]
        -(x * y)            -> [True, [[False, x], [False, y]]]
        -x * y              -> [False, [[True, x], [False, y]]]
        -x                  -> [True, x]
        (x * y) * -z        -> [False, [[False, [[False, x], [False, y]]],
                                        [True, z]]]

    """
    # Is it a multiplication?
    mul_info = is_mul(root)
    if mul_info is None:
        # Is it a negation?
        neg_info = is_neg(root)
        if neg_info is None:
            # Keep the root "as is".
            return [False, root]
        else:
            # Recurse, inverting the negation.
            neg, sub_tree = parse_mul_tree(neg_info)
            return [not neg, sub_tree]
    else:
        # Recurse into inputs.
        return [False, list(map(parse_mul_tree, mul_info))]


def replace_leaf(arg, leaves, new_leaves, op, neg):
    """
    Attempt to replace a leaf of a multiplication tree.

    We search for a leaf in `leaves` whose argument is `arg`, and if we find
    one, we remove it from `leaves` and add to `new_leaves` a leaf with
    argument `arg` and variable `op(arg)`.

    Parameters
    ----------
    arg
        The argument of the leaf we are looking for.
    leaves
        List of leaves to look into. Each leaf should be a pair
        (x, l) with `x` the argument of the Op found in the leaf, and `l` the
        actual leaf as found in a multiplication tree output by `parse_mul_tree`
        (i.e. a pair [boolean, variable]).
    new_leaves
        If a replacement occurred, then the leaf is removed from `leaves`
        and added to the list `new_leaves` (after being modified by `op`).
    op
        A function that, when applied to `arg`, returns the Variable
        we want to replace the original leaf variable with.
    neg : bool
        If True, then the boolean value associated to the leaf should
        be swapped. If False, then this value should remain unchanged.

    Returns
    -------
    bool
        True if a replacement occurred, or False otherwise.

    """
    for idx, x in enumerate(leaves):
        if x[0] == arg:
            x[1][0] ^= neg
            x[1][1] = op(arg)
            leaves.pop(idx)
            new_leaves.append(x)
            return True
    return False


def simplify_mul(tree):
    """
    Simplify a multiplication tree.

    Parameters
    ----------
    tree
        A multiplication tree (as output by `parse_mul_tree`).

    Returns
    -------
    object
        A multiplication tree computing the same output as `tree` but without
        useless multiplications by 1 or -1 (identified by leaves of the form
        [False, None] or [True, None] respectively). Useless multiplications
        (with fewer than two inputs) are also removed from the tree.

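    Examples
    --------
        [False, [[False, x], [True, None]]] -> [True, x]
        [False, [[False, x], [False, y]]]   -> unchanged
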
746    """
747    neg, inputs = tree
748    if isinstance(inputs, list):
749        # Recurse through inputs.
750        s_inputs = []
751        for s_i in imap(simplify_mul, inputs):
752            if s_i[1] is None:
753                # Multiplication by +/-1.
754                neg ^= s_i[0]
755            else:
756                s_inputs.append(s_i)
757        if not s_inputs:
758            # The multiplication is empty.
759            rval = [neg, None]
760        elif len(s_inputs) == 1:
761            # The multiplication has a single input.
762            s_inputs[0][0] ^= neg
763            rval = s_inputs[0]
764        else:
765            rval = [neg, s_inputs]
766    else:
767        rval = tree
768    # print 'simplify_mul: %s -> %s' % (tree, rval)
769    return rval
770
771
def compute_mul(tree):
    """
    Compute the Variable that is the output of a multiplication tree.

    This is the inverse of the operation performed by `parse_mul_tree`, i.e.
    compute_mul(parse_mul_tree(tree)) == tree.

    Parameters
    ----------
    tree
        A multiplication tree (as output by `parse_mul_tree`).

    Returns
    -------
    object
        A Variable that computes the multiplication represented by the tree.

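    Examples
    --------
        [True, x]                        -> -x
        [False, [[False, x], [True, y]]] -> x * (-y)
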
789    """
790    neg, inputs = tree
791    if inputs is None:
792        raise AssertionError(
793            'Function `compute_mul` found a missing leaf, did you forget to '
794            'call `simplify_mul` on the tree first?')
795    elif isinstance(inputs, list):
796        # Recurse through inputs.
797        rval = tensor.mul(*list(map(compute_mul, inputs)))
798    else:
799        rval = inputs
800    if neg:
801        rval = -rval
802    return rval
803
804
def perform_sigm_times_exp(tree, exp_x=None, exp_minus_x=None, sigm_x=None,
                           sigm_minus_x=None, parent=None, child_idx=None,
                           full_tree=None):
    """
    Core processing of the `local_sigm_times_exp` optimization.

    This recursive function operates on a multiplication tree as output by
    `parse_mul_tree`. It walks through the tree and modifies it in-place
    by replacing matching pairs (exp, sigmoid) with the desired optimized
    version.

    Parameters
    ----------
    tree
        The sub-tree to operate on.
    exp_x
        List of arguments x so that `exp(x)` exists somewhere in the whole
        multiplication tree. Each argument is a pair (x, leaf) with `x` the
        argument of the exponential, and `leaf` the corresponding leaf in the
        multiplication tree (of the form [n, exp(x)] -- see `parse_mul_tree`).
        If None, this argument is initialized to an empty list.
    exp_minus_x
        Similar to `exp_x`, but for `exp(-x)`.
    sigm_x
        Similar to `exp_x`, but for `sigmoid(x)`.
    sigm_minus_x
        Similar to `exp_x`, but for `sigmoid(-x)`.
    parent
        Parent of `tree` (None if `tree` is the global root).
    child_idx
        Index of `tree` in its parent's inputs (None if `tree` is the global
        root).
    full_tree
        The global multiplication tree (should not be set except by recursive
        calls to this function). Used for debugging only.

    Returns
    -------
    bool
        True if a modification was performed somewhere in the whole
        multiplication tree, or False otherwise.

    """
    if exp_x is None:
        exp_x = []
    if exp_minus_x is None:
        exp_minus_x = []
    if sigm_x is None:
        sigm_x = []
    if sigm_minus_x is None:
        sigm_minus_x = []
    if full_tree is None:
        full_tree = tree
    if False:  # Debug code.
        print('<perform_sigm_times_exp>')
        print('  full_tree   = %s' % full_tree)
        print('  tree        = %s' % tree)
        print('  exp_x       = %s' % exp_x)
        print('  exp_minus_x = %s' % exp_minus_x)
        print('  sigm_x      = %s' % sigm_x)
        print('  sigm_minus_x= %s' % sigm_minus_x)
    neg, inputs = tree
    if isinstance(inputs, list):
        # Recurse through inputs of the multiplication.
        rval = False
        for sub_idx, sub_tree in enumerate(inputs):
            rval |= perform_sigm_times_exp(
                tree=sub_tree, parent=tree, child_idx=sub_idx,
                exp_x=exp_x, exp_minus_x=exp_minus_x, sigm_x=sigm_x,
                sigm_minus_x=sigm_minus_x, full_tree=full_tree)
        return rval
    else:
        # Reached a leaf: if it is an exponential or a sigmoid, then we
        # first attempt to find a match in leaves already visited.
        # If there is such a match, we modify the already-visited leaf
        # accordingly: for instance, if we visited a leaf sigmoid(x) and
        # later find a -exp(-x), we replace the previous leaf by
        # -sigmoid(-x) and remove the -exp(-x) from the tree.
        # If no match is found, then we register this leaf so that it can
        # be found later while walking the tree.
        var = inputs
        keep_it = False
        exp_info = is_exp(var)
        if exp_info is not None:
            exp_neg, exp_arg = exp_info
            neg ^= exp_neg
            neg_arg = is_neg(exp_arg)
            if neg_arg is None:
                if not replace_leaf(exp_arg, sigm_minus_x, sigm_x,
                                    sigmoid, neg):
                    exp_x.append((exp_arg, tree))
                    keep_it = True
            else:
                if not replace_leaf(neg_arg, sigm_x, sigm_minus_x,
                                    lambda x: sigmoid(-x), neg):
                    exp_minus_x.append((neg_arg, tree))
                    keep_it = True
        elif var.owner and var.owner.op == sigmoid:
            sigm_arg = var.owner.inputs[0]
            neg_arg = is_neg(sigm_arg)
            if neg_arg is None:
                if not replace_leaf(sigm_arg, exp_minus_x, sigm_minus_x,
                                    lambda x: sigmoid(-x), neg):
                    sigm_x.append((sigm_arg, tree))
                    keep_it = True
            else:
                if not replace_leaf(neg_arg, exp_x, sigm_x, sigmoid, neg):
                    sigm_minus_x.append((neg_arg, tree))
                    keep_it = True
        else:
            # It is neither an exponential nor a sigmoid.
            keep_it = True
        if not keep_it:
            # Delete this leaf, i.e. replace it by [False, None] (corresponding
            # to a multiplication by 1).
            assert parent is not None
            parent[1][child_idx] = [False, None]
        return not keep_it


@opt.register_stabilize
@gof.local_optimizer([tensor.mul])
def local_sigm_times_exp(node):
    """
    exp(x) * sigm(-x) -> sigm(x)
    exp(-x) * sigm(x) -> sigm(-x)

    TODO: add stack traces to the intermediate variables.
    """
    # Bail early if it is not a multiplication.
    if node.op != tensor.mul:
        return None
    # Obtain the tree of multiplications starting at this node.
    mul_tree = parse_mul_tree(node.outputs[0])
    # Perform the core optimization.
    did_something = perform_sigm_times_exp(mul_tree)
    if not did_something:
        # No change.
        return None
    # The optimization may have introduced multiplications by 1 in the tree:
    # get rid of them.
    mul_tree = simplify_mul(mul_tree)
    # Recompute the final output based on the updated tree.
    out = compute_mul(mul_tree)
    # Keep the stack trace.
    copy_stack_trace(node.outputs[0], out)
    return [out]


@opt.register_stabilize
@gof.local_optimizer([tensor.inv])
def local_inv_1_plus_exp(node):
    """
    1/(1+exp(x)) -> sigm(-x)

    """
    # This optimization should be done for numerical stability,
    # so we don't bother checking client counts.
    if node.op == tensor.inv:
        inv_arg = node.inputs[0]
        if inv_arg.owner and inv_arg.owner.op == tensor.add:
            scalars, scalar_inputs, nonconsts = \
                opt.scalarconsts_rest(inv_arg.owner.inputs,
                                      only_process_constants=True)
            # scalar_inputs are potentially dimshuffled and filled scalars.
            if len(nonconsts) == 1:
                if nonconsts[0].owner and nonconsts[0].owner.op == tensor.exp:
                    if scalars and np.allclose(np.sum(scalars), 1):
                        out = opt._fill_chain(
                            sigmoid(
                                tensor.neg(nonconsts[0].owner.inputs[0])),
                            scalar_inputs)
                        # Keep the combined stack traces of
                        #     exp(x):           nonconsts[0],
                        #     1 + exp(x):       inv_arg,
                        #     1 / (1 + exp(x)): node.outputs[0]
                        copy_stack_trace(
                            [nonconsts[0], inv_arg, node.outputs[0]], out)
                        return out

# Registration of local_1msigmoid is below, and conditional.


@gof.local_optimizer([tensor.sub])
def local_1msigmoid(node):
    """
    1-sigm(x) -> sigm(-x)

    """
    if node.op == tensor.sub:
        sub_l, sub_r = node.inputs
        if len(sub_r.clients) > 1:
            return  # The graph is using both sigm and 1 - sigm.
        if sub_r.owner and sub_r.owner.op == sigmoid:
            try:
                val_l = opt.get_scalar_constant_value(sub_l)
            except tensor.NotScalarConstantError:
                return
            if np.allclose(np.sum(val_l), 1):
                out = sigmoid(-sub_r.owner.inputs[0])
                copy_stack_trace([sub_r, node.outputs[0]], out)
                return [out]

register_local_1msigmoid = False
# This is False because the stabilization pattern above
# (log1msigm_to_softplus) is looking for 1 - sigm. Also, Canonizer turns
# neg into *(-1), so this optimization might set off an unwanted chain of
# rewrites.
# On the other hand, this transformation can be seen as pushing normal
# arithmetic either below or above the sigmoidal nonlinearity, so if the
# canonical form had anything to say about that, it would be a
# consideration. Anyway, leaving it False for now.

if register_local_1msigmoid:
    opt.register_canonicalize(local_1msigmoid)