import numpy

import chainer
from chainer import backend
from chainer.backends import cuda
from chainer import function_node
from chainer import utils
from chainer.utils import type_check

try:
    from scipy import sparse
    _scipy_available = True
except ImportError:
    _scipy_available = False


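# Sparse matrices are handled in COO format as separate data/row/col arrays.
# In the batched case, matrices in a batch may have different numbers of
# non-zero elements; entries are padded up to a common length (ldnz) and the
# padding positions are marked with negative row/col indices, which the CPU
# and GPU paths below skip.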
def _coo_matmul(sp_data, sp_row, sp_col, sp_shape, sp_order,
                dn, transa, transb, transc, dtype=None):
    if dtype is None:
        dtype = numpy.result_type(sp_data.dtype, dn.dtype)

    A_data = sp_data
    if transa:
        A_row = sp_col
        A_col = sp_row
        A_shape = (sp_shape[1], sp_shape[0])
        if sp_order == 'C':
            A_order = 'F'
        elif sp_order == 'F':
            A_order = 'C'
        else:
            A_order = sp_order
    else:
        A_row = sp_row
        A_col = sp_col
        A_shape = sp_shape
        A_order = sp_order
    if transb:
        B = dn.swapaxes(-1, -2)
    else:
        B = dn

    xp = backend.get_array_module(A_data, B)
    if xp is numpy:
        C = _coo_matmul_cpu(A_data, A_row, A_col, A_shape, B, dtype)
    else:
        C = _coo_matmul_gpu(A_data, A_row, A_col, A_shape, A_order,
                            B, dtype)

    if transc:
        C = C.swapaxes(-1, -2)
    return C


def _coo_matmul_cpu(A_data, A_row, A_col, A_shape, B, dtype):
    # A_shape: (_m, _k)
    # B.shape: ((nb,) _k, _n)
    # A_data/row/col.shape: ((nb,) ldnz)
    if not _scipy_available:
        msg = 'SciPy is not available. The CPU implementation of' \
              ' sparse_matmul requires SciPy, so sparse_matmul cannot' \
              ' be used on the CPU.'
        raise RuntimeError(msg)

    _m, _k = A_shape
    _n = B.shape[-1]
    if B.ndim == 2:
        sp_A = sparse.coo_matrix((A_data, (A_row, A_col)), shape=(_m, _k))
        C = sp_A.dot(B).astype(dtype, copy=False)
    else:
        nb = B.shape[0]
        C = numpy.empty((nb, _m, _n), dtype=dtype)
        for i in range(nb):
            nnz = len(numpy.where(A_row[i] >= 0)[0])
            sp_A = sparse.coo_matrix((A_data[i, :nnz],
                                      (A_row[i, :nnz], A_col[i, :nnz])),
                                     shape=(_m, _k))
            C[i] = sp_A.dot(B[i]).astype(dtype, copy=False)

    return C


def _coo_matmul_gpu(A_data, A_row, A_col, A_shape, A_order, B, dtype):
    cupy_dtype = dtype
    if cupy_dtype == numpy.float16:
        cupy_dtype = numpy.float32
        # fp32 is used in cupy kernel because fp16 atomicAdd is not supported

    # A_shape: (_m, _k)
    # B.shape: ((nb,) _k, _n)
    # A_data/row/col.shape: ((nb,) ldnz)
    _m, _k = A_shape
    _n = B.shape[-1]
    ldnz = A_data.shape[-1]
    if B.ndim == 2:
        nb = 1
        C = cuda.cupy.zeros((_m, _n), dtype=cupy_dtype)
    else:
        nb = B.shape[0]
        C = cuda.cupy.zeros((nb, _m, _n), dtype=cupy_dtype)

    if A_order == 'C':
        # A chunk is the number of non-zero elements handled by a single GPU
        # thread. If contiguous non-zero elements belong to the same location
        # of the output matrix and they are processed in the same thread, the
        # number of atomic-add operations can be reduced.
        chunk = max(ldnz // _m, 1)
    else:
        chunk = 1
    nthreads = (nb * ldnz + chunk - 1) // chunk * _n
    _cupy_coo_matmul()(nb, _m, _n, _k, ldnz, chunk,
                       A_data, A_row, A_col, B, C,
                       size=nthreads)

    return C.astype(dtype, copy=False)


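# Each GPU thread handles ``chunk`` consecutive non-zero entries for a single
# output column. Partial products that target the same output element are
# accumulated locally and flushed with a single atomicAdd when the target
# changes, which is what makes the result non-deterministic.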
def _cupy_coo_matmul():
    utils.nondeterministic('atomicAdd')
    return cuda.elementwise(
        'int32 nb, int32 _m, int32 _n, int32 _k, int32 nnz, int32 chunk, \
         raw A A_data, raw T A_row, raw T A_col, \
         raw B _B',
        'raw C _C',
        '''
        int i_n = (i % _n);
        int i0 = (i / _n) * chunk;
        int i_C = -1;
        C val_C = 0;
        for (int i1 = 0; i1 < chunk; i1++) {
            int i_A = i0 + i1;
            int i_b = i_A / nnz;
            if (i_b >= nb) {
                continue;
            }
            int i_k = A_col[i_A];
            if (i_k < 0) {
                continue;
            }
            assert(i_k < _k);
            int i_m = A_row[i_A];
            if (i_m < 0) {
                continue;
            }
            assert(i_m < _m);
            int i_B = i_n + _n * (i_k + _k * i_b);
            int i_C_now = i_n + _n * (i_m + _m * i_b);
            A val_A = A_data[i_A];
            B val_B = _B[i_B];
            C val_C_now = static_cast<C>(val_A * val_B);
            if (i_C >= 0 && i_C != i_C_now) {
                atomicAdd(&_C[i_C], val_C);
                val_C = 0;
            }
            i_C = i_C_now;
            val_C += val_C_now;
        }
        if (i_C >= 0) {
            atomicAdd(&_C[i_C], val_C);
        }
        ''',
        'coo_matmul')


class CooMatMul(function_node.FunctionNode):

    def __init__(self, sp_row, sp_col, sp_shape, sp_order='other',
                 transa=False, transb=False, transc=False, dtype=None):
        if sp_row.ndim != sp_col.ndim:
            raise ValueError('ndim of sp_row and sp_col must be the same.')
        if sp_row.ndim != 1 and sp_row.ndim != 2:
            raise ValueError('ndim of sp_row and sp_col must be one or two.')
        for i in range(sp_row.ndim):
            if sp_row.shape[i] != sp_col.shape[i]:
                msg = 'shape of sp_row and sp_col must be the same.'
                raise ValueError(msg)
        if len(sp_shape) != 2:
            raise ValueError('len(sp_shape) must be two.')
        self.sp_row = sp_row  # ((nb,) ldnz)
        self.sp_col = sp_col  # ((nb,) ldnz)
        self.sp_shape = sp_shape  # (_m, _k) when transa is False
        self.sp_order = sp_order
        self.transa = transa
        self.transb = transb
        self.transc = transc
        self.dtype = dtype

    def check_type_forward(self, in_types):
        type_check._argname(in_types, ('sp', 'dn'))
        sp_type, dn_type = in_types
        # sp_type.shape: ((nb,) ldnz)
        # dn_type.shape: ((nb,) _k, _n) when transb is False
        sp_k_axis = -1
        if self.transa:
            sp_k_axis = -2
        dn_k_axis = -2
        if self.transb:
            dn_k_axis = -1
        type_check.expect(
            sp_type.dtype.kind == 'f',
            dn_type.dtype.kind == 'f',
            dn_type.ndim >= 2,
            dn_type.ndim <= 3,
            sp_type.ndim == dn_type.ndim - 1,
            sp_type.shape[-1] == self.sp_row.shape[-1],
            self.sp_shape[sp_k_axis] == dn_type.shape[dn_k_axis],
        )
        dn_ndim = type_check.eval(dn_type.ndim)
        if dn_ndim == 3:
            type_check.expect(
                sp_type.shape[0] == self.sp_row.shape[0],
                dn_type.shape[0] == self.sp_row.shape[0],
            )

    def forward(self, inputs):
        self.retain_inputs((0, 1))
        sp, dn = inputs
        c = _coo_matmul(sp, self.sp_row, self.sp_col, self.sp_shape,
                        self.sp_order, dn,
                        self.transa, self.transb, self.transc, self.dtype)
        return utils.force_array(c, self.dtype),

    def backward(self, indexes, grad_outputs):
        sp, dn = self.get_retained_inputs()
        g_c, = grad_outputs
        ret = []
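        # The gradients follow the dense product rule dC = dA B + A dB,
        # adjusted for the transpose flags: the gradient w.r.t. the sparse
        # operand is evaluated only at its non-zero pattern (CooMatMulGradSP),
        # and the gradient w.r.t. the dense operand is another sparse-dense
        # product (CooMatMul).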
        if 0 in indexes:
            g_sp = CooMatMulGradSP(self.sp_row, self.sp_col, self.sp_shape,
                                   self.sp_order,
                                   self.transc, not self.transb, self.transa,
                                   dtype=sp.dtype).apply((g_c, dn))[0]
            ret.append(g_sp)
        if 1 in indexes:
            g_dn = CooMatMul(self.sp_row, self.sp_col, self.sp_shape,
                             self.sp_order,
                             not self.transa, self.transc, self.transb,
                             dtype=dn.dtype).apply((sp, g_c))[0]
            ret.append(g_dn)
        return ret


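# Computes a dense product A * B but materializes only the entries at the COO
# pattern given by (c_row, c_col); this is how the gradient w.r.t. the sparse
# operand of CooMatMul is formed.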
def _coo_matmul_gradsp(a, b, c_row, c_col, c_shape, transa, transb, transc,
                       dtype):
    if dtype is None:
        dtype = numpy.result_type(a.dtype, b.dtype)

    if transa:
        A = a.swapaxes(-1, -2)
    else:
        A = a
    if transb:
        B = b.swapaxes(-1, -2)
    else:
        B = b
    if transc:
        C_row = c_col
        C_col = c_row
    else:
        C_row = c_row
        C_col = c_col

    xp = backend.get_array_module(A, B)
    if xp is numpy:
        return _coo_matmul_gradsp_cpu(A, B, C_row, C_col, dtype)
    else:
        return _coo_matmul_gradsp_gpu(A, B, C_row, C_col, dtype)


def _coo_matmul_gradsp_cpu(A, B, C_row, C_col, dtype):
    # A.shape: ((nb,) _m, _k)
    # B.shape: ((nb,) _k, _n)
    # C_row/col.shape: ((nb,) ldnz)
    _m, _k = A.shape[-2:]
    ldnz = C_row.shape[-1]
    if hasattr(numpy, 'matmul'):
        C = numpy.matmul(A, B)
    elif A.ndim == 2:
        C = numpy.dot(A, B)
    else:
        C = numpy.einsum('...ij,...jk->...ik', A, B)
    C = C.astype(dtype, copy=False)
    if A.ndim == 2:
        C_data = numpy.zeros((ldnz), dtype=dtype)
        nnz = len(numpy.where(C_row >= 0)[0])
        C_data[:nnz] = C[C_row[:nnz], C_col[:nnz]]
    else:
        nb = A.shape[0]
        C_data = numpy.zeros((nb, ldnz), dtype=dtype)
        for i in range(nb):
            nnz = len(numpy.where(C_row[i] >= 0)[0])
            C_data[i, :nnz] = C[i, C_row[i, :nnz], C_col[i, :nnz]]

    return C_data


def _coo_matmul_gradsp_gpu(A, B, C_row, C_col, dtype):
    # A.shape: ((nb,) _m, _k)
    # B.shape: ((nb,) _k, _n)
    # C_row/col.shape: ((nb,) ldnz)
    _m, _k = A.shape[-2:]
    _n = B.shape[-1]
    ldnz = C_row.shape[-1]
    if A.ndim == 2:
        nb = 1
        C_data = cuda.cupy.zeros((ldnz), dtype=dtype)
    else:
        nb = A.shape[0]
        C_data = cuda.cupy.zeros((nb, ldnz), dtype=dtype)

    nthreads = nb * ldnz
    _cupy_coo_matmul_gradsp()(nb, _m, _n, _k, ldnz, A, B, C_row, C_col, C_data,
                              size=nthreads)

    return C_data


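# One GPU thread per (possibly padded) non-zero output entry: each thread
# computes the dot product of the corresponding row of A and column of B,
# skipping padded entries whose row/col index is negative.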
def _cupy_coo_matmul_gradsp():
    return cuda.elementwise(
        'int32 nb, int32 _m, int32 _n, int32 _k, int32 nnz, \
         raw A _A, raw B _B, \
         raw T C_row, raw T C_col',
        'raw C C_data',
        '''
        int i_nz = (i % nnz);
        int i_b = (i / nnz);
        if (i_b >= nb) {
            continue;
        }
        int i_C = i;
        int i_m = C_row[i_C];
        if (i_m < 0) {
            continue;
        }
        assert(i_m < _m);
        int i_n = C_col[i_C];
        if (i_n < 0) {
            continue;
        }
        assert(i_n < _n);
        C val_C = 0.0;
        for (int i_k = 0; i_k < _k; i_k++) {
            int i_A = i_k + _k * (i_m + _m * i_b);
            int i_B = i_n + _n * (i_k + _k * i_b);
            A val_A = _A[i_A];
            B val_B = _B[i_B];
            val_C += static_cast<C>(val_A * val_B);
        }
        C_data[i_C] = val_C;
        ''',
        'coo_matmul_gradsp')


class CooMatMulGradSP(function_node.FunctionNode):

    def __init__(self, sp_row, sp_col, sp_shape, sp_order='other',
                 transa=False, transb=False, transc=False,
                 dtype=None):
        if sp_row.ndim != sp_col.ndim:
            raise ValueError('ndim of sp_row and sp_col must be the same.')
        if sp_row.ndim != 1 and sp_row.ndim != 2:
            raise ValueError('ndim of sp_row and sp_col must be one or two.')
        for i in range(sp_row.ndim):
            if sp_row.shape[i] != sp_col.shape[i]:
                msg = 'shape of sp_row and sp_col must be the same.'
                raise ValueError(msg)
        if len(sp_shape) != 2:
            raise ValueError('len(sp_shape) must be two.')
        self.sp_row = sp_row  # ((nb,) ldnz)
        self.sp_col = sp_col  # ((nb,) ldnz)
        self.sp_shape = sp_shape  # (_m, _n) when transc is False
        self.sp_order = sp_order
        self.transa = transa
        self.transb = transb
        self.transc = transc
        self.dtype = dtype

    def check_type_forward(self, in_types):
        type_check.expect(in_types.size() == 2)
        a_type, b_type = in_types
        # a_type.shape: ((nb,) _m, _k) when transa is False
        # b_type.shape: ((nb,) _k, _n) when transb is False
        a_m_axis, a_k_axis = -2, -1
        b_k_axis, b_n_axis = -2, -1
        sp_m_axis, sp_n_axis = -2, -1
        if self.transa:
            a_m_axis, a_k_axis = -1, -2
        if self.transb:
            b_k_axis, b_n_axis = -1, -2
        if self.transc:
            sp_m_axis, sp_n_axis = -1, -2
        type_check.expect(
            a_type.dtype.kind == 'f',
            b_type.dtype.kind == 'f',
            a_type.ndim >= 2,
            a_type.ndim <= 3,
            a_type.ndim == b_type.ndim,
            a_type.shape[a_m_axis] == self.sp_shape[sp_m_axis],
            b_type.shape[b_n_axis] == self.sp_shape[sp_n_axis],
            a_type.shape[a_k_axis] == b_type.shape[b_k_axis],
        )
        a_ndim = type_check.eval(a_type.ndim)
        if a_ndim == 3:
            type_check.expect(
                a_type.shape[0] == self.sp_row.shape[0],
                b_type.shape[0] == self.sp_row.shape[0],
            )

    def forward(self, inputs):
        self.retain_inputs((0, 1))
        a, b = inputs
        c = _coo_matmul_gradsp(a, b, self.sp_row, self.sp_col, self.sp_shape,
                               self.transa, self.transb, self.transc,
                               self.dtype)
        return utils.force_array(c),

    def backward(self, indexes, grad_outputs):
        a, b = self.get_retained_inputs()
        g_sp, = grad_outputs
        ret = []
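        # g_sp is the upstream gradient defined only at the sparse pattern;
        # the gradients w.r.t. the two dense inputs are again sparse-dense
        # products with the transpose flags rearranged accordingly.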
        if 0 in indexes:
            g_a = CooMatMul(self.sp_row, self.sp_col, self.sp_shape,
                            self.sp_order,
                            self.transc, not self.transb, self.transa,
                            dtype=a.dtype).apply((g_sp, b))[0]
            ret.append(g_a)
        if 1 in indexes:
            g_b = CooMatMul(self.sp_row, self.sp_col, self.sp_shape,
                            self.sp_order,
                            not self.transc, self.transa, not self.transb,
                            dtype=b.dtype).apply((g_sp, a))[0]
            ret.append(g_b)
        return ret


def sparse_matmul(a, b, transa=False, transb=False):
    """Computes the batched multiplication of sparse and dense matrices.

    The following use cases are supported:

        1. C (dense) = A (sparse) * B (dense)
        2. C (dense) = A (dense) * B (sparse)

    Args:
        a (~chainer.Variable or ~chainer.utils.CooMatrix): The left operand of
            matrix multiplication.
        b (~chainer.Variable or ~chainer.utils.CooMatrix): The right operand of
            matrix multiplication.
        transa (bool): If ``True``, each matrix in ``a`` will be transposed.
        transb (bool): If ``True``, each matrix in ``b`` will be transposed.

    Returns:
        ~chainer.Variable: Result of the batched matrix multiplication.

    .. seealso::
        See :func:`~chainer.utils.to_coo` for how to construct a COO matrix
        from an array.

    .. note::
        Performance of this function on GPU can be improved by using the
        ``order`` argument of :class:`~chainer.utils.CooMatrix` when the sparse
        matrix is created.

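    .. admonition:: Example

        A minimal usage sketch (the array values below are made up for
        illustration)::

            import numpy as np
            import chainer.functions as F
            from chainer.utils import to_coo

            a = np.array([[0, 1, 0], [2, 0, 3]], dtype=np.float32)
            b = np.arange(6, dtype=np.float32).reshape(3, 2)
            sp_a = to_coo(a)              # COO representation of ``a``
            c = F.sparse_matmul(sp_a, b)  # dense Variable of shape (2, 2)
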
    """
    if (isinstance(a, utils.CooMatrix) and
            isinstance(b, (chainer.Variable, numpy.ndarray, cuda.ndarray))):
        return CooMatMul(a.row, a.col, a.shape, a.order,
                         transa=transa,
                         transb=transb,
                         transc=False).apply((a.data, b))[0]
    elif (isinstance(a, (chainer.Variable, numpy.ndarray, cuda.ndarray)) and
          isinstance(b, utils.CooMatrix)):
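        # A (dense) * B (sparse) is computed via the identity
        # (A * B)^T = B^T * A^T: the operands are swapped, the transpose
        # flags are flipped, and the result is transposed back (transc=True).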
        return CooMatMul(b.row, b.col, b.shape, b.order,
                         transa=not transb,
                         transb=not transa,
                         transc=True).apply((b.data, a))[0]
    else:
        msg = 'This combination of input types is not supported.\n'
        msg += '    a: {}\n'.format(type(a))
        msg += '    b: {}\n'.format(type(b))
        raise ValueError(msg)