import chainer
from chainer import function_node
from chainer.utils import type_check
import chainerx


class Cholesky(function_node.FunctionNode):
    """Cholesky decomposition of a single square matrix.

    The forward pass delegates to the backend's ``linalg.cholesky`` and
    retains the resulting factor ``L`` so that ``backward`` can compute the
    gradient with respect to the input from it.
    """

    @property
    def label(self):
        return 'cholesky'

    def check_type_forward(self, in_types):
        type_check._argname(in_types, ('a', ))
        a_type, = in_types

        type_check.expect(
            a_type.dtype.kind == 'f',
            a_type.ndim == 2,
        )
        # A Cholesky factorization only exists for square matrices. This is
        # checked in a separate expectation so ``shape[1]`` is only examined
        # after the input has been confirmed to be 2-D.
        type_check.expect(a_type.shape[0] == a_type.shape[1])

    def forward(self, inputs):
        a, = inputs
        # The factor L is reused by ``backward``; keep it alive.
        self.retain_outputs((0,))
        xp = chainer.backend.get_array_module(a)
        return xp.linalg.cholesky(a),

    def forward_chainerx(self, inputs):
        # NOTE(review): unlike ``forward``, no output is retained here;
        # presumably ChainerX differentiates its own ``linalg.cholesky``
        # when this path is taken -- confirm.
        return chainerx.linalg.cholesky(*inputs),

    def backward(self, indexes, grad_outputs):
        """Gradient of the Cholesky factorization w.r.t. its input.

        With ``L`` the retained output and ``gL`` the incoming gradient,
        this computes::

            phi = Phi(L^T gL)
            S   = L^{-T} phi L^{-1}
            gA  = (S + S^T) / 2

        ``Phi`` keeps the lower triangle and halves the diagonal. This is
        the formula given in I. Murray, "Differentiation of the Cholesky
        decomposition" (2016).
        """
        gy, = grad_outputs
        xp = chainer.backend.get_array_module(gy)
        y, = self.get_retained_outputs()
        n = y.shape[0]
        dtype = y.dtype

        F = chainer.functions
        y_inv = F.inv(y)
        # Lower-triangular mask of ones with 0.5 on the diagonal: applying it
        # elementwise implements Phi (take lower triangle, halve diagonal).
        mask = xp.tri(n, dtype=dtype) - 0.5 * xp.eye(n, dtype=dtype)
        phi = mask * F.matmul(y, gy, transa=True)
        # s = L^{-T} phi L^{-1}
        s = F.matmul(F.matmul(y_inv, phi, transa=True), y_inv)
        # Symmetrize: the input is a symmetric matrix, so the gradient is
        # spread evenly over its upper and lower triangles.
        gx = 0.5 * (s + s.T)
        return gx,


def cholesky(a):
    """Cholesky Decomposition

    Computes the factor ``L`` of ``a`` returned by the backend's
    ``linalg.cholesky`` (for NumPy/CuPy this is the lower-triangular ``L``
    with ``a = L L^T``). The gradient propagated back to ``a`` is symmetric.

    Args:
        a (:class:`~chainer.Variable` or :ref:`ndarray`): Input variable.
            A 2-D square floating-point matrix. It should be symmetric
            positive-definite for the decomposition to exist.

    Returns:
        ~chainer.Variable: Output variable holding the Cholesky factor.
    """
    return Cholesky().apply((a,))[0]