1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file caffe_op.cc
22 * \brief caffe operator
23 * \author Haoran Wang
24 */
25 #include "./caffe_op-inl.h"
26 namespace mxnet {
27 namespace op {
28
29 template<>
CreateOp(CaffeOpParam param,int dtype)30 Operator* CreateOp<cpu>(CaffeOpParam param, int dtype) {
31 Operator *op = NULL;
32 switch (dtype) {
33 case mshadow::kFloat32:
34 op = new CaffeOp<cpu, float>(param);
35 break;
36 case mshadow::kFloat64:
37 op = new CaffeOp<cpu, double>(param);
38 break;
39 case mshadow::kFloat16:
40 LOG(FATAL) << "float16 layer is not supported by caffe";
41 break;
42 case mshadow::kBfloat16:
43 LOG(FATAL) << "bfloat16 layer is not supported by caffe";
44 break;
45 default:
46 LOG(FATAL) << "Unsupported type " << dtype;
47 }
48 return op;
49 }
50
51 // DO_BIND_DISPATCH comes from static_operator_common.h
CreateOperatorEx(Context ctx,mxnet::ShapeVector * in_shape,std::vector<int> * in_type) const52 Operator *CaffeOpProp::CreateOperatorEx(Context ctx, mxnet::ShapeVector *in_shape,
53 std::vector<int> *in_type) const {
54 std::vector<int> out_type, aux_type;
55 mxnet::ShapeVector out_shape, aux_shape;
56 out_type.resize(this->ListOutputs().size());
57 out_shape.resize(this->ListOutputs().size());
58 aux_type.resize(this->ListAuxiliaryStates().size());
59 aux_shape.resize(this->ListAuxiliaryStates().size());
60 CHECK(InferType(in_type, &out_type, &aux_type));
61 CHECK(InferShape(in_shape, &out_shape, &aux_shape));
62 DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]);
63 }
64
// Registers CaffeOpParam with dmlc's parameter machinery; its declared fields
// are exposed as operator arguments via CaffeOpParam::__FIELDS__() below.
DMLC_REGISTER_PARAMETER(CaffeOpParam);

// Registers CaffeOpProp as the operator property for "CaffeOp".
// Inputs are declared as one "data" argument of type Symbol[] (a list of
// tensors); the remaining arguments are CaffeOpParam's declared fields.
MXNET_REGISTER_OP_PROPERTY(CaffeOp, CaffeOpProp)
.describe("Apply caffe operator")
.add_argument("data", "Symbol[]", "List of tensors")
.add_arguments(CaffeOpParam::__FIELDS__());
71
72 } // namespace op
73 } // namespace mxnet
74