1import numpy as np
2import six
3
4from chainer import backend
5from chainer.backends import cuda
6from chainer import function_node
7from chainer.utils import type_check
8
9
# Maps an element size in bytes to the signed integer type of that size.
# The types are listed explicitly because ``np.sctypes`` was removed in
# NumPy 2.0; on older NumPy this yields the same mapping as
# ``np.sctypes['int']`` did.
index_dtype = {
    np.dtype(t).itemsize: t
    for t in (np.int8, np.int16, np.int32, np.int64)
}
11
12
13def _byte2step(iterable, itemsize):
14    for i in iterable:
15        assert i % itemsize == 0
16    return tuple([i // itemsize for i in iterable])
17
18
19def _step2byte(iterable, itemsize):
20    return tuple([i * itemsize for i in iterable])
21
22
23def _maybe_overlapping_memory(shape, strides):
24    """Returns bool value indicating the array with such shape and strides
25    might have overlapping memory.
26
27    Args:
28    shape (tuple of int): The shape of output.
29    strides (tuple of int): The strides of output, given in the unit of steps.
30    storage_offset (int):
31        The offset between the head of allocated memory and the pointer of
32        first element, given in the unit of steps.
33
34    Returns:
35        bool: Existence of the overlapping memory
36    """
37    max_ptr_in_slice = 0
38    for stride, size in sorted(zip([abs(s) for s in strides], shape)):
39        if stride <= max_ptr_in_slice:
40            return True
41        max_ptr_in_slice += stride * (size - 1)
42    return False
43
44
45def _min_index(shape, strides, storage_offset):
46    """Returns the leftest index in the array (in the unit-steps)
47
48    Args:
49        shape (tuple of int): The shape of output.
50        strides (tuple of int):
51            The strides of output, given in the unit of steps.
52        storage_offset (int):
53            The offset between the head of allocated memory and the pointer of
54            first element, given in the unit of steps.
55
56    Returns:
57        int: The leftest pointer in the array
58    """
59    sh_st_neg = [sh_st for sh_st in zip(shape, strides) if sh_st[1] < 0]
60    if not sh_st_neg:
61        return storage_offset
62    else:
63        return storage_offset + six.moves.reduce(
64            lambda base, sh_st: base + (sh_st[0] - 1) * sh_st[1], sh_st_neg, 0)
65
66
67def _max_index(shape, strides, storage_offset):
68    """Returns the rightest index in the array
69
70    Args:
71        shape (tuple of int): The shape of output.
72        strides (tuple of int): The strides of output, given in unit-steps.
73        storage_offset (int):
74            The offset between the head of allocated memory and the pointer of
75            first element, given in the unit of steps.
76
77    Returns:
78        int: The rightest pointer in the array
79    """
80    sh_st_pos = [sh_st for sh_st in zip(shape, strides) if sh_st[1] > 0]
81    if not sh_st_pos:
82        return storage_offset
83    else:
84        return storage_offset + six.moves.reduce(
85            lambda base, sh_st: base + (sh_st[0] - 1) * sh_st[1], sh_st_pos, 0)
86
87
def _index_add(augend, indices, addend):
    """Accumulate ``addend`` into ``augend`` at ``indices``, in place.

    Dispatches to :func:`cupyx.scatter_add` for CuPy arrays and to
    :func:`numpy.add.at` for NumPy arrays. Any other type of ``augend`` is
    left untouched.

    Args:
        augend (:class:`numpy.ndarray` or :class:`cupy.ndarray`):
            The array modified in-place.
        indices (:class:`numpy.ndarray` or :class:`cupy.ndarray`):
            The indices of ``augend``. The shape is the same to the ``addend``.
        addend (:class:`numpy.ndarray` or :class:`cupy.ndarray`):
            The array to be added.

    Returns:
        None
    """
    if isinstance(augend, cuda.ndarray):
        accumulate = cuda.cupyx.scatter_add
    elif isinstance(augend, np.ndarray):
        accumulate = np.add.at
    else:
        return
    accumulate(augend, indices, addend)
106
107
108def _get_base_array(array):
109    """Get the founder of :class:`numpy.ndarray`.
110
111    Args:
112        array (:class:`numpy.ndarray`):
113            The view of the base array.
114
115    Returns:
116        :class:`numpy.ndarray`:
117            The base array.
118    """
119    base_array_candidate = array
120    while base_array_candidate.base is not None:
121        base_array_candidate = base_array_candidate.base
122    return base_array_candidate
123
124
def _stride_array(array, shape, strides, storage_offset):
    """Build a strided view onto the memory behind ``array``.

    Plays the role of :func:`numpy.lib.stride_tricks.as_strided` for both
    NumPy and CuPy arrays, with bounds checks against the underlying buffer.

    .. note:
        ``strides`` and ``storage_offset`` are given in the unit of steps
        instead of the unit of bytes. This specification differs from that of
        :func:`numpy.lib.stride_tricks.as_strided`.

    Args:
        array (:class:`numpy.ndarray` or :class:`cupy.ndarray`):
            The base array for the returned view.
        shape (tuple of int):
            The shape of the returned view.
        strides (tuple of int):
            The strides of the returned view, given in the unit of steps.
        storage_offset (int):
            The offset from the leftmost pointer of allocated memory to
            the first element of returned view, given in the unit of steps.

    Returns:
        :class:`numpy.ndarray` or :class:`cupy.ndarray`:
            The new view for the base array.

    Raises:
        ValueError: If the view would reach below index 0 or beyond the end
            of the underlying buffer.
        TypeError: If ``array`` is neither a NumPy nor a CuPy array.
    """
    # Extremes are computed while everything is still expressed in steps.
    lowest = _min_index(shape, strides, storage_offset)
    highest = _max_index(shape, strides, storage_offset)

    # The array constructors below expect bytes, not steps.
    itemsize = array.itemsize
    byte_strides = tuple([step * itemsize for step in strides])
    byte_offset = storage_offset * itemsize

    if lowest < 0:
        raise ValueError('Out of buffer: too small index was specified')

    if isinstance(array, cuda.ndarray):
        memory = array.data.mem
        if memory.size < (highest + 1) * array.itemsize:
            raise ValueError('Out of buffer: too large index was specified')

        pointer = cuda.cupy.cuda.memory.MemoryPointer(memory, byte_offset)
        return cuda.cupy.ndarray(shape, array.dtype, pointer, byte_strides)

    if isinstance(array, np.ndarray):
        # The view is created on the buffer of the base array, so the bound
        # is checked against the base array as well.
        root = _get_base_array(array)
        if root.nbytes < (highest + 1) * root.itemsize:
            raise ValueError('Out of buffer: too large index was specified')

        return np.ndarray(shape, root.dtype, root.data,
                          byte_offset, byte_strides)

    raise TypeError('Only (np|cp).ndarray is accepted')
175
176
class TensorGeometry(object):
    """Memory layout of an array, recorded in the unit of steps.

    Attributes:
        shape (tuple of int): The shape of the array.
        strides (tuple of int): The strides of the array, in steps.
        storage_offset (int): The distance, in steps, from the head of the
            underlying memory to the first element of the array.
        itemsize (int): The size of one element in bytes.
    """

    def __init__(self, array):
        self.shape = array.shape
        self.strides = _byte2step(array.strides, array.itemsize)
        head_distance = self._offset_in_bytes(array)
        self.storage_offset, = _byte2step((head_distance,), array.itemsize)
        self.itemsize = array.itemsize

    @staticmethod
    def _offset_in_bytes(array):
        """Distance in bytes from the head of the memory to ``array``."""
        if isinstance(array, np.ndarray):
            head = _get_base_array(array).__array_interface__['data'][0]
            return array.__array_interface__['data'][0] - head
        if isinstance(array, cuda.ndarray):
            return array.data.ptr - array.data.mem.ptr
        raise ValueError('only (np|cp).ndarray is supported')

    @property
    def ndim(self):
        """int: The number of axes."""
        return len(self.shape)
196
197
class AsStrided(function_node.FunctionNode):
    """Port of :func:`torch.Tensor.as_strided`.

    While :func:`torch.Tensor.as_strided` does not support negative strides,
    this implementation does support them.

    ``shape``, ``strides`` and ``storage_offset`` describe the output view;
    strides and offset are given in the unit of steps. When
    ``storage_offset`` is ``None`` it is replaced, during ``forward``, by the
    storage offset of the input array.
    """

    def __init__(self, shape, strides, storage_offset=None):
        self.shape, self.strides = shape, strides
        self.storage_offset = storage_offset
        # Filled in by ``forward``; ``backward`` hands it to AsStridedGrad.
        self.input_geometry = None

    def check_type_forward(self, in_types):
        n_in = in_types.size()
        type_check.expect(n_in == 1)

    def forward(self, inputs):
        """Return the strided view of the first input as a 1-tuple."""
        assert len(inputs) > 0
        x = inputs[0]

        geometry = TensorGeometry(x)
        self.input_geometry = geometry

        if self.storage_offset is None:
            self.storage_offset = geometry.storage_offset

        y = _stride_array(x, self.shape, self.strides, self.storage_offset)
        return y,

    def backward(self, _, grad_outputs):
        """Backward computation which delegates to :class:`AsStridedGrad`.

        .. note:
            While this implementation is based on *New-Style Function
            Implementation*, the backward computation does not support
            double-backpropagation due to *layout agnostic* algorithm (
            originally named in the note of pytorch).
        """
        grad_node = AsStridedGrad(self.input_geometry, self.shape,
                                  self.strides, self.storage_offset)
        return grad_node.apply(grad_outputs)
236
237
class AsStridedGrad(function_node.FunctionNode):
    """Backward of :func:`~chainer.functions.as_strided`.

    The output gradient is scattered into a flat scratch buffer that spans
    every position touched by either the input or the output view, and the
    buffer is then read back through the geometry of the input.

    Args:
        input_geometry (TensorGeometry): Layout of the forward input.
        shape (tuple of int): The shape of the forward output.
        strides (tuple of int):
            The strides of the forward output, given in the unit of steps.
        storage_offset (int):
            The storage offset of the forward output, given in the unit of
            steps.
    """

    def __init__(self, input_geometry, shape, strides, storage_offset):
        self.input_geometry = input_geometry
        self.shape = shape
        self.strides = strides
        self.storage_offset = storage_offset

    def forward(self, grads):
        """Compute the gradient with respect to the forward input.

        Args:
            grads (tuple of arrays): Its first element is the gradient with
                respect to the forward output.

        Returns:
            tuple: A 1-tuple holding the gradient array, which has the shape
            of the forward input and the dtype of ``grads[0]``.

        Raises:
            TypeError: If the gradient is not of a floating point dtype.
        """
        assert len(grads) > 0
        gy = grads[0]

        # ``dtype.kind`` replaces the lookup in ``np.sctypes``, which was
        # removed in NumPy 2.0; both accept exactly the real float dtypes.
        if gy.dtype.kind != 'f':
            raise TypeError('Only float is supported for back propagation')

        xp = backend.get_array_module(gy)
        input_geometry = self.input_geometry
        itemsize = input_geometry.itemsize

        # Nothing can be scattered when either side is empty. ``forward``
        # has to return a tuple, and the gradient keeps the dtype of ``gy``.
        if 0 in input_geometry.shape or 0 in gy.shape:
            return xp.zeros(input_geometry.shape, dtype=gy.dtype),

        #  1. remove redundant axis from input/output
        #  [redundant axis]
        #  axis with shape==0, shape==1 or strides==0
        out_ndim = gy.ndim
        out_axes = [i for i in range(out_ndim)
                    if self.shape[i] != 1 and self.strides[i] != 0]
        out_shape = tuple([self.shape[i] for i in out_axes])
        out_strides = tuple([self.strides[i] for i in out_axes])
        # Every element along a stride-0 axis refers to the same position,
        # so their gradients are summed before the axis is dropped.
        gy = gy.sum(
            tuple([i for i in range(out_ndim) if self.strides[i] == 0]))
        gy = gy.squeeze()

        out_storage_offset = self.storage_offset

        inp_axes = [i for i in range(input_geometry.ndim)
                    if input_geometry.shape[i] != 1]
        inp_shape = tuple([input_geometry.shape[i] for i in inp_axes])
        inp_strides = tuple([input_geometry.strides[i] for i in inp_axes])
        inp_storage_offset = input_geometry.storage_offset

        #  2. calculate minimum required storage for gradient computation
        inp_min_ptr = _min_index(inp_shape, inp_strides, inp_storage_offset)
        out_min_ptr = _min_index(out_shape, out_strides, out_storage_offset)
        common_min_ptr = min(inp_min_ptr, out_min_ptr)

        inp_max_ptr = _max_index(inp_shape, inp_strides, inp_storage_offset)
        out_max_ptr = _max_index(out_shape, out_strides, out_storage_offset)
        common_max_ptr = max(inp_max_ptr, out_max_ptr)

        base_size = (common_max_ptr - common_min_ptr) + 1

        # ``storage`` is indexed relative to ``common_min_ptr``; hence every
        # offset below is shifted by it.
        storage = xp.zeros(base_size, dtype=gy.dtype)
        flatten_full_indices = xp.arange(base_size,
                                         dtype=index_dtype[itemsize])

        out_maybe_overlap = _maybe_overlapping_memory(out_shape, out_strides)

        if out_maybe_overlap:
            # Several output elements may share a position: accumulate.
            out_indices = _stride_array(flatten_full_indices, out_shape,
                                        out_strides,
                                        out_storage_offset - common_min_ptr)
            _index_add(storage, out_indices, gy)
        else:
            storage_view = _stride_array(storage, out_shape, out_strides,
                                         out_storage_offset - common_min_ptr)
            # ``[...]`` also works when every output axis was redundant and
            # both arrays are 0-dimensional, where ``[:]`` raises.
            storage_view[...] = gy

        inp_maybe_overlap = _maybe_overlapping_memory(inp_shape, inp_strides)
        if inp_maybe_overlap:
            # Divide each position by the number of input elements that
            # refer to it. Positions no input element refers to divide by
            # zero, which is why the floating point warnings are silenced.
            count = xp.zeros_like(storage)
            inp_indices = _stride_array(flatten_full_indices, inp_shape,
                                        inp_strides,
                                        inp_storage_offset - common_min_ptr)
            _index_add(count, inp_indices, xp.ones(1))
            with np.errstate(divide='ignore', invalid='ignore'):
                storage /= count

        # Read back with the complete input geometry, size-1 axes included,
        # so that the gradient has exactly the shape of the input. Size-1
        # axes address no additional memory, so the bounds are unaffected.
        return _stride_array(storage, input_geometry.shape,
                             input_geometry.strides,
                             inp_storage_offset - common_min_ptr),

    def backward(self, target_input_indexes, grad_outputs):
        """Double backpropagation is not supported."""
        raise NotImplementedError
333
334
def as_strided(x, shape, strides, storage_offset=None):
    """Create a new view of array with the given shape, strides, and offset.

    Args:
        x (:class:`~chainer.Variable` or :class:`numpy.ndarray` or \
        :class:`cupy.ndarray`):
            The array pointing a memory buffer. Its view is totally ignored.
        shape (tuple of int):
            The shape of output.
        strides (tuple of int):
            The strides of output, given in the unit of steps.
        storage_offset (int):
            The offset between the head of allocated memory and the pointer of
            first element, given in the unit of steps. If ``None``, the
            storage offset of ``x`` is used.

    Returns:
        ~chainer.Variable: The strided variable.

    .. warning::
        Users should be aware that this function potentially causes unintended
        side effects. See `numpy.lib.stride_tricks.as_strided`_ for the detail.

    .. note::
        The backward algorithm is borrowed from `torch.Tensor.as_strided`.
        Therefore, the returned gradient of ``backward`` is *layout-agnostic*
        when ``x`` contains memory overlap. See notes in pytorch's source
        code (as_strided Backward and layout-aware/agnostic autograd) too.

    .. note::
        In this function ``strides`` and ``storage_offset`` are given in the
        unit of steps instead of bytes. This specification differs from
        :func:`numpy.lib.stride_tricks.as_strided`.

    .. admonition:: Example

        >>> from chainer import functions as F, Variable
        >>> x = Variable(np.arange(4, dtype=np.float32))
        >>> x
        variable([0., 1., 2., 3.])
        >>> y = F.as_strided(x, (3, 2), (1, 1), 0)
        >>> y
        variable([[0., 1.],
                  [1., 2.],
                  [2., 3.]])
        >>> y.grad = np.ones((3, 2), dtype=np.float32)
        >>> y.backward()
        >>> x.grad
        array([1., 2., 2., 1.], dtype=float32)

    .. _numpy.lib.stride_tricks.as_strided:
        https://docs.scipy.org/doc/numpy/reference/generated/\
        numpy.lib.stride_tricks.as_strided.html

    """
    node = AsStrided(shape, strides, storage_offset)
    outputs = node.apply((x,))
    return outputs[0]
390