/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef GLOW_QUANTIZATION_BASE_BASE_H
#define GLOW_QUANTIZATION_BASE_BASE_H

#include "glow/Base/Tensor.h"
#include "glow/Base/Traits.h"
#include "glow/Base/Type.h"

#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <limits>

namespace glow {

/// Profiling parameters of a tensor, consisting of the global minimum and
/// maximum values and the histogram obtained during profiling. Note that the
/// histogram is not normalized.
struct TensorProfilingParams {
  float min;
  float max;
  std::vector<float> histogram;

  TensorProfilingParams() = default;
  TensorProfilingParams(float min, float max) : min(min), max(max) {}
  TensorProfilingParams(float min, float max, const std::vector<float> &hist)
      : min(min), max(max), histogram(hist) {}
  TensorProfilingParams(float min, float max, const Tensor &hist)
      : min(min), max(max) {
    auto histH = hist.getHandle<float>();
    histogram = std::vector<float>(histH.size());
    for (dim_t idx = 0, e = histH.size(); idx < e; idx++) {
      histogram[idx] = histH.raw(idx);
    }
  }
};

/// Main attributes of a quantized tensor.
/// Scale and offset allow quantization of a float tensor and dequantization of
/// an integer tensor back to a float one.
struct TensorQuantizationParams {
  float scale;
  int32_t offset;
};

/// A data structure that represents the 32-bit to 8-bit quantization
/// scaling operation. This data structure represents the transformation:
/// ((((input >> pre) * scale) + rtn) >> post) + offset, where rtn is the
/// round-to-nearest term added before the final shift.
struct QuantizationTransform32To8 {
  int pre;
  int post;
  int scale;
  int offset;

  /// Initializes the transformation based on the conversion formula (above).
  QuantizationTransform32To8(int pre, int post, int scale, int offset)
      : pre(pre), post(post), scale(scale), offset(offset) {}

  /// \returns the scaled integer.
  int32_t transform(int32_t input) {
    // The operation x >> post is rounded down to negative infinity. To get to
    // round-nearest we add (1 << (post - 1)) to the value prior to shifting.
    // Rounding is performed only when shifting right (post > 0).
    int rtn = (post > 0) ? (1 << (post - 1)) : 0;
    return ((((input >> pre) * scale) + rtn) >> post) + offset;
  }
};
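
// Illustrative sketch (the parameter values below are made up, not produced by
// Glow): with pre = 0, post = 8, scale = 77, offset = 0 the transform rescales
// a 32-bit accumulator by roughly 77 / 256 ~= 0.3 with round-to-nearest:
//
//   QuantizationTransform32To8 t(/* pre */ 0, /* post */ 8, /* scale */ 77,
//                                /* offset */ 0);
//   int32_t q = t.transform(200); // ((200 * 77) + 128) >> 8 == 60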

/// Tensor profiling parameters for a given node.
struct NodeProfilingInfo {
  std::string nodeOutputName_;
  TensorProfilingParams tensorProfilingParams_;

  NodeProfilingInfo() = default;
  NodeProfilingInfo(const std::string &nodeOutputName,
                    const TensorProfilingParams &tensorProfilingParams)
      : nodeOutputName_(nodeOutputName),
        tensorProfilingParams_(tensorProfilingParams) {}

  float min() const { return tensorProfilingParams_.min; }
  float max() const { return tensorProfilingParams_.max; }
  const std::vector<float> &histogram() const {
    return tensorProfilingParams_.histogram;
  }
};

/// Tensor quantization parameters for a given node.
struct NodeQuantizationInfo {
  std::string nodeOutputName_;
  TensorQuantizationParams tensorQuantizationParams_;

  NodeQuantizationInfo() = default;
  NodeQuantizationInfo(const std::string &nodeOutputName,
                       const TensorQuantizationParams &tensorQuantizationParams)
      : nodeOutputName_(nodeOutputName),
        tensorQuantizationParams_(tensorQuantizationParams) {}

  float scale() const { return tensorQuantizationParams_.scale; }
  int32_t offset() const { return tensorQuantizationParams_.offset; }
};

namespace quantization {

/// Type definition for a float min/max range.
using FloatRange = std::pair<float, float>;

/// Type definition for a quantized min/max range.
using QuantizedRange = std::pair<int64_t, int64_t>;

/// Quantization schema which influences the way the quantization parameters
/// scale and offset are computed based on the target min/max dynamic range.
enum Schema {
  /// Asymmetric quantization produces ranges not necessarily centered on 0.
  Asymmetric,
  /// Symmetric quantization produces ranges centered on 0.
  Symmetric,
  /// Symmetric quantization produces ranges centered on 0 or -qmin, qmin being
  /// the minimum value of the quantized type.
  /// An offset of qmin (i.e., offset == -128 for int8) represents an unsigned
  /// version of the quantized type with an offset of zero:
  /// For example, int8 is [-128; 127] - (-128) == uint8 [0; 255] - 0
  SymmetricWithUnsigned,
  /// Quantization schema with:
  /// - range centered on 0 (symmetric): offset == 0.
  /// - scale parameter is a power of 2: scale = 2^E where E is a signed
  ///   exponent. Since the scale parameter is mostly subunitary, the
  ///   exponent is mostly negative.
  /// Since the scale parameter is stored as floating point, the values
  /// of E which are exactly representable range from -126 to 127.
  SymmetricWithPower2Scale,
};
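
// Worked example (numbers chosen for exposition only): for a profiled dynamic
// range of [-0.5, 1.5] quantized to int8 ([-128, 127]):
// - Asymmetric: scale = (1.5 - (-0.5)) / 255 ~= 0.00784, offset ~= -64, so the
//   whole profiled range maps onto [-128, 127].
// - Symmetric: the range is first widened to [-1.5, 1.5], giving
//   scale = 3.0 / 255 ~= 0.01176 and offset = 0, trading some resolution for a
//   zero offset.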

/// Calibration mode which influences the way the dynamic range min/max
/// obtained during profiling is narrowed in order to have a more precise
/// representation for the majority of the values, at the price of saturating
/// the outliers.
enum Calibration {
  /// No calibration. The quantization parameters will be computed using the
  /// unaltered dynamic range min/max obtained during profiling such that all
  /// the profiled dynamic range will be representable without saturation.
  None,
  /// Calibration mode based on minimizing the Kullback-Leibler divergence.
  KLMinimization
};

/// Configuration for Profiling, passed into \ref profileQuantization().
struct ProfilingConfiguration {
  /// Number of bins used to compute the histogram during profiling.
  unsigned numHistogramBins{10};
};

/// Configuration for Quantization, passed into \ref quantizeFunction().
struct QuantizationConfiguration {
  /// Profiling infos to use when computing the scale and offset for all the
  /// Nodes inside the function being quantized, including the referenced
  /// Placeholders and Constants.
  std::vector<NodeProfilingInfo> infos{};

  /// The hash of the graph obtained during profiling in the pre-lowering
  /// stage. This hash is used to verify during quantization that the graph
  /// being compiled matches the graph used for obtaining the profiling
  /// information.
  llvm::hash_code graphPreLowerHash{0};

  /// Whether to check the graph hash during quantization.
  bool checkGraphPreLowerHash{false};

  /// Precision to use when quantizing a Function.
  ElemKind precision{ElemKind::Int8QTy};

  /// Schema to use when quantizing a Function.
  Schema schema{Schema::Asymmetric};

  /// Calibration mode used when computing the quantization parameters.
  Calibration calibration{Calibration::None};

  /// Whether to enable the calibration for constant weights.
  bool calibrateConstants{false};

  /// Whether to use rowwise quantization when quantizing a Function.
  bool enableRowwise{false};

  /// Whether to use channelwise quantization when quantizing a Function.
  bool enableChannelwise{false};

  /// New name for the quantized function. If no name is given then
  /// \ref quantizeFunction() will generate a name.
  std::string newFuncName{""};

  /// If true, the quantizer will abort when encountering a node that it would
  /// like to quantize but that the backend cannot support. Note that node
  /// kinds in doNotQuantizeKinds will skip this check and not cause an abort.
  bool assertAllNodesQuantized{false};

  /// Precision used for bias quantization for Convolution and FullyConnected.
  /// This allows specializing the bias quantization. Default is int32.
  ElemKind precisionBias{ElemKind::Int32QTy};

  /// If true, don't apply quantization to FC bias inputs.
  bool skipQuantizeFCBias{false};

  QuantizationConfiguration() = default;
  QuantizationConfiguration(llvm::ArrayRef<NodeProfilingInfo> i) : infos(i) {}
};
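
// Illustrative sketch (the chosen field values are examples, not recommended
// settings): building a configuration for symmetric int8 quantization with KL
// calibration from previously collected profiling infos:
//
//   std::vector<NodeProfilingInfo> infos = ...; // gathered during profiling
//   QuantizationConfiguration cfg(infos);
//   cfg.schema = Schema::Symmetric;
//   cfg.precision = ElemKind::Int8QTy;
//   cfg.calibration = Calibration::KLMinimization;
//   cfg.enableRowwise = true;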

/// \returns the tensor average value based on the profiling info \p profParams.
float getTensorAverageValue(const TensorProfilingParams &profParams);

/// \returns the value \p in as clipped to the range of \p DestTy.
template <class SrcTy, class DestTy> DestTy clip(SrcTy in) {
  static_assert(sizeof(SrcTy) >= sizeof(DestTy), "Invalid types");

  auto mx = std::numeric_limits<DestTy>::max();
  auto mn = std::numeric_limits<DestTy>::min();
  return std::max<SrcTy>(mn, std::min<SrcTy>(mx, in));
}

/// Converts floating point value to DestTy (quantized type) based on the
/// quantization parameters \p TQP.
template <class DestTy = int8_t>
inline DestTy quantize(float input, const TensorQuantizationParams &TQP) {
  float result = input / TQP.scale + TQP.offset;
  // Note: use int64_t since casts of large values might be wrapped around
  // before clipping, for example for result = 2147483648.00 (float).
  return quantization::clip<int64_t, DestTy>((int64_t)nearbyintf(result));
}

/// Converts a quantized value (type eTy) to floating point based on the
/// quantization parameters \p TQP.
/// Note: use int64_t to cover the 'symmetric int32 with unsigned' case.
template <class eTy = int8_t>
inline float dequantize(eTy input, const TensorQuantizationParams &TQP) {
  return TQP.scale * ((int64_t)input - TQP.offset);
}
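
// Illustrative sketch (parameter values chosen for exposition): quantizing a
// float with scale 0.1 and offset 0 and dequantizing it again recovers the
// value up to the quantization step:
//
//   TensorQuantizationParams tqp{0.1f, 0};
//   int8_t q = quantization::quantize(1.27f, tqp); // round(12.7) == 13
//   float back = quantization::dequantize(q, tqp); // 13 * 0.1f == 1.3f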

/// Converts floating point value to DestTy (quantized type) based on the
/// quantization parameters \p scale and \p offset. If the dest type is int8_t,
/// then an offset of 128 is subtracted to convert to int8_t.
template <class DestTy>
inline DestTy quantizeWithFloatOffset(float input, float scale, float offset) {
  uint8_t d = static_cast<uint8_t>(std::round((input - offset) / scale));
  if (std::is_same<int8_t, DestTy>::value) {
    d -= 128;
  }
  return static_cast<DestTy>(d);
}

/// Converts floating point value \p input to 4-bit quantization based on the
/// quantization parameters \p scale and \p offset.
inline uint8_t quantize4BitsWithFloatOffset(float input, float scale,
                                            float offset) {
  uint8_t d = std::max(
      0, std::min(static_cast<int>(std::round((input - offset) / scale)), 15));
  return d;
}

/// Converts a quantized value (type eTy) to floating point based on the
/// quantization parameters \p scale and \p offset. If the input type is int8_t,
/// then an offset of 128 is added to convert to uint8_t.
template <class eTy>
inline float dequantizeWithFloatOffset(eTy input, float scale, float offset) {
  uint8_t d = static_cast<uint8_t>(input);
  if (std::is_same<int8_t, eTy>::value) {
    d += 128;
  }
  return (d * scale) + offset;
}
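
// Illustrative sketch (values chosen for exposition): in this scheme the
// offset is a float (typically the row minimum), so a value equal to the
// offset quantizes to 0 for uint8_t (or to -128 for int8_t):
//
//   uint8_t q = quantizeWithFloatOffset<uint8_t>(0.75f, 0.01f, 0.5f); // 25
//   float back = dequantizeWithFloatOffset(q, 0.01f, 0.5f);           // 0.75f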

/// Converts a 4-bit quantized value, which is stored in \p input (MSB if \p
/// isMSB is true, otherwise LSB), to floating point based on the quantization
/// parameters \p scale and \p offset.
inline float dequantize4BitWithFloatOffset(uint8_t input, float scale,
                                           float offset, bool isMSB) {
  if (isMSB) {
    input >>= 4;
  }
  input &= 0x0f;
  return (input * scale) + offset;
}
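
// Illustrative sketch: two 4-bit values packed into one byte are recovered by
// selecting the LSB or MSB nibble before rescaling:
//
//   uint8_t packed = 0xA3; // LSB nibble == 3, MSB nibble == 10
//   float lo = dequantize4BitWithFloatOffset(packed, 0.5f, 0.0f, false); // 1.5
//   float hi = dequantize4BitWithFloatOffset(packed, 0.5f, 0.0f, true);  // 5.0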

/// Converts a floating point \p tensor to quantized tensor based on the
/// quantization parameters \p TQP and \p Ty.
Tensor quantizeTensor(const Tensor &tensor, const TensorQuantizationParams &TQP,
                      ElemKind Ty = ElemKind::Int8QTy);

/// Converts quantized tensor \p tensor to a floating point tensor of type
/// \p floatKind.
Tensor dequantizeTensor(const Tensor &tensor, ElemKind floatKind);

/// Dequantize 4-bit fused quantized tensor \p input. \returns the float type
/// output.
Tensor tensor4BitsFusedRowwiseDequantization(const Tensor &input);

/// Convert the floating point quantization parameters \p scale and \p offset
/// into the integer sequence of:
/// result = (((input >> pre) * scale) >> post) + offset.
/// This scales a 32-bit signed integer word into an 8-bit signed integer.
/// \returns transformation parameters.
QuantizationTransform32To8 quantizeScaleOffset32To8(float scale,
                                                    int32_t offset);
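
// Illustrative sketch: a floating point rescale factor of roughly 0.3 can be
// encoded as an integer-only transform (the exact pre/post/scale values are
// chosen internally and are not specified here):
//
//   QuantizationTransform32To8 t = quantizeScaleOffset32To8(0.3f, 0);
//   int32_t q = t.transform(200); // approximately 200 * 0.3 == 60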

/// Function to get the quantized range for a given precision type \p qTy.
/// \returns the range as a (min, max) pair.
QuantizedRange getQuantizedRange(ElemKind qTy);

/// Function to validate that the given quantization parameters \p qParams
/// comply with the given quantization \p schema and precision \p qTy.
void validateQuantizationParams(TensorQuantizationParams qParams, Schema schema,
                                ElemKind qTy);

/// Calculate the TensorQuantizationParams from the TensorProfilingParams
/// \p profParams using the quantization type \p qTy and the quantization
/// method described by \p schema. The calibration of the quantization
/// parameters will be done using the method given by \p calibration.
TensorQuantizationParams
chooseQuantizationParams(TensorProfilingParams profParams,
                         Schema schema = Asymmetric,
                         ElemKind qTy = ElemKind::Int8QTy,
                         Calibration calibration = Calibration::None);
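
// Illustrative sketch (profiled range chosen for exposition): computing int8
// quantization parameters for a tensor profiled with dynamic range
// [-0.5, 1.5]:
//
//   TensorProfilingParams prof(-0.5f, 1.5f);
//   TensorQuantizationParams tqp =
//       chooseQuantizationParams(prof, Schema::Asymmetric, ElemKind::Int8QTy);
//   // tqp.scale is roughly 2.0 / 255 and tqp.offset is roughly -64.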

/// Function to specialize the TensorQuantizationParams of the bias operand
/// for nodes like Convolution and FullyConnected given the initially computed
/// parameters \p biasTQP and the parameters of the input \p inputTQP and the
/// weights \p weightsTQP, for given quantization schema \p schema and bias type
/// \p biasQTy. The bias operand requires a more thoughtful quantization since
/// every bias value has a higher impact on the precision of the output value
/// than any particular weight value. The specialization logic is:
/// - for INT32 bias quantization: since the dynamic range of INT32 is large we
///   can always force symmetric quantization (offset = 0). This allows a faster
///   implementation since no offset subtraction is required at run-time.
/// - for INT8/INT16 bias quantization: since the dynamic range is small we
///   will keep the original offset.
/// - regardless of precision, we try to force the bias scale parameter to
///   bias_scale = input_scale * weights_scale since this has a performance
///   benefit by specializing the parameters to biasPre = 0, biasPost = 0,
///   biasScale = 1. We must verify that by changing the bias scale we don't
///   saturate the bias data. This is also equivalent to forcing the effective
///   scale applied at run-time (bias_scale / (input_scale * weights_scale))
///   to be always greater than or equal to 1.0 which is a common constraint
///   for the bias for most libraries with quantized implementations.
TensorQuantizationParams
specializeBiasQuantizationParams(const TensorQuantizationParams &biasTQP,
                                 const TensorQuantizationParams &inputTQP,
                                 const TensorQuantizationParams &weightsTQP,
                                 Schema schema, ElemKind biasQTy);

/// \returns an int8 vector mapping from the \p inTy to the \p outTy given the
/// function \p f.
/// \pre inTy and outTy should be Int8QTy.
std::vector<int8_t> createMapping(TypeRef inTy, TypeRef outTy,
                                  std::function<float(float)> f);

/// Row-wise quantize the tensor \p input. \p scales and \p offsets are
/// generated for each row of \p input; \p output is a tensor of the same shape
/// as input, quantized from \p input using \p scales and \p offsets for each
/// row. Note that the shape of input/output can be any non-zero number of
/// dimensions; row refers to all data in the first dimension of the shape.
/// Template parameters \p ScaleT and \p OffsetT represent the types to use for
/// the scales and offsets for quantization respectively. Template parameter
/// \p QP represents the quantization precision, typically int8_t or uint8_t.
template <typename ScaleT, typename OffsetT, typename QP>
void tensorRowwiseQuantization(const Tensor &input, Tensor &output,
                               Tensor &scales, Tensor &offsets,
                               quantization::Schema schema) {
  constexpr bool offsetIsFP = std::is_same<float, OffsetT>::value ||
                              std::is_same<float16_t, OffsetT>::value;
  constexpr bool offsetIsInt32 = std::is_same<int32_t, OffsetT>::value;
  static_assert((offsetIsInt32 && std::is_same<float, ScaleT>::value) ||
                    (offsetIsFP && std::is_same<ScaleT, OffsetT>::value),
                "Invalid combination of Scale/Offset types.");

  const auto fDims = flattenCdr(input.dims());
  Tensor finalIn = input.getUnowned({fDims.first, fDims.second});
  Tensor finalOut = output.getUnowned({fDims.first, fDims.second});
  ShapeHW idim(finalIn.dims());

  auto srcH = finalIn.getHandle<float>();
  auto destH = finalOut.getHandle<QP>();
  auto scalesH = scales.getHandle<ScaleT>();
  auto offsetsH = offsets.getHandle<OffsetT>();
  for (dim_t i = 0; i < idim.height; i++) {
    auto slice = srcH.extractSlice(i);
    auto rSrc = slice.getHandle<float>();
    auto res = rSrc.minMaxArg();
    float min = rSrc.raw(res.first);
    float max = rSrc.raw(res.second);

    // Handle rowwise quantization for FCs.
    if (offsetIsInt32) {
      TensorQuantizationParams qParams =
          chooseQuantizationParams({min, max}, schema);
      for (dim_t j = 0; j < idim.width; j++) {
        destH.at({i, j}) = quantization::quantize(srcH.at({i, j}), qParams);
      }
      scalesH.raw(i) = qParams.scale;
      offsetsH.raw(i) = qParams.offset;
    } else if (offsetIsFP) {
      // Handle rowwise quantization for Rowwise quantized SLS.
      constexpr float kEqualityThreshold = 1e-10f;
      const float scale = ((max - min) < kEqualityThreshold)
                              ? 1.0
                              : ((double)max - (double)min) / 255.0;
      float offset = min;

      for (dim_t j = 0; j < idim.width; j++) {
        destH.at({i, j}) = quantization::quantizeWithFloatOffset<QP>(
            srcH.at({i, j}), scale, offset);
      }
      scalesH.raw(i) = static_cast<ScaleT>(scale);
      offsetsH.raw(i) = static_cast<OffsetT>(offset);
    } else {
      llvm_unreachable("Unsupported offset type.");
    }
  }
}
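
// Illustrative usage sketch (tensor names and shapes are assumptions): rowwise
// int8 quantization of a [4 x 10] matrix, producing one float scale and one
// int32 offset per row. The scale/offset on the output type are placeholders;
// the real per-row parameters live in the scales/offsets tensors:
//
//   Tensor weights(ElemKind::FloatTy, {4, 10});
//   Tensor rowwiseQ(ElemKind::Int8QTy, {4, 10}, /* scale */ 1.0, /* off */ 0);
//   Tensor scales(ElemKind::FloatTy, {4});
//   Tensor offsets(ElemKind::Int32ITy, {4});
//   tensorRowwiseQuantization<float, int32_t, int8_t>(
//       weights, rowwiseQ, scales, offsets, quantization::Schema::Asymmetric);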

/// Fused-rowwise quantize the tensor \p input. Scales and offsets are generated
/// from each row of \p input. This function supports 8-bit quantization (i.e.
/// each quantized value uses 8 bits) and 4-bit quantization (i.e. each
/// quantized value uses 4 bits).
/// For 8-bit quantization, \p output is a tensor of the same shape as input but
/// with extra columns for storing the fused scale and offset. Template
/// parameter \p T represents the datatype used for storing the scale and offset
/// in the row:
/// |   .... int8 data ...    |   scale   |  offset   |
/// |num_of_input_columns * 1B| sizeof(T) | sizeof(T) |
/// For 4-bit quantization, in \p output, 1 byte contains 2 quantized values.
/// Template parameter \p T here must be float16_t.
/// |   .... int4 data ...       | scale | offset |
/// |num_of_input_columns * 0.5B |  2B   |   2B   |
/// \pre input.dims().size() == 2
/// \pre output.dims().size() == 2
/// For 8-bit quantization:
/// \pre input.dims()[1] + 2 * sizeof(T) == output.dims()[1]
/// For 4-bit quantization:
/// \pre input.dims()[1] % 2 == 0
/// \pre input.dims()[1] / 2 + 2 * sizeof(T) == output.dims()[1]
template <typename T>
void tensorFusedRowwiseQuantization(const Tensor &input, Tensor &output) {
  // We are fusing the scale and offset onto the end of each row. Thus input and
  // output must both be 2 dimensional, with output having 2*sizeof(T) extra
  // columns for the scale and offset.
  auto outputType = output.getElementType();
  assert(input.dims().size() == 2 && output.dims().size() == 2 &&
         "Input and output must be 2 dimensional.");
  if (outputType == ElemKind::UInt8FusedFP16QTy ||
      outputType == ElemKind::UInt8FusedQTy) {
    assert(input.dims()[1] + 2 * sizeof(T) == output.dims()[1] &&
           "Output must have 2*sizeof(T) more columns than input for 8-bits "
           "quantization.");
  } else if (outputType == ElemKind::UInt4FusedFP16QTy) {
    constexpr bool scaleIsFP16 = std::is_same<float16_t, T>::value;
    (void)scaleIsFP16;
    assert(scaleIsFP16 && "Only float16_t scale and offset are supported "
                          "in 4-bit fused quantization");
    assert(input.dims()[1] % 2 == 0 &&
           "4-bit fused quantization only works when the number of input "
           "columns is a multiple of 2");
    assert(
        input.dims()[1] / 2 + 2 * sizeof(T) == output.dims()[1] &&
        "Output must have 2*sizeof(T) more columns than half of input columns "
        "for 4-bits quantization.");
  }

  auto srcH = input.getHandle<float>();
  auto destH = output.getHandle<uint8_t>();
  for (dim_t i = 0, e = input.dims()[0]; i < e; i++) {
    auto slice = srcH.extractSlice(i);
    auto rSrc = slice.getHandle<float>();
    auto res = rSrc.minMaxArg();
    float min = rSrc.raw(res.first);
    float max = rSrc.raw(res.second);

    float range;
    switch (outputType) {
    case ElemKind::UInt8FusedQTy:
    case ElemKind::UInt8FusedFP16QTy:
      range = 255.0;
      break;
    case ElemKind::UInt4FusedFP16QTy:
      range = 15.0;
      break;
    default:
      llvm_unreachable("Not yet supported");
    }

    // This matches the Caffe2 implementation for FloatToRowwiseQuantized8BitsOp
    // found in operators/lengths_reducer_rowwise_8bit_ops.h.
    constexpr float kEqualityThreshold = 1e-10f;
    const float scale = ((max - min) < kEqualityThreshold)
                            ? 1.0
                            : ((double)max - (double)min) / range;
    const float offset = min;

    for (dim_t j = 0, f = input.dims()[1]; j < f; j++) {
      if (outputType == ElemKind::UInt8FusedFP16QTy ||
          outputType == ElemKind::UInt8FusedQTy) {
        destH.at({i, j}) = quantization::quantizeWithFloatOffset<uint8_t>(
            srcH.at({i, j}), scale, offset);
      } else if (outputType == ElemKind::UInt4FusedFP16QTy) {
        uint8_t quantized = quantization::quantize4BitsWithFloatOffset(
            srcH.at({i, j}), scale, offset);
        if (j % 2 == 0) {
          // Even columns use LSB 4-bit.
          destH.at({i, j / 2}) = quantized;
        } else {
          // Odd columns use MSB 4-bit.
          destH.at({i, j / 2}) |= quantized << 4;
        }
      } else {
        llvm_unreachable("Not yet supported");
      }
    }

    // Now set the scale/offset at the end of each row.
    destH.setFusedScaleOffsetInRow<T>(i, scale, offset);
  }
}
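
// Illustrative usage sketch (shapes are assumptions): fused rowwise
// quantization of a [3 x 8] float matrix into UInt8FusedQTy, where each output
// row stores 8 quantized bytes followed by a float scale and a float offset
// (8 + 2 * sizeof(float) == 16 columns). The scale/offset on the output type
// are placeholders for fused kinds:
//
//   Tensor in(ElemKind::FloatTy, {3, 8});
//   Tensor out(ElemKind::UInt8FusedQTy, {3, 16}, /* scale */ 1.0, /* off */ 0);
//   tensorFusedRowwiseQuantization<float>(in, out);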

/// Generic function to compute the quantization parameters for an input
/// floating-point tensor \p tensor with given schema \p qSchema and type
/// \p qTy. A separate set of quantization parameters (scale, offset) will
/// be computed for each group of \p qStep indices along the \p qDim dimension.
/// This allows quantizing a given tensor with finer granularity (e.g. rowwise
/// or channelwise).
/// For example, for a tensor of size [4, 6, 8, 10], qDim = 1 and qStep = 3:
/// -> one set of quantization parameters will be computed for [:,0:2,:,:].
/// -> one set of quantization parameters will be computed for [:,3:5,:,:].
/// The number of sets of computed quantization parameters (scale, offset) is
/// tensor.dims()[qDim] / qStep. \returns the set of quantization parameters.
std::vector<TensorQuantizationParams>
getTensorQuantizationParams(const Tensor &tensor, Schema qSchema = Asymmetric,
                            ElemKind qTy = ElemKind::Int8QTy, dim_t qDim = 0,
                            dim_t qStep = 1);
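
// Illustrative sketch (shapes are assumptions): channelwise parameters for a
// convolution filter of shape [numFilters, kH, kW, inChannels], one
// (scale, offset) pair per output filter (qDim = 0, qStep = 1):
//
//   Tensor filter(ElemKind::FloatTy, {16, 3, 3, 8});
//   std::vector<TensorQuantizationParams> tqps = getTensorQuantizationParams(
//       filter, Schema::Symmetric, ElemKind::Int8QTy, /* qDim */ 0,
//       /* qStep */ 1);
//   // tqps.size() == 16, one entry per output filter.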

/// Similar function to the one above with the difference that the quantization
/// parameters scales and offsets are written into separate tensors \p scales
/// and \p offsets which are assumed allocated with the correct type and size.
void getTensorQuantizationParams(const Tensor &tensor, Tensor &scales,
                                 Tensor &offsets, Schema qSchema = Asymmetric,
                                 ElemKind qTy = ElemKind::Int8QTy,
                                 dim_t qDim = 0, dim_t qStep = 1);

/// Generic function to quantize a given input floating-point tensor \p tensor
/// with given tensor quantization parameters \p TQP and type \p qTy. A separate
/// set of quantization parameters (scale, offset) is provided for each group
/// of \p qStep indices along the \p qDim dimension and can be obtained using
/// the function \ref getTensorQuantizationParams. This allows quantizing a
/// given tensor with finer granularity (e.g. rowwise or channelwise).
/// For example, for a tensor of size [4, 6, 8, 10], qDim = 1 and qStep = 3:
/// -> one set of quantization parameters will be provided for [:,0:2,:,:].
/// -> one set of quantization parameters will be provided for [:,3:5,:,:].
/// The number of sets of provided quantization parameters (scale, offset) is
/// tensor.dims()[qDim] / qStep. \returns the quantized tensor.
Tensor quantizeTensor(const Tensor &tensor,
                      llvm::ArrayRef<TensorQuantizationParams> TQP,
                      ElemKind qTy = ElemKind::Int8QTy, dim_t qDim = 0,
                      dim_t qStep = 1);

/// Similar function to the one above with the difference that the quantization
/// parameters scales and offsets are loaded from separate tensors \p scales
/// and \p offsets.
Tensor quantizeTensor(const Tensor &tensor, const Tensor &scales,
                      const Tensor &offsets, ElemKind qTy = ElemKind::Int8QTy,
                      dim_t qDim = 0, dim_t qStep = 1);

/// Verify if float \p val is an exact power of 2 (mantissa is exactly 1.0).
bool isFloatPowerOf2(float val);

/// Get the base-2 exponent of float \p val.
int getFloat2Exp(float val);

} // namespace quantization
} // namespace glow

#endif // GLOW_QUANTIZATION_BASE_BASE_H