/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_FC_INL_H_
#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_FC_INL_H_
#if MXNET_USE_MKLDNN == 1

#include <string>
#include <utility>
#include <vector>
#include "mkldnn.hpp"
#include "../../nn/mkldnn/mkldnn_fully_connected-inl.h"

namespace mxnet {
namespace op {

/*!
 * \brief Check whether an operator can be fused as an eltwise post-op into an
 *        MKL-DNN FullyConnected subgraph node.
 * \param op_name Name of the candidate operator.
 * \return true if op_name is one of "Activation", "square", "sqrt", "exp",
 *         "abs", "clip" or "LeakyReLU"; false otherwise.
 */
static inline bool SupportMKLDNNFCEltwiseFusion(const std::string& op_name) {
  return op_name == "Activation" ||
         op_name == "square" ||
         op_name == "sqrt" ||
         op_name == "exp" ||
         op_name == "abs" ||
         op_name == "clip" ||
         op_name == "LeakyReLU";
}

/*!
 * \brief Map a unary operator name to its MKL-DNN eltwise algorithm.
 * \param op_name One of "square", "sqrt", "exp" or "abs".
 * \return The matching mkldnn::algorithm value.
 *
 * Any other name hits LOG(FATAL). That includes "Activation", "clip" and
 * "LeakyReLU", even though SupportMKLDNNFCEltwiseFusion accepts them.
 * NOTE(review): presumably callers resolve those three ops to an algorithm
 * elsewhere -- confirm at the call sites.
 */
static inline mkldnn::algorithm GetMKLDNNEltwiseAlgo(const std::string& op_name) {
  if (op_name == "square") {
    return mkldnn::algorithm::eltwise_square;
  } else if (op_name == "sqrt") {
    return mkldnn::algorithm::eltwise_sqrt;
  } else if (op_name == "exp") {
    return mkldnn::algorithm::eltwise_exp;
  } else if (op_name == "abs") {
    return mkldnn::algorithm::eltwise_abs;
  }
  LOG(FATAL) << "Unsupported eltwise fusion op: " << op_name;
  // Keeps every control path returning a value; not expected to be reached
  // when LOG(FATAL) terminates the call.
  return mkldnn::algorithm::undef;
}

/*!
 * \brief Check whether a fused FullyConnected node may produce uint8 output.
 * \param full_param Full parameter set of the fused FC node.
 * \return true when an eltwise post-op is fused (mkldnn_param.with_eltwise)
 *         and its algorithm is one of relu, logistic, soft_relu,
 *         bounded_relu, square, sqrt, exp or abs; false otherwise.
 *
 * NOTE(review): the listed algorithms presumably all yield non-negative
 * results, which is why an unsigned output is allowed -- confirm with the
 * quantization flow before extending the list.
 */
static inline bool IsOutputUint8(const MKLDNNFCFullParam& full_param) {
  const auto alg = full_param.eltwise_param.alg;
  // TODO(ciyong): some alg doesn't support int8 so far.
  return full_param.mkldnn_param.with_eltwise &&
         (alg == mkldnn::algorithm::eltwise_relu ||
          alg == mkldnn::algorithm::eltwise_logistic ||
          alg == mkldnn::algorithm::eltwise_soft_relu ||
          alg == mkldnn::algorithm::eltwise_bounded_relu ||
          alg == mkldnn::algorithm::eltwise_square ||
          alg == mkldnn::algorithm::eltwise_sqrt ||
          alg == mkldnn::algorithm::eltwise_exp ||
          alg == mkldnn::algorithm::eltwise_abs);
}

}  // namespace op
}  // namespace mxnet

#endif  // MXNET_USE_MKLDNN == 1
#endif  // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_FC_INL_H_