import numpy

import chainer
from chainer import functions


class ScatterGGNNReadout(chainer.Chain):
    """GGNN submodule for readout part using scatter operation.

    Each node feature is gated as
    ``sigmoid(i_layer(h1)) * activation(j_layer(h1))`` and the gated features
    are then summed per graph with ``functions.scatter_add``, using ``batch``
    as the graph index of every node.

    Args:
        out_dim (int): dimension of output feature vector. When
            ``concat_n_info`` is ``True``, the last of these dimensions holds
            the node count, so the internal linear layers produce
            ``out_dim - 1`` features.
        in_channels (int or None): dimension of feature vector associated to
            each node. `in_channels` is the total dimension of `h` and `h0`.
        nobias (bool): If ``True``, then this function does not use
            the bias
        activation (~chainer.Function or ~chainer.FunctionNode):
            activation function for node representation
            `functions.tanh` was suggested in original paper.
        activation_agg (~chainer.Function or ~chainer.FunctionNode):
            activation function for aggregation
            `functions.tanh` was suggested in original paper.
        concat_n_info (bool): If ``True``, the number of nodes of each graph
            is concatenated to the result as its last column.
    """

    def __init__(self, out_dim, in_channels=None, nobias=False,
                 activation=functions.identity,
                 activation_agg=functions.identity,
                 concat_n_info=False):
        super(ScatterGGNNReadout, self).__init__()
        self.concat_n_info = concat_n_info
        if self.concat_n_info:
            # Reserve the last output column for the per-graph node count.
            out_dim -= 1
        with self.init_scope():
            self.i_layer = chainer.links.Linear(
                in_channels, out_dim, nobias=nobias)
            self.j_layer = chainer.links.Linear(
                in_channels, out_dim, nobias=nobias)
        # NOTE: stores the reduced width when ``concat_n_info`` is True.
        self.out_dim = out_dim
        self.in_channels = in_channels
        self.nobias = nobias
        self.activation = activation
        self.activation_agg = activation_agg

    def __call__(self, h, batch, h0=None, is_real_node=None):
        """Aggregate node features into one feature vector per graph.

        Args:
            h: node features; concatenated with ``h0`` along axis 1 when
                ``h0`` is given (presumably shape ``(n_nodes, ch)`` -- TODO
                confirm against callers).
            batch: integer array with one entry per node giving the index of
                the graph that node belongs to. Used as the scatter index.
            h0: optional initial node features concatenated to ``h``.
            is_real_node: not used by this implementation.

        Returns:
            chainer.Variable of shape ``(n_graphs, out_dim)`` where
            ``n_graphs = max(batch) + 1`` and ``out_dim`` is the value passed
            to the constructor (including the node-count column when
            ``concat_n_info`` is True).
        """
        # --- Readout part ---
        h1 = functions.concat((h, h0), axis=1) if h0 is not None else h

        g1 = functions.sigmoid(self.i_layer(h1))
        g2 = self.activation(self.j_layer(h1))
        g = g1 * g2

        # Number of graphs: use max() instead of batch[-1] so nodes do not
        # have to be sorted by graph index.
        n_graphs = int(batch.max()) + 1

        # Sum along node axis. Accumulate in g's dtype so scatter_add never
        # mixes a hard-coded float32 buffer with float16/float64 features.
        y = self.xp.zeros((n_graphs, self.out_dim), dtype=g.dtype)
        y = functions.scatter_add(y, batch, g)
        y = self.activation_agg(y)

        if self.concat_n_info:
            # Count nodes per graph by scattering ones. The ones must share
            # y's dtype: xp.ones defaults to float64.
            n_nodes = self.xp.zeros(y.shape[0], dtype=y.dtype)
            n_nodes = functions.scatter_add(
                n_nodes, batch, self.xp.ones(batch.shape[0], dtype=y.dtype))
            y = functions.concat((y, n_nodes.reshape((-1, 1))))

        return y