1import numpy
2
3import chainer
4from chainer import functions
5
6
class ScatterGGNNReadout(chainer.Chain):
    """GGNN submodule for readout part using scatter operation.

    Args:
        out_dim (int): dimension of output feature vector
        in_channels (int or None): dimension of feature vector associated to
            each node. `in_channels` is the total dimension of `h` and `h0`.
        nobias (bool): If ``True``, then this function does not use
            the bias
        activation (~chainer.Function or ~chainer.FunctionNode):
            activate function for node representation
            `functions.tanh` was suggested in original paper.
        activation_agg (~chainer.Function or ~chainer.FunctionNode):
            activate function for aggregation
            `functions.tanh` was suggested in original paper.
        concat_n_info (bool): If ``True``, node information is concated
            to the result.
    """

    def __init__(self, out_dim, in_channels=None, nobias=False,
                 activation=functions.identity,
                 activation_agg=functions.identity,
                 concat_n_info=False):
        super(ScatterGGNNReadout, self).__init__()
        self.concat_n_info = concat_n_info
        if self.concat_n_info:
            # Reserve one output column for the per-graph node count that
            # ``__call__`` appends, so the final width still equals the
            # requested ``out_dim``.
            out_dim -= 1
        with self.init_scope():
            self.i_layer = chainer.links.Linear(
                in_channels, out_dim, nobias=nobias)
            self.j_layer = chainer.links.Linear(
                in_channels, out_dim, nobias=nobias)
        self.out_dim = out_dim
        self.in_channels = in_channels
        self.nobias = nobias
        self.activation = activation
        self.activation_agg = activation_agg

    def __call__(self, h, batch, h0=None, is_real_node=None):
        """Compute the graph-level readout.

        Args:
            h: node feature matrix of shape ``(n_nodes, feature_dim)``.
            batch: int array of shape ``(n_nodes,)`` mapping each node to
                its graph index; assumed sorted so ``batch[-1]`` is the
                largest graph index — TODO confirm with callers.
            h0: optional initial node features, concatenated with ``h``
                along the feature axis before the gated projection.
            is_real_node: unused; kept for interface compatibility.

        Returns:
            Per-graph feature matrix of shape
            ``(n_graphs, out_dim)`` (plus one node-count column when
            ``concat_n_info`` is set).
        """
        # --- Readout part ---
        h1 = functions.concat((h, h0), axis=1) if h0 is not None else h

        # Gated projection: sigmoid gate (i_layer) modulating the
        # activated node representation (j_layer).
        g1 = functions.sigmoid(self.i_layer(h1))
        g2 = self.activation(self.j_layer(h1))
        g = g1 * g2

        # Sum node features into their owning graph's row.
        y = self.xp.zeros((int(batch[-1]) + 1, self.out_dim),
                          dtype=numpy.float32)
        y = functions.scatter_add(y, batch, g)
        y = self.activation_agg(y)

        if self.concat_n_info:
            # Append the node count of each graph as an extra feature.
            # The ones array must be float32 explicitly: ``xp.ones``
            # defaults to float64, and ``scatter_add`` requires both
            # operands to share a dtype.
            n_nodes = self.xp.zeros(y.shape[0], dtype=numpy.float32)
            ones = self.xp.ones(batch.shape[0], dtype=numpy.float32)
            n_nodes = functions.scatter_add(n_nodes, batch, ones)
            y = functions.concat((y, n_nodes.reshape((-1, 1))))

        return y
66