/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "glow/Base/Tensor.h"

#include "glow/Base/Type.h"

#include "llvm/Support/NativeFormatting.h"
#include "llvm/Support/raw_ostream.h"
#include <glog/logging.h>

using namespace glow;

namespace {

/// This is a helper method that's used in the visualization of tensors.
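/// It maps a numeric value to an ASCII character whose visual density grows
/// with the magnitude of the value (for example, 0.5 maps to ',' and 3.0 maps
/// to '@'), so a 2-D slice printed with these characters reads like a
/// grayscale image.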
template <class ElemTy> static char valueToChar(ElemTy input) {
  char ch = ' ';
  const double val = input;
  if (val > 0.2) {
    ch = '.';
  }
  if (val > 0.4) {
    ch = ',';
  }
  if (val > 0.6) {
    ch = ':';
  }
  if (val > 0.8) {
    ch = 'o';
  }
  if (val > 1.0) {
    ch = 'O';
  }
  if (val > 1.5) {
    ch = '0';
  }
  if (val > 2.0) {
    ch = '@';
  }
  if (val < -0.1) {
    ch = '-';
  }
  if (val < -0.2) {
    ch = '~';
  }
  if (val < -0.4) {
    ch = '=';
  }
  if (val < -1.0) {
    ch = '#';
  }
  return ch;
}

static void dumpShape(llvm::ArrayRef<dim_t> shape, llvm::raw_ostream &os) {
  os << "shape: ( ";
  for (auto &d : shape) {
    os << d << " ";
  }
  os << ")";
}

template <class ElemTy>
static void dumpGenericImpl(Handle<ElemTy> handle, llvm::raw_ostream &os,
                            unsigned maxNumElem) {
  auto shape = handle.dims();
  size_t numDims = shape.size();
  auto &Ty = handle.getType();

  // Check for 0-dimensional tensor.
  if (!numDims) {
    os << "[ Scalar containing: ";
    llvm::write_double(os, handle.raw(0), llvm::FloatStyle::Fixed, 3);
    os << " ]\n";
    return;
  }

  // Output shape.
  dumpShape(shape, os);
  os << "\n";

  // Check for tensor of size 0.
  if (handle.getUnpaddedSizeInBytes() == 0) {
    os << "[ tensor has no elements ]\n";
    return;
  }

  ElemTy mx = handle.raw(0);
  ElemTy mn = handle.raw(0);

  for (auto elem : handle) {
    mx = std::max(mx, elem);
    mn = std::min(mn, elem);
  }

  // Check for zero tensor.
  if (mn == ElemTy(.0) && mx == ElemTy(.0)) {
    os << "[ Zero tensor ]\n";
    return;
  }

  // Output max and min.
  os << "max: ";
  llvm::write_double(os, mx, llvm::FloatStyle::Fixed, 3);
  os << "  min: ";
  llvm::write_double(os, mn, llvm::FloatStyle::Fixed, 3);
  os << "\n";

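  // Note: Ty.getSliceSize(d) is the number of elements in a slice that starts
  // at dimension d, so an element index that is a multiple of it sits on a
  // row/slice boundary. The loops below use this to decide where to print the
  // nested '[' and ']' brackets and the newlines.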
  os << "[";

  for (size_t i = 0, e = std::min<size_t>(maxNumElem, handle.size()); i < e;
       i++) {

    // Print one open brace at the beginning of every row, slice, and tensor.
    for (size_t j = 0, e = numDims - 1; numDims > 1 && j < e; j++) {
      if (i % Ty.getSliceSize(j + 1) == 0) {
        // This iteration of outer loop is a new row, slice or tensor.
        os << "[";
      }
    }

    // Print the value at the current index.
    llvm::write_double(os, handle.raw(i), llvm::FloatStyle::Fixed, 3);

    // Print one closed brace at the end of every row, slice, or tensor.
    for (size_t j = 0, e = numDims - 1; numDims > 1 && j < e; j++) {
      size_t next_index = i + 1;
      if (next_index % Ty.getSliceSize(j + 1) == 0u) {
        os << "]";
      }
    }

    os << ", ";

    // Print one newline at the end of every row, slice, or tensor.
    for (size_t j = 0, e = numDims - 1; numDims > 1 && j < e; j++) {
      size_t next_index = i + 1;
      if (next_index % Ty.getSliceSize(j + 1) == 0u) {
        // Next iteration of outer loop will be a new row, slice or tensor.
        os << "\n";
      }
    }
  }

  if (handle.size() > maxNumElem) {
    os << "...";
  }

  os << "]\n";

  os.flush();
}

template <class ElemTy>
static void dumpAsciiGenericImpl(Handle<ElemTy> handle, llvm::raw_ostream &os) {
  auto d = handle.dims();

  if (d.size() == 2) {
    for (dim_t x = 0; x < d[0]; x++) {
      for (dim_t y = 0; y < d[1]; y++) {
        auto val = handle.at({x, y});
        os << valueToChar(val);
      }
      os << "\n";
    }
  } else if (d.size() == 3) {
    // Print monochrome (one-color channel) tensors:
    if (d[2] == 1) {
      for (dim_t x = 0; x < d[0]; x++) {
        for (dim_t y = 0; y < d[1]; y++) {
          auto val = handle.at({x, y, 0});
          os << valueToChar(val);
        }
        os << "\n";
      }
    } else {
      for (dim_t z = 0; z < d[2]; z++) {
        os << "\n";
        for (dim_t x = 0; x < d[0]; x++) {
          for (dim_t y = 0; y < d[1]; y++) {
            auto val = handle.at({x, y, z});
            os << valueToChar(val);
          }
          os << "\n";
        }
      }
    }

  } else {
    llvm_unreachable("Invalid tensor size");
  }

  os.flush();
}

/// This is a slow generic transpose. This method performs a single for loop
/// over a single dimension, or, if we've reached the last dimension, performs
/// a single copy of a single element.
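/// The mapping between coordinates is destCoor[d] == srcCoor[shuffle[d]] for
/// every dimension d; for example, a shuffle of {0, 2, 3, 1} writes
/// dest[n][h][w][c] = src[n][c][h][w] (an NCHW-to-NHWC transpose).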
template <class ElemTy>
static void
transposeGenericImpl(const Handle<ElemTy> &src, Handle<ElemTy> &dest,
                     dim_t *srcCoor, dim_t *destCoor,
                     llvm::ArrayRef<unsigned_t> shuffle, unsigned depth = 0) {
  if (depth == shuffle.size()) {
    auto srcIdx = llvm::ArrayRef<dim_t>(srcCoor, depth);
    auto destIdx = llvm::ArrayRef<dim_t>(destCoor, depth);
    dest.at(destIdx) = src.at(srcIdx);
    return;
  }

  // Iterate over one dimension and continue recursively to the next dim.
  for (dim_t x = 0, e = dest.dims()[depth]; x < e; x++) {
    unsigned_t swizzledDepth = shuffle[depth];
    srcCoor[swizzledDepth] = x;
    destCoor[depth] = x;
    transposeGenericImpl(src, dest, srcCoor, destCoor, shuffle, depth + 1);
  }
}

/// Faster function for transposing a tensor for important/common tensor
/// shapes. If a transpose successfully occurs, the function \returns true;
/// otherwise it \returns false, meaning no transpose occurred and some other
/// transpose function (e.g. transposeGenericImpl) must be called instead.
/// \p src is the tensor to transpose into \p dest, and \p shuffle defines how
/// to transpose.
template <class ElemTy>
static bool tryTransposeFastImpl(const Handle<ElemTy> &src,
                                 Handle<ElemTy> &dest,
                                 llvm::ArrayRef<unsigned_t> shuffle) {
  const dim_t numDims = dest.dims().size();
  dim_t srcCoorArr[max_tensor_dimensions];
  dim_t destCoorArr[max_tensor_dimensions] = {0};
  auto srcCoor = llvm::ArrayRef<dim_t>(srcCoorArr, numDims);
  auto destCoor = llvm::ArrayRef<dim_t>(destCoorArr, numDims);

  /// This defines a single depth of the for loop used to iterate over the
  /// source and destination tensors for transposing.
#define TRANSPOSE_LOOP_LEVEL(DEPTH_)                                           \
  for (srcCoorArr[shuffle[DEPTH_]] = 0, destCoorArr[DEPTH_] = 0;               \
       destCoorArr[DEPTH_] < dest.dims()[DEPTH_];                              \
       srcCoorArr[shuffle[DEPTH_]]++, destCoorArr[DEPTH_]++)

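  // Only rank-2 and rank-4 transposes are unrolled below; for any other rank
  // this function returns false and the caller falls back to the slower
  // recursive transposeGenericImpl.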
  switch (numDims) {
  case 2:
    TRANSPOSE_LOOP_LEVEL(1) {
      TRANSPOSE_LOOP_LEVEL(0) { dest.at(destCoor) = src.at(srcCoor); }
    }
    return true;
  case 4:
    TRANSPOSE_LOOP_LEVEL(1) {
      TRANSPOSE_LOOP_LEVEL(2) {
        TRANSPOSE_LOOP_LEVEL(0) {
          TRANSPOSE_LOOP_LEVEL(3) { dest.at(destCoor) = src.at(srcCoor); }
        }
      }
    }
    return true;
  }
  return false;
}

template <class ElemTy>
static void transposeSelectImpl(const Handle<ElemTy> &src, Handle<ElemTy> &dest,
                                llvm::ArrayRef<unsigned_t> shuffle) {
  bool transposeOccurred = tryTransposeFastImpl(src, dest, shuffle);
  if (!transposeOccurred) {
    dim_t srcCoor[max_tensor_dimensions];
    dim_t destCoor[max_tensor_dimensions];
    transposeGenericImpl(src, dest, srcCoor, destCoor, shuffle);
  }
}

template <class ElemTy>
static bool isTiledImpl(const Tensor *tensor, unsigned_t axis, dim_t size,
                        bool fractional) {
  assert(axis < tensor->dims().size() && "Axis parameter invalid!");
  assert(size <= tensor->dims()[axis] && "Size parameter invalid!");
  assert(size >= 1 && "Size parameter invalid!");

  // When the tile size matches the dimension size then we return true.
  // This is because a tensor can be considered a tiled version of itself.
  if (size == tensor->dims()[axis]) {
    return true;
  }

  // If fractional tiling verification is disabled and the dimension size
  // is NOT divisible by the tile size then we return false.
  if (!fractional && ((tensor->dims()[axis] % size) != 0)) {
    return false;
  }

  static_assert(max_tensor_dimensions == 6,
                "Implementation assumes max_tensor_dimensions = 6.");

  // Get tensor view with maximum number of dimensions.
  auto dimsMax = expandDimsToMax(tensor->dims());
  Tensor tensorMax = tensor->getUnowned(dimsMax);
  auto tensorH = tensorMax.getHandle<ElemTy>();
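  // Walk every element of the (expanded) tensor and compare it against the
  // corresponding element of the first tile, obtained by wrapping the index
  // along the tiled axis. Any mismatch means the tensor is not tiled.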
  for (dim_t idx0 = 0; idx0 < dimsMax[0]; ++idx0) {
    for (dim_t idx1 = 0; idx1 < dimsMax[1]; ++idx1) {
      for (dim_t idx2 = 0; idx2 < dimsMax[2]; ++idx2) {
        for (dim_t idx3 = 0; idx3 < dimsMax[3]; ++idx3) {
          for (dim_t idx4 = 0; idx4 < dimsMax[4]; ++idx4) {
            for (dim_t idx5 = 0; idx5 < dimsMax[5]; ++idx5) {
              std::vector<dim_t> idx = {idx0, idx1, idx2, idx3, idx4, idx5};
              std::vector<dim_t> idxWrapped = idx;
              idxWrapped[axis] = (idx[axis] % size);
              double delta = tensorH.at(idx) - tensorH.at(idxWrapped);
              // Since any comparison with NAN returns false, we use a negated
              // condition so that this function correctly returns false when
              // delta is NAN.
              if (!(delta == 0.0)) {
                return false;
              }
            }
          }
        }
      }
    }
  }
  return true;
}
} // namespace

void glow::dumpAsciiImpl(const Tensor *T, llvm::raw_ostream &os) {
  switch (T->getElementType()) {
  case ElemKind::FloatTy:
    return dumpAsciiGenericImpl(T->getHandle<float>(), os);
  case ElemKind::Float16Ty:
    return dumpAsciiGenericImpl(T->getHandle<float16_t>(), os);
  case ElemKind::BFloat16Ty:
    return dumpAsciiGenericImpl(T->getHandle<bfloat16_t>(), os);
  case ElemKind::Int8QTy:
    return dumpAsciiGenericImpl(T->getHandle<int8_t>(), os);
  case ElemKind::UInt8QTy:
    return dumpAsciiGenericImpl(T->getHandle<uint8_t>(), os);
  case ElemKind::Int16QTy:
    return dumpAsciiGenericImpl(T->getHandle<int16_t>(), os);
  case ElemKind::Int32QTy:
    return dumpAsciiGenericImpl(T->getHandle<int32_t>(), os);
  case ElemKind::Int32ITy:
    return dumpAsciiGenericImpl(T->getHandle<int32_t>(), os);
  case ElemKind::Int64ITy:
    return dumpAsciiGenericImpl(T->getHandle<int64_t>(), os);
  case ElemKind::UInt8FusedQTy:
    return dumpAsciiGenericImpl(T->getHandle<uint8_t>(), os);
  case ElemKind::UInt8FusedFP16QTy:
    return dumpAsciiGenericImpl(T->getHandle<uint8_t>(), os);
  case ElemKind::UInt4FusedFP16QTy:
    return dumpAsciiGenericImpl(T->getHandle<uint8_t>(), os);
  case ElemKind::BoolTy:
    return dumpAsciiGenericImpl(T->getHandle<bool>(), os);
  }
}

void glow::dumpAsciiImpl(const Tensor *T) { dumpAsciiImpl(T, llvm::outs()); }

void glow::dumpImpl(const Tensor *T, llvm::raw_ostream &os,
                    unsigned maxNumElem) {
  switch (T->getElementType()) {
  case ElemKind::FloatTy:
    return dumpGenericImpl(T->getHandle<float>(), os, maxNumElem);
  case ElemKind::Float16Ty:
    return dumpGenericImpl(T->getHandle<float16_t>(), os, maxNumElem);
  case ElemKind::BFloat16Ty:
    return dumpGenericImpl(T->getHandle<bfloat16_t>(), os, maxNumElem);
  case ElemKind::Int8QTy:
    return dumpGenericImpl(T->getHandle<int8_t>(), os, maxNumElem);
  case ElemKind::UInt8QTy:
    return dumpGenericImpl(T->getHandle<uint8_t>(), os, maxNumElem);
  case ElemKind::Int16QTy:
    return dumpGenericImpl(T->getHandle<int16_t>(), os, maxNumElem);
  case ElemKind::Int32QTy:
    return dumpGenericImpl(T->getHandle<int32_t>(), os, maxNumElem);
  case ElemKind::Int32ITy:
    return dumpGenericImpl(T->getHandle<int32_t>(), os, maxNumElem);
  case ElemKind::Int64ITy:
    return dumpGenericImpl(T->getHandle<int64_t>(), os, maxNumElem);
  case ElemKind::UInt8FusedQTy:
    return dumpGenericImpl(T->getHandle<uint8_t>(), os, maxNumElem);
  case ElemKind::UInt8FusedFP16QTy:
    return dumpGenericImpl(T->getHandle<uint8_t>(), os, maxNumElem);
  case ElemKind::UInt4FusedFP16QTy:
    return dumpGenericImpl(T->getHandle<uint8_t>(), os, maxNumElem);
  case ElemKind::BoolTy:
    return dumpGenericImpl(T->getHandle<bool>(), os, maxNumElem);
  }
}

void glow::dumpImpl(const Tensor *T, unsigned maxNumElem) {
  dumpImpl(T, llvm::outs(), maxNumElem);
}

void glow::dumpImpl(const Tensor *T) { dumpImpl(T, llvm::outs()); }

// Dump functions.
void Tensor::dump(llvm::raw_ostream &os) const { dumpImpl(this, os); }

void Tensor::dump() const { dumpImpl(this, llvm::outs()); }

std::string Tensor::toString() const {
  std::string storage;
  llvm::raw_string_ostream os(storage);
  dumpImpl(this, os);
  return os.str();
}

std::string Tensor::getShapeToString() const {
  std::string storage;
  llvm::raw_string_ostream os(storage);
  dumpShape(dims(), os);
  return os.str();
}

void Tensor::dump(llvm::raw_ostream &os, unsigned maxNumElem) const {
  dumpImpl(this, os, maxNumElem);
}

void Tensor::dump(unsigned maxNumElem) const {
  dumpImpl(this, llvm::outs(), maxNumElem);
}

/// Dump a textual representation of a specific number of elements in the
/// Tensor to std::string.
std::string Tensor::toString(unsigned maxNumElem) const {
  std::string storage;
  llvm::raw_string_ostream os(storage);
  dumpImpl(this, os, maxNumElem);
  return os.str();
}

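/// Illustrative example: with \p src of shape {N, C, H, W} and \p shuffle of
/// {0, 2, 3, 1}, \p dest is reset to shape {N, H, W, C} and filled with the
/// NCHW-to-NHWC transposed data.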
void glow::genericTranspose(const Tensor *src, Tensor *dest,
                            llvm::ArrayRef<unsigned_t> shuffle) {
  DCHECK(src->dims().size() == shuffle.size())
      << "Invalid dimensions " << src->dims().size()
      << " != " << shuffle.size();

  dim_t newSizes[max_tensor_dimensions];

  // Generate the swizzled dimensions.
  auto origDims = src->dims();
  for (unsigned i = 0; i < origDims.size(); i++) {
    newSizes[i] = origDims[shuffle[i]];
  }

  // Resize the tensor to the transposed shape.
  auto destType = Type::newShape(src->getType(), {newSizes, origDims.size()});
  // genericTranspose doesn't know how to set non-trivial strides and
  // alignments, and it cannot figure out the correct ones since they can be
  // backend-specific. Therefore, set the type to destType only if it has not
  // already been set properly by the caller.
  // Reset should be called anyway to allocate memory for the tensor.
  if (dest->dims() != destType.dims()) {
    dest->reset(destType);
  } else {
    dest->reset(dest->getType());
  }

  // Fill padding bytes with zero.
  if (src->actualSize() != dest->actualSize()) {
    dest->zero();
  }

  switch (src->getElementType()) {
  case ElemKind::FloatTy: {
    auto srcH = src->getHandle<float>();
    auto destH = dest->getHandle<float>();
    transposeSelectImpl(srcH, destH, shuffle);
    return;
  }
  case ElemKind::Float16Ty: {
    auto srcH = src->getHandle<float16_t>();
    auto destH = dest->getHandle<float16_t>();
    transposeSelectImpl(srcH, destH, shuffle);
    return;
  }
  case ElemKind::BFloat16Ty: {
    auto srcH = src->getHandle<bfloat16_t>();
    auto destH = dest->getHandle<bfloat16_t>();
    transposeSelectImpl(srcH, destH, shuffle);
    return;
  }
  case ElemKind::Int8QTy: {
    auto srcH = src->getHandle<int8_t>();
    auto destH = dest->getHandle<int8_t>();
    transposeSelectImpl(srcH, destH, shuffle);
    return;
  }
  case ElemKind::UInt8QTy: {
    auto srcH = src->getHandle<uint8_t>();
    auto destH = dest->getHandle<uint8_t>();
    transposeSelectImpl(srcH, destH, shuffle);
    return;
  }
  case ElemKind::Int16QTy: {
    auto srcH = src->getHandle<int16_t>();
    auto destH = dest->getHandle<int16_t>();
    transposeSelectImpl(srcH, destH, shuffle);
    return;
  }
  case ElemKind::Int32QTy: {
    auto srcH = src->getHandle<int32_t>();
    auto destH = dest->getHandle<int32_t>();
    transposeSelectImpl(srcH, destH, shuffle);
    return;
  }
  case ElemKind::Int32ITy: {
    auto srcH = src->getHandle<int32_t>();
    auto destH = dest->getHandle<int32_t>();
    transposeSelectImpl(srcH, destH, shuffle);
    return;
  }
  case ElemKind::Int64ITy: {
    auto srcH = src->getHandle<int64_t>();
    auto destH = dest->getHandle<int64_t>();
    transposeSelectImpl(srcH, destH, shuffle);
    return;
  }
  case ElemKind::UInt8FusedQTy: {
    llvm_unreachable("Transposing UInt8FusedQTy is unsupported.");
  }
  case ElemKind::UInt8FusedFP16QTy: {
    llvm_unreachable("Transposing UInt8FusedFP16QTy is unsupported.");
  }
  case ElemKind::UInt4FusedFP16QTy: {
    llvm_unreachable("Transposing UInt4FusedFP16QTy is unsupported.");
  }
  case ElemKind::BoolTy: {
    auto srcH = src->getHandle<bool>();
    auto destH = dest->getHandle<bool>();
    transposeSelectImpl(srcH, destH, shuffle);
    return;
  }
  }
}

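/// For example, with max_tensor_dimensions == 6, a shape of {2, 3} is
/// expanded to {2, 3, 1, 1, 1, 1}.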
ShapeVector glow::expandDimsToMax(llvm::ArrayRef<dim_t> currDims) {
  ShapeVector newDims(currDims.begin(), currDims.end());
  for (size_t i = newDims.size(); i < max_tensor_dimensions; i++) {
    newDims.push_back(1);
  }
  return newDims;
}

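/// For example, reducing dims {2, 3, 4} over axes {1} yields {2, 1, 4} when
/// keepDims is true and {2, 4} when keepDims is false.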
ShapeVector glow::reduceDims(llvm::ArrayRef<dim_t> dims,
                             llvm::ArrayRef<unsigned_t> axes, bool keepDims) {
  ShapeVector newDims;
  for (unsigned_t dim = 0, end = dims.size(); dim < end; ++dim) {
    auto it = std::find(axes.begin(), axes.end(), dim);
    bool dimReduced = (it != axes.end());
    if (dimReduced) {
      if (keepDims) {
        newDims.push_back(1);
      } else {
        continue;
      }
    } else {
      newDims.push_back(dims[dim]);
    }
  }
  return newDims;
}

void Tensor::init(InitKind init, float val, PseudoRNG &PRNG) {
  assert(!isDeviceResident() && "Tensor must reside on host to access data.");
  switch (init) {
  case InitKind::Zero:
    zero();
    break;

  case InitKind::Broadcast: {
    switch (getElementType()) {
    case ElemKind::FloatTy: {
      getHandle<float>().clear(val);
      break;
    }
    case ElemKind::Float16Ty: {
      getHandle<float16_t>().clear(float16_t(val));
      break;
    }
    case ElemKind::BFloat16Ty: {
      getHandle<bfloat16_t>().clear(bfloat16_t(val));
      break;
    }
    case ElemKind::Int8QTy: {
      getHandle<int8_t>().clear(val);
      break;
    }
    case ElemKind::UInt8QTy: {
      getHandle<uint8_t>().clear(val);
      break;
    }
    case ElemKind::Int16QTy: {
      getHandle<int16_t>().clear(val);
      break;
    }
    case ElemKind::Int32QTy: {
      getHandle<int32_t>().clear(val);
      break;
    }
    case ElemKind::Int32ITy: {
      getHandle<int32_t>().clear(val);
      break;
    }
    case ElemKind::Int64ITy: {
      getHandle<int64_t>().clear(val);
      break;
    }

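// For the fused quantized kinds handled by FUSED_CASE below, each row of the
// 2-D tensor stores its payload bytes followed by a trailing per-row scale and
// offset; the loop only overwrites the payload columns and leaves the
// scale/offset bytes untouched.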
#define FUSED_CASE(ELEM_KIND, DATA_TYPE)                                       \
  case ElemKind::ELEM_KIND: {                                                  \
    DCHECK(dims().size() == 2)                                                 \
        << "Fused tensor must be 2-dimensional but instead has "               \
        << dims().size() << " dimensions.";                                    \
    DCHECK(dims()[1] > 2 * sizeof(DATA_TYPE))                                  \
        << "Fused tensor must have space for scale/offset, but only has "      \
        << dims()[1] << " columns.";                                           \
    auto H = getHandle<uint8_t>();                                             \
    for (dim_t i = 0; i < dims()[0]; i++) {                                    \
      for (dim_t j = 0, f = dims()[1] - 2 * sizeof(DATA_TYPE); j < f; j++) {   \
        H.at({i, j}) = val;                                                    \
      }                                                                        \
    }                                                                          \
    break;                                                                     \
  }
      FUSED_CASE(UInt8FusedQTy, float);
      FUSED_CASE(UInt8FusedFP16QTy, float16_t);
      FUSED_CASE(UInt4FusedFP16QTy, float16_t);
#undef FUSED_CASE

    case ElemKind::BoolTy: {
      getHandle<bool>().clear(val);
      break;
    }
    }
    break;
  }

  case InitKind::Xavier: {
    switch (getElementType()) {
    case ElemKind::FloatTy: {
      getHandle<float>().initXavier(val, PRNG);
      break;
    }
    case ElemKind::Float16Ty: {
      getHandle<float16_t>().initXavier(val, PRNG);
      break;
    }
    case ElemKind::BFloat16Ty: {
      getHandle<bfloat16_t>().initXavier(val, PRNG);
      break;
    }
    default: {
      llvm_unreachable("Undefined to Xavier-initialize non-Float Tensors.");
    }
    }
    break;
  }
  }
}

void Tensor::convertToType(ElemKind newTy) {
  assert(!isDeviceResident() && "Tensor must reside on host to access data.");
  *this = this->getCopyConvertedToType(newTy);
}

Tensor Tensor::getCopyConvertedToType(ElemKind newKind) const {
  assert(!isDeviceResident() && "Tensor must reside on host to access data.");
  const ElemKind origKind = getElementType();
  DCHECK((origKind == ElemKind::FloatTy && newKind == ElemKind::Float16Ty) ||
         (origKind == ElemKind::FloatTy && newKind == ElemKind::BFloat16Ty) ||
         (origKind == ElemKind::FloatTy && newKind == ElemKind::Int32ITy) ||
         (origKind == ElemKind::FloatTy && newKind == ElemKind::Int64ITy) ||
         (origKind == ElemKind::Float16Ty && newKind == ElemKind::FloatTy) ||
         (origKind == ElemKind::BFloat16Ty && newKind == ElemKind::FloatTy) ||
         (origKind == ElemKind::Int64ITy && newKind == ElemKind::Int32ITy) ||
         (origKind == ElemKind::Int64ITy && newKind == ElemKind::FloatTy) ||
         (origKind == ElemKind::Int32ITy && newKind == ElemKind::Int64ITy) ||
         (origKind == ElemKind::Int32ITy && newKind == ElemKind::FloatTy) ||
         (origKind == ElemKind::UInt8FusedQTy &&
          newKind == ElemKind::UInt8FusedFP16QTy))
      << "Conversion from " << Type::getElementName(origKind).str() << " to "
      << Type::getElementName(newKind).str() << " is not yet implemented";

  if (!isQuantizedElemKind(newKind)) {
    Tensor tmp(newKind, dims());
    switch (newKind) {
    case ElemKind::Float16Ty:
      tmp.copyWithCast<float16_t, float>(this);
      break;
    case ElemKind::BFloat16Ty:
      tmp.copyWithCast<bfloat16_t, float>(this);
      break;

    case ElemKind::FloatTy:
      if (getElementType() == ElemKind::Int32ITy) {
        tmp.copyWithCast<float, int32_t>(this);
      } else if (getElementType() == ElemKind::Int64ITy) {
        tmp.copyWithCast<float, int64_t>(this);
      } else if (getElementType() == ElemKind::Float16Ty) {
        tmp.copyWithCast<float, float16_t>(this);
      } else if (getElementType() == ElemKind::BFloat16Ty) {
        tmp.copyWithCast<float, bfloat16_t>(this);
      } else if (getElementType() == ElemKind::FloatTy) {
        tmp.copyRawFrom(this);
      } else {
        llvm_unreachable("Invalid conversion to FLOAT.");
      }
      break;

    case ElemKind::Int32ITy:
      if (getElementType() == ElemKind::Int64ITy) {
        tmp.copyWithCast<int32_t, int64_t>(this);
      } else if (getElementType() == ElemKind::FloatTy) {
        tmp.copyWithCast<int32_t, float>(this);
      } else {
        llvm_unreachable("Invalid conversion to INT32.");
      }
      break;
    case ElemKind::Int64ITy:
      if (getElementType() == ElemKind::Int32ITy) {
        tmp.copyWithCast<int64_t, int32_t>(this);
      } else {
        llvm_unreachable("Invalid conversion to INT64.");
      }
      break;

    default:
      llvm_unreachable("Type not supported");
    }
    return tmp;
  }

  // Handle Fused conversion. Currently only supports UInt8FusedQTy ->
  // UInt8FusedFP16QTy.
  DCHECK(origKind == ElemKind::UInt8FusedQTy && dims().size() == 2)
      << "UInt8FusedQTy must be 2 dimensional.";
  Tensor tmp(newKind,
             {dims()[0], dims()[1] - 2 * ((dim_t)sizeof(float) -
                                          (dim_t)sizeof(float16_t))},
             1.0, 0);
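  // Each row shrinks by 2 * (sizeof(float) - sizeof(float16_t)) = 4 bytes
  // because the per-row scale and offset are narrowed from float to float16_t
  // while the uint8 payload keeps its width.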

  const size_t dstWidth = tmp.dims()[1];
  auto srcH = getHandle<uint8_t>();
  auto dstH = tmp.getHandle<uint8_t>();
  for (dim_t i = 0, e = dims()[0]; i < e; i++) {
    // Copy the scale/offset from src to dst.
    float scale, offset;
    std::tie(scale, offset) = srcH.getFusedScaleOffsetFromRow<float>(i);
    dstH.setFusedScaleOffsetInRow<float16_t>(i, static_cast<float16_t>(scale),
                                             static_cast<float16_t>(offset));

    // Copy over the row's uint8 data from src to dst; scales and offsets were
    // already copied over above.
    for (dim_t j = 0, f = dstWidth - 2 * sizeof(float16_t); j < f; j++) {
      dstH.at({i, j}) = srcH.at({i, j});
    }
  }
  return tmp;
}

namespace glow {
llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Tensor &t) {
  t.dump(os);
  return os;
}

llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Tensor *t) {
  assert(t != nullptr && "Null Pointer.");
  t->dump(os);
  return os;
}

void Tensor::moveToDevice(DeviceTensorTransferManager *deviceManager,
                          void *locationContext) {
  if (deviceResidency_ == nullptr) {
    deviceResidency_ = new DeviceResidencyInfo();
  }
  deviceResidency_->deviceManager_ = deviceManager;
  deviceResidency_->locationContext_ = locationContext;
  deviceResidency_->tensorResidency_ =
      DeviceResidencyInfo::TensorResidency::Device;
}

void Tensor::ensureOnHost() {
  if (deviceResidency_ == nullptr) {
    // Already on host.
    return;
  }
  if (deviceResidency_->isDeviceResident()) {
    deviceResidency_->deviceManager_->transferFromDevice(*this);
  }
  assert(!isDeviceResident());
}

void Tensor::copyRawToDevice(const Tensor *t) {
  assert(isDeviceResident());
  void *locationContext = deviceResidency_->locationContext_;
  DeviceTensorTransferManager *DM = deviceResidency_->deviceManager_;
  clearDeviceResidency();
  copyRawFrom(t);
  DM->transferToDevice(*this, locationContext);
}

bool Tensor::isTiled(unsigned_t axis, dim_t size, bool fractional) const {
  switch (getElementType()) {
  case ElemKind::FloatTy: {
    return isTiledImpl<float>(this, axis, size, fractional);
  }
  case ElemKind::Float16Ty: {
    return isTiledImpl<float16_t>(this, axis, size, fractional);
  }
  case ElemKind::Int8QTy: {
    return isTiledImpl<int8_t>(this, axis, size, fractional);
  }
  case ElemKind::UInt8QTy: {
    return isTiledImpl<uint8_t>(this, axis, size, fractional);
  }
  case ElemKind::Int16QTy: {
    return isTiledImpl<int16_t>(this, axis, size, fractional);
  }
  case ElemKind::Int32QTy: {
    return isTiledImpl<int32_t>(this, axis, size, fractional);
  }
  case ElemKind::Int32ITy: {
    return isTiledImpl<int32_t>(this, axis, size, fractional);
  }
  case ElemKind::Int64ITy: {
    return isTiledImpl<int64_t>(this, axis, size, fractional);
  }
  case ElemKind::BoolTy: {
    return isTiledImpl<bool>(this, axis, size, fractional);
  }
  default: { llvm_unreachable("isTiled: Precision not supported!"); }
  }
}

bool Tensor::isTiled(llvm::ArrayRef<unsigned_t> axes,
                     llvm::ArrayRef<dim_t> sizes, bool fractional) const {
  assert(axes.size() == sizes.size() &&
         "Mismatch between axes and sizes length!");
  for (size_t idx = 0, end = axes.size(); idx < end; ++idx) {
    if (!isTiled(axes[idx], sizes[idx], fractional)) {
      return false;
    }
  }
  return true;
}

bool isSliceContiguous(llvm::ArrayRef<dim_t> sliceShape,
                       llvm::ArrayRef<dim_t> tensorShape) {
  assert(sliceShape.size() == tensorShape.size() &&
         "Array length mismatch for slice/tensor sizes!");
  // Find the first non-singleton slice dimension. If all the dimensions are
  // singleton then, by convention, the first non-singleton dimension is set to
  // the slice rank (one past the last dimension).
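  // For example, a {1, 2, 4} slice of a {3, 2, 4} tensor is contiguous, but a
  // {1, 2, 2} slice of the same tensor is not, because a dimension after the
  // first non-singleton slice dimension is only partially extracted.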
  size_t firstNonSingleDim = sliceShape.size();
  for (size_t dim = 0, dimEnd = sliceShape.size(); dim < dimEnd; ++dim) {
    if (sliceShape[dim] != 1) {
      firstNonSingleDim = dim;
      break;
    }
  }
  // First non-singleton slice dimension can be partially or fully extracted.
  // The following dimensions must be fully extracted.
  for (size_t dim = firstNonSingleDim + 1, dimEnd = sliceShape.size();
       dim < dimEnd; ++dim) {
    if (sliceShape[dim] != tensorShape[dim]) {
      return false;
    }
  }
  return true;
}

} // namespace glow