1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file reset_arrays.cc
 * \brief Sets every element of multiple arrays to zero
23  * \author Moises Hernandez-Fernandez, Andrei Ivanov
24  */
25 
26 #include "./reset_arrays-inl.h"
27 
28 namespace mxnet {
29 namespace op {
30 
31 DMLC_REGISTER_PARAMETER(ResetArraysParam);
32 
33 NNVM_REGISTER_OP(reset_arrays)
34 .describe(R"code(Set to zero multiple arrays
35 )code" ADD_FILELINE)
__anon2c8944c20102(const nnvm::NodeAttrs& attrs) 36 .set_num_inputs([](const nnvm::NodeAttrs& attrs) {
37     return static_cast<uint32_t>(dmlc::get<ResetArraysParam>(attrs.parsed).num_arrays);
38   })
39 .set_attr<nnvm::FMutateInputs>("FMutateInputs",
__anon2c8944c20202(const nnvm::NodeAttrs& attrs) 40   [](const nnvm::NodeAttrs& attrs) {
41     const uint32_t num_args = dmlc::get<ResetArraysParam>(attrs.parsed).num_arrays;
42     std::vector<uint32_t> ret;
43     for (uint32_t i = 0; i < num_args; ++i) {
44       ret.push_back(i);
45     }
46     return ret;
47   })
48 .set_num_outputs(0)
49 .set_attr_parser(ParamParser<ResetArraysParam>)
50 .set_attr<mxnet::FInferShape>("FInferShape", ResetArraysShape)
51 .set_attr<nnvm::FInferType>("FInferType", ResetArraysType)
52 .set_attr<nnvm::FListInputNames>("FListInputNames",
__anon2c8944c20302(const NodeAttrs& attrs) 53   [](const NodeAttrs& attrs) {
54     const uint32_t num_args = dmlc::get<ResetArraysParam>(attrs.parsed).num_arrays;
55     std::vector<std::string> ret;
56     for (uint32_t i = 0; i < num_args; ++i) {
57       ret.push_back(std::string("array_") + std::to_string(i));
58     }
59     return ret;
60   })
61 .add_argument("data", "NDArray-or-Symbol[]", "Arrays")
62 .add_arguments(ResetArraysParam::__FIELDS__());
63 
64 NNVM_REGISTER_OP(reset_arrays)
65 .set_attr<FCompute>("FCompute<cpu>", ResetArrays<cpu>);
66 
67 template<>
ResetMemory(void * pntr,size_t len,mshadow::Stream<cpu> * s)68 void ResetMemory<cpu>(void *pntr, size_t len, mshadow::Stream<cpu> *s) {
69   memset(pntr, 0, len);
70 }
71 
72 }  // namespace op
73 }  // namespace mxnet
74