1 // 2 // PostTreatUtils.hpp 3 // MNNConverter 4 // 5 // Created by MNN on 2019/01/31. 6 // Copyright © 2018, Alibaba Group Holding Limited 7 // 8 9 #ifndef POSTTREATUTILS_HPP 10 #define POSTTREATUTILS_HPP 11 12 #include <stdio.h> 13 #include <stdlib.h> 14 #include <algorithm> 15 #include <cmath> 16 #include <fstream> 17 #include <map> 18 #include <sstream> 19 #include "MNN_generated.h" 20 #include "logkit.h" 21 class PostConverter { 22 public: 23 PostConverter() = default; 24 virtual ~PostConverter() = default; 25 virtual bool onExecute(std::unique_ptr<MNN::NetT>& net) const = 0; 26 static PostConverter* get(std::string key); 27 static void add(std::shared_ptr<PostConverter> converter, std::string key); 28 29 private: 30 static std::map<std::string, std::shared_ptr<PostConverter>>* getConvertMap(); 31 }; 32 33 template <class T> 34 class PostConverterRegister { 35 public: PostConverterRegister(const char * claim)36 PostConverterRegister(const char* claim) { 37 T* instance = new T; 38 PostConverter::add(std::shared_ptr<PostConverter>(instance), claim); 39 } 40 }; 41 42 class PostTreatUtils { 43 public: 44 static MNN::OpT* _findOpByOutputIndex(int outputIndex, const MNN::NetT* net); 45 static std::vector<MNN::OpT*> _findOpByInputIndex(int inputIndex, const MNN::NetT* net); 46 static void _removeOpInNet(MNN::OpT* op, MNN::NetT* net); 47 static bool _isSingleInputOutput(const MNN::OpT* op); 48 49 static int _getOpDecestorCount(MNN::OpT* op, const MNN::NetT* net); 50 static bool _replace(std::vector<int>& indexes, int freshIndex, int oldIndex); 51 52 private: 53 PostTreatUtils(); 54 }; 55 56 #endif // POSTTREATUTILS_HPP 57