1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #ifndef MXNET_OPERATOR_CONTRIB_DGL_GRAPH_INL_H_
21 #define MXNET_OPERATOR_CONTRIB_DGL_GRAPH_INL_H_
22 
23 #include <dmlc/logging.h>
24 #include <dmlc/parameter.h>
25 #include <mxnet/operator.h>
26 #include <mxnet/ndarray.h>
27 #include <map>
28 #include <algorithm>
29 #include <vector>
30 #include <string>
31 #include <utility>
32 #include "../operator_common.h"
33 #include "../mxnet_op.h"
34 #include "../mshadow_op.h"
35 #include "../tensor/init_op.h"
36 
37 namespace mxnet {
38 namespace op {
39 
40 template<typename xpu>
DGLAdjacencyForwardEx(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<NDArray> & inputs,const std::vector<OpReqType> & req,const std::vector<NDArray> & outputs)41 void DGLAdjacencyForwardEx(const nnvm::NodeAttrs& attrs,
42                            const OpContext& ctx,
43                            const std::vector<NDArray>& inputs,
44                            const std::vector<OpReqType>& req,
45                            const std::vector<NDArray>& outputs) {
46   CHECK_EQ(inputs.size(), 1U);
47   CHECK_EQ(outputs.size(), 1U);
48   CHECK_EQ(req.size(), 1U);
49   CHECK_EQ(inputs[0].storage_type(), kCSRStorage);
50   CHECK_EQ(outputs[0].storage_type(), kCSRStorage);
51   CHECK_EQ(req[0], kWriteTo);
52   const TBlob &in_idx = inputs[0].aux_data(csr::kIdx);
53   const TBlob &in_indptr = inputs[0].aux_data(csr::kIndPtr);
54 
55   outputs[0].CheckAndAllocData(in_idx.shape_);
56   outputs[0].CheckAndAllocAuxData(csr::kIdx, in_idx.shape_);
57   outputs[0].CheckAndAllocAuxData(csr::kIndPtr, in_indptr.shape_);
58 
59   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
60   Fill<false>(s, outputs[0].data(), req[0], 1.0);
61   mxnet_op::copy(s, outputs[0].aux_data(csr::kIdx), in_idx);
62   mxnet_op::copy(s, outputs[0].aux_data(csr::kIndPtr), in_indptr);
63 }
64 
65 }  // namespace op
66 }  // namespace mxnet
67 
68 #endif  // MXNET_OPERATOR_CONTRIB_DGL_GRAPH_INL_H_
69