1 //
2 // PipelineBuilder.cpp
3 // MNN
4 //
5 // Created by MNN on 2020/07/28.
6 // Copyright © 2018, Alibaba Group Holding Limited
7 //
8
9 #include <fstream>
10
11 #include "PipelineBuilder.hpp"
12 #include "MNN/MNNDefine.h"
13
14 namespace compression {
15
PipelineBuilder(const std::string & filename)16 PipelineBuilder::PipelineBuilder(const std::string& filename)
17 : filename_(filename) {}
18
Build() const19 Pipeline PipelineBuilder::Build() const {
20 Pipeline pipeline;
21 if (!filename_.empty()) {
22 MNN::Compression::Pipeline proto;
23 std::fstream input(filename_.c_str(), std::ios::in | std::ios::binary);
24 if (!proto.ParseFromIstream(&input)) {
25 MNN_ERROR("Failed to parse compression pipeline proto.\n");
26 } else {
27 ParsePipeline(proto, &pipeline);
28 }
29 }
30 return std::move(pipeline);
31 }
32
ParsePipeline(const MNN::Compression::Pipeline & proto,Pipeline * pipeline) const33 bool PipelineBuilder::ParsePipeline(const MNN::Compression::Pipeline& proto,
34 Pipeline* pipeline) const {
35 for (const auto& algo : proto.algo()) {
36 Progress progress;
37 progress.type = algo.type();
38 switch (progress.type) {
39 case CompressionAlgo::QUANTIZE: {
40 ParseQuantization(algo.quant_params(),
41 &(progress.quant_params));
42 break;
43 }
44 case CompressionAlgo::PRUNE:
45 default: {
46 MNN_ERROR("Unsupported compression type: %d.\n", progress.type);
47 }
48 }
49 pipeline->progress_.push_back(std::move(progress));
50 }
51 return true;
52 }
53
ParseActivationQuantization(const LayerQuantizeParams::ActivationParams & proto) const54 Quantization::TensorParams PipelineBuilder::ParseActivationQuantization(
55 const LayerQuantizeParams::ActivationParams& proto) const {
56 Quantization::TensorParams tensor_params;
57 tensor_params.nbit = proto.bits();
58 int size = proto.scales().size();
59 tensor_params.scale.resize(size);
60 for (int i = 0; i < size; ++i) {
61 tensor_params.scale[i] = proto.scales(i);
62 }
63 tensor_params.zero_point = proto.zero_point();
64 tensor_params.clamp_min = proto.clamp_min();
65 tensor_params.clamp_max = proto.clamp_max();
66 return std::move(tensor_params);
67 }
68
ParseWeightQuantization(const LayerQuantizeParams::WeightParams & proto) const69 Quantization::TensorParams PipelineBuilder::ParseWeightQuantization(
70 const LayerQuantizeParams::WeightParams& proto) const {
71 Quantization::TensorParams tensor_params;
72 tensor_params.nbit = proto.bits();
73 int size = proto.scales().size();
74 tensor_params.scale.resize(size);
75 for (int i = 0; i < size; ++i) {
76 tensor_params.scale[i] = proto.scales(i);
77 }
78 tensor_params.zero_point = 0.f;
79 return std::move(tensor_params);
80 }
81
ParseQuantization(const MNN::Compression::QuantizeParams & proto,Quantization * quant_params) const82 bool PipelineBuilder::ParseQuantization(
83 const MNN::Compression::QuantizeParams& proto,
84 Quantization* quant_params) const {
85 quant_params->round_mode = proto.round_mode();
86 for (const auto& layer_proto : proto.layer()) {
87 auto method = layer_proto.method();
88 for (const auto& input : layer_proto.input()) {
89 const std::string& tensor_name = input.name();
90 Quantization::TensorParams tensor_params =
91 ParseActivationQuantization(input);
92 tensor_params.method = method;
93 quant_params->tensors[tensor_name].push_back(tensor_params);
94 }
95 int length = 0;
96 for (const auto& weight : layer_proto.weight()) {
97 const std::string& tensor_name = weight.name();
98 Quantization::TensorParams tensor_params =
99 ParseWeightQuantization(weight);
100 // TODO(): FIXME
101 // quant_params->tensors[tensor_name].push_back(tensor_params);
102 length = tensor_params.scale.size();
103 }
104 for (const auto& output : layer_proto.output()) {
105 const std::string& tensor_name = output.name();
106 Quantization::TensorParams tensor_params =
107 ParseActivationQuantization(output);
108 if (tensor_params.scale.size() != length) {
109 MNN_ERROR("Output(%s) scale and weight scale length are "
110 "mismatch, (%d vs %d).\n", tensor_name.c_str(),
111 int(tensor_params.scale.size()), length);
112 MNN_ASSERT(false);
113 }
114 tensor_params.method = method;
115 quant_params->tensors[tensor_name].push_back(tensor_params);
116 }
117 }
118 return true;
119 }
120
121 };
122