1 /**
2  * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef GLOW_BASE_TENSOR_H
17 #define GLOW_BASE_TENSOR_H
18 
19 #include <algorithm>
20 #include <cassert>
21 #include <vector>
22 
23 #include "glow/Base/DeviceTensorTransferManager.h"
24 #include "glow/Base/Type.h"
25 #include "glow/Support/Compiler.h"
26 #include "glow/Support/Memory.h"
27 #include "glow/Support/Random.h"
28 
29 #include "llvm/ADT/ArrayRef.h"
30 #include "llvm/Support/raw_ostream.h"
31 
32 namespace glow {
33 
34 //===----------------------------------------------------------------------===//
35 //                               Tensor
36 //===----------------------------------------------------------------------===//
37 
38 template <class ElemTy> class Handle;
39 
40 class Tensor;
41 class TensorPool;
42 
43 void genericTranspose(const Tensor *src, Tensor *dest,
44                       llvm::ArrayRef<unsigned_t> shuffle);
45 
46 /// Helper function that \returns a ShapeVector of those dimensions in \p
47 /// currDims expanded with dimension = 1 until the maximum tensor dimension is
48 /// reached. The total element count implied by the input dims is the same as
49 /// for the returned dims. For example, input {2,1,4} would result in {2,1,4,1,1,1}.
50 ShapeVector expandDimsToMax(llvm::ArrayRef<dim_t> currDims);
51 
52 /// Helper function that \returns a ShapeVector obtained from \p dims by
53 /// reducing (setting to 1) the dimensions given by \p axes. If the flag
54 /// \p keepDims is also used then the reduced dimensions are kept, otherwise
55 /// they are pruned. For example, given the dimensions [2,3,4] and axes [0,2] the
56 /// returned shape will be [1,3,1] for keepDims true and [3] for keepDims false.
57 ShapeVector reduceDims(llvm::ArrayRef<dim_t> dims,
58                        llvm::ArrayRef<unsigned_t> axes, bool keepDims);
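
// Illustrative sketch of the two shape helpers above (values follow the
// doc-comment examples; the expanded form assumes the usual six-dimension
// maximum):
//
//   expandDimsToMax({2, 1, 4});            // -> {2, 1, 4, 1, 1, 1}
//   reduceDims({2, 3, 4}, {0, 2}, true);   // -> {1, 3, 1}
//   reduceDims({2, 3, 4}, {0, 2}, false);  // -> {3}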
59 
60 namespace runtime {
61 class DeviceManager;
62 }
63 
64 /// Holds information regarding whether this Tensor exists in a device-specific
65 /// form, either resident on or specific to a device, and which device holds it.
66 class DeviceResidencyInfo final {
67   enum class TensorResidency {
68     Host,
69     Device,
70   };
71 
72   // A pointer to the device manager of the device on which the tensor
73   // resides.
74   DeviceTensorTransferManager *deviceManager_{nullptr};
75   /// The residency status of the tensor.
76   TensorResidency tensorResidency_{TensorResidency::Host};
77   // A pointer to a context structure, containing the required info to access
78   // tensor data and perform transfers.
79   void *locationContext_{nullptr};
80 
81 public:
82   DeviceResidencyInfo()
83       : deviceManager_(nullptr), tensorResidency_(TensorResidency::Host),
84         locationContext_(nullptr) {}
85 
86   /// Move ctor.
87   DeviceResidencyInfo(DeviceResidencyInfo &&other) = delete;
88 
89   /// Move assignment operator.
90   DeviceResidencyInfo &operator=(DeviceResidencyInfo &&other) = delete;
91 
92   ~DeviceResidencyInfo() {
93     // If a tensor is device resident, let its device manager free the device
94     // buffer.
95     if (isDeviceResident()) {
96       deviceManager_->releaseDeviceTensor(locationContext_);
97     }
98   }
99 
100   /// Removes all device specific state.
101   void clear() {
102     deviceManager_ = nullptr;
103     locationContext_ = nullptr;
104     tensorResidency_ = TensorResidency::Host;
105   }
106 
107   /// \returns true if this Tensor is resident on or specific to a device.
108   bool isDeviceResident() const {
109     assert((tensorResidency_ == TensorResidency::Host || deviceManager_) &&
110            "Device resident tensor must have an assigned device manager.");
111     return tensorResidency_ == TensorResidency::Device;
112   }
113 
114   /// \returns the DeviceManager this tensor is resident on, if any.
115   DeviceTensorTransferManager *getDeviceManager() const {
116     return deviceManager_;
117   }
118 
119   /// \returns the device specific location context for a resident Tensor.
120   void *getLocationContext() const { return locationContext_; }
121 
122   friend class Tensor;
123 };
124 
125 /// A class that represents a contiguous n-dimensional array (a tensor).
126 class Tensor final {
127 public:
128   /// Specifies the kind of initialization for the tensor.
129   enum class InitKind {
130     Zero,      // The tensor is initialized to zero.
131     Broadcast, // Broadcast a single value to all elements.
132     Xavier,    // Init the tensor with random values using the Xavier method.
133   };
134 
135 private:
136   /// A pointer to the tensor data.
137   char *data_{nullptr};
138 
139   /// The type of the tensor.
140   Type type_;
141 
142   /// If the tensor is unowned.
143   bool isUnowned_{false};
144 
145   /// The TensorPool that is managing this Tensor (if any).
146   TensorPool *tensorPool_{nullptr};
147 
148   /// The device residency info associated with the tensor.
149   DeviceResidencyInfo *deviceResidency_{nullptr};
150 
151   /// If this tensor owns the DeviceResidencyInfo.
152   bool ownsDeviceResidency_{false};
153 
154   /// Size in bytes of the unpadded region of memory. This is useful for
155   /// communicating the actual size of the data; it allows for copying only the
156   /// inputs and not the padding to the device.
157   size_t unpaddedSize_{0};
158 
159   template <class ElemTy> friend class Handle;
160 
161   /// \returns a pointer to the tensor data buffer.
162   char *getData() const { return data_; }
163 
164 public:
165   /// \returns true if it is an unowned tensor.
166   bool isUnowned() const { return isUnowned_; }
167 
168   /// \returns the number of allocated bytes pointed to by \ref data_.
169   size_t getUnpaddedSizeInBytes() const { return unpaddedSize_; }
170 
171   /// \returns the type of the tensor.
172   const Type &getType() const { return type_; }
173 
174   /// Set the type of the Tensor to \p t.
175   void setType(const TypeRef t) {
176     assert(type_.dims() == t->dims() && "New type must retain the same shape.");
177     assert(((type_.getElementType() == t->getElementType() &&
178              type_.size() == t->size()) ||
179             type_.getSizeInBytes() == t->getSizeInBytes()) &&
180            "New type must retain the same size in bytes.");
181     type_ = *t;
182   }
183 
184   /// \return the element type of the tensor.
185   ElemKind getElementType() const { return type_.getElementType(); }
186 
187   /// \returns True if the coordinate is within the array.
188   bool isInBounds(llvm::ArrayRef<dim_t> indices) const {
189     assert(type_.numSizes_ == indices.size() && "Invalid number of indices");
190     for (size_t i = 0u, e = indices.size(); i < e; i++) {
191       if (indices[i] >= type_.sizes_[i]) {
192         return false;
193       }
194     }
195     return true;
196   }
197 
198   /// Set the content of the tensor to zero. If \p resetFusedScalesOffsets, then
199   /// fused scales/offsets will be set to 1.0/0.0 as well.
200   void zero(bool resetFusedScalesOffsets = false) {
201     assert(!isDeviceResident() && "Tensor must reside on host to access data.");
202     size_t size = actualSize();
203     // Quantized tensors should go to their offset.
204     switch (type_.getElementType()) {
205     case ElemKind::Int8QTy: {
206       auto *data = reinterpret_cast<int8_t *>(getData());
207       std::fill(&data[0], &data[0] + size, (int8_t)type_.getOffset());
208       break;
209     }
210     case ElemKind::UInt8QTy: {
211       auto *data = reinterpret_cast<uint8_t *>(getData());
212       std::fill(&data[0], &data[0] + size, (uint8_t)type_.getOffset());
213       break;
214     }
215     case ElemKind::Int16QTy: {
216       auto *data = reinterpret_cast<int16_t *>(getData());
217       std::fill(&data[0], &data[0] + size, (int16_t)type_.getOffset());
218       break;
219     }
220     case ElemKind::Int32QTy: {
221       auto *data = reinterpret_cast<int32_t *>(getData());
222       std::fill(&data[0], &data[0] + size, (int32_t)type_.getOffset());
223       break;
224     }
225 #define FUSED_CASE(ELEM_KIND, DATA_TYPE)                                       \
226   case ElemKind::ELEM_KIND: {                                                  \
227     assert(dims().size() == 2 && "Fused tensor must be 2-dimensional.");       \
228     assert(dims()[1] > 2 * sizeof(DATA_TYPE) &&                                \
229            "Fused tensor must have space for scale and offset.");              \
230     const size_t dataWidth = dims()[1];                                        \
231     const size_t alignedLength = type_.strides()[0];                           \
232     auto *data = reinterpret_cast<uint8_t *>(getData());                       \
233     for (size_t i = 0, e = dims()[0]; i < e; i++) {                            \
234       uint8_t *scaleOffsetPtr =                                                \
235           data + i * alignedLength + dataWidth - 2 * sizeof(DATA_TYPE);        \
236       DATA_TYPE scale, offset;                                                 \
237       if (resetFusedScalesOffsets) {                                           \
238         /* Use these as defaults, and copy them into each row. */              \
239         scale = 1.0;                                                           \
240         offset = 0.0;                                                          \
241         memcpy(scaleOffsetPtr, &scale, sizeof(DATA_TYPE));                     \
242         memcpy(scaleOffsetPtr + sizeof(DATA_TYPE), &offset,                    \
243                sizeof(DATA_TYPE));                                             \
244       } else {                                                                 \
245         memcpy(&scale, scaleOffsetPtr, sizeof(DATA_TYPE));                     \
246         memcpy(&offset, scaleOffsetPtr + sizeof(DATA_TYPE),                    \
247                sizeof(DATA_TYPE));                                             \
248       }                                                                        \
249       DCHECK_NE(static_cast<float>(scale), 0.0)                                \
250           << "Disallow scale = 0.0 for Fused ElemKinds; causes div by zero.";  \
251       float zero = nearbyintf(-1 * static_cast<float>(offset / scale));        \
252       std::fill(data + i * alignedLength, scaleOffsetPtr,                      \
253                 static_cast<uint8_t>(zero));                                   \
254     }                                                                          \
255     break;                                                                     \
256   }
257       FUSED_CASE(UInt8FusedQTy, float);
258       FUSED_CASE(UInt8FusedFP16QTy, float16_t);
259 #undef FUSED_CASE
260 
261     default:
262       // Non-quantized tensors are set to 0.
263       std::fill(&getData()[0], &getData()[0] + size * type_.getElementSize(),
264                 0);
265       break;
266     }
267   }
268 
269   /// \returns the shape of the tensor.
270   llvm::ArrayRef<dim_t> dims() const { return type_.dims(); }
271 
272   /// \returns the number of real meaningful elements in the tensor. Does not
273   /// take strides into account.
274   dim_t size() const { return type_.size(); }
275 
276   /// \returns the actual number of elements in the tensor taking striding into
277   /// account. Since size() does not take striding into account, size() is
278   /// always <= actualSize().
279   dim_t actualSize() const { return type_.actualSize(); }
280 
281   /// \returns the number of bytes required to store the tensor based on its
282   /// Type. Note that this includes the size required for padding.
283   uint64_t getSizeInBytes() const { return type_.getSizeInBytes(); }
284 
285   /// \returns the TensorPool managing this object, or nullptr if it is
286   /// unmanaged.
287   TensorPool *getOwningPool() { return tensorPool_; }
288 
289   /// Initialize an empty tensor.
290   Tensor() = default;
291 
292   /// Initialize from a list of float literals.
293   Tensor(const std::initializer_list<float> &vec) {
294     reset(ElemKind::FloatTy, {(dim_t)vec.size()});
295     auto *data = getRawDataPointer<float>();
296     int i = 0;
297     for (auto &f : vec) {
298       data[i++] = f;
299     }
300   }
301 
302   /// Allocate and initialize a new tensor.
303   explicit Tensor(TypeRef ty) : data_(nullptr), type_(*ty), isUnowned_{false} {
304     reset(*ty);
305   }
306 
307   /// Allocate and initialize a new tensor.
308   explicit Tensor(const Type &ty)
309       : data_(nullptr), type_(ty), isUnowned_{false} {
310     reset(ty);
311   }
312 
313   /// Allocate and initialize a new float tensor.
314   Tensor(ElemKind elemTy, llvm::ArrayRef<dim_t> dims)
315       : data_(nullptr), type_(elemTy, dims), isUnowned_{false} {
316     reset(elemTy, dims);
317   }
318 
319   /// Construct an unowned tensor provided an existing payload buffer.
320   /// This constructor can be used when there is a need to work with
321   /// "externally" managed payload buffers using Tensor APIs. Additionally
322   /// \p unpaddedSize can be set to indicate actual size of the inputs. If
323   /// negative then it defaults back to the size of the input type.
324   Tensor(void *data, TypeRef ty, ssize_t unpaddedSize = -1)
325       : data_(reinterpret_cast<char *>(data)), type_(*ty) {
326     // Mark as unowned.
327     isUnowned_ = true;
328     // We do want DeviceResidency however, since there is no owning Glow Tensor.
329     resetDeviceInfo();
330     if (unpaddedSize < 0) {
331       unpaddedSize_ = type_.getSizeInBytes();
332     } else {
333       unpaddedSize_ = static_cast<size_t>(unpaddedSize);
334     }
335   }
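
  // Usage sketch for the unowned-payload constructor above (illustrative
  // only; assumes TypeRef is a pointer to a const Type, as declared in
  // glow/Base/Type.h):
  //
  //   std::vector<float> payload(6, 0.f);        // externally managed storage
  //   Tensor shaped(ElemKind::FloatTy, {2, 3});  // provides the Type
  //   Tensor view(payload.data(), &shaped.getType());
  //   view.getHandle<float>().at({1, 2}) = 42.f; // writes into payload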
336 
337   /// Allocate and initialize a new integer tensor with \p scale and \p offset.
338   Tensor(ElemKind elemTy, llvm::ArrayRef<dim_t> dims, float scale,
339          int32_t offset)
340       : data_(nullptr), type_(elemTy, dims, scale, offset), isUnowned_{false} {
341     reset(type_);
342   }
343 
344   /// Allocate a new Tensor managed by the \p tensorPool.
345   explicit Tensor(TypeRef ty, TensorPool *tensorPool)
346       : data_(nullptr), type_(*ty), tensorPool_(tensorPool) {
347     reset(*ty);
348   }
349 
350   Tensor(const Tensor &other) = delete;
351   Tensor &operator=(const Tensor &other) = delete;
352 
353   /// Initialize the content of the tensor using the \p init method. The value
354   /// \p val is the initialization parameter. \p PRNG is used to generate random
355   /// numbers. Note that if the tensor's kind is Fused, then the fused
356   /// scaled/offsets will not be modified.
357   void init(InitKind init, float val, PseudoRNG &PRNG);
358 
359   /// \returns an unowned tensor with the exact same dimensions as this.
360   Tensor getUnowned() const { return getUnowned(dims()); }
361 
362   /// \returns unowned tensor using the same data buffer as the current tensor
363   /// but having different dimensions \p dims. \p offsets represents an optional
364   /// offset into the tensor representing the location of the first element to
365   /// start a subview from. The returned unowned tensor is essentially a
366   /// different view or subview on the same data.
367   ///
368   /// The lifetime of the returned unowned tensor should be always within
369   /// the lifetime of its parent tensor, i.e. the unowned tensor should not
370   /// outlive its parent tensor.
371   Tensor getUnowned(llvm::ArrayRef<dim_t> dims,
372                     llvm::ArrayRef<dim_t> offsets = {}) const {
373     Tensor unownedTensor;
374 
375     auto *firstElemPtr = getData();
376     if (offsets.size()) {
377       assert(offsets.size() == this->dims().size() &&
378              "Number of dims of tensor must equal number of dims in offsets");
379       // Find the index of the first element and use it to find the pointer to
380       // the first element.
381       size_t index = 0;
382       for (size_t i = 0; i < this->dims().size(); i++) {
383         index += type_.strides()[i] * offsets[i];
384       }
385       firstElemPtr = &firstElemPtr[index * type_.getElementSize()];
386     }
387 
388     unownedTensor.data_ = firstElemPtr;
389     unownedTensor.isUnowned_ = true;
390     unownedTensor.type_ = Type::newShape(getType(), dims);
391     unownedTensor.deviceResidency_ = deviceResidency_;
392 
393     // If the original base Tensor is padded, then we only allow the unowned
394     // Tensor to be padded if there are no offsets. Otherwise assert that the
395     // base Tensor is not padded, and set unpaddedSize to that of the new
396     // unowned type.
397     if (offsets.size() == 0) {
398       unownedTensor.unpaddedSize_ = unpaddedSize_;
399       assert(actualSize() == unownedTensor.actualSize() &&
400              "The size of the unowned tensor "
401              "should be the same as the size of "
402              "the original tensor");
403 
404     } else {
405       unownedTensor.unpaddedSize_ = unownedTensor.type_.getSizeInBytes();
406       assert(getSizeInBytes() == getUnpaddedSizeInBytes() &&
407              "Problematic to get unowned offsetted view of a padded tensor");
408       assert(actualSize() >= unownedTensor.actualSize() &&
409              "The size of the unowned tensor "
410              "should be no greater than the "
411              "size of the original tensor");
412     }
413     return unownedTensor;
414   }
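
  // Subview sketch (illustrative only; the view must not outlive `base`):
  //
  //   Tensor base(ElemKind::FloatTy, {4, 10});
  //   Tensor view = base.getUnowned({2, 10}, {1, 0}); // rows 1..2 of base
  //   view.zero();                                    // zeroes those rows in base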
415 
416   /// This is the same as \ref getUnowned() but it produces an owned tensor
417   /// instead. \returns owned tensor copied from the data buffer of the current
418   /// tensor but having different dimensions \p dims. \p offsets represents an
419   /// optional offset into the tensor representing the location of the first
420   /// element to start a subview from.
421   Tensor getOwnedSlice(llvm::ArrayRef<dim_t> dims,
422                        llvm::ArrayRef<dim_t> offsets = {}) const {
423     assert(!isDeviceResident() && "Tensor must reside on host to access data.");
424     return getUnowned(dims, offsets).clone();
425   }
426 
427   /// Reset the shape and type of this tensor to match the shape and type of
428   /// \p other. The size of the buffer is set to \p unpaddedSize unless it is
429   /// negative, which will instead default back to the number of bytes needed
430   /// for the type of \p other.
431   void reset(const Tensor *other, ssize_t unpaddedSize = -1) {
432     reset(other->getType(), unpaddedSize);
433   }
434 
435   void reset(ElemKind elemTy, llvm::ArrayRef<dim_t> shape) {
436     Type t(elemTy, shape);
437     reset(t);
438   }
439 
440   void reset(ElemKind elemTy, llvm::ArrayRef<dim_t> shape, float scale,
441              int32_t offset) {
442     Type t(elemTy, shape, scale, offset);
443     reset(t);
444   }
445 
446   /// Assigns a new shape to the tensor and allocates a new buffer. The size of
447   /// the buffer is set to \p unpaddedSize unless it is negative, which will
448   /// instead default back to the number of bytes needed for \p T.
449   void reset(const Type &T, ssize_t unpaddedSize = -1) {
450     assert(!isDeviceResident() && "Tensor must reside on host to access data.");
451 
452     // If negative then fall back to the passed in Type's padded size.
453     if (unpaddedSize < 0) {
454       unpaddedSize = T.getSizeInBytes();
455     }
456 
457     // If the new size is identical to the allocated size then there is no need
458     // to re-allocate the buffer.
459     const bool isOrigPadded = getSizeInBytes() != getUnpaddedSizeInBytes();
460     const bool isNewPadded = T.getSizeInBytes() != unpaddedSize;
461     const bool isBufReuseAllowed = (isOrigPadded == isNewPadded) &&
462                                    (getUnpaddedSizeInBytes() == unpaddedSize);
463     if (type_ == T && getData() && isBufReuseAllowed) {
464 #ifdef GLOW_DEBUG_TENSOR_INIT
465       PseudoRNG rng;
466       init(InitKind::Broadcast, GLOW_DEBUG_TENSOR_INIT, rng);
467 #endif
468       resetDeviceInfo();
469       return;
470     }
471 
472     // Delete the old buffer, update the shape, and allocate a new one.
473     if (!isUnowned())
474       alignedFree(getData());
475     type_ = T;
476 
477     // We are allocating memory specifically for this tensor, thus, it owns it.
478     isUnowned_ = false;
479 
480     // We are allocating memory on the host so it is not device resident.
481     resetDeviceInfo();
482 
483     // Note: zero-dimensional tensors (i.e. {}) have size 1. However, Tensors
484     // may have 0 for some dimension, meaning they have size of 0, and so we do
485     // not allocate anything for them.
486     data_ = unpaddedSize == 0 ? nullptr
487                               : reinterpret_cast<char *>(alignedAlloc(
488                                     unpaddedSize, TensorAlignment));
489 
490     // Set unpaddedSize_ to the actual number of bytes.
491     unpaddedSize_ = unpaddedSize;
492 
493 #ifdef GLOW_DEBUG_TENSOR_INIT
494     PseudoRNG rng;
495     init(InitKind::Broadcast, GLOW_DEBUG_TENSOR_INIT, rng);
496 #endif
497   }
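
  // Sketch of typical reset() usage (illustrative only). The second reset
  // reuses the existing allocation because the type and unpadded size are
  // unchanged.
  //
  //   Tensor t;
  //   t.reset(ElemKind::FloatTy, {8, 8}); // allocates 8 x 8 floats
  //   t.zero();
  //   t.reset(ElemKind::FloatTy, {8, 8}); // same type: buffer is reused
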
498   /// Releases the data buffer and sets the unowned flag to true. This is
499   /// useful for keeping metadata around but not the actual contents.
500   void release() {
501     if (!isUnowned()) {
502       alignedFree(getData());
503     }
504     if (ownsDeviceResidency_) {
505       delete deviceResidency_;
506       ownsDeviceResidency_ = false;
507     }
508 
509     isUnowned_ = true;
510   }
511   ~Tensor() {
512     if (!isUnowned()) {
513       alignedFree(getData());
514     }
515 
516     if (ownsDeviceResidency_) {
517       delete deviceResidency_;
518       ownsDeviceResidency_ = false;
519     }
520   }
521 
522   // Move ctor.
523   Tensor(Tensor &&other) noexcept {
524     std::swap(data_, other.data_);
525     std::swap(type_, other.type_);
526     std::swap(isUnowned_, other.isUnowned_);
527     std::swap(tensorPool_, other.tensorPool_);
528     std::swap(unpaddedSize_, other.unpaddedSize_);
529     std::swap(deviceResidency_, other.deviceResidency_);
530     std::swap(ownsDeviceResidency_, other.ownsDeviceResidency_);
531   }
532 
533   /// Move assignment operator.
534   Tensor &operator=(Tensor &&other) noexcept {
535     std::swap(data_, other.data_);
536     std::swap(type_, other.type_);
537     std::swap(isUnowned_, other.isUnowned_);
538     std::swap(tensorPool_, other.tensorPool_);
539     std::swap(unpaddedSize_, other.unpaddedSize_);
540     std::swap(deviceResidency_, other.deviceResidency_);
541     std::swap(ownsDeviceResidency_, other.ownsDeviceResidency_);
542     return *this;
543   }
544 
545   /// Dump a textual representation of the Tensor into provided output stream.
546   void dump(llvm::raw_ostream &os) const;
547 
548   /// Dump a textual representation of the Tensor into default output stream.
549   void dump() const;
550 
551   /// Dump a textual representation of a specific number of elements in the
552   /// Tensor into provided output stream.
553   void dump(llvm::raw_ostream &os, unsigned maxNumElem) const;
554 
555   /// Dump a textual representation of a specific number of elements in the
556   /// Tensor into default output stream.
557   void dump(unsigned maxNumElem) const;
558 
559   /// Dump a textual representation of the Tensor to std::string.
560   std::string toString() const;
561 
562   /// Dump a textual representation of a specific number of elements in the
563   /// Tensor to std::string.
564   std::string toString(unsigned maxNumElem) const;
565 
566   /// Dump a textual representation of the shape of this Tensor to std::string.
567   std::string getShapeToString() const;
568 
569   /// \returns true if the content of the other tensor \p other is identical to
570   /// this one, given some \p allowedError. If \p verbose and the tensors are
571   /// not equal, then we will log information about the mismatch (number of
572   /// elements exceeding allowed error; maximum error and location found; etc.).
573   bool isEqual(const Tensor &other, float allowedError = 0.0001,
574                bool verbose = true) const {
575     if (isDeviceResident()) {
576       if (!other.isDeviceResident()) {
577         if (verbose) {
578           LOG(INFO) << "Tensors cannot be compared as they are not resident in "
579                        "the same location.";
580         }
581         return false;
582       }
583 
584       return getDeviceManager() == other.getDeviceManager() &&
585              getLocationContext() == other.getLocationContext();
586     }
587     return isEqualImpl(other, /*isBitwise=*/false, allowedError, verbose);
588   }
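
  // Comparison sketch (illustrative only): two host tensors with the same
  // shape compare element-wise within the allowed error.
  //
  //   Tensor a(ElemKind::FloatTy, {3});
  //   Tensor b(ElemKind::FloatTy, {3});
  //   a.getHandle<float>() = {1.0f, 2.0f, 3.0f};
  //   b.getHandle<float>() = {1.0f, 2.0f, 3.00005f};
  //   bool same = a.isEqual(b, /*allowedError=*/0.001f); // true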
589 
590   /// \returns true if the content of the other tensor \p other is bitwise
591   /// identical to this one.
592   bool isBitwiseEqual(const Tensor &other, bool verbose = false) const {
593     return isEqualImpl(other, /*isBitwise=*/true, /*allowedError=*/0.0,
594                        verbose);
595   }
596 
597   bool isEqualImpl(const Tensor &other, bool isBitwise, float allowedError,
598                    bool verbose) const {
599     if (other.dims() != dims()) {
600       if (verbose) {
601         LOG(INFO) << "Tensors are not equal as they have different shapes: "
602                   << this->getShapeToString() << " vs. "
603                   << other.getShapeToString();
604       }
605       return false;
606     }
607 
608     // For now, make sure that either both or neither of the tensors have
609     // UInt8FusedQTy or UInt8FusedFP16QTy. While it is possible for an Int8QTy
610     // tensor to equal a fused tensor if the fused tensor has the same
611     // scale/offset on all of its rows, and that scale/offset match that of the
612     // Int8QTy, we do not support checking this for now.
613     assert(((getElementType() == ElemKind::UInt8FusedQTy &&
614              other.getElementType() == ElemKind::UInt8FusedQTy) ||
615             (getElementType() == ElemKind::UInt8FusedFP16QTy &&
616              other.getElementType() == ElemKind::UInt8FusedFP16QTy) ||
617             (getElementType() != ElemKind::UInt8FusedFP16QTy &&
618              other.getElementType() != ElemKind::UInt8FusedQTy)) &&
619            "Fused ElemKinds only supports comparing against same ElemKind.");
620 
621     // Assert that the scale and offset match for the quantized types.
622     switch (getElementType()) {
623     default:
624       break;
625     case ElemKind::Int8QTy:
626     case ElemKind::UInt8QTy:
627     case ElemKind::Int16QTy:
628     case ElemKind::Int32QTy:
629       assert(getType().getScale() == other.getType().getScale() &&
630              "Scales must match.");
631       assert(getType().getOffset() == other.getType().getOffset() &&
632              "Offsets must match.");
633     }
634 
635     // Bitwise compare.
636     if (isBitwise) {
637       return isBitwiseEqualImpl(other, verbose);
638     }
639 
640     switch (getElementType()) {
641     case ElemKind::FloatTy:
642       return isEqualImpl<float>(other, allowedError, verbose);
643     case ElemKind::Float16Ty:
644       return isEqualImpl<float16_t>(other, allowedError, verbose);
645     case ElemKind::BFloat16Ty:
646       return isEqualImpl<bfloat16_t>(other, allowedError, verbose);
647     case ElemKind::Int8QTy:
648       return isEqualImpl<int8_t>(other, allowedError, verbose);
649     case ElemKind::UInt8QTy:
650       return isEqualImpl<uint8_t>(other, allowedError, verbose);
651     case ElemKind::Int16QTy:
652       return isEqualImpl<int16_t>(other, allowedError, verbose);
653     case ElemKind::Int32QTy:
654       return isEqualImpl<int32_t>(other, allowedError, verbose);
655     case ElemKind::Int32ITy:
656       return isEqualImpl<int32_t>(other, allowedError, verbose);
657     case ElemKind::Int64ITy:
658       return isEqualImpl<int64_t>(other, allowedError, verbose);
659       // Note: We can use isEqualImpl() here because the scales/offsets will be
660       // compared as if they were data, so we will return false if any rowwise
661       // scale/offset do not match.
662     case ElemKind::UInt8FusedQTy:
663       return isEqualImpl<uint8_t>(other, allowedError, verbose);
664     case ElemKind::UInt8FusedFP16QTy:
665       return isEqualImpl<uint8_t>(other, allowedError, verbose);
666     case ElemKind::UInt4FusedFP16QTy:
667       return isEqualImpl<uint8_t>(other, allowedError, verbose);
668     case ElemKind::BoolTy:
669       return isEqualImpl<bool>(other, allowedError, verbose);
670     }
671 
672     // This is to make the compiler happy. It can never reach this point as the
673     // switch always covers all possible values.
674     llvm_unreachable("unreachable");
675   }
676 
677   /// \returns whether this Tensor is tiled (repeated) along \p axis for the
678   /// given tile size \p size. Some examples:
679   /// - A Tensor with size [2, 3] equal to [[1,2,3],[1,2,3]] is tiled along
680   ///   axis 0 for a tile size equal to 1.
681   /// - A Tensor with size [2, 4] equal to [[1, 2, 1, 2],[3, 4, 3, 4]] is tiled
682   ///   along axis 1 for a tile size equal to 2.
683   /// When the tile size matches the dimensions size this function returns TRUE.
684   /// If the \p fractional flag is optionally given then this function will also
685   /// perform fractional tiling verification (default is FALSE). Some examples:
686   /// - For a Tensor with size [5] equal to [1,2,3,1,2], axis 0 and tile size 3,
687   ///   this function returns TRUE if \p fractional is TRUE and returns FALSE if
688   ///   \p fractional is FALSE.
689   bool isTiled(unsigned_t axis, dim_t size = 1, bool fractional = false) const;
690 
691   /// \returns whether this Tensor is tiled (repeated) along \p axes for the
692   /// given tile sizes \p sizes. Some examples:
693   /// - A Tensor with size [2, 4] equal to [[1,2,1,2],[1,2,1,2]] is tiled along
694   ///   axes {0,1} for the tile sizes {1,2}.
695   /// When the tile sizes match the dimension sizes this function returns TRUE.
696   /// If the \p fractional flag is optionally given then this function will also
697   /// perform fractional tiling verification (default is FALSE). Some examples:
698   /// - For a Tensor with size [5] equal to [1,2,3,1,2], axes {0} and sizes {3},
699   ///   this function returns TRUE if \p fractional is TRUE and returns FALSE if
700   ///   \p fractional is FALSE.
701   bool isTiled(llvm::ArrayRef<unsigned_t> axes, llvm::ArrayRef<dim_t> sizes,
702                bool fractional = false) const;
703 
704   /// Update the content and type of the tensor from the tensor \p t.
705   void assign(const Tensor *t) {
706     assert(!isDeviceResident() && "Tensor must reside on host to access data.");
707     assert(this != t && "Copying to self");
708     const size_t bufferSize = t->getUnpaddedSizeInBytes();
709     reset(t, bufferSize);
710     std::copy(&t->getData()[0], &t->getData()[bufferSize], getData());
711   }
712 
713   /// Update the raw data of the tensor from the tensor \p t.
714   void copyRawFrom(const Tensor *t) {
715     assert(!isDeviceResident() && "Tensor must reside on host to access data.");
716     assert(this != t && "Copying to self");
717     assert(actualSize() == t->actualSize());
718     assert(getElementType() == t->getElementType() && "Invalid element type");
719     assert(t->getUnpaddedSizeInBytes() == getUnpaddedSizeInBytes() &&
720            "Do not support copying between different unpadded sized tensors");
721     size_t bufferSize = type_.getSizeInBytes();
722     std::copy(&t->getData()[0], &t->getData()[bufferSize], getData());
723   }
724 
725   /// Update the raw data of the tensor from a raw buffer \p data.
726   void copyRawFrom(const char *data) {
727     assert(!isDeviceResident() && "Tensor must reside on host to access data.");
728     assert(data && "Null data pointer!");
729     assert(getData() != data && "Copying to self");
730     size_t bufferSize = type_.getSizeInBytes();
731     std::memcpy(getData(), data, bufferSize);
732   }
733 
734   /// Update the content of the tensor with a slice from tensor \p t. A slice
735   /// is one index from the first dimension of the tensor.
736   void copySlice(const Tensor *t, size_t slice) {
737     assert(!isDeviceResident() && "Tensor must reside on host to access data.");
738     auto dim = t->dims().slice(1);
739     (void)dim;
740     assert(dim == dims() && "Invalid slice size");
741     assert(getElementType() == t->getElementType() && "Invalid element type");
742 
743     size_t bufferSize = type_.getSizeInBytes();
744     std::copy(&t->getData()[bufferSize * slice],
745               &t->getData()[bufferSize * (slice + 1)], getData());
746   }
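
  // Sketch (illustrative only): copy the second row of a {3, 4} tensor into a
  // {4} tensor.
  //
  //   Tensor batch(ElemKind::FloatTy, {3, 4});
  //   Tensor row(ElemKind::FloatTy, {4});
  //   row.copySlice(&batch, 1);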
747 
748   /// Update the content of the tensor with a sequence of slices from the
749   /// tensor \p t. A slice is one index from the first dimension of the tensor.
750   /// The copying operation may wrap around the end of the tensor \p t one or
751   /// more times. This means that the data in the input tensor may be duplicated.
752   void copyConsecutiveSlices(const Tensor *t, size_t startSliceIdx) {
753     assert(!isDeviceResident() && "Tensor must reside on host to access data.");
754     auto onceSliceDim = t->dims().slice(1);
755     (void)onceSliceDim;
756     assert(onceSliceDim == dims().slice(1) && "Invalid slice size");
757     assert(getElementType() == t->getElementType() && "Invalid element type");
758     assert(dims().size() > 1 && "Tensor must contain at least two dimensions");
759 
760     size_t numSlicesInInput = t->dims()[0];
761     size_t numElementsInSlice = actualSize() / dims()[0];
762     size_t bufferSize = numElementsInSlice * type_.getElementSize();
763 
764     // For each outer slice in the current tensor:
765     for (size_t n = 0, e = dims()[0]; n < e; n++) {
766       size_t startIdx = (startSliceIdx + n) % numSlicesInInput;
767       std::copy(&t->getData()[bufferSize * startIdx],
768                 &t->getData()[bufferSize * (startIdx + 1)],
769                 &getData()[bufferSize * n]);
770     }
771   }
772 
773   /// Convenience method to copy the content of \p t
774   /// into this tensor when the two have different underlying types.
775   /// This copy reads each element of \p t as SrcElemType
776   /// and casts it to DestElemType in this tensor.
777   template <typename DestElemType, typename SrcElemType>
778   void copyWithCast(const Tensor *t) {
779     assert(!isDeviceResident() && "Tensor must reside on host to access data.");
780     static_assert(!std::is_same<DestElemType, SrcElemType>::value,
781                   "Use copyRawFrom instead");
782     assert(this != t && "Copying to self");
783     assert(getElementType() != t->getElementType() &&
784            "Use copyRawFrom instead");
785     assert(actualSize() == t->actualSize() && "Different sizes");
786     const auto *src = t->getRawDataPointer<SrcElemType>();
787     auto *dst = getRawDataPointer<DestElemType>();
788     for (size_t idx = 0, end = actualSize(); idx != end; ++idx) {
789       dst[idx] = DestElemType(src[idx]);
790     }
791   }
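
  // Conversion sketch (illustrative only): read float data and store it as
  // float16_t in a destination tensor of the same shape.
  //
  //   Tensor src(ElemKind::FloatTy, {2, 2});
  //   Tensor dst(ElemKind::Float16Ty, {2, 2});
  //   src.getHandle<float>().clear(0.5f);
  //   dst.copyWithCast<float16_t, float>(&src);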
792 
793   /// Convert each element of this tensor to \p newTy. Calls into
794   /// \ref getCopyConvertedToType() to do the conversion, and hence supports
795   /// converting between whatever ElemKinds it supports.
796   void convertToType(ElemKind newTy);
797 
798   /// \returns a copy of the Tensor but converted to \p newKind. Currently
799   /// supports conversion for:
800   /// - FloatTy to Float16Ty
801   /// - FloatTy to BFloat16Ty
802   /// - Float16Ty to FloatTy
803   /// - BFloat16Ty to FloatTy
804   /// - UInt8FusedQTy to UInt8FusedFP16QTy
805   Tensor getCopyConvertedToType(ElemKind newKind) const;
806 
807   /// Transpose the tensor \p src into the empty tensor \p dest. Shuffle the
808   /// axis based on the list \p shuffle, where each element is the src index.
809   void transpose(Tensor *dest, llvm::ArrayRef<unsigned_t> shuffle) const {
810     assert(!isDeviceResident() && "Tensor must reside on host to access data.");
811     genericTranspose(this, dest, shuffle);
812   }
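
  // Transpose sketch (illustrative only): each entry of the shuffle is the
  // source axis, so {0, 3, 1, 2} turns an NHWC layout into NCHW.
  //
  //   Tensor src(ElemKind::FloatTy, {1, 4, 4, 3});
  //   Tensor dst;
  //   src.transpose(&dst, {0, 3, 1, 2}); // dst has shape {1, 3, 4, 4}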
813 
814   /// Create a new copy of the current tensor.
815   Tensor clone() const {
816     assert(!isDeviceResident() && "Tensor must reside on host to access data.");
817     Tensor slice;
818     slice.assign(this);
819     return slice;
820   }
821 
822   /// Return the raw unsafe pointer to the tensor payload.
823   char *getUnsafePtr() const { return getData(); }
824 
825   /// \returns true if tensor data is stored on a device.
826   bool isDeviceResident() const {
827     return deviceResidency_ && deviceResidency_->isDeviceResident();
828   }
829 
830   /// Update device residency info with new device manager and context
831   void moveToDevice(DeviceTensorTransferManager *deviceManager,
832                     void *locationContext);
833 
834   /// If device resident, copy Tensor contents back to host memory and release
835   /// associated device memory.
836   void ensureOnHost();
837 
838   /// Updates contents of a device resident Tensor with the data from \p t
839   /// without copying its contents to host.
840   void copyRawToDevice(const Tensor *t);
841 
842   /// \returns the pointer to the device manager where the tensor resides.
843   DeviceTensorTransferManager *getDeviceManager() const {
844     assert(deviceResidency_ != nullptr && "DeviceResidencyInfo must exist");
845     assert(deviceResidency_->isDeviceResident() &&
846            "Tensor must be device resident");
847     return deviceResidency_->getDeviceManager();
848   }
849 
850   /// \returns the pointer to the location context of where the tensor resides.
851   void *getLocationContext() const {
852     assert(deviceResidency_ != nullptr && "DeviceResidencyInfo must exist");
853     assert(deviceResidency_->isDeviceResident() &&
854            "Tensor must be device resident");
855     return deviceResidency_->getLocationContext();
856   }
857 
858   void resetDeviceInfo() {
859     if (deviceResidency_ && ownsDeviceResidency_) {
860       deviceResidency_->clear();
861       return;
862     }
863 
864     deviceResidency_ = new DeviceResidencyInfo();
865     ownsDeviceResidency_ = true;
866   }
867 
868   /// Clears DeviceResidencyInfo.
869   /// Note that this does not affect the associated DeviceManager or device
870   /// memory.
871   void clearDeviceResidency() {
872     assert(deviceResidency_ != nullptr && "DeviceResidencyInfo must exist");
873     assert(deviceResidency_->isDeviceResident() &&
874            "Tensor must be device resident");
875     deviceResidency_->clear();
876   }
877 
878   /// \return a new handle that points and manages this tensor.
879   template <class ElemTy = float> Handle<ElemTy> getHandle() &;
880 
881   template <class ElemTy = float> const Handle<ElemTy> getHandle() const &;
882 
883   /// If Tensor is rvalue, it is an error to get its Handle.
884   template <class ElemTy = float> Handle<ElemTy> getHandle() && = delete;
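
  // Handle usage sketch (illustrative only): iterate and index a tensor
  // through a typed Handle.
  //
  //   Tensor t(ElemKind::FloatTy, {2, 3});
  //   auto h = t.getHandle<float>();
  //   h.clear(0.0f);
  //   h.at({1, 2}) = 7.0f;
  //   float sum = 0.0f;
  //   for (float v : h) { sum += v; } // sum == 7.0f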
885 
886 private:
887   /// \returns a pointer to the raw data, of type \p ElemTy.
888   template <class ElemTy> ElemTy *getRawDataPointer() {
889     assert(!isDeviceResident() && "Tensor must reside on host to access data.");
890     assert(type_.isType<ElemTy>() && "Asking for the wrong ptr type.");
891     return reinterpret_cast<ElemTy *>(data_);
892   }
893 
894   /// \returns a const pointer to the raw data, of type \p ElemTy.
895   template <class ElemTy> const ElemTy *getRawDataPointer() const {
896     assert(!isDeviceResident() && "Tensor must reside on host to access data.");
897     assert(type_.isType<ElemTy>() && "Asking for the wrong ptr type.");
898     return reinterpret_cast<const ElemTy *>(data_);
899   }
900 
901   template <class ElemTy>
902   bool isEqualImpl(const Tensor &other, float allowedError,
903                    bool verbose) const {
904     assert(!isDeviceResident() && "Tensor must reside on host to access data.");
905     auto thisHandle = getHandle<ElemTy>();
906     auto otherHandle = other.getHandle<ElemTy>();
907     double maxFoundError = 0.0;
908     size_t numExceedingError = 0;
909     size_t currIndex = 0;
910     size_t maxFoundErrorIdx = 0;
911     double maxRE = 0.0; // relative error.
912     size_t maxREIdx = 0;
913     for (auto thisHandleIt = thisHandle.begin(),
914               otherHandleIt = otherHandle.begin();
915          thisHandleIt != thisHandle.end() && otherHandleIt != otherHandle.end();
916          ++thisHandleIt, ++otherHandleIt, ++currIndex) {
917       double delta = *thisHandleIt - *otherHandleIt;
918       delta = std::abs(delta);
919       // Since any comparison with NAN returns false, we use a negated condition
920       // so that this function correctly returns false when delta is NAN.
921       if (!(delta <= allowedError)) {
922         if (!verbose) {
923           return false;
924         }
925         numExceedingError += 1;
926         if (!(delta <= maxFoundError)) {
927           maxFoundError = delta;
928           maxFoundErrorIdx = currIndex;
929         }
930         double sum = *thisHandleIt + *otherHandleIt;
931         double re = delta / std::abs(sum);
932         if (!(re <= maxRE)) {
933           maxRE = re;
934           maxREIdx = currIndex;
935         }
936       }
937     }
938     auto thisHandleIt = thisHandle.begin();
939     auto otherHandleIt = otherHandle.begin();
940     if (numExceedingError != 0) {
941       LOG(INFO) << "Tensors not equal: " << numExceedingError << " out of "
942                 << actualSize() << " elements exceeded allowed error threshold "
943                 << allowedError << ". Maximum error found was " << maxFoundError
944                 << " at index " << maxFoundErrorIdx << ": "
945                 << *(thisHandleIt.operator+(maxFoundErrorIdx)) << " vs. "
946                 << *(otherHandleIt.operator+(maxFoundErrorIdx));
947       LOG(INFO) << "Maximum relative error found was: " << maxRE
948                 << " at index: " << maxREIdx << ": "
949                 << *(thisHandleIt.operator+(maxREIdx)) << " vs. "
950                 << *(otherHandleIt.operator+(maxREIdx));
951     }
952     return numExceedingError == 0;
953   }
954 
955   bool isBitwiseEqualImpl(const Tensor &other, bool verbose) const {
956     assert(!isDeviceResident() && "Tensor must reside on host to access data.");
957     auto const *myData = getUnsafePtr();
958     auto const *otherData = other.getUnsafePtr();
959     dim_t mismatchCount = 0;
960     for (size_t i = 0, e = getSizeInBytes(); i < e; i++) {
961       if (myData[i] != otherData[i]) {
962         if (!verbose) {
963           return false;
964         }
965         ++mismatchCount;
966       }
967     }
968     if (mismatchCount != 0) {
969       LOG(INFO) << "Tensors not bitwise equal: " << mismatchCount
970                 << " bytes out of " << getSizeInBytes() << " mismatched.";
971     }
972     return mismatchCount == 0;
973   }
974 };
975 
976 //===----------------------------------------------------------------------===//
977 //                    Tensor Handle
978 //===----------------------------------------------------------------------===//
979 
980 constexpr unsigned MAX_DUMP_ELEMS = 100;
981 
982 void dumpAsciiImpl(const Tensor *T, llvm::raw_ostream &os);
983 void dumpAsciiImpl(const Tensor *T);
984 
985 void dumpImpl(const Tensor *T, llvm::raw_ostream &os,
986               unsigned maxNumElem = MAX_DUMP_ELEMS);
987 void dumpImpl(const Tensor *T, unsigned maxNumElem);
988 void dumpImpl(const Tensor *T);
989 
990 template <class ElemTy> class Handle;
991 
992 /// A class that provides ability to iterate over a Handle<ElemTy>. Since it's
993 /// common to have both mutating and const iterators, this class has template
994 /// parameter IsConst, which is true to create const_iterator and false
995 /// otherwise.
996 template <class ElemTy, bool IsConst>
997 class HandleIterator
998     : public std::iterator<std::random_access_iterator_tag, ElemTy> {
999   using HandleTy = typename std::conditional_t<IsConst, const Handle<ElemTy> *,
1000                                                Handle<ElemTy> *>;
1001   using ElemTyRef =
1002       typename std::conditional_t<IsConst, const ElemTy &, ElemTy &>;
1003 
1004   /// At every given moment, the iterator maintains an index, which is used to
1005   /// access the Handle. When moving the iterator forward, the index is
1006   /// incremented. Only valid elements can be accessed.
1007   /// 0 <= idx_ <= handle_->size()
1008   HandleTy handle_;
1009   llvm::ArrayRef<dim_t> sizes_;
1010   dim_t idx_;
1011   /// Holds true if the underlying tensor has non-trivial alignment (i.e. not 1)
1012   bool isAligned_;
1013 
1014   HandleIterator() = default;
1015 
1016   HandleIterator(HandleTy handle) : handle_(handle) {
1017     sizes_ = handle->dims();
1018     isAligned_ = handle->size() < handle->actualSize();
1019   }
1020 
1021   static HandleIterator begin(HandleTy handle) {
1022     auto res = HandleIterator(handle);
1023     res.idx_ = 0;
1024     return res;
1025   }
1026 
1027   static HandleIterator end(HandleTy handle) {
1028     auto res = HandleIterator(handle);
1029     res.idx_ = res.handle_->size();
1030     return res;
1031   }
1032 
1033   friend class Handle<ElemTy>;
1034 
1035 public:
1036   HandleIterator &operator++() {
1037     if (*this != handle_->end()) {
1038       idx_++;
1039     }
1040     return *this;
1041   }
1042   HandleIterator &operator--() {
1043     if (idx_) {
1044       idx_--;
1045     }
1046     return *this;
1047   }
1048   HandleIterator operator+(int n) const {
1049     auto res = HandleIterator(handle_);
1050     res.idx_ = std::max(static_cast<int>(idx_) + n, 0);
1051     res.idx_ = std::min(res.idx_, res.handle_->size());
1052     return res;
1053   }
1054   HandleIterator operator-(int n) const { return *this + (-n); }
1055   operator int() const { return idx_; }
1056 
1057   ElemTyRef operator*() {
1058     if (!isAligned_) {
1059       return handle_->raw(idx_);
1060     }
1061     std::vector<dim_t> indices(sizes_.size(), 0);
1062     size_t rem = idx_;
1063     for (int i = static_cast<int>(sizes_.size()) - 1; i >= 0; i--) {
1064       indices[i] = rem % sizes_[i];
1065       rem /= sizes_[i];
1066     }
1067     return handle_->at(indices);
1068   }
1069 
1070   bool operator==(const HandleIterator<ElemTy, IsConst> &other) const {
1071     return idx_ == other.idx_;
1072   }
1073 
1074   bool operator!=(const HandleIterator<ElemTy, IsConst> &other) const {
1075     return !(*this == other);
1076   }
1077 };
1078 
1079 /// Helper which \returns the flattened 1D offset given \p indices into a tensor
1080 /// with \p strides.
1081 inline size_t getFlattenedOffset(llvm::ArrayRef<dim_t> strides,
1082                                  llvm::ArrayRef<dim_t> indices) {
1083   assert(indices.size() <= strides.size() && "Invalid number of indices");
1084   // The loop below can be rewritten using std::inner_product. Unfortunately
1085   // std::inner_product does not optimize very well and loops that use this
1086   // method don't get vectorized. Don't change this loop without benchmarking
1087   // the program on a few compilers.
1088   size_t index = 0;
1089   for (size_t i = 0, e = indices.size(); i < e; i++) {
1090     index += size_t(strides[i]) * size_t(indices[i]);
1091   }
1092 
1093   return index;
1094 }
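
// Worked example (illustrative): a contiguous {2, 3, 4} tensor has strides
// {12, 4, 1}, so indices {1, 2, 3} flatten to 1 * 12 + 2 * 4 + 3 * 1 = 23.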
1095 
1096 /// Helper function which \returns true if a slice with the shape \p sliceShape
1097 /// referenced from a larger tensor with the shape \p tensorShape is contiguous
1098 /// in memory (assuming the tensor it is referenced from is contiguous). This
1099 /// happens when the slice dimensions:
1100 /// - Start with singleton dimensions (dimensions equal to 1).
1101 /// - Continue with a partially extracted dimension (one maximum).
1102 /// - End with fully extracted dimensions.
1103 bool isSliceContiguous(llvm::ArrayRef<dim_t> sliceShape,
1104                        llvm::ArrayRef<dim_t> tensorShape);
1105 
1106 /// A class that provides indexed access to a tensor. This class has value
1107 /// semantics and it's copied around. One of the reasons for making this class
1108 /// value semantics is to allow efficient index calculation that the compiler
1109 /// can optimize (because stack allocated structures don't alias).
1110 template <class ElemTy> class Handle final {
1111   /// A pointer to the tensor that this handle wraps.
1112   Tensor *tensor_{nullptr};
1113 
1114   /// Contains the multiplication of the sizes from current position to end.
1115   /// For example, for index (w,x,y,z):  [x * y * z, y * z, z, 1]
1116   dim_t sizeIntegral_[max_tensor_dimensions] = {
1117       0,
1118   };
1119 
1120   dim_t sizes_[max_tensor_dimensions] = {
1121       0,
1122   };
1123 
1124   /// Saves the number of dimensions used in the tensor.
1125   uint8_t numDims_{0};
1126 
1127   /// Remember end iterators. This is needed to speed up iterator increment,
1128   /// which has to check that iterator hasn't reached the end yet.
1129   HandleIterator<ElemTy, false> mutating_end_;
1130   HandleIterator<ElemTy, true> const_end_;
1131 
1132   /// Create a new invalid handle. Notice that this method is private and may
1133   /// only be used by the static factory method below.
1134   Handle() = default;
1135 
1136 public:
1137   /// \returns an iterator to the first element of the tensor.
1138   HandleIterator<ElemTy, false> begin() {
1139     return HandleIterator<ElemTy, false>::begin(this);
1140   }
1141   HandleIterator<ElemTy, true> begin() const {
1142     return HandleIterator<ElemTy, true>::begin(this);
1143   }
1144 
1145   /// \returns an iterator referring to the past-the-end element.
1146   HandleIterator<ElemTy, false> end() { return mutating_end_; }
1147   HandleIterator<ElemTy, true> end() const { return const_end_; }
1148 
1149   /// Allocate a new invalid handle.
1150   static Handle createInvalidHandle() { return Handle(); }
1151 
1152   /// \returns true if this Handle points to a valid tensor.
1153   bool isValid() const { return tensor_; }
1154 
1155   /// Calculate the index for a specific element in the tensor. Notice that
1156   /// the list of indices may be incomplete. This method provides access to
1157   /// padding elements, meaning that it is possible to get an index pointing
1158   /// at data that was added to meet alignment requirements.
1159   size_t getElementPtr(llvm::ArrayRef<dim_t> indices) const {
1160     return getFlattenedOffset(llvm::makeArrayRef(sizeIntegral_, numDims_),
1161                               indices);
1162   }
1163 
1164   /// \returns the value of the n'th dimension \p dim, for the index \p idx.
1165   /// 0 <= idx < size(), meaning that \p idx addresses a real data element,
1166   /// not padding.
1167   size_t getDimForPtr(size_t dim, size_t idx) const {
1168     assert(dim < numDims_ && "Invalid dimension");
1169     assert(idx < size() && "Invalid index");
1170     auto R = idx;
1171     for (size_t i = dim + 1; i < numDims_; i++) {
1172       R /= sizes_[i];
1173     }
1174     return R % sizes_[dim];
1175   }
1176 
1177   /// \returns the type of the tensor.
1178   const Type &getType() const { return tensor_->getType(); }
1179 
1180   /// \returns the element type of the tensor.
1181   ElemKind getElementType() const { return tensor_->getElementType(); }
1182 
1183   /// Construct a Tensor handle.
1184   explicit Handle(Tensor *tensor) : tensor_(tensor) {
1185     auto sizes = tensor->dims();
1186     numDims_ = sizes.size();
1187 
1188     /// We allow handles that wrap uninitialized tensors.
1189     if (numDims_) {
1190       // Copy the sizes of the tensor.
1191       memcpy(sizes_, tensor_->type_.sizes_,
1192              max_tensor_dimensions * sizeof(sizes_[0]));
1193       // Copy the strides of the tensor.
1194       memcpy(sizeIntegral_, tensor_->type_.strides_,
1195              max_tensor_dimensions * sizeof(tensor_->type_.strides_[0]));
1196       assert(numDims_ <= max_tensor_dimensions && "Too many dimensions.");
1197     }
1198 
1199     mutating_end_ = HandleIterator<ElemTy, false>::end(this);
1200     const_end_ = HandleIterator<ElemTy, true>::end(this);
1201   }
1202 
1203   llvm::ArrayRef<dim_t> dims() const {
1204     return llvm::ArrayRef<dim_t>(sizes_, numDims_);
1205   }
1206 
1207   /// \returns the number of elements in the whole tensor.
1208   dim_t size() const { return tensor_->size(); }
1209 
1210   /// \returns the actual number of elements in the tensor taking striding into
1211   /// account. Since size() does not take striding into account, size() is
1212   /// always <= actualSize().
1213   dim_t actualSize() const { return tensor_->actualSize(); }
1214 
1215   /// \returns the unpadded size of the underlying \ref tensor_.
1216   size_t getUnpaddedSizeInBytes() const {
1217     return tensor_->getUnpaddedSizeInBytes();
1218   }
1219 
1220   bool isInBounds(llvm::ArrayRef<dim_t> indices) const {
1221     return tensor_->isInBounds(indices);
1222   }
1223 
1224   void clear(ElemTy value = 0) { std::fill(begin(), end(), value); }
1225 
1226   /// Returns a reference to a meaningful data element. This method does not
1227   /// address padding elements.
1228   ElemTy &at(llvm::ArrayRef<dim_t> indices) {
1229     size_t index = getElementPtr(indices);
1230     auto *data = tensor_->getRawDataPointer<ElemTy>();
1231     return data[index];
1232   }
1233 
1234   const ElemTy &at(llvm::ArrayRef<dim_t> indices) const {
1235     size_t index = getElementPtr(indices);
1236     auto *data = tensor_->getRawDataPointer<ElemTy>();
1237     return data[index];
1238   }
1239 
1240   /// \returns the element at offset \p idx without any size calculations.
1241   /// The returned element can be a pad element.
1242   ElemTy &raw(size_t index) {
1243     auto *data = tensor_->getRawDataPointer<ElemTy>();
1244     return data[index];
1245   }
1246 
1247   /// \returns the element at offset \p idx without any size calculations.
1248   /// The returned element can be a pad element.
1249   const ElemTy &raw(size_t index) const {
1250     auto *data = tensor_->getRawDataPointer<ElemTy>();
1251     return data[index];
1252   }

  /// Extract a tensor of one fewer dimension that corresponds to slice \p idx
  /// along the first dimension.
  Tensor extractSlice(size_t idx) const {
    auto sizes = tensor_->dims();
    assert(sizes.size() > 1 && "Tensor must have at least two dimensions");
    assert(idx < sizes[0] && "Invalid first index");

    Tensor slice{Type::newShape(tensor_->getType(), sizes.slice(1),
                                tensor_->type_.strides().slice(1))};

    // Extract the whole slice.
    size_t startIdx = sizeIntegral_[0] * idx;
    ElemTy *base = tensor_->getRawDataPointer<ElemTy>() + startIdx;
    auto *dest = slice.getRawDataPointer<ElemTy>();
    std::copy(base, base + sizeIntegral_[0], dest);

    return slice;
  }
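
  // Illustrative sketch of extractSlice() above (names are hypothetical):
  // slicing row 1 out of a {2, 3} tensor yields a {3} tensor holding {3, 4, 5}.
  //
  //   Tensor T(ElemKind::FloatTy, {2, 3});
  //   auto H = T.getHandle<float>();
  //   H = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f};
  //   Tensor row = H.extractSlice(1);
  //   assert(row.getHandle<float>().at({0}) == 3.f);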

  /// Insert a smaller dimension tensor into a larger tensor at a specific
  /// first-dimension index.
  void insertSlice(const Tensor &slice, size_t idx) {
    auto dims = tensor_->dims();
    (void)dims;
    assert(getElementType() == slice.getElementType());
    assert(dims.size() > 1 && "Tensor must have at least two dimensions");
    assert(idx < dims[0] && "Invalid first index");

    auto sliceSize = sizeIntegral_[0];
    size_t startIdx = sliceSize * idx;
    ElemTy *base = &raw(startIdx);
    const ElemTy *slicePtr = slice.getRawDataPointer<ElemTy>();
    std::copy(slicePtr, slicePtr + sliceSize, base);
  }
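
  // Illustrative sketch of insertSlice() above (names are hypothetical):
  // writing a {3} tensor into row 0 of a {2, 3} tensor.
  //
  //   Tensor big(ElemKind::FloatTy, {2, 3});
  //   Tensor sliceT(ElemKind::FloatTy, {3});
  //   auto SH = sliceT.getHandle<float>();
  //   SH = {7.f, 8.f, 9.f};
  //   big.getHandle<float>().insertSlice(sliceT, 0);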

  /// Create a new copy of the current tensor.
  Tensor clone() const { return tensor_->clone(); }

  /// Update the content of the tensor from an initializer list.
  void operator=(const std::initializer_list<ElemTy> &vec) {
    assert(actualSize() == vec.size() && "Invalid input size.");
    size_t i = 0;
    for (auto &e : vec) {
      raw(i++) = e;
    }
  }

  /// Update the content of the tensor from the array \p array.
  void operator=(llvm::ArrayRef<ElemTy> array) {
    assert(actualSize() == array.size() && "Invalid input size.");
    std::copy(array.begin(), array.end(), &raw(0));
  }
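
  // Illustrative sketch of the assignment operators above (names are
  // hypothetical); both forms fill the tensor in raw element order:
  //
  //   Tensor T(ElemKind::FloatTy, {2, 2});
  //   auto H = T.getHandle<float>();
  //   H = {1.f, 2.f, 3.f, 4.f};            // From an initializer list.
  //   std::vector<float> v{5.f, 6.f, 7.f, 8.f};
  //   H = llvm::makeArrayRef(v);           // From an llvm::ArrayRef.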

  /// Dump a human-readable ASCII representation of the tensor.
  void dumpAscii(llvm::raw_ostream &os) const { dumpAsciiImpl(tensor_, os); }
  void dumpAscii() const { dumpAsciiImpl(tensor_); }

  /// \returns the raw indices of the minimum and maximum values in the tensor.
  /// If the minimum or maximum occurs more than once, the smallest such index
  /// is returned.
  std::pair<dim_t, dim_t> minMaxArg() const {
    ElemTy max = raw(0);
    ElemTy min = raw(0);

    size_t maxIdx = 0;
    size_t minIdx = 0;

    for (size_t i = 1, e = actualSize(); i < e; i++) {
      ElemTy val = raw(i);
      if (val > max) {
        max = val;
        maxIdx = i;
      } else if (val < min) {
        min = val;
        minIdx = i;
      }
    }

    return std::make_pair(minIdx, maxIdx);
  }
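
  // Illustrative sketch of minMaxArg() above (names are hypothetical):
  //
  //   Tensor T(ElemKind::FloatTy, {4});
  //   auto H = T.getHandle<float>();
  //   H = {3.f, -1.f, 7.f, -1.f};
  //   auto res = H.minMaxArg();
  //   // res.first == 1 (first index of the minimum -1),
  //   // res.second == 2 (index of the maximum 7).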

  /// \returns true if the tensor contains only elements equal to zero.
  /// \p allowedError represents the delta from zero that is allowed before
  /// returning false.
  bool isZero(float allowedError = 0.0) const {
#define RETURN_WHETHER_FUSED_IS_ZERO(DATA_TYPE)                                \
  assert(dims().size() == 2 && "Fused tensor must be 2-dimensional.");         \
  assert(dims()[1] > 2 * sizeof(DATA_TYPE) &&                                  \
         "Fused tensor must have space for scale/offset.");                    \
  const dim_t dataWidth = dims()[1];                                           \
  const dim_t alignedLength = tensor_->getType().strides()[0];                 \
  auto *data = reinterpret_cast<uint8_t *>(tensor_->getUnsafePtr());           \
  for (dim_t i = 0, e = dims()[0]; i < e; i++) {                               \
    uint8_t *scaleOffsetPtr =                                                  \
        data + i * alignedLength + dataWidth - 2 * sizeof(DATA_TYPE);          \
    DATA_TYPE scale, offset;                                                   \
    memcpy(&scale, scaleOffsetPtr, sizeof(DATA_TYPE));                         \
    memcpy(&offset, scaleOffsetPtr + sizeof(DATA_TYPE), sizeof(DATA_TYPE));    \
    for (dim_t j = 0, e = dataWidth - 2 * sizeof(DATA_TYPE); j < e; j++) {     \
      float currVal = (at({i, j}) * (float)scale) + (float)offset;             \
      if (std::abs(currVal) > allowedError) {                                  \
        return false;                                                          \
      }                                                                        \
    }                                                                          \
  }                                                                            \
  return true;

    if (getElementType() == ElemKind::UInt8FusedQTy) {
      RETURN_WHETHER_FUSED_IS_ZERO(float);
    }
    if (getElementType() == ElemKind::UInt8FusedFP16QTy) {
      RETURN_WHETHER_FUSED_IS_ZERO(float16_t);
    }
#undef RETURN_WHETHER_FUSED_IS_ZERO

    int32_t trueZero = getType().isQuantizedType() ? getType().getOffset() : 0;
    return std::all_of(begin(), end(), [=](ElemTy e) { return e == trueZero; });
  }

  void dump(llvm::raw_ostream &os, unsigned maxNumElem = MAX_DUMP_ELEMS) const {
    dumpImpl(tensor_, os, maxNumElem);
  }
  void dump(unsigned maxNumElem) const { dumpImpl(tensor_, maxNumElem); }
  void dump() const { dumpImpl(tensor_, MAX_DUMP_ELEMS); }

  /// Fill the tensor with random data that is close to zero using the
  /// Xavier method, based on the paper [Bengio and Glorot 2010].
  /// This type of initialization facilitates better training performance.
  /// The parameter \p filterSize is the number of "input" neurons in the
  /// tensor (or the relevant slice). For example, consider the case of MatMul:
  /// NxM (\p input) * MxK (\p weights) == NxK (\p result).
  /// The correct \p filterSize for the weights tensor is M, so that the norm
  /// of each row of \p input equals the norm of the corresponding row of
  /// \p result.
  void initXavier(size_t filterSize, PseudoRNG &PRNG) {
    assert(filterSize > 0 && "invalid filter size");
    assert(getType().isFPType() &&
           "Only support floating point Xavier initialization.");
    double scale = std::sqrt(3.0 / double(filterSize));
    std::uniform_real_distribution<> dist(-scale, scale);
    for (auto &e : *this) {
      e = dist(PRNG);
    }
  }
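
  // Illustrative sketch of initXavier() above (names are hypothetical): for
  // weights of shape {M, K} = {256, 128} used as the RHS of a MatMul, the
  // fan-in is the contracted dimension M, so pass 256 as the filter size.
  //
  //   Tensor W(ElemKind::FloatTy, {256, 128});
  //   PseudoRNG PRNG;
  //   W.getHandle<float>().initXavier(/* filterSize */ 256, PRNG);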

  /// Fill the tensor with uniformly distributed values in the range
  /// [low .. high).
  template <typename T = ElemTy>
  typename std::enable_if<std::is_floating_point<T>::value>::type
  randomize(float low, float high, PseudoRNG &PRNG) {
    assert(low <= high && "invalid range");
    std::uniform_real_distribution<ElemTy> dist(low, high);
    for (auto &elem : *this) {
      elem = dist(PRNG);
    }
  }

  /// Fill the tensor with uniformly distributed values in the range
  /// [low .. high]. For quantized fused tensors leave scales/offsets unchanged.
  template <typename T = ElemTy>
  typename std::enable_if<std::is_integral<T>::value>::type
  randomize(int low, int high, PseudoRNG &PRNG) {
    assert(low <= high && "invalid range");
    assert(low >= std::numeric_limits<ElemTy>::lowest() &&
           high <= std::numeric_limits<ElemTy>::max() &&
           "Cannot initialize outside range of representable values.");
    std::uniform_int_distribution<long long> dist(low, high);
    switch (getElementType()) {
    default: {
      for (auto &elem : *this) {
        elem = dist(PRNG);
      }
      return;
    }

#define FUSED_CASE(ELEM_KIND, DATA_TYPE)                                       \
  case ElemKind::ELEM_KIND: {                                                  \
    assert(dims().size() == 2 && "Fused tensor must be 2-dimensional.");       \
    assert(dims()[1] > 2 * sizeof(DATA_TYPE) &&                                \
           "Fused tensor must have space for scale/offset.");                  \
    for (dim_t i = 0, e = dims()[0]; i < e; i++) {                             \
      for (dim_t j = 0, f = dims()[1] - 2 * sizeof(DATA_TYPE); j < f; j++) {   \
        at({i, j}) = dist(PRNG);                                               \
      }                                                                        \
    }                                                                          \
    return;                                                                    \
  }
      FUSED_CASE(UInt8FusedQTy, float);
      FUSED_CASE(UInt8FusedFP16QTy, float16_t);
#undef FUSED_CASE
    }
  }

  /// Fill the tensor with uniformly distributed values in the range
  /// [low .. high).
  template <typename T = ElemTy>
  typename std::enable_if<!std::is_floating_point<T>::value &&
                          !std::is_integral<T>::value>::type
  randomize(float low, float high, PseudoRNG &PRNG) {
    assert(low <= high && "invalid range");
    std::uniform_real_distribution<float> dist(low, high);
    for (auto &elem : *this) {
      elem = dist(PRNG);
    }
  }
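
  // Illustrative sketch of the randomize() overloads above (names are
  // hypothetical, and the quantized-tensor constructor taking a scale/offset
  // pair is assumed); the overload is chosen by the handle's element type:
  //
  //   PseudoRNG PRNG;
  //   Tensor F(ElemKind::FloatTy, {10});
  //   F.getHandle<float>().randomize(-1.0f, 1.0f, PRNG);  // Floating point.
  //   Tensor I(ElemKind::Int8QTy, {10}, /* scale */ 1.0, /* offset */ 0);
  //   I.getHandle<int8_t>().randomize(-128, 127, PRNG);   // Integral.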

  /// \returns the mean and variance of the tensor.
  std::pair<double, double> calculateMeanVariance() const {
    size_t n = actualSize();
    assert(n > 1 && "Input must have at least 2 elements.");

    // Calculate mean.
    double mean = 0;
    for (size_t i = 0; i < n; i++) {
      mean += raw({i});
    }
    mean /= n;

    // Calculate variance.
    double var = 0;
    for (size_t i = 0; i < n; i++) {
      double t = raw({i}) - mean;
      var += t * t;
    }
    var /= (n - 1);

    return {mean, var};
  }
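
  // Illustrative sketch of calculateMeanVariance() above (names are
  // hypothetical); it returns the sample mean and the unbiased (n - 1)
  // sample variance:
  //
  //   Tensor T(ElemKind::FloatTy, {4});
  //   auto H = T.getHandle<float>();
  //   H = {1.f, 2.f, 3.f, 4.f};
  //   auto mv = H.calculateMeanVariance();
  //   // mv.first == 2.5, mv.second == 5.0 / 3.0.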

  /// Insert the tensor \p slice at location \p offset \p count times along the
  /// \p axis. This operation is equivalent to the operation of scanning the
  /// source tensor, and saving the value that is stored at coordinate {d_0,
  /// d_1, ... d_n} in the new tensor at {d_0 + O_0, d_1 + O_1, ... d_n + O_n},
  /// where O is the offset vector, assuming \p count = 1. For \p count > 1, the
  /// same Tensor is copied \p count times along the provided \p axis. The
  /// tensors must be of the right dimensions.
  void insertTensors(Handle<ElemTy> &slice, llvm::ArrayRef<dim_t> offset,
                     size_t count = 1, size_t axis = 0) {
    auto sliceCoor = slice.dims().vec();
    auto fusedCoor = dims().vec();
    insertTensorsImpl(sliceCoor, fusedCoor, slice, true, offset, count, axis,
                      0);
  }

  /// Extract the tensor \p slice at location \p offset. This operation is
  /// equivalent to the operation of scanning the destination tensor, and
  /// copying into the cell at coordinate {d_0, d_1, ... d_n} a value from the
  /// tensor at {d_0 + O_0, d_1 + O_1, ... d_n + O_n}, where O is the offset
  /// vector. The tensors must be of the right dimensions.
  void extractTensors(Handle<ElemTy> &slice, llvm::ArrayRef<dim_t> offset) {
    auto sliceCoor = slice.dims().vec();
    auto fusedCoor = dims().vec();
    insertTensorsImpl(sliceCoor, fusedCoor, slice, false, offset, /* count */ 1,
                      /* axis */ 0, 0);
  }
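
  // Illustrative sketch of insertTensors() / extractTensors() above (names are
  // hypothetical): tile a {1, 2} slice twice along axis 0 of a {2, 2} tensor,
  // then read one row back.
  //
  //   Tensor big(ElemKind::FloatTy, {2, 2});
  //   Tensor sliceT(ElemKind::FloatTy, {1, 2});
  //   auto BH = big.getHandle<float>();
  //   auto SH = sliceT.getHandle<float>();
  //   SH = {1.f, 2.f};
  //   BH.insertTensors(SH, {0, 0}, /* count */ 2, /* axis */ 0);
  //   // big is now {{1, 2}, {1, 2}}.
  //   BH.extractTensors(SH, {1, 0}); // SH now holds big's second row.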

  /// \returns a pair of the scale and offset from row \p rowIdx of a
  /// FusedRowwiseQuantized Tensor.
  template <typename T>
  std::pair<T, T> getFusedScaleOffsetFromRow(dim_t rowIdx) {
    ElemTy *rowScaleOffsetPtr = getFusedRowScaleOffsetPtr<T>(rowIdx);
    T scale;
    T offset;
    memcpy(&scale, rowScaleOffsetPtr, sizeof(T));
    memcpy(&offset, rowScaleOffsetPtr + sizeof(T), sizeof(T));
    return std::make_pair(scale, offset);
  }

  /// Sets the \p scale and \p offset for row \p rowIdx of a
  /// FusedRowwiseQuantized Tensor.
  template <typename T>
  void setFusedScaleOffsetInRow(dim_t rowIdx, T scale, T offset) {
    ElemTy *rowScaleOffsetPtr = getFusedRowScaleOffsetPtr<T>(rowIdx);
    T finalScale = static_cast<T>(scale);
    T finalOffset = static_cast<T>(offset);
    memcpy(rowScaleOffsetPtr, &finalScale, sizeof(T));
    memcpy(rowScaleOffsetPtr + sizeof(T), &finalOffset, sizeof(T));
  }
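
  // Illustrative sketch of getFusedScaleOffsetFromRow() /
  // setFusedScaleOffsetInRow() above (names and the dummy scale/offset passed
  // to the assumed quantized constructor are hypothetical). Each row of a
  // UInt8FusedQTy tensor stores its uint8 payload followed by a float scale
  // and a float offset, so the handle element type is uint8_t:
  //
  //   Tensor T(ElemKind::UInt8FusedQTy, {2, 16}, /* dummy scale */ 1.0,
  //            /* dummy offset */ 0);
  //   auto H = T.getHandle<uint8_t>();
  //   H.setFusedScaleOffsetInRow<float>(0, /* scale */ 0.5f, /* offset */ 2.f);
  //   auto so = H.getFusedScaleOffsetFromRow<float>(0); // {0.5f, 2.f}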

private:
  /// Inserts or extracts a slice from this tensor.
  /// \p sliceCoor and \p fusedCoor are temporary storage that the function uses
  /// to construct the coordinates to access the tensors. They must be
  /// initialized to the size of the shape of the corresponding tensor. \p slice
  /// is the tensor to insert or extract; this handle's tensor is the fused
  /// (larger) tensor. \p offset is the per-dimension offset of the slice within
  /// the fused tensor. \p d is the recursion depth parameter that follows the
  /// current axis number. If \p isInsert is set then data is copied from
  /// \p slice to the fused tensor; otherwise data is copied from the fused
  /// tensor to \p slice. \p count and \p axis are used in conjunction for
  /// inserting the same tensor \p count times along the \p axis.
  void insertTensorsImpl(llvm::MutableArrayRef<dim_t> sliceCoor,
                         llvm::MutableArrayRef<dim_t> fusedCoor,
                         Handle<ElemTy> &slice, bool isInsert,
                         llvm::ArrayRef<dim_t> offset, size_t count,
                         size_t axis, unsigned d) {
    bool isDone = (d == slice.dims().size());

    if (isDone) {
      if (isInsert) {
        at(fusedCoor) = slice.at(sliceCoor);
      } else {
        slice.at(sliceCoor) = at(fusedCoor);
      }
      return;
    }

    // Only need to iterate over count if the current dimension d is equal to
    // the axis we're inserting over.
    const size_t countIters = (axis == d) ? count : 1;
    for (size_t c = 0; c < countIters; c++) {
      for (size_t i = 0, e = slice.dims()[d]; i < e; i++) {
        // Construct the coordinates for the slice and for the joint shape.
        // Add the 'offset' to the dimension that we concat the shapes on.
        sliceCoor[d] = i;
        // If this is the correct axis to insert multiple times then calculate
        // the additional offset to use.
        const size_t countAxisOffset = (axis == d) ? c * slice.dims()[d] : 0;
        fusedCoor[d] = i + offset[d] + countAxisOffset;
        insertTensorsImpl(sliceCoor, fusedCoor, slice, isInsert, offset, count,
                          axis, d + 1);
      }
    }
  }

  /// Given a Fused tensor, \returns a pointer to the scale and offset with type
  /// \p T of a row \p rowIdx.
  template <typename T> ElemTy *getFusedRowScaleOffsetPtr(dim_t rowIdx) {
    switch (getElementType()) {
    case ElemKind::UInt8FusedQTy: {
      constexpr auto isFloat = std::is_same<float, T>::value;
      DCHECK(isFloat) << "Expected float scale/offset";
      break;
    }
    case ElemKind::UInt4FusedFP16QTy:
    case ElemKind::UInt8FusedFP16QTy: {
      constexpr auto isFloat16 = std::is_same<float16_t, T>::value;
      DCHECK(isFloat16) << "Expected float16_t scale/offset";
      break;
    }
    default:
      llvm_unreachable("Must be used with Tensor of supported Fused ElemKind");
    }

    static_assert(std::is_same<uint8_t, ElemTy>::value,
                  "Handle of current Fused tensors expected to be uint8_t.");
    const dim_t colIdx = dims()[1] - 2 * sizeof(T);
    return &at({rowIdx, colIdx});
  }
};

template <class ElemTy> Handle<ElemTy> Tensor::getHandle() & {
  assert(!isDeviceResident() && "Tensor must reside on host to access data.");
  assert(type_.isType<ElemTy>() && "Getting a handle to the wrong type.");
  return Handle<ElemTy>(this);
}

template <class ElemTy> const Handle<ElemTy> Tensor::getHandle() const & {
  assert(!isDeviceResident() && "Tensor must reside on host to access data.");
  assert(type_.isType<ElemTy>() && "Getting a handle to the wrong type.");
  return Handle<ElemTy>(const_cast<Tensor *>(this));
}
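
// Illustrative sketch of the getHandle() accessors above (the name T is
// hypothetical): the requested element type must match the tensor's element
// type, which the assertion checks.
//
//   Tensor T(ElemKind::FloatTy, {2, 2});
//   auto H = T.getHandle<float>();     // OK: element types match.
//   // T.getHandle<int8_t>();          // Would trip the type assertion.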

/// Dump a textual representation of the Tensor \p t to \p os.
llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Tensor &t);

/// Dump a textual representation of the Tensor pointed to by \p t to \p os.
llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Tensor *t);
} // namespace glow

#endif // GLOW_BASE_TENSOR_H