1from chainer import function_node
2
3
class Identity(function_node.FunctionNode):

    """Function node that passes its inputs through unchanged.

    ``forward`` hands back the very input container it receives, and
    ``backward`` hands back the very upstream-gradient container it
    receives, so no data is copied or transformed in either direction.
    """

    def forward(self, xs):
        # Nothing to compute: the outputs are the inputs themselves
        # (the same container object, not a copy).
        passthrough = xs
        return passthrough

    def backward(self, indexes, gys):
        # The derivative of the identity is 1, so upstream gradients are
        # returned as-is. ``indexes`` is not consulted here.
        upstream_grads = gys
        return upstream_grads
13
14
def identity(*inputs):
    """Just returns input variables.

    Wraps all positional ``inputs`` in a single ``Identity`` node via
    ``apply``. When that yields exactly one output, the output itself is
    returned. Otherwise the whole result of ``apply`` is returned unchanged.
    """
    outputs = Identity().apply(inputs)
    if len(outputs) == 1:
        # Unwrap the common single-input case for caller convenience.
        return outputs[0]
    return outputs
19