1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file caffe_op.cc
22  * \brief caffe operator
23  * \author Haoran Wang
24 */
25 #include "./caffe_op-inl.h"
26 namespace mxnet {
27 namespace op {
28 
29 template<>
CreateOp(CaffeOpParam param,int dtype)30 Operator* CreateOp<cpu>(CaffeOpParam param, int dtype) {
31   Operator *op = NULL;
32   switch (dtype) {
33   case mshadow::kFloat32:
34     op = new CaffeOp<cpu, float>(param);
35     break;
36   case mshadow::kFloat64:
37     op = new CaffeOp<cpu, double>(param);
38     break;
39   case mshadow::kFloat16:
40     LOG(FATAL) << "float16 layer is not supported by caffe";
41     break;
42   case mshadow::kBfloat16:
43     LOG(FATAL) << "bfloat16 layer is not supported by caffe";
44     break;
45   default:
46     LOG(FATAL) << "Unsupported type " << dtype;
47   }
48   return op;
49 }
50 
51 // DO_BIND_DISPATCH comes from static_operator_common.h
CreateOperatorEx(Context ctx,mxnet::ShapeVector * in_shape,std::vector<int> * in_type) const52 Operator *CaffeOpProp::CreateOperatorEx(Context ctx, mxnet::ShapeVector *in_shape,
53                                      std::vector<int> *in_type) const {
54   std::vector<int> out_type, aux_type;
55   mxnet::ShapeVector out_shape, aux_shape;
56   out_type.resize(this->ListOutputs().size());
57   out_shape.resize(this->ListOutputs().size());
58   aux_type.resize(this->ListAuxiliaryStates().size());
59   aux_shape.resize(this->ListAuxiliaryStates().size());
60   CHECK(InferType(in_type, &out_type, &aux_type));
61   CHECK(InferShape(in_shape, &out_shape, &aux_shape));
62   DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]);
63 }
64 
65 DMLC_REGISTER_PARAMETER(CaffeOpParam);
66 
67 MXNET_REGISTER_OP_PROPERTY(CaffeOp, CaffeOpProp)
68 .describe("Apply caffe operator")
69 .add_argument("data", "Symbol[]", "List of tensors")
70 .add_arguments(CaffeOpParam::__FIELDS__());
71 
72 }  // namespace op
73 }  // namespace mxnet
74