1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 #ifndef MXNET_OPERATOR_CONTRIB_DGL_GRAPH_INL_H_
21 #define MXNET_OPERATOR_CONTRIB_DGL_GRAPH_INL_H_
22
23 #include <dmlc/logging.h>
24 #include <dmlc/parameter.h>
25 #include <mxnet/operator.h>
26 #include <mxnet/ndarray.h>
27 #include <map>
28 #include <algorithm>
29 #include <vector>
30 #include <string>
31 #include <utility>
32 #include "../operator_common.h"
33 #include "../mxnet_op.h"
34 #include "../mshadow_op.h"
35 #include "../tensor/init_op.h"
36
37 namespace mxnet {
38 namespace op {
39
40 template<typename xpu>
DGLAdjacencyForwardEx(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<NDArray> & inputs,const std::vector<OpReqType> & req,const std::vector<NDArray> & outputs)41 void DGLAdjacencyForwardEx(const nnvm::NodeAttrs& attrs,
42 const OpContext& ctx,
43 const std::vector<NDArray>& inputs,
44 const std::vector<OpReqType>& req,
45 const std::vector<NDArray>& outputs) {
46 CHECK_EQ(inputs.size(), 1U);
47 CHECK_EQ(outputs.size(), 1U);
48 CHECK_EQ(req.size(), 1U);
49 CHECK_EQ(inputs[0].storage_type(), kCSRStorage);
50 CHECK_EQ(outputs[0].storage_type(), kCSRStorage);
51 CHECK_EQ(req[0], kWriteTo);
52 const TBlob &in_idx = inputs[0].aux_data(csr::kIdx);
53 const TBlob &in_indptr = inputs[0].aux_data(csr::kIndPtr);
54
55 outputs[0].CheckAndAllocData(in_idx.shape_);
56 outputs[0].CheckAndAllocAuxData(csr::kIdx, in_idx.shape_);
57 outputs[0].CheckAndAllocAuxData(csr::kIndPtr, in_indptr.shape_);
58
59 mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
60 Fill<false>(s, outputs[0].data(), req[0], 1.0);
61 mxnet_op::copy(s, outputs[0].aux_data(csr::kIdx), in_idx);
62 mxnet_op::copy(s, outputs[0].aux_data(csr::kIndPtr), in_indptr);
63 }
64
65 } // namespace op
66 } // namespace mxnet
67
68 #endif // MXNET_OPERATOR_CONTRIB_DGL_GRAPH_INL_H_
69