from chainer.functions.activation import swish
from chainer import initializers
from chainer import link
from chainer import variable


class Swish(link.Link):

    """Swish activation function as a link.

    Wraps :func:`chainer.functions.swish` together with its trainable
    parameter :math:`\\beta`, so the activation can be dropped into a model
    like any other link.

    Args:
        beta_shape (tuple of ints or None): Shape of the parameter variable
            :math:`\\beta`. Pass ``None`` to defer initialization until the
            first forward pass, at which point the shape is taken from the
            input (all axes except the batch axis).
        beta_init (float): Initial value filled into the parameter variable
            :math:`\\beta`.

    See the paper for details: `Searching for Activation Functions
    <https://arxiv.org/abs/1710.05941>`_

    To try Swish instead of ReLU, replace each ``F.relu`` call with its own
    ``Swish`` link registered to the model. For example, the model defined
    in the `MNIST example
    <https://github.com/chainer/chainer/tree/master/examples/mnist/train_mnist.py>`_
    can be rewritten as follows.

    ReLU version (original)::

        class MLP(chainer.Chain):

            def __init__(self, n_units, n_out):
                super(MLP, self).__init__()
                with self.init_scope():
                    self.l1 = L.Linear(None, n_units)
                    self.l2 = L.Linear(None, n_units)
                    self.l3 = L.Linear(None, n_out)

            def forward(self, x):
                h1 = F.relu(self.l1(x))
                h2 = F.relu(self.l2(h1))
                return self.l3(h2)

    Swish version::

        class MLP(chainer.Chain):

            def __init__(self, n_units, n_out):
                super(MLP, self).__init__()
                with self.init_scope():
                    self.l1 = L.Linear(None, n_units)
                    self.s1 = L.Swish(None)
                    self.l2 = L.Linear(None, n_units)
                    self.s2 = L.Swish(None)
                    self.l3 = L.Linear(None, n_out)

            def forward(self, x):
                h1 = self.s1(self.l1(x))
                h2 = self.s2(self.l2(h1))
                return self.l3(h2)

    .. seealso::
        See :func:`chainer.functions.swish` for the definition of the Swish
        activation function.

    Attributes:
        beta (~chainer.Parameter): Parameter variable :math:`\\beta`.

    """

    def __init__(self, beta_shape, beta_init=1.0):
        super(Swish, self).__init__()

        with self.init_scope():
            if beta_shape is None:
                # Shape unknown yet: register an uninitialized parameter
                # whose values will be filled with ``beta_init`` once the
                # shape is determined in ``forward``.
                self.beta = variable.Parameter(
                    initializers.Constant(beta_init))
            else:
                self.beta = variable.Parameter(beta_init, beta_shape)

    def forward(self, x):
        """Applies the Swish activation function.

        Args:
            x (~chainer.Variable): Input variable.

        Returns:
            ~chainer.Variable: Output of the Swish activation function.

        """
        beta = self.beta
        if beta.array is None:
            # Deferred initialization: take the per-sample shape of the
            # input (everything but the batch axis).
            beta.initialize(x.shape[1:])

        return swish.swish(x, beta)