1 //
2 //  PostTreatUtils.hpp
3 //  MNNConverter
4 //
5 //  Created by MNN on 2019/01/31.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #ifndef POSTTREATUTILS_HPP
10 #define POSTTREATUTILS_HPP
11 
#include <stdio.h>
#include <stdlib.h>
#include <algorithm>
#include <cmath>
#include <fstream>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "MNN_generated.h"
#include "logkit.h"
21 class PostConverter {
22 public:
23     PostConverter()                                               = default;
24     virtual ~PostConverter()                                      = default;
25     virtual bool onExecute(std::unique_ptr<MNN::NetT>& net) const = 0;
26     static PostConverter* get(std::string key);
27     static void add(std::shared_ptr<PostConverter> converter, std::string key);
28 
29 private:
30     static std::map<std::string, std::shared_ptr<PostConverter>>* getConvertMap();
31 };
32 
33 template <class T>
34 class PostConverterRegister {
35 public:
PostConverterRegister(const char * claim)36     PostConverterRegister(const char* claim) {
37         T* instance = new T;
38         PostConverter::add(std::shared_ptr<PostConverter>(instance), claim);
39     }
40 };
41 
42 class PostTreatUtils {
43 public:
44     static MNN::OpT* _findOpByOutputIndex(int outputIndex, const MNN::NetT* net);
45     static std::vector<MNN::OpT*> _findOpByInputIndex(int inputIndex, const MNN::NetT* net);
46     static void _removeOpInNet(MNN::OpT* op, MNN::NetT* net);
47     static bool _isSingleInputOutput(const MNN::OpT* op);
48 
49     static int _getOpDecestorCount(MNN::OpT* op, const MNN::NetT* net);
50     static bool _replace(std::vector<int>& indexes, int freshIndex, int oldIndex);
51 
52 private:
53     PostTreatUtils();
54 };
55 
56 #endif // POSTTREATUTILS_HPP
57