1import functools
2import sys
3import math
4import warnings
5
6import numpy.core.numeric as _nx
7from numpy.core.numeric import (
8    asarray, ScalarType, array, alltrue, cumprod, arange, ndim
9    )
10from numpy.core.numerictypes import find_common_type, issubdtype
11
12import numpy.matrixlib as matrixlib
13from .function_base import diff
14from numpy.core.multiarray import ravel_multi_index, unravel_index
15from numpy.core.overrides import set_module
16from numpy.core import overrides, linspace
17from numpy.lib.stride_tricks import as_strided
18
19
# Local shorthand for NumPy's __array_function__ dispatch decorator with the
# public module name pre-bound to 'numpy'.
array_function_dispatch = functools.partial(
    overrides.array_function_dispatch,
    module='numpy',
)
22
23
# Public names exported by ``from numpy.lib.index_tricks import *``.
__all__ = [
    'ravel_multi_index',
    'unravel_index',
    'mgrid',
    'ogrid',
    'r_',
    'c_',
    's_',
    'index_exp',
    'ix_',
    'ndenumerate',
    'ndindex',
    'fill_diagonal',
    'diag_indices',
    'diag_indices_from',
]
29
30
31def _ix__dispatcher(*args):
32    return args
33
34
@array_function_dispatch(_ix__dispatcher)
def ix_(*args):
    """
    Construct an open mesh from multiple sequences.

    Given N 1-D sequences, return N arrays that each have N dimensions.
    Every output has shape 1 along all axes except one, and the axis with
    the non-unit length moves through all N positions in turn.

    With `ix_` you can quickly build index arrays that select the cross
    product of the inputs. For example, ``a[np.ix_([1,3],[2,5])]`` returns
    the array ``[[a[1,2] a[1,5]], [a[3,2] a[3,5]]]``.

    Parameters
    ----------
    args : 1-D sequences
        Each sequence should be of integer or boolean type.
        A boolean sequence acts as a mask for its dimension, which is
        equivalent to passing ``np.nonzero(boolean_sequence)``.

    Returns
    -------
    out : tuple of ndarrays
        N arrays with N dimensions each, where N is the number of input
        sequences. Together they form an open mesh.

    Raises
    ------
    ValueError
        If any input is not 1-dimensional.

    See Also
    --------
    ogrid, mgrid, meshgrid

    Examples
    --------
    >>> a = np.arange(10).reshape(2, 5)
    >>> a
    array([[0, 1, 2, 3, 4],
           [5, 6, 7, 8, 9]])
    >>> ixgrid = np.ix_([0, 1], [2, 4])
    >>> ixgrid
    (array([[0],
           [1]]), array([[2, 4]]))
    >>> ixgrid[0].shape, ixgrid[1].shape
    ((2, 1), (1, 2))
    >>> a[ixgrid]
    array([[2, 4],
           [7, 9]])

    >>> ixgrid = np.ix_([True, True], [2, 4])
    >>> a[ixgrid]
    array([[2, 4],
           [7, 9]])
    >>> ixgrid = np.ix_([True, True], [False, False, True, False, True])
    >>> a[ixgrid]
    array([[2, 4],
           [7, 9]])

    """
    ndims = len(args)
    grids = []
    for axis, seq in enumerate(args):
        if not isinstance(seq, _nx.ndarray):
            seq = asarray(seq)
            if seq.size == 0:
                # An empty Python sequence would default to float64, which
                # is not a valid index dtype; force an integer dtype.
                seq = seq.astype(_nx.intp)
        if seq.ndim != 1:
            raise ValueError("Cross index must be 1 dimensional")
        if issubdtype(seq.dtype, _nx.bool_):
            # Boolean masks are turned into the positions of their True values.
            seq = seq.nonzero()[0]
        # Put this sequence's length on its own axis and 1 everywhere else.
        shape = [1] * ndims
        shape[axis] = seq.size
        grids.append(seq.reshape(tuple(shape)))
    return tuple(grids)
108
class nd_grid:
    """
    Construct a multi-dimensional "meshgrid".

    ``grid = nd_grid()`` creates an instance which will return a mesh-grid
    when indexed.  The dimension and number of the output arrays are equal
    to the number of indexing dimensions.  If the step length is not a
    complex number, then the stop is not inclusive.

    However, if the step length is a **complex number** (e.g. 5j), then the
    integer part of its magnitude is interpreted as specifying the
    number of points to create between the start and stop values, where
    the stop value **is inclusive**.

    If instantiated with an argument of ``sparse=True``, the mesh-grid is
    open (or not fleshed out) so that only one-dimension of each returned
    argument is greater than 1.

    Indexing with a single slice returns a 1-D array. Indexing with several
    slices returns a single dense ndarray when ``sparse=False`` and a list
    of broadcastable arrays when ``sparse=True``.

    Parameters
    ----------
    sparse : bool, optional
        Whether the grid is sparse or not. Default is False.

    Notes
    -----
    Two instances of `nd_grid` are made available in the NumPy namespace,
    `mgrid` and `ogrid`, approximately defined as::

        mgrid = nd_grid(sparse=False)
        ogrid = nd_grid(sparse=True)

    Users should use these pre-defined instances instead of using `nd_grid`
    directly.
    """

    def __init__(self, sparse=False):
        self.sparse = sparse

    def __getitem__(self, key):
        try:
            # Multi-dimensional case: ``key`` is a sequence of slices.  A
            # single slice has no len(), which raises TypeError and falls
            # through to the 1-D branch below.
            size = []
            typ = int
            for k in range(len(key)):
                step = key[k].step
                start = key[k].start
                if start is None:
                    start = 0
                if step is None:
                    step = 1
                if isinstance(step, (_nx.complexfloating, complex)):
                    # Complex step: its magnitude is a point count, and the
                    # resulting coordinates are interpolated (floats).
                    size.append(int(abs(step)))
                    typ = float
                else:
                    size.append(
                        int(math.ceil((key[k].stop - start)/(step*1.0))))
                if (isinstance(step, (_nx.floating, float)) or
                        isinstance(start, (_nx.floating, float)) or
                        isinstance(key[k].stop, (_nx.floating, float))):
                    typ = float
            if self.sparse:
                nn = [_nx.arange(_x, dtype=_t)
                        for _x, _t in zip(size, (typ,)*len(size))]
            else:
                nn = _nx.indices(size, typ)
            # Scale the 0..size-1 index grids into the requested coordinates.
            for k in range(len(size)):
                step = key[k].step
                start = key[k].start
                if start is None:
                    start = 0
                if step is None:
                    step = 1
                if isinstance(step, (_nx.complexfloating, complex)):
                    step = int(abs(step))
                    if step != 1:
                        step = (key[k].stop - start)/float(step-1)
                nn[k] = (nn[k]*step+start)
            if self.sparse:
                # Give each 1-D array its own axis so the outputs broadcast
                # against each other into the full grid.
                slobj = [_nx.newaxis]*len(size)
                for k in range(len(size)):
                    slobj[k] = slice(None, None)
                    nn[k] = nn[k][tuple(slobj)]
                    slobj[k] = _nx.newaxis
            return nn
        except (IndexError, TypeError):
            # 1-D case: ``key`` is a single slice.
            step = key.step
            stop = key.stop
            start = key.start
            if start is None:
                start = 0
            if isinstance(step, (_nx.complexfloating, complex)):
                step = abs(step)
                length = int(step)
                if step != 1:
                    step = (stop - start)/float(step-1)
                # No ``stop + step`` here: the value was never used, and it
                # raised TypeError for a single point with no stop
                # (e.g. ``[3::1j]``), unlike the multi-dimensional path.
                return _nx.arange(0, length, 1, float)*step + start
            else:
                return _nx.arange(start, stop, step)
207
208
class MGridClass(nd_grid):
    """
    `nd_grid` instance which returns a dense multi-dimensional "meshgrid".

    An instance of `numpy.lib.index_tricks.nd_grid` which returns a dense
    (fully fleshed out) mesh-grid when indexed, so that every returned
    argument has the same shape.  The dimensions and number of the output
    arrays equal the number of indexing dimensions.  If the step length is
    not a complex number, the stop is not inclusive.

    If the step length is a **complex number** (e.g. 5j), however, the
    integer part of its magnitude is taken as the number of points to
    create between the start and stop values, and the stop value
    **is inclusive**.

    Returns
    -------
    mesh-grid `ndarrays` all of the same dimensions

    See Also
    --------
    numpy.lib.index_tricks.nd_grid : class of `ogrid` and `mgrid` objects
    ogrid : like mgrid but returns open (not fleshed out) mesh grids
    r_ : array concatenator

    Examples
    --------
    >>> np.mgrid[0:5,0:5]
    array([[[0, 0, 0, 0, 0],
            [1, 1, 1, 1, 1],
            [2, 2, 2, 2, 2],
            [3, 3, 3, 3, 3],
            [4, 4, 4, 4, 4]],
           [[0, 1, 2, 3, 4],
            [0, 1, 2, 3, 4],
            [0, 1, 2, 3, 4],
            [0, 1, 2, 3, 4],
            [0, 1, 2, 3, 4]]])
    >>> np.mgrid[-1:1:5j]
    array([-1. , -0.5,  0. ,  0.5,  1. ])

    """

    def __init__(self):
        # Dense grids: every output array is fully broadcast.
        super().__init__(sparse=False)


mgrid = MGridClass()
255
class OGridClass(nd_grid):
    """
    `nd_grid` instance which returns an open multi-dimensional "meshgrid".

    An instance of `numpy.lib.index_tricks.nd_grid` which returns an open
    (i.e. not fleshed out) mesh-grid when indexed, so that each returned
    array has only one dimension greater than 1.  The dimension and number
    of the output arrays equal the number of indexing dimensions.  If the
    step length is not a complex number, the stop is not inclusive.

    If the step length is a **complex number** (e.g. 5j), however, the
    integer part of its magnitude is taken as the number of points to
    create between the start and stop values, and the stop value
    **is inclusive**.

    Returns
    -------
    mesh-grid
        `ndarrays` with only one dimension not equal to 1

    See Also
    --------
    numpy.lib.index_tricks.nd_grid : class of `ogrid` and `mgrid` objects
    mgrid : like `ogrid` but returns dense (or fleshed out) mesh grids
    r_ : array concatenator

    Examples
    --------
    >>> from numpy import ogrid
    >>> ogrid[-1:1:5j]
    array([-1. , -0.5,  0. ,  0.5,  1. ])
    >>> ogrid[0:5,0:5]
    [array([[0],
            [1],
            [2],
            [3],
            [4]]), array([[0, 1, 2, 3, 4]])]

    """

    def __init__(self):
        # Open grids: each output varies along a single axis only.
        super().__init__(sparse=True)


ogrid = OGridClass()
299
300
class AxisConcatenator:
    """
    Translates slice objects to concatenation along an axis.

    For detailed documentation on usage, see `r_`.

    Parameters
    ----------
    axis : int, optional
        Default axis along which the entries are concatenated.
    matrix : bool, optional
        If True, the concatenated result is converted with `makemat`.
        A 1-D result becomes a row matrix unless a leading ``'c'``
        directive asks for a column.
    ndmin : int, optional
        Minimum number of dimensions each entry is promoted to.
    trans1d : int, optional
        Where the original axes of a promoted entry are placed. ``-1`` (the
        default) leaves the entry as produced by ``array(..., ndmin=ndmin)``.
    """
    # allow ma.mr_ to override this
    concatenate = staticmethod(_nx.concatenate)
    makemat = staticmethod(matrixlib.matrix)

    def __init__(self, axis=0, matrix=False, ndmin=1, trans1d=-1):
        self.axis = axis
        self.matrix = matrix
        self.trans1d = trans1d
        self.ndmin = ndmin

    def __getitem__(self, key):
        # handle matrix builder syntax; the caller's frame supplies the
        # names referenced inside the string, so this must stay directly in
        # __getitem__ (one frame below the caller).
        if isinstance(key, str):
            frame = sys._getframe().f_back
            mymat = matrixlib.bmat(key, frame.f_globals, frame.f_locals)
            return mymat

        if not isinstance(key, tuple):
            key = (key,)

        # copy attributes, since they can be overridden in the first argument
        trans1d = self.trans1d
        ndmin = self.ndmin
        matrix = self.matrix
        axis = self.axis
        # Row matrix by default; only a leading 'c' directive selects a
        # column.  Initialised here so matrix=True from the constructor
        # works without an 'r'/'c' directive (previously UnboundLocalError).
        col = False

        objs = []
        scalars = []
        arraytypes = []
        scalartypes = []

        for k, item in enumerate(key):
            scalar = False
            if isinstance(item, slice):
                step = item.step
                start = item.start
                stop = item.stop
                if start is None:
                    start = 0
                if step is None:
                    step = 1
                if isinstance(step, (_nx.complexfloating, complex)):
                    # Complex step: number of points, stop inclusive.
                    size = int(abs(step))
                    newobj = linspace(start, stop, num=size)
                else:
                    newobj = _nx.arange(start, stop, step)
                if ndmin > 1:
                    newobj = array(newobj, copy=False, ndmin=ndmin)
                    if trans1d != -1:
                        newobj = newobj.swapaxes(-1, trans1d)
            elif isinstance(item, str):
                if k != 0:
                    raise ValueError("special directives must be the "
                            "first entry.")
                if item in ('r', 'c'):
                    matrix = True
                    col = (item == 'c')
                    continue
                if ',' in item:
                    # "axis,ndmin" or "axis,ndmin,trans1d"
                    vec = item.split(',')
                    try:
                        axis, ndmin = [int(x) for x in vec[:2]]
                        if len(vec) == 3:
                            trans1d = int(vec[2])
                        continue
                    except Exception as e:
                        raise ValueError(
                            "unknown special directive {!r}".format(item)
                        ) from e
                try:
                    axis = int(item)
                    continue
                except (ValueError, TypeError) as e:
                    raise ValueError(
                        "unknown special directive {!r}".format(item)
                    ) from e
            elif type(item) in ScalarType:
                newobj = array(item, ndmin=ndmin)
                scalars.append(len(objs))
                scalar = True
                scalartypes.append(newobj.dtype)
            else:
                item_ndim = ndim(item)
                newobj = array(item, copy=False, subok=True, ndmin=ndmin)
                if trans1d != -1 and item_ndim < ndmin:
                    # Move the entry's original axes so that they start at
                    # position trans1d instead of at the end of the shape.
                    k2 = ndmin - item_ndim
                    k1 = trans1d
                    if k1 < 0:
                        k1 += k2 + 1
                    defaxes = list(range(ndmin))
                    axes = defaxes[:k1] + defaxes[k2:] + defaxes[k1:k2]
                    newobj = newobj.transpose(axes)
            objs.append(newobj)
            if not scalar and isinstance(newobj, _nx.ndarray):
                arraytypes.append(newobj.dtype)

        # Ensure that scalars won't up-cast unless warranted
        final_dtype = find_common_type(arraytypes, scalartypes)
        if final_dtype is not None:
            for k in scalars:
                objs[k] = objs[k].astype(final_dtype)

        res = self.concatenate(tuple(objs), axis=axis)

        if matrix:
            oldndim = res.ndim
            res = self.makemat(res)
            if oldndim == 1 and col:
                res = res.T
        return res

    def __len__(self):
        return 0
418
# Separate classes are used here instead of just making r_ = AxisConcatenator(0),
# etc., because otherwise we couldn't get the docstring to come out right
# in help(r_).
422
class RClass(AxisConcatenator):
    """
    Translates slice objects to concatenation along the first axis.

    This offers a compact way to build up arrays. There are two use cases.

    1. If the index expression contains comma separated arrays, they are
       stacked along their first axis.
    2. If the index expression contains slice notation or scalars, a 1-D
       array is created with the range indicated by the slice notation.

    Inside the brackets, the slice syntax ``start:stop:step`` is equivalent
    to ``np.arange(start, stop, step)``. If ``step`` is an imaginary number
    (i.e. 100j), however, its integer portion is interpreted as the desired
    number of points and both start and stop are inclusive. In other words,
    ``start:stop:stepj`` is interpreted as
    ``np.linspace(start, stop, step, endpoint=1)`` inside the brackets.
    After slice notation has been expanded, all comma separated sequences
    are concatenated together.

    An optional character string placed as the first element of the index
    expression changes the output. The strings 'r' or 'c' produce matrix
    output: if the result is 1-D, 'r' gives a 1 x N (row) matrix and 'c'
    gives an N x 1 (column) matrix. If the result is 2-D, both give the
    same matrix.

    A string integer specifies the axis along which multiple comma separated
    arrays are stacked. A string of two comma-separated integers also gives,
    as its second integer, the minimum number of dimensions each entry is
    forced into (the axis to concatenate along is still the first integer).

    A string of three comma-separated integers specifies the axis to
    concatenate along, the minimum number of dimensions to force the entries
    to, and which axis should contain the start of the arrays that have
    fewer than that number of dimensions. In other words, the third integer
    lets you choose where the 1's are placed in the shape of arrays whose
    shapes are upgraded. By default they are placed at the front of the
    shape tuple; the third argument specifies where the start of the array
    should go instead. Thus a third argument of '0' places the 1's at the
    end of the array shape. Negative integers give the position in the new
    shape tuple of the last dimension of upgraded arrays, so the default
    is '-1'.

    Parameters
    ----------
    Not a function, so takes no parameters


    Returns
    -------
    A concatenated ndarray or matrix.

    See Also
    --------
    concatenate : Join a sequence of arrays along an existing axis.
    c_ : Translates slice objects to concatenation along the second axis.

    Examples
    --------
    >>> np.r_[np.array([1,2,3]), 0, 0, np.array([4,5,6])]
    array([1, 2, 3, ..., 4, 5, 6])
    >>> np.r_[-1:1:6j, [0]*3, 5, 6]
    array([-1. , -0.6, -0.2,  0.2,  0.6,  1. ,  0. ,  0. ,  0. ,  5. ,  6. ])

    String integers specify the axis to concatenate along or the minimum
    number of dimensions to force entries into.

    >>> a = np.array([[0, 1, 2], [3, 4, 5]])
    >>> np.r_['-1', a, a] # concatenate along last axis
    array([[0, 1, 2, 0, 1, 2],
           [3, 4, 5, 3, 4, 5]])
    >>> np.r_['0,2', [1,2,3], [4,5,6]] # concatenate along first axis, dim>=2
    array([[1, 2, 3],
           [4, 5, 6]])

    >>> np.r_['0,2,0', [1,2,3], [4,5,6]]
    array([[1],
           [2],
           [3],
           [4],
           [5],
           [6]])
    >>> np.r_['1,2,0', [1,2,3], [4,5,6]]
    array([[1, 4],
           [2, 5],
           [3, 6]])

    Using 'r' or 'c' as a first string argument creates a matrix.

    >>> np.r_['r',[1,2,3], [4,5,6]]
    matrix([[1, 2, 3, 4, 5, 6]])

    """

    def __init__(self):
        # Concatenate along the first axis with all other settings default.
        AxisConcatenator.__init__(self, axis=0)


r_ = RClass()
522
class CClass(AxisConcatenator):
    """
    Translates slice objects to concatenation along the second axis.

    This is shorthand for ``np.r_['-1,2,0', index expression]``, which is
    handy because the pattern is so common. In particular, arrays are
    stacked along their last axis after being upgraded to at least 2-D,
    with 1's appended to the shape (so 1-D arrays become column vectors).

    See Also
    --------
    column_stack : Stack 1-D arrays as columns into a 2-D array.
    r_ : For more detailed documentation.

    Examples
    --------
    >>> np.c_[np.array([1,2,3]), np.array([4,5,6])]
    array([[1, 4],
           [2, 5],
           [3, 6]])
    >>> np.c_[np.array([[1,2,3]]), 0, 0, np.array([[4,5,6]])]
    array([[1, 2, 3, ..., 4, 5, 6]])

    """

    def __init__(self):
        # Last axis, at least 2-D, original axes placed first ('-1,2,0').
        AxisConcatenator.__init__(self, axis=-1, ndmin=2, trans1d=0)


c_ = CClass()
553
554
@set_module('numpy')
class ndenumerate:
    """
    Multidimensional index iterator.

    Iterating over an instance yields ``(coordinates, value)`` pairs for
    every element of the array, in the order of its flat iterator.

    Parameters
    ----------
    arr : ndarray
      Input array (anything accepted by `asarray`).

    See Also
    --------
    ndindex, flatiter

    Examples
    --------
    >>> a = np.array([[1, 2], [3, 4]])
    >>> for index, x in np.ndenumerate(a):
    ...     print(index, x)
    (0, 0) 1
    (0, 1) 2
    (1, 0) 3
    (1, 1) 4

    """

    def __init__(self, arr):
        self.iter = asarray(arr).flat

    def __next__(self):
        """
        Standard iterator method, returns the index tuple and array value.

        Returns
        -------
        coords : tuple of ints
            The indices of the current iteration.
        val : scalar
            The array element of the current iteration.

        """
        # The coordinates must be read before advancing the flat iterator,
        # because they describe the element that next() is about to return.
        position = self.iter.coords
        element = next(self.iter)
        return position, element

    def __iter__(self):
        return self
602
603
@set_module('numpy')
class ndindex:
    """
    An N-dimensional iterator object to index arrays.

    Given the shape of an array, an `ndindex` instance iterates over the
    N-dimensional index of the array. Each iteration produces a tuple of
    indices, with the last dimension varying fastest.

    Parameters
    ----------
    shape : ints, or a single tuple of ints
        The size of each dimension of the array, given either as separate
        arguments or as the elements of a single tuple.

    See Also
    --------
    ndenumerate, flatiter

    Examples
    --------
    Dimensions given as individual arguments:

    >>> for index in np.ndindex(3, 2, 1):
    ...     print(index)
    (0, 0, 0)
    (0, 1, 0)
    (1, 0, 0)
    (1, 1, 0)
    (2, 0, 0)
    (2, 1, 0)

    The same dimensions given as one tuple ``(3, 2, 1)``:

    >>> for index in np.ndindex((3, 2, 1)):
    ...     print(index)
    (0, 0, 0)
    (0, 1, 0)
    (1, 0, 0)
    (1, 1, 0)
    (2, 0, 0)
    (2, 1, 0)

    """

    def __init__(self, *shape):
        # ndindex((3, 2)) is treated exactly like ndindex(3, 2).
        if len(shape) == 1 and isinstance(shape[0], tuple):
            (shape,) = shape
        # View a single element with all-zero strides so the array has the
        # requested shape without allocating it; nditer then only walks the
        # index space in C order.
        zero_strides = _nx.zeros_like(shape)
        dummy = as_strided(_nx.zeros(1), shape=shape, strides=zero_strides)
        self._it = _nx.nditer(dummy, flags=['multi_index', 'zerosize_ok'],
                              order='C')

    def __iter__(self):
        return self

    def ndincr(self):
        """
        Increment the multi-dimensional index by one.

        This method is for backward compatibility only: do not use.

        .. deprecated:: 1.20.0
            This method has been advised against since numpy 1.8.0, but only
            started emitting DeprecationWarning as of this version.
        """
        # NumPy 1.20.0, 2020-09-08
        warnings.warn(
            "`ndindex.ndincr()` is deprecated, use `next(ndindex)` instead",
            DeprecationWarning, stacklevel=2)
        next(self)

    def __next__(self):
        """
        Standard iterator method, updates the index and returns the index
        tuple.

        Returns
        -------
        val : tuple of ints
            Returns a tuple containing the indices of the current
            iteration.

        """
        it = self._it
        # Advance first; multi_index then reports the new position.
        next(it)
        return it.multi_index
688
689
690# You can do all this with slice() plus a few special objects,
691# but there's a lot to remember. This version is simpler because
692# it uses the standard array indexing syntax.
693#
694# Written by Konrad Hinsen <hinsen@cnrs-orleans.fr>
695# last revision: 1999-7-23
696#
697# Cosmetic changes by T. Oliphant 2001
698#
699#
700
class IndexExpression:
    """
    A nicer way to build up index tuples for arrays.

    .. note::
       Use one of the two predefined instances `index_exp` or `s_`
       rather than directly using `IndexExpression`.

    For any index combination, including slicing and axis insertion,
    ``a[indices]`` is the same as ``a[np.index_exp[indices]]`` for any
    array `a`. Unlike a bare index, however, ``np.index_exp[indices]`` can
    be written anywhere in Python code; it returns the index object itself
    (wrapped in a tuple when requested), which can then be used to build
    complex index expressions.

    Parameters
    ----------
    maketuple : bool
        If True, always returns a tuple.

    See Also
    --------
    index_exp : Predefined instance that always returns a tuple:
       `index_exp = IndexExpression(maketuple=True)`.
    s_ : Predefined instance without tuple conversion:
       `s_ = IndexExpression(maketuple=False)`.

    Notes
    -----
    You can do all this with `slice()` plus a few special objects,
    but there's a lot to remember and this version is simpler because
    it uses the standard array indexing syntax.

    Examples
    --------
    >>> np.s_[2::2]
    slice(2, None, 2)
    >>> np.index_exp[2::2]
    (slice(2, None, 2),)

    >>> np.array([0, 1, 2, 3, 4])[np.s_[2::2]]
    array([2, 4])

    """

    def __init__(self, maketuple):
        self.maketuple = maketuple

    def __getitem__(self, item):
        # Multi-part indices already arrive as tuples; only single items
        # need wrapping when a tuple was requested.
        needs_wrap = self.maketuple and not isinstance(item, tuple)
        return (item,) if needs_wrap else item


index_exp = IndexExpression(maketuple=True)
s_ = IndexExpression(maketuple=False)
756
757# End contribution from Konrad.
758
759
760# The following functions complement those in twodim_base, but are
761# applicable to N-dimensions.
762
763
764def _fill_diagonal_dispatcher(a, val, wrap=None):
765    return (a,)
766
767
@array_function_dispatch(_fill_diagonal_dispatcher)
def fill_diagonal(a, val, wrap=False):
    """Fill the main diagonal of the given array of any dimensionality.

    For an array `a` with ``a.ndim >= 2``, the diagonal is the list of
    locations with indices ``a[i, ..., i]`` all identical. This function
    modifies the input array in-place, it does not return a value.

    Parameters
    ----------
    a : array, at least 2-D.
      Array whose diagonal is to be filled, it gets modified in-place.

    val : scalar or array_like
      Value(s) to write on the diagonal. If `val` is scalar, the value is
      written along the diagonal. If array-like, the flattened `val` is
      written along the diagonal, repeating if necessary to fill all
      diagonal entries.

    wrap : bool
      For tall matrices in NumPy version up to 1.6.2, the
      diagonal "wrapped" after N columns. You can have this behavior
      with this option. This affects only tall matrices.

    Raises
    ------
    ValueError
      If ``a.ndim < 2``, or if ``a.ndim > 2`` and the dimensions of `a` are
      not all equal.

    See also
    --------
    diag_indices, diag_indices_from

    Notes
    -----
    .. versionadded:: 1.4.0

    This functionality can be obtained via `diag_indices`, but internally
    this version uses a much faster implementation that never constructs the
    indices and uses simple slicing.

    Examples
    --------
    >>> a = np.zeros((3, 3), int)
    >>> np.fill_diagonal(a, 5)
    >>> a
    array([[5, 0, 0],
           [0, 5, 0],
           [0, 0, 5]])

    The same function can operate on a 4-D array:

    >>> a = np.zeros((3, 3, 3, 3), int)
    >>> np.fill_diagonal(a, 4)

    We only show a few blocks for clarity:

    >>> a[0, 0]
    array([[4, 0, 0],
           [0, 0, 0],
           [0, 0, 0]])
    >>> a[1, 1]
    array([[0, 0, 0],
           [0, 4, 0],
           [0, 0, 0]])
    >>> a[2, 2]
    array([[0, 0, 0],
           [0, 0, 0],
           [0, 0, 4]])

    The wrap option affects only tall matrices:

    >>> # tall matrices no wrap
    >>> a = np.zeros((5, 3), int)
    >>> np.fill_diagonal(a, 4)
    >>> a
    array([[4, 0, 0],
           [0, 4, 0],
           [0, 0, 4],
           [0, 0, 0],
           [0, 0, 0]])

    >>> # tall matrices wrap
    >>> a = np.zeros((5, 3), int)
    >>> np.fill_diagonal(a, 4, wrap=True)
    >>> a
    array([[4, 0, 0],
           [0, 4, 0],
           [0, 0, 4],
           [0, 0, 0],
           [4, 0, 0]])

    >>> # wide matrices
    >>> a = np.zeros((3, 5), int)
    >>> np.fill_diagonal(a, 4, wrap=True)
    >>> a
    array([[4, 0, 0, 0, 0],
           [0, 4, 0, 0, 0],
           [0, 0, 4, 0, 0]])

    The anti-diagonal can be filled by reversing the order of elements
    using either `numpy.flipud` or `numpy.fliplr`.

    >>> a = np.zeros((3, 3), int)
    >>> np.fill_diagonal(np.fliplr(a), [1,2,3])  # Horizontal flip
    >>> a
    array([[0, 0, 1],
           [0, 2, 0],
           [3, 0, 0]])
    >>> np.fill_diagonal(np.flipud(a), [1,2,3])  # Vertical flip
    >>> a
    array([[0, 0, 3],
           [0, 2, 0],
           [1, 0, 0]])

    Note that the order in which the diagonal is filled varies depending
    on the flip function.
    """
    if a.ndim < 2:
        raise ValueError("array must be at least 2-d")
    end = None
    if a.ndim == 2:
        # Explicit, fast formula for the common case.  For 2-d arrays, we
        # accept rectangular ones.
        step = a.shape[1] + 1
        # Stop after shape[1] rows so that the diagonal of a tall matrix
        # does not wrap around to the first column again.
        if not wrap:
            end = a.shape[1] * a.shape[1]
    else:
        # For more than d=2, the strided formula is only valid for arrays with
        # all dimensions equal, so we check first.  (A plain comparison is
        # used instead of the deprecated ``alltrue(diff(shape) == 0)``.)
        if any(dim != a.shape[0] for dim in a.shape):
            raise ValueError("All dimensions of input must be of equal length")
        # For an (n, ..., n) array, consecutive diagonal elements are
        # 1 + n + n**2 + ... + n**(d-1) flat positions apart.
        step = 1 + (cumprod(a.shape[:-1])).sum()

    # Write the value out into the diagonal.
    a.flat[:end:step] = val
900
901
@set_module('numpy')
def diag_indices(n, ndim=2):
    """
    Return the indices to access the main diagonal of an array.

    The result is a tuple of indices that selects the main diagonal of an
    array `a` with ``a.ndim >= 2`` dimensions and shape (n, n, ..., n).
    For ``a.ndim = 2`` this is the usual diagonal; for ``a.ndim > 2`` it is
    the set of positions ``a[i, i, ..., i]`` for ``i = [0..n-1]``.

    Parameters
    ----------
    n : int
      The size, along each dimension, of the arrays for which the returned
      indices can be used.

    ndim : int, optional
      The number of dimensions.

    Returns
    -------
    tuple of ndarray
      `ndim` references to the same ``arange(n)`` array.

    See Also
    --------
    diag_indices_from

    Notes
    -----
    .. versionadded:: 1.4.0

    Examples
    --------
    Create a set of indices to access the diagonal of a (4, 4) array:

    >>> di = np.diag_indices(4)
    >>> di
    (array([0, 1, 2, 3]), array([0, 1, 2, 3]))
    >>> a = np.arange(16).reshape(4, 4)
    >>> a
    array([[ 0,  1,  2,  3],
           [ 4,  5,  6,  7],
           [ 8,  9, 10, 11],
           [12, 13, 14, 15]])
    >>> a[di] = 100
    >>> a
    array([[100,   1,   2,   3],
           [  4, 100,   6,   7],
           [  8,   9, 100,  11],
           [ 12,  13,  14, 100]])

    Now, we create indices to manipulate a 3-D array:

    >>> d3 = np.diag_indices(2, 3)
    >>> d3
    (array([0, 1]), array([0, 1]), array([0, 1]))

    And use it to set the diagonal of an array of zeros to 1:

    >>> a = np.zeros((2, 2, 2), dtype=int)
    >>> a[d3] = 1
    >>> a
    array([[[1, 0],
            [0, 0]],
           [[0, 0],
            [0, 1]]])

    """
    # The same index array is reused for every axis.
    return tuple([arange(n)] * ndim)
969
970
971def _diag_indices_from(arr):
972    return (arr,)
973
974
@array_function_dispatch(_diag_indices_from)
def diag_indices_from(arr):
    """
    Return the indices to access the main diagonal of an n-dimensional array.

    See `diag_indices` for full details.

    Parameters
    ----------
    arr : array, at least 2-D

    Returns
    -------
    tuple of ndarray
        The result of ``diag_indices(arr.shape[0], arr.ndim)``.

    Raises
    ------
    ValueError
        If ``arr.ndim < 2`` or if the dimensions of `arr` are not all equal.

    See Also
    --------
    diag_indices

    Notes
    -----
    .. versionadded:: 1.4.0

    Examples
    --------
    >>> a = np.arange(16).reshape(4, 4)
    >>> np.diag_indices_from(a)
    (array([0, 1, 2, 3]), array([0, 1, 2, 3]))

    """

    if not arr.ndim >= 2:
        raise ValueError("input array must be at least 2-d")
    # diag_indices() describes an (n, n, ..., n) array, so every dimension
    # must equal the first one (this applies to 2-d input as well).
    # A plain comparison replaces the deprecated ``alltrue(diff(...) == 0)``.
    if any(dim != arr.shape[0] for dim in arr.shape):
        raise ValueError("All dimensions of input must be of equal length")

    return diag_indices(arr.shape[0], arr.ndim)
1004