1import chainer
2from chainer import function_node
3from chainer.utils import type_check
4import chainerx
5
6
class Cholesky(function_node.FunctionNode):
    """Cholesky decomposition of a single 2-D floating-point matrix.

    The forward pass delegates to ``xp.linalg.cholesky`` (the array module
    of the input) or to ``chainerx.linalg.cholesky``. The output factor is
    retained because the backward pass is expressed entirely in terms of it.
    """

    @property
    def label(self):
        return 'cholesky'

    def check_type_forward(self, in_types):
        type_check._argname(in_types, ('a', ))
        a_type, = in_types

        type_check.expect(
            a_type.dtype.kind == 'f',
            a_type.ndim == 2,
            # The decomposition is only defined for square matrices.
            # Checking here reports a type-check error up front instead of
            # failing inside the backend's linalg.cholesky call.
            a_type.shape[0] == a_type.shape[1],
        )

    def forward(self, inputs):
        a, = inputs
        # backward() needs the computed factor, not the input matrix.
        self.retain_outputs((0,))
        xp = chainer.backend.get_array_module(a)
        return xp.linalg.cholesky(a),

    def forward_chainerx(self, inputs):
        return chainerx.linalg.cholesky(*inputs),

    def backward(self, indexes, grad_outputs):
        """Return the gradient w.r.t. the input matrix.

        With ``y`` the retained factor and ``gy`` its upstream gradient,
        this computes ``s = y^{-T} Phi(y^T gy) y^{-1}``. ``Phi`` keeps the
        lower triangle and halves the diagonal. The result is returned
        symmetrized as ``0.5 * (s + s^T)``, which is the symmetric-form
        Cholesky gradient (cf. Murray, 2016).
        """
        gy, = grad_outputs
        xp = chainer.backend.get_array_module(gy)
        y, = self.get_retained_outputs()
        n = y.shape[0]
        dtype = y.dtype

        F = chainer.functions
        y_inv = F.inv(y)
        # Ones on and below the diagonal, minus 0.5 on the diagonal:
        # lower-triangular mask with 0.5 on the diagonal (the Phi operator).
        mask = xp.tri(n, dtype=dtype) - 0.5 * xp.eye(n, dtype=dtype)
        phi = mask * F.matmul(y, gy, transa=True)
        # s = y^{-T} @ phi @ y^{-1}
        s = F.matmul(F.matmul(y_inv, phi, transa=True), y_inv)
        # Symmetrize so the returned gradient is a symmetric matrix.
        gx = 0.5 * (s + s.T)
        return gx,
45
46
def cholesky(a):
    """Cholesky Decomposition

    Applies :class:`Cholesky` to ``a`` and returns its single output. The
    forward computation is delegated to the backend's ``linalg.cholesky``.
    That routine expects a square, symmetric (Hermitian) positive-definite
    floating-point matrix. For NumPy it returns the lower-triangular factor
    ``L`` with ``a = L @ L.T``.

    Args:
        a (:class:`~chainer.Variable` or :ref:`ndarray`): Input variable
            holding the 2-D matrix to decompose.

    Returns:
        ~chainer.Variable: Output variable holding the Cholesky factor.
    """
    factor, = Cholesky().apply((a,))
    return factor
57