1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_FC_INL_H_
21 #define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_FC_INL_H_
22 #if MXNET_USE_MKLDNN == 1
23 
24 #include <string>
25 #include <utility>
26 #include <vector>
27 #include "mkldnn.hpp"
28 #include "../../nn/mkldnn/mkldnn_fully_connected-inl.h"
29 
30 namespace mxnet {
31 namespace op {
32 
/*!
 * \brief Check whether an operator name is in the set accepted for
 *        MKLDNN FullyConnected + eltwise fusion.
 * \param op_name Operator name to test (exact, case-sensitive match).
 * \return true if op_name is one of "Activation", "square", "sqrt", "exp",
 *         "abs", "clip" or "LeakyReLU"; false otherwise.
 *
 * NOTE(review): only the operator name is checked here; any finer-grained
 * checks (e.g. which act_type an Activation uses) presumably happen in the
 * caller -- confirm.
 */
static inline bool SupportMKLDNNFCEltwiseFusion(const std::string& op_name) {
  // Taken by const reference: the original by-value parameter copied the
  // string on every call for no benefit.
  return op_name == "Activation" ||
         op_name == "square" ||
         op_name == "sqrt" ||
         op_name == "exp" ||
         op_name == "abs" ||
         op_name == "clip" ||
         op_name == "LeakyReLU";
}
46 
GetMKLDNNEltwiseAlgo(const std::string op_name)47 static inline mkldnn::algorithm GetMKLDNNEltwiseAlgo(const std::string op_name) {
48   if (op_name == "square")
49     return mkldnn::algorithm::eltwise_square;
50   else if (op_name == "sqrt")
51     return mkldnn::algorithm::eltwise_sqrt;
52   else if (op_name == "exp")
53     return mkldnn::algorithm::eltwise_exp;
54   else if (op_name == "abs")
55     return mkldnn::algorithm::eltwise_abs;
56   else
57     LOG(FATAL) << "Unsupported eltwise fusion op: " << op_name;
58 
59   return mkldnn::algorithm::undef;
60 }
61 
IsOutputUint8(const MKLDNNFCFullParam & full_param)62 static inline bool IsOutputUint8(const MKLDNNFCFullParam& full_param) {
63   auto alg = full_param.eltwise_param.alg;
64   // TODO(ciyong): some alg doesn't support int8 so far.
65   if (full_param.mkldnn_param.with_eltwise &&
66       (alg == mkldnn::algorithm::eltwise_relu ||
67        alg == mkldnn::algorithm::eltwise_logistic ||
68        alg == mkldnn::algorithm::eltwise_soft_relu ||
69        alg == mkldnn::algorithm::eltwise_bounded_relu ||
70        alg == mkldnn::algorithm::eltwise_square ||
71        alg == mkldnn::algorithm::eltwise_sqrt ||
72        alg == mkldnn::algorithm::eltwise_exp ||
73        alg == mkldnn::algorithm::eltwise_abs)) {
74     return true;
75   }
76 
77   return false;
78 }
79 
80 }  // namespace op
81 }  // namespace mxnet
82 
83 #endif  // MXNET_USE_MKLDNN == 1
84 #endif  // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_FC_INL_H_
85