# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
"""gather_nd in python"""
import numpy as np


def gather_nd_python(a_np, indices_np):
    """Python reference implementation of the GatherND operator.

    The first axis of ``indices_np`` enumerates coordinates into the leading
    axes of ``a_np``. With ``a_np.shape == (X_0, ..., X_{N-1})`` and
    ``indices_np.shape == (M, Y_0, ..., Y_{K-1})``, the result has shape
    ``(Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1})`` and::

        b_np[y_0, ..., y_{K-1}] = a_np[indices_np[0, y_0, ..., y_{K-1}],
                                       ...,
                                       indices_np[M-1, y_0, ..., y_{K-1}]]

    Parameters
    ----------
    a_np : numpy.ndarray
        Data array to gather from.

    indices_np : numpy.ndarray
        Index array of rank >= 2 whose first dimension ``M`` must not exceed
        ``a_np.ndim``. Values are cast to int32 first, so float indices are
        truncated; negative indices wrap around as in normal NumPy indexing.

    Returns
    -------
    b_np : numpy.ndarray
        Newly allocated array of shape ``indices_np.shape[1:] + a_np.shape[M:]``
        with the same dtype as ``a_np``.

    Raises
    ------
    AssertionError
        If ``indices_np`` has rank < 2 or ``M > a_np.ndim``.
    IndexError
        If any index is out of bounds for its axis of ``a_np``.
    """
    a_shape = a_np.shape
    indices_np = indices_np.astype('int32')
    indices_shape = indices_np.shape
    assert len(indices_shape) > 1, "indices must have rank >= 2"
    m = indices_shape[0]
    assert m <= len(a_shape), "indices.shape[0] must not exceed data rank"

    b_shape = list(indices_shape[1:]) + list(a_shape[m:])
    # Preserve the input dtype; a default float64 buffer would silently
    # drop imaginary parts and round large integers.
    b_np = np.empty(b_shape, dtype=a_np.dtype)
    # tuple(indices_np) yields M index arrays of shape indices_shape[1:];
    # NumPy advanced indexing with them is exactly GatherND. Assigning into
    # the preallocated buffer also broadcasts correctly when M == 0, where
    # a_np[()] is the whole array and must be repeated for every position.
    b_np[...] = a_np[tuple(indices_np)]
    return b_np