1from chainer.functions.activation import swish
2from chainer import initializers
3from chainer import link
4from chainer import variable
5
6
class Swish(link.Link):

    r"""Swish activation function as a link.

    Wraps :func:`chainer.functions.swish` together with its learnable
    scale parameter :math:`\beta`.

    Args:
        beta_shape (tuple of ints or None): Shape of the parameter variable
            :math:`\beta`. Passing ``None`` defers parameter allocation
            until the first forward pass, at which point the shape is taken
            from the input.
        beta_init (float): Initial value filled into the parameter variable
            :math:`\beta`.

    See the paper for details: `Searching for Activation Functions
    <https://arxiv.org/abs/1710.05941>`_

    To experiment with Swish in place of ReLU, swap each ``F.relu`` call
    for a dedicated ``Swish`` link registered on the model. For example,
    the model defined in the `MNIST example
    <https://github.com/chainer/chainer/tree/master/examples/mnist/train_mnist.py>`_
    can be rewritten as follows.

    ReLU version (original)::

        class MLP(chainer.Chain):

            def __init__(self, n_units, n_out):
                super(MLP, self).__init__()
                with self.init_scope():
                    self.l1 = L.Linear(None, n_units)
                    self.l2 = L.Linear(None, n_units)
                    self.l3 = L.Linear(None, n_out)

            def forward(self, x):
                h1 = F.relu(self.l1(x))
                h2 = F.relu(self.l2(h1))
                return self.l3(h2)

    Swish version::

        class MLP(chainer.Chain):

            def __init__(self, n_units, n_out):
                super(MLP, self).__init__()
                with self.init_scope():
                    self.l1 = L.Linear(None, n_units)
                    self.s1 = L.Swish(None)
                    self.l2 = L.Linear(None, n_units)
                    self.s2 = L.Swish(None)
                    self.l3 = L.Linear(None, n_out)

            def forward(self, x):
                h1 = self.s1(self.l1(x))
                h2 = self.s2(self.l2(h1))
                return self.l3(h2)

    .. seealso::
        See :func:`chainer.functions.swish` for the definition of Swish
        activation function.

    Attributes:
        beta (~chainer.Parameter): Parameter variable :math:`\beta`.

    """

    def __init__(self, beta_shape, beta_init=1.0):
        super(Swish, self).__init__()

        with self.init_scope():
            if beta_shape is None:
                # Shape unknown yet: register an uninitialized parameter
                # carrying a constant initializer; it is materialized on
                # the first forward pass.
                self.beta = variable.Parameter(
                    initializers.Constant(beta_init))
            else:
                self.beta = variable.Parameter(beta_init, beta_shape)

    def forward(self, x):
        """Applies the Swish activation function.

        Args:
            x (~chainer.Variable): Input variable.

        Returns:
            ~chainer.Variable: Output of the Swish activation function.

        """
        # Deferred initialization: take the per-sample shape (all axes
        # except the batch axis) from the first input seen.
        if self.beta.array is None:
            self.beta.initialize(x.shape[1:])

        return swish.swish(x, self.beta)
95