1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
18"""gather_nd in python"""
19import numpy as np
20
def gather_nd_python(a_np, indices_np):
    """ Python version of GatherND operator

    Gathers slices of ``a_np`` addressed by the leading axis of
    ``indices_np``. With ``indices_np`` of shape ``(M, Y_0, ..., Y_{K-1})``
    and ``a_np`` of shape ``(X_0, ..., X_{N-1})`` (``M <= N``), the result
    has shape ``(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})`` and::

        b[y_0, ..., y_{K-1}, x_M, ..., x_{N-1}] =
            a[indices[0, y...], ..., indices[M-1, y...], x_M, ..., x_{N-1}]

    Parameters
    ----------
    a_np : numpy.ndarray
        Data array to gather from.

    indices_np : numpy.ndarray
        Index array with at least 2 dimensions; its first axis has length
        ``M <= a_np.ndim``. Values are cast to integers before use.

    Returns
    -------
    b_np : numpy.ndarray
        Newly allocated array with the same dtype as ``a_np``.

    Raises
    ------
    AssertionError
        If ``indices_np`` has fewer than 2 dimensions or ``M > a_np.ndim``.
    IndexError
        If any index is out of range for its axis of ``a_np``.
    """
    a_shape = a_np.shape
    # Cast (e.g. float-valued indices) to a wide integer type so large
    # indices are not silently wrapped as they would be with int32.
    indices_np = indices_np.astype('int64')
    indices_shape = indices_np.shape
    assert len(indices_shape) > 1, "indices must have at least 2 dimensions"
    assert indices_shape[0] <= len(a_shape), \
        "indices.shape[0] must not exceed the rank of the data array"
    num_indexed = indices_shape[0]
    b_shape = tuple(indices_shape[1:]) + tuple(a_shape[num_indexed:])
    # Keep the input dtype instead of defaulting to float64, which would
    # corrupt complex data and large integers.
    b_np = np.empty(b_shape, dtype=a_np.dtype)
    # tuple(indices_np) splits the leading axis into M integer arrays of
    # shape (Y_0, ..., Y_{K-1}); advanced indexing then gathers all positions
    # at once. When M == 0 this is a_np[()] (the whole array), which the
    # assignment broadcasts across the Y positions.
    b_np[...] = a_np[tuple(indices_np)]
    return b_np
53