/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "glow/Base/Tensor.h"

#include "glow/Base/Type.h"

#include "llvm/Support/NativeFormatting.h"
#include "llvm/Support/raw_ostream.h"
#include <glog/logging.h>

using namespace glow;

namespace {

/// This is a helper method that's used in the visualization of tensors.
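/// Increasingly positive values map to denser glyphs (' ', '.', ',', ':',
/// 'o', 'O', '0', '@') and increasingly negative values map to ('-', '~',
/// '=', '#').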
template <class ElemTy> static char valueToChar(ElemTy input) {
  char ch = ' ';
  const double val = input;
  if (val > 0.2) {
    ch = '.';
  }
  if (val > 0.4) {
    ch = ',';
  }
  if (val > 0.6) {
    ch = ':';
  }
  if (val > 0.8) {
    ch = 'o';
  }
  if (val > 1.0) {
    ch = 'O';
  }
  if (val > 1.5) {
    ch = '0';
  }
  if (val > 2.0) {
    ch = '@';
  }
  if (val < -0.1) {
    ch = '-';
  }
  if (val < -0.2) {
    ch = '~';
  }
  if (val < -0.4) {
    ch = '=';
  }
  if (val < -1.0) {
    ch = '#';
  }
  return ch;
}

static void dumpShape(llvm::ArrayRef<dim_t> shape, llvm::raw_ostream &os) {
  os << "shape: ( ";
  for (auto &d : shape) {
    os << d << " ";
  }
  os << ")";
}

template <class ElemTy>
static void dumpGenericImpl(Handle<ElemTy> handle, llvm::raw_ostream &os,
                            unsigned maxNumElem) {
  auto shape = handle.dims();
  size_t numDims = shape.size();
  auto &Ty = handle.getType();

  // Check for 0-dimensional tensor.
  if (!numDims) {
    os << "[ Scalar containing: ";
    llvm::write_double(os, handle.raw(0), llvm::FloatStyle::Fixed, 3);
    os << " ]\n";
    return;
  }

  // Output shape.
  dumpShape(shape, os);
  os << "\n";

  // Check for tensor of size 0.
  if (handle.getUnpaddedSizeInBytes() == 0) {
    os << "[ tensor has no elements ]\n";
    return;
  }

  ElemTy mx = handle.raw(0);
  ElemTy mn = handle.raw(0);

  for (auto elem : handle) {
    mx = std::max(mx, elem);
    mn = std::min(mn, elem);
  }

  // Check for zero tensor.
  if (mn == ElemTy(.0) && mx == ElemTy(.0)) {
    os << "[ Zero tensor ]\n";
    return;
  }

  // Output max and min.
  os << "max: ";
  llvm::write_double(os, mx, llvm::FloatStyle::Fixed, 3);
  os << " min: ";
  llvm::write_double(os, mn, llvm::FloatStyle::Fixed, 3);
  os << "\n";

  os << "[";

  for (size_t i = 0, e = std::min<size_t>(maxNumElem, handle.size()); i < e;
       i++) {

    // Print one open brace at the beginning of every row, slice, and tensor.
    for (size_t j = 0, e = numDims - 1; numDims > 1 && j < e; j++) {
      if (i % Ty.getSliceSize(j + 1) == 0) {
        // This iteration of outer loop is a new row, slice or tensor.
        os << "[";
      }
    }

    // Print the value at the current index.
    llvm::write_double(os, handle.raw(i), llvm::FloatStyle::Fixed, 3);

    // Print one closed brace at the end of every row, slice, or tensor.
    for (size_t j = 0, e = numDims - 1; numDims > 1 && j < e; j++) {
      size_t next_index = i + 1;
      if (next_index % Ty.getSliceSize(j + 1) == 0u) {
        os << "]";
      }
    }

    os << ", ";

    // Print one newline at the end of every row, slice, or tensor.
    for (size_t j = 0, e = numDims - 1; numDims > 1 && j < e; j++) {
      size_t next_index = i + 1;
      if (next_index % Ty.getSliceSize(j + 1) == 0u) {
        // Next iteration of outer loop will be a new row, slice or tensor.
        os << "\n";
      }
    }
  }

  if (handle.size() > maxNumElem) {
    os << "...";
  }

  os << "]\n";

  os.flush();
}

template <class ElemTy>
static void dumpAsciiGenericImpl(Handle<ElemTy> handle, llvm::raw_ostream &os) {
  auto d = handle.dims();

  if (d.size() == 2) {
    for (dim_t x = 0; x < d[0]; x++) {
      for (dim_t y = 0; y < d[1]; y++) {
        auto val = handle.at({x, y});
        os << valueToChar(val);
      }
      os << "\n";
    }
  } else if (d.size() == 3) {
    // Print monochrome (one-color channel) tensors:
    if (d[2] == 1) {
      for (dim_t x = 0; x < d[0]; x++) {
        for (dim_t y = 0; y < d[1]; y++) {
          auto val = handle.at({x, y, 0});
          os << valueToChar(val);
        }
        os << "\n";
      }
    } else {
      for (dim_t z = 0; z < d[2]; z++) {
        os << "\n";
        for (dim_t x = 0; x < d[0]; x++) {
          for (dim_t y = 0; y < d[1]; y++) {
            auto val = handle.at({x, y, z});
            os << valueToChar(val);
          }
          os << "\n";
        }
      }
    }

  } else {
    llvm_unreachable("Invalid tensor size");
  }

  os.flush();
}

/// This is a slow generic transpose. At each level of recursion this function
/// performs a single for loop over a single dimension; once the last
/// dimension is reached it copies a single element.
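/// For example, with \p shuffle = {1, 0} the recursion copies src.at({j, i})
/// into dest.at({i, j}) for every (i, j); in general dest's dimension d
/// corresponds to src's dimension shuffle[d].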
template <class ElemTy>
static void
transposeGenericImpl(const Handle<ElemTy> &src, Handle<ElemTy> &dest,
                     dim_t *srcCoor, dim_t *destCoor,
                     llvm::ArrayRef<unsigned_t> shuffle, unsigned depth = 0) {
  if (depth == shuffle.size()) {
    auto srcIdx = llvm::ArrayRef<dim_t>(srcCoor, depth);
    auto destIdx = llvm::ArrayRef<dim_t>(destCoor, depth);
    dest.at(destIdx) = src.at(srcIdx);
    return;
  }

  // Iterate over one dimension and continue recursively to the next dim.
  for (dim_t x = 0, e = dest.dims()[depth]; x < e; x++) {
    unsigned_t swizzledDepth = shuffle[depth];
    srcCoor[swizzledDepth] = x;
    destCoor[depth] = x;
    transposeGenericImpl(src, dest, srcCoor, destCoor, shuffle, depth + 1);
  }
}

/// Faster function for transposing a tensor for important/common tensor
/// shapes. If a transpose successfully occurs, the function \returns true;
/// otherwise it \returns false, meaning that no transpose occurred and some
/// other transpose function (e.g. transposeGenericImpl) must be called.
/// \p src is the tensor to transpose into \p dest, and \p shuffle defines how
/// to transpose.
template <class ElemTy>
static bool tryTransposeFastImpl(const Handle<ElemTy> &src,
                                 Handle<ElemTy> &dest,
                                 llvm::ArrayRef<unsigned_t> shuffle) {
  const dim_t numDims = dest.dims().size();
  dim_t srcCoorArr[max_tensor_dimensions];
  dim_t destCoorArr[max_tensor_dimensions] = {0};
  auto srcCoor = llvm::ArrayRef<dim_t>(srcCoorArr, numDims);
  auto destCoor = llvm::ArrayRef<dim_t>(destCoorArr, numDims);

/// This defines a single depth of the for loop used to iterate over the
/// source and destination tensors for transposing.
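/// For example, TRANSPOSE_LOOP_LEVEL(1) expands to a for loop that walks
/// destCoorArr[1] from 0 to dest.dims()[1] while advancing
/// srcCoorArr[shuffle[1]] in lockstep.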
#define TRANSPOSE_LOOP_LEVEL(DEPTH_)                                           \
  for (srcCoorArr[shuffle[DEPTH_]] = 0, destCoorArr[DEPTH_] = 0;               \
       destCoorArr[DEPTH_] < dest.dims()[DEPTH_];                              \
       srcCoorArr[shuffle[DEPTH_]]++, destCoorArr[DEPTH_]++)

  switch (numDims) {
  case 2:
    TRANSPOSE_LOOP_LEVEL(1) {
      TRANSPOSE_LOOP_LEVEL(0) { dest.at(destCoor) = src.at(srcCoor); }
    }
    return true;
  case 4:
    TRANSPOSE_LOOP_LEVEL(1) {
      TRANSPOSE_LOOP_LEVEL(2) {
        TRANSPOSE_LOOP_LEVEL(0) {
          TRANSPOSE_LOOP_LEVEL(3) { dest.at(destCoor) = src.at(srcCoor); }
        }
      }
    }
    return true;
  }
  return false;
}

template <class ElemTy>
static void transposeSelectImpl(const Handle<ElemTy> &src, Handle<ElemTy> &dest,
                                llvm::ArrayRef<unsigned_t> shuffle) {
  bool transposeOccurred = tryTransposeFastImpl(src, dest, shuffle);
  if (!transposeOccurred) {
    dim_t srcCoor[max_tensor_dimensions];
    dim_t destCoor[max_tensor_dimensions];
    transposeGenericImpl(src, dest, srcCoor, destCoor, shuffle);
  }
}

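/// \returns true if the data of \p tensor repeats with period \p size along
/// dimension \p axis; for example, a 1-D tensor containing {1, 2, 3, 1, 2, 3}
/// is tiled along axis 0 with size 3. When \p fractional is false, the tile
/// size must also evenly divide the dimension.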
template <class ElemTy>
static bool isTiledImpl(const Tensor *tensor, unsigned_t axis, dim_t size,
                        bool fractional) {
  assert(axis < tensor->dims().size() && "Axis parameter invalid!");
  assert(size <= tensor->dims()[axis] && "Size parameter invalid!");
  assert(size >= 1 && "Size parameter invalid!");

  // When the tile size matches the dimension size then we return true.
  // This is because a tensor can be considered a tiled version of itself.
  if (size == tensor->dims()[axis]) {
    return true;
  }

  // If fractional tiling verification is disabled and the dimension size
  // is NOT divisible by the tile size then we return false.
  if (!fractional && ((tensor->dims()[axis] % size) != 0)) {
    return false;
  }

  static_assert(max_tensor_dimensions == 6,
                "Implementation assumes max_tensor_dimensions = 6.");

  // Get tensor view with maximum number of dimensions.
  auto dimsMax = expandDimsToMax(tensor->dims());
  Tensor tensorMax = tensor->getUnowned(dimsMax);
  auto tensorH = tensorMax.getHandle<ElemTy>();
  for (dim_t idx0 = 0; idx0 < dimsMax[0]; ++idx0) {
    for (dim_t idx1 = 0; idx1 < dimsMax[1]; ++idx1) {
      for (dim_t idx2 = 0; idx2 < dimsMax[2]; ++idx2) {
        for (dim_t idx3 = 0; idx3 < dimsMax[3]; ++idx3) {
          for (dim_t idx4 = 0; idx4 < dimsMax[4]; ++idx4) {
            for (dim_t idx5 = 0; idx5 < dimsMax[5]; ++idx5) {
              std::vector<dim_t> idx = {idx0, idx1, idx2, idx3, idx4, idx5};
              std::vector<dim_t> idxWrapped = idx;
              idxWrapped[axis] = (idx[axis] % size);
              double delta = tensorH.at(idx) - tensorH.at(idxWrapped);
              // Since any comparison with NAN returns false, we use a negated
              // condition so that this function correctly returns false when
              // delta is NAN.
              if (!(delta == 0.0)) {
                return false;
              }
            }
          }
        }
      }
    }
  }
  return true;
}
} // namespace

void glow::dumpAsciiImpl(const Tensor *T, llvm::raw_ostream &os) {
  switch (T->getElementType()) {
  case ElemKind::FloatTy:
    return dumpAsciiGenericImpl(T->getHandle<float>(), os);
  case ElemKind::Float16Ty:
    return dumpAsciiGenericImpl(T->getHandle<float16_t>(), os);
  case ElemKind::BFloat16Ty:
    return dumpAsciiGenericImpl(T->getHandle<bfloat16_t>(), os);
  case ElemKind::Int8QTy:
    return dumpAsciiGenericImpl(T->getHandle<int8_t>(), os);
  case ElemKind::UInt8QTy:
    return dumpAsciiGenericImpl(T->getHandle<uint8_t>(), os);
  case ElemKind::Int16QTy:
    return dumpAsciiGenericImpl(T->getHandle<int16_t>(), os);
  case ElemKind::Int32QTy:
    return dumpAsciiGenericImpl(T->getHandle<int32_t>(), os);
  case ElemKind::Int32ITy:
    return dumpAsciiGenericImpl(T->getHandle<int32_t>(), os);
  case ElemKind::Int64ITy:
    return dumpAsciiGenericImpl(T->getHandle<int64_t>(), os);
  case ElemKind::UInt8FusedQTy:
    return dumpAsciiGenericImpl(T->getHandle<uint8_t>(), os);
  case ElemKind::UInt8FusedFP16QTy:
    return dumpAsciiGenericImpl(T->getHandle<uint8_t>(), os);
  case ElemKind::UInt4FusedFP16QTy:
    return dumpAsciiGenericImpl(T->getHandle<uint8_t>(), os);
  case ElemKind::BoolTy:
    return dumpAsciiGenericImpl(T->getHandle<bool>(), os);
  }
}

void glow::dumpAsciiImpl(const Tensor *T) { dumpAsciiImpl(T, llvm::outs()); }

void glow::dumpImpl(const Tensor *T, llvm::raw_ostream &os,
                    unsigned maxNumElem) {
  switch (T->getElementType()) {
  case ElemKind::FloatTy:
    return dumpGenericImpl(T->getHandle<float>(), os, maxNumElem);
  case ElemKind::Float16Ty:
    return dumpGenericImpl(T->getHandle<float16_t>(), os, maxNumElem);
  case ElemKind::BFloat16Ty:
    return dumpGenericImpl(T->getHandle<bfloat16_t>(), os, maxNumElem);
  case ElemKind::Int8QTy:
    return dumpGenericImpl(T->getHandle<int8_t>(), os, maxNumElem);
  case ElemKind::UInt8QTy:
    return dumpGenericImpl(T->getHandle<uint8_t>(), os, maxNumElem);
  case ElemKind::Int16QTy:
    return dumpGenericImpl(T->getHandle<int16_t>(), os, maxNumElem);
  case ElemKind::Int32QTy:
    return dumpGenericImpl(T->getHandle<int32_t>(), os, maxNumElem);
  case ElemKind::Int32ITy:
    return dumpGenericImpl(T->getHandle<int32_t>(), os, maxNumElem);
  case ElemKind::Int64ITy:
    return dumpGenericImpl(T->getHandle<int64_t>(), os, maxNumElem);
  case ElemKind::UInt8FusedQTy:
    return dumpGenericImpl(T->getHandle<uint8_t>(), os, maxNumElem);
  case ElemKind::UInt8FusedFP16QTy:
    return dumpGenericImpl(T->getHandle<uint8_t>(), os, maxNumElem);
  case ElemKind::UInt4FusedFP16QTy:
    return dumpGenericImpl(T->getHandle<uint8_t>(), os, maxNumElem);
  case ElemKind::BoolTy:
    return dumpGenericImpl(T->getHandle<bool>(), os, maxNumElem);
  }
}

void glow::dumpImpl(const Tensor *T, unsigned maxNumElem) {
  dumpImpl(T, llvm::outs(), maxNumElem);
}

void glow::dumpImpl(const Tensor *T) { dumpImpl(T, llvm::outs()); }

// Dump functions.
void Tensor::dump(llvm::raw_ostream &os) const { dumpImpl(this, os); }

void Tensor::dump() const { dumpImpl(this, llvm::outs()); }

std::string Tensor::toString() const {
  std::string storage;
  llvm::raw_string_ostream os(storage);
  dumpImpl(this, os);
  return os.str();
}

std::string Tensor::getShapeToString() const {
  std::string storage;
  llvm::raw_string_ostream os(storage);
  dumpShape(dims(), os);
  return os.str();
}

void Tensor::dump(llvm::raw_ostream &os, unsigned maxNumElem) const {
  dumpImpl(this, os, maxNumElem);
}

void Tensor::dump(unsigned maxNumElem) const {
  dumpImpl(this, llvm::outs(), maxNumElem);
}

/// Dump a textual representation of a specific number of elements in the
/// Tensor to std::string.
std::string Tensor::toString(unsigned maxNumElem) const {
  std::string storage;
  llvm::raw_string_ostream os(storage);
  dumpImpl(this, os, maxNumElem);
  return os.str();
}

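/// A minimal usage sketch (assuming a Tensor constructed from an element kind
/// and dims): the dims of \p dest become the dims of \p src permuted by
/// \p shuffle.
///   Tensor src(ElemKind::FloatTy, {3, 4});
///   Tensor dest;
///   genericTranspose(&src, &dest, {1, 0}); // dest now has dims {4, 3}.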
void glow::genericTranspose(const Tensor *src, Tensor *dest,
                            llvm::ArrayRef<unsigned_t> shuffle) {
  DCHECK(src->dims().size() == shuffle.size())
      << "Invalid dimensions " << src->dims().size()
      << " != " << shuffle.size();

  dim_t newSizes[max_tensor_dimensions];

  // Generate the swizzled dimensions.
  auto origDims = src->dims();
  for (unsigned i = 0; i < origDims.size(); i++) {
    newSizes[i] = origDims[shuffle[i]];
  }

  // Resize the tensor to the transposed shape.
  auto destType = Type::newShape(src->getType(), {newSizes, origDims.size()});
  // genericTranspose doesn't know how to set non-trivial strides and
  // alignments, and it cannot infer the correct ones since they can be
  // backend-specific. Therefore, set the type to destType only if the caller
  // has not already set it properly. Reset must be called either way to
  // allocate memory for the tensor.
  if (dest->dims() != destType.dims()) {
    dest->reset(destType);
  } else {
    dest->reset(dest->getType());
  }

  // Fill padding bytes with zero.
  if (src->actualSize() != dest->actualSize()) {
    dest->zero();
  }

  switch (src->getElementType()) {
  case ElemKind::FloatTy: {
    auto srcH = src->getHandle<float>();
    auto destH = dest->getHandle<float>();
    transposeSelectImpl(srcH, destH, shuffle);
    return;
  }
  case ElemKind::Float16Ty: {
    auto srcH = src->getHandle<float16_t>();
    auto destH = dest->getHandle<float16_t>();
    transposeSelectImpl(srcH, destH, shuffle);
    return;
  }
  case ElemKind::BFloat16Ty: {
    auto srcH = src->getHandle<bfloat16_t>();
    auto destH = dest->getHandle<bfloat16_t>();
    transposeSelectImpl(srcH, destH, shuffle);
    return;
  }
  case ElemKind::Int8QTy: {
    auto srcH = src->getHandle<int8_t>();
    auto destH = dest->getHandle<int8_t>();
    transposeSelectImpl(srcH, destH, shuffle);
    return;
  }
  case ElemKind::UInt8QTy: {
    auto srcH = src->getHandle<uint8_t>();
    auto destH = dest->getHandle<uint8_t>();
    transposeSelectImpl(srcH, destH, shuffle);
    return;
  }
  case ElemKind::Int16QTy: {
    auto srcH = src->getHandle<int16_t>();
    auto destH = dest->getHandle<int16_t>();
    transposeSelectImpl(srcH, destH, shuffle);
    return;
  }
  case ElemKind::Int32QTy: {
    auto srcH = src->getHandle<int32_t>();
    auto destH = dest->getHandle<int32_t>();
    transposeSelectImpl(srcH, destH, shuffle);
    return;
  }
  case ElemKind::Int32ITy: {
    auto srcH = src->getHandle<int32_t>();
    auto destH = dest->getHandle<int32_t>();
    transposeSelectImpl(srcH, destH, shuffle);
    return;
  }
  case ElemKind::Int64ITy: {
    auto srcH = src->getHandle<int64_t>();
    auto destH = dest->getHandle<int64_t>();
    transposeSelectImpl(srcH, destH, shuffle);
    return;
  }
  case ElemKind::UInt8FusedQTy: {
    llvm_unreachable("Transposing UInt8FusedQTy is unsupported.");
  }
  case ElemKind::UInt8FusedFP16QTy: {
    llvm_unreachable("Transposing UInt8FusedFP16QTy is unsupported.");
  }
  case ElemKind::UInt4FusedFP16QTy: {
    llvm_unreachable("Transposing UInt4FusedFP16QTy is unsupported.");
  }
  case ElemKind::BoolTy: {
    auto srcH = src->getHandle<bool>();
    auto destH = dest->getHandle<bool>();
    transposeSelectImpl(srcH, destH, shuffle);
    return;
  }
  }
}

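/// For example, with max_tensor_dimensions == 6, dims {3, 4} expand to
/// {3, 4, 1, 1, 1, 1}.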
ShapeVector glow::expandDimsToMax(llvm::ArrayRef<dim_t> currDims) {
  ShapeVector newDims(currDims.begin(), currDims.end());
  for (size_t i = newDims.size(); i < max_tensor_dimensions; i++) {
    newDims.push_back(1);
  }
  return newDims;
}

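/// For example, reduceDims({2, 3, 4}, /* axes */ {1}, /* keepDims */ false)
/// returns {2, 4}, while keepDims == true yields {2, 1, 4}.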
ShapeVector glow::reduceDims(llvm::ArrayRef<dim_t> dims,
                             llvm::ArrayRef<unsigned_t> axes, bool keepDims) {
  ShapeVector newDims;
  for (unsigned_t dim = 0, end = dims.size(); dim < end; ++dim) {
    auto it = std::find(axes.begin(), axes.end(), dim);
    bool dimReduced = (it != axes.end());
    if (dimReduced) {
      if (keepDims) {
        newDims.push_back(1);
      } else {
        continue;
      }
    } else {
      newDims.push_back(dims[dim]);
    }
  }
  return newDims;
}

void Tensor::init(InitKind init, float val, PseudoRNG &PRNG) {
  assert(!isDeviceResident() && "Tensor must reside on host to access data.");
  switch (init) {
  case InitKind::Zero:
    zero();
    break;

  case InitKind::Broadcast: {
    switch (getElementType()) {
    case ElemKind::FloatTy: {
      getHandle<float>().clear(val);
      break;
    }
    case ElemKind::Float16Ty: {
      getHandle<float16_t>().clear(float16_t(val));
      break;
    }
    case ElemKind::BFloat16Ty: {
      getHandle<bfloat16_t>().clear(bfloat16_t(val));
      break;
    }
    case ElemKind::Int8QTy: {
      getHandle<int8_t>().clear(val);
      break;
    }
    case ElemKind::UInt8QTy: {
      getHandle<uint8_t>().clear(val);
      break;
    }
    case ElemKind::Int16QTy: {
      getHandle<int16_t>().clear(val);
      break;
    }
    case ElemKind::Int32QTy: {
      getHandle<int32_t>().clear(val);
      break;
    }
    case ElemKind::Int32ITy: {
      getHandle<int32_t>().clear(val);
      break;
    }
    case ElemKind::Int64ITy: {
      getHandle<int64_t>().clear(val);
      break;
    }

#define FUSED_CASE(ELEM_KIND, DATA_TYPE)                                       \
  case ElemKind::ELEM_KIND: {                                                  \
    DCHECK(dims().size() == 2)                                                 \
        << "Fused tensor must be 2-dimensional but instead has "               \
        << dims().size() << " dimensions.";                                     \
    DCHECK(dims()[1] > 2 * sizeof(DATA_TYPE))                                   \
        << "Fused tensor must have space for scale/offset, but only has "      \
        << dims()[1] << " columns.";                                            \
    auto H = getHandle<uint8_t>();                                              \
    for (dim_t i = 0; i < dims()[0]; i++) {                                     \
      for (dim_t j = 0, f = dims()[1] - 2 * sizeof(DATA_TYPE); j < f; j++) {    \
        H.at({i, j}) = val;                                                     \
      }                                                                         \
    }                                                                           \
    break;                                                                      \
  }
      FUSED_CASE(UInt8FusedQTy, float);
      FUSED_CASE(UInt8FusedFP16QTy, float16_t);
      FUSED_CASE(UInt4FusedFP16QTy, float16_t);
#undef FUSED_CASE

    case ElemKind::BoolTy: {
      getHandle<bool>().clear(val);
      break;
    }
    }
    break;
  }

  case InitKind::Xavier: {
    switch (getElementType()) {
    case ElemKind::FloatTy: {
      getHandle<float>().initXavier(val, PRNG);
      break;
    }
    case ElemKind::Float16Ty: {
      getHandle<float16_t>().initXavier(val, PRNG);
      break;
    }
    case ElemKind::BFloat16Ty: {
      getHandle<bfloat16_t>().initXavier(val, PRNG);
      break;
    }
    default: {
      llvm_unreachable("Undefined to Xavier-initialize non-Float Tensors.");
    }
    }
    break;
  }
  }
}

void Tensor::convertToType(ElemKind newTy) {
  assert(!isDeviceResident() && "Tensor must reside on host to access data.");
  *this = this->getCopyConvertedToType(newTy);
}

Tensor Tensor::getCopyConvertedToType(ElemKind newKind) const {
  assert(!isDeviceResident() && "Tensor must reside on host to access data.");
  const ElemKind origKind = getElementType();
  DCHECK((origKind == ElemKind::FloatTy && newKind == ElemKind::Float16Ty) ||
         (origKind == ElemKind::FloatTy && newKind == ElemKind::BFloat16Ty) ||
         (origKind == ElemKind::FloatTy && newKind == ElemKind::Int32ITy) ||
         (origKind == ElemKind::FloatTy && newKind == ElemKind::Int64ITy) ||
         (origKind == ElemKind::Float16Ty && newKind == ElemKind::FloatTy) ||
         (origKind == ElemKind::BFloat16Ty && newKind == ElemKind::FloatTy) ||
         (origKind == ElemKind::Int64ITy && newKind == ElemKind::Int32ITy) ||
         (origKind == ElemKind::Int64ITy && newKind == ElemKind::FloatTy) ||
         (origKind == ElemKind::Int32ITy && newKind == ElemKind::Int64ITy) ||
         (origKind == ElemKind::Int32ITy && newKind == ElemKind::FloatTy) ||
         (origKind == ElemKind::UInt8FusedQTy &&
          newKind == ElemKind::UInt8FusedFP16QTy))
      << "Conversion from " << Type::getElementName(origKind).str() << " to "
      << Type::getElementName(newKind).str() << " is not yet implemented";

  if (!isQuantizedElemKind(newKind)) {
    Tensor tmp(newKind, dims());
    switch (newKind) {
    case ElemKind::Float16Ty:
      tmp.copyWithCast<float16_t, float>(this);
      break;
    case ElemKind::BFloat16Ty:
      tmp.copyWithCast<bfloat16_t, float>(this);
      break;

    case ElemKind::FloatTy:
      if (getElementType() == ElemKind::Int32ITy) {
        tmp.copyWithCast<float, int32_t>(this);
      } else if (getElementType() == ElemKind::Int64ITy) {
        tmp.copyWithCast<float, int64_t>(this);
      } else if (getElementType() == ElemKind::Float16Ty) {
        tmp.copyWithCast<float, float16_t>(this);
      } else if (getElementType() == ElemKind::BFloat16Ty) {
        tmp.copyWithCast<float, bfloat16_t>(this);
      } else if (getElementType() == ElemKind::FloatTy) {
        tmp.copyRawFrom(this);
      } else {
        llvm_unreachable("Invalid conversion to FLOAT.");
      }
      break;

    case ElemKind::Int32ITy:
      if (getElementType() == ElemKind::Int64ITy) {
        tmp.copyWithCast<int32_t, int64_t>(this);
      } else if (getElementType() == ElemKind::FloatTy) {
        tmp.copyWithCast<int32_t, float>(this);
      } else {
        llvm_unreachable("Invalid conversion to INT32.");
      }
      break;
    case ElemKind::Int64ITy:
      if (getElementType() == ElemKind::Int32ITy) {
        tmp.copyWithCast<int64_t, int32_t>(this);
      } else {
        llvm_unreachable("Invalid conversion to INT64.");
      }
      break;

    default:
      llvm_unreachable("Type not supported");
    }
    return tmp;
  }

  // Handle Fused conversion. Currently only supports UInt8FusedQTy ->
  // UInt8FusedFP16QTy.
  DCHECK(origKind == ElemKind::UInt8FusedQTy && dims().size() == 2)
      << "UInt8FusedQTy must be 2 dimensional.";
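  // Each row shrinks by 2 * (sizeof(float) - sizeof(float16_t)) = 4 bytes,
  // since the trailing scale and offset narrow from float to float16_t.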
  Tensor tmp(newKind,
             {dims()[0], dims()[1] - 2 * ((dim_t)sizeof(float) -
                                          (dim_t)sizeof(float16_t))},
             1.0, 0);

  const size_t dstWidth = tmp.dims()[1];
  auto srcH = getHandle<uint8_t>();
  auto dstH = tmp.getHandle<uint8_t>();
  for (dim_t i = 0, e = dims()[0]; i < e; i++) {
    // Copy the scale/offset from src to dst.
    float scale, offset;
    std::tie(scale, offset) = srcH.getFusedScaleOffsetFromRow<float>(i);
    dstH.setFusedScaleOffsetInRow<float16_t>(i, static_cast<float16_t>(scale),
                                             static_cast<float16_t>(offset));

    // Copy over the row's uint8 data from src to dst; scales and offsets were
    // already copied over above.
    for (dim_t j = 0, f = dstWidth - 2 * sizeof(float16_t); j < f; j++) {
      dstH.at({i, j}) = srcH.at({i, j});
    }
  }
  return tmp;
}

namespace glow {
llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Tensor &t) {
  t.dump(os);
  return os;
}

llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Tensor *t) {
  assert(t != nullptr && "Null Pointer.");
  t->dump(os);
  return os;
}

void Tensor::moveToDevice(DeviceTensorTransferManager *deviceManager,
                          void *locationContext) {
  if (deviceResidency_ == nullptr) {
    deviceResidency_ = new DeviceResidencyInfo();
  }
  deviceResidency_->deviceManager_ = deviceManager;
  deviceResidency_->locationContext_ = locationContext;
  deviceResidency_->tensorResidency_ =
      DeviceResidencyInfo::TensorResidency::Device;
}

void Tensor::ensureOnHost() {
  if (deviceResidency_ == nullptr) {
    // already on host.
    return;
  }
  if (deviceResidency_->isDeviceResident()) {
    deviceResidency_->deviceManager_->transferFromDevice(*this);
  }
  assert(!isDeviceResident());
}

void Tensor::copyRawToDevice(const Tensor *t) {
  assert(isDeviceResident());
  void *locationContext = deviceResidency_->locationContext_;
  DeviceTensorTransferManager *DM = deviceResidency_->deviceManager_;
  clearDeviceResidency();
  copyRawFrom(t);
  DM->transferToDevice(*this, locationContext);
}

bool Tensor::isTiled(unsigned_t axis, dim_t size, bool fractional) const {
  switch (getElementType()) {
  case ElemKind::FloatTy: {
    return isTiledImpl<float>(this, axis, size, fractional);
  }
  case ElemKind::Float16Ty: {
    return isTiledImpl<float16_t>(this, axis, size, fractional);
  }
  case ElemKind::Int8QTy: {
    return isTiledImpl<int8_t>(this, axis, size, fractional);
  }
  case ElemKind::UInt8QTy: {
    return isTiledImpl<uint8_t>(this, axis, size, fractional);
  }
  case ElemKind::Int16QTy: {
    return isTiledImpl<int16_t>(this, axis, size, fractional);
  }
  case ElemKind::Int32QTy: {
    return isTiledImpl<int32_t>(this, axis, size, fractional);
  }
  case ElemKind::Int32ITy: {
    return isTiledImpl<int32_t>(this, axis, size, fractional);
  }
  case ElemKind::Int64ITy: {
    return isTiledImpl<int64_t>(this, axis, size, fractional);
  }
  case ElemKind::BoolTy: {
    return isTiledImpl<bool>(this, axis, size, fractional);
  }
  default: { llvm_unreachable("isTiled: Precision not supported!"); }
  }
}

bool Tensor::isTiled(llvm::ArrayRef<unsigned_t> axes,
                     llvm::ArrayRef<dim_t> sizes, bool fractional) const {
  assert(axes.size() == sizes.size() &&
         "Mismatch between axes and sizes length!");
  for (size_t idx = 0, end = axes.size(); idx < end; ++idx) {
    if (!isTiled(axes[idx], sizes[idx], fractional)) {
      return false;
    }
  }
  return true;
}

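/// \returns true if a slice of shape \p sliceShape is a contiguous region of
/// a row-major tensor of shape \p tensorShape. For example, a (1, 2, 4) slice
/// of a (5, 2, 4) tensor is contiguous, while a (1, 2, 2) slice is not.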
bool isSliceContiguous(llvm::ArrayRef<dim_t> sliceShape,
                       llvm::ArrayRef<dim_t> tensorShape) {
  assert(sliceShape.size() == tensorShape.size() &&
         "Array length mismatch for slice/tensor sizes!");
  // Find the first non-singleton slice dimension. If all the dimensions are
  // singleton then, by convention, firstNonSingleDim is set to the number of
  // dimensions.
  size_t firstNonSingleDim = sliceShape.size();
  for (size_t dim = 0, dimEnd = sliceShape.size(); dim < dimEnd; ++dim) {
    if (sliceShape[dim] != 1) {
      firstNonSingleDim = dim;
      break;
    }
  }
  // First non-singleton slice dimension can be partially or fully extracted.
  // The following dimensions must be fully extracted.
  for (size_t dim = firstNonSingleDim + 1, dimEnd = sliceShape.size();
       dim < dimEnd; ++dim) {
    if (sliceShape[dim] != tensorShape[dim]) {
      return false;
    }
  }
  return true;
}

} // namespace glow