//===- SparseUtils.cpp - Sparse Utils for MLIR execution ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a light-weight runtime support library that is useful
// for sparse tensor manipulations. The functionality provided in this library
// is meant to simplify benchmarking, testing, and debugging MLIR code that
// operates on sparse tensors. The provided functionality is **not** part
// of core MLIR, however.
//
//===----------------------------------------------------------------------===//

#include "mlir/ExecutionEngine/CRunnerUtils.h"

#ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS

#include <algorithm>
#include <cassert>
#include <cctype>
#include <cinttypes>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <vector>

//===----------------------------------------------------------------------===//
//
// Internal support for storing and reading sparse tensors.
//
// The following memory-resident sparse storage schemes are supported:
//
// (a) A coordinate scheme for temporarily storing and lexicographically
//     sorting a sparse tensor by index.
//
// (b) A "one-size-fits-all" sparse storage scheme defined by per-rank
//     sparse/dense annotations to be used by generated MLIR code.
//
// The following external formats are supported:
//
// (1) Matrix Market Exchange (MME): *.mtx
//     https://math.nist.gov/MatrixMarket/formats.html
//
// (2) Formidable Repository of Open Sparse Tensors and Tools (FROSTT): *.tns
//     http://frostt.io/tensors/file-formats.html
//
//===----------------------------------------------------------------------===//

namespace {

/// A sparse tensor element in coordinate scheme (value and indices).
/// For example, a rank-1 vector element would look like
///   ({i}, a[i])
/// and a rank-5 tensor element like
///   ({i,j,k,l,m}, a[i,j,k,l,m])
struct Element {
  Element(const std::vector<uint64_t> &ind, double val)
      : indices(ind), value(val) {}
  std::vector<uint64_t> indices;
  double value;
};

/// A memory-resident sparse tensor in coordinate scheme (collection of
/// elements). This data structure is used to read a sparse tensor from
/// external file format into memory and sort the elements lexicographically
/// by indices before passing it back to the client (most packed storage
/// formats require the elements to appear in lexicographic index order).
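///
/// A minimal usage sketch (the indices and values below are hypothetical):
///   SparseTensor t({10, 10}, /*capacity=*/2);
///   t.add({3, 4}, 1.0);
///   t.add({0, 7}, 2.0);
///   t.sort(); // ({0,7}, 2.0) now precedes ({3,4}, 1.0)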
struct SparseTensor {
public:
  SparseTensor(const std::vector<uint64_t> &szs, uint64_t capacity)
      : sizes(szs), pos(0) {
    elements.reserve(capacity);
  }
  /// Adds element as indices and value.
  void add(const std::vector<uint64_t> &ind, double val) {
    assert(getRank() == ind.size());
    for (uint64_t r = 0, rank = getRank(); r < rank; r++)
      assert(ind[r] < sizes[r]); // within bounds
    elements.emplace_back(ind, val);
  }
  /// Sorts elements lexicographically by index.
  void sort() { std::sort(elements.begin(), elements.end(), lexOrder); }
  /// Primitive one-time iteration.
  const Element &next() { return elements[pos++]; }
  /// Returns rank.
  uint64_t getRank() const { return sizes.size(); }
  /// Getter for sizes array.
  const std::vector<uint64_t> &getSizes() const { return sizes; }
  /// Getter for elements array.
  const std::vector<Element> &getElements() const { return elements; }

private:
  /// Returns true if indices of e1 < indices of e2.
  static bool lexOrder(const Element &e1, const Element &e2) {
    assert(e1.indices.size() == e2.indices.size());
    for (uint64_t r = 0, rank = e1.indices.size(); r < rank; r++) {
      if (e1.indices[r] == e2.indices[r])
        continue;
      return e1.indices[r] < e2.indices[r];
    }
    return false;
  }
  std::vector<uint64_t> sizes; // per-rank dimension sizes
  std::vector<Element> elements;
  uint64_t pos;
};

/// Abstract base class of sparse tensor storage. Note that we use
/// function overloading to implement "partial" method specialization.
class SparseTensorStorageBase {
public:
  enum DimLevelType : uint8_t { kDense = 0, kCompressed = 1, kSingleton = 2 };

  virtual uint64_t getDimSize(uint64_t) = 0;

  // Overhead storage.
  virtual void getPointers(std::vector<uint64_t> **, uint64_t) { fatal("p64"); }
  virtual void getPointers(std::vector<uint32_t> **, uint64_t) { fatal("p32"); }
  virtual void getPointers(std::vector<uint16_t> **, uint64_t) { fatal("p16"); }
  virtual void getPointers(std::vector<uint8_t> **, uint64_t) { fatal("p8"); }
  virtual void getIndices(std::vector<uint64_t> **, uint64_t) { fatal("i64"); }
  virtual void getIndices(std::vector<uint32_t> **, uint64_t) { fatal("i32"); }
  virtual void getIndices(std::vector<uint16_t> **, uint64_t) { fatal("i16"); }
  virtual void getIndices(std::vector<uint8_t> **, uint64_t) { fatal("i8"); }

  // Primary storage.
  virtual void getValues(std::vector<double> **) { fatal("valf64"); }
  virtual void getValues(std::vector<float> **) { fatal("valf32"); }
  virtual void getValues(std::vector<int64_t> **) { fatal("vali64"); }
  virtual void getValues(std::vector<int32_t> **) { fatal("vali32"); }
  virtual void getValues(std::vector<int16_t> **) { fatal("vali16"); }
  virtual void getValues(std::vector<int8_t> **) { fatal("vali8"); }

  virtual ~SparseTensorStorageBase() {}

private:
  void fatal(const char *tp) {
    fprintf(stderr, "unsupported %s\n", tp);
    exit(1);
  }
};

/// A memory-resident sparse tensor using a storage scheme based on per-rank
/// annotations on dense/sparse. This data structure provides a bufferized
/// form of an imaginary SparseTensorType, until such a type becomes a
/// first-class citizen of MLIR. In contrast to generating setup methods for
/// each differently annotated sparse tensor, this class provides a convenient
/// "one-size-fits-all" solution that simply takes an input tensor and
/// annotations to implement all required setup in a general manner.
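///
/// For example (a sketch; the exact layout depends on the annotations),
/// a 3x4 matrix with annotations {kDense, kCompressed} and nonzeros
/// a[0][2]=1, a[2][0]=2, a[2][3]=3 yields a CSR-like storage:
///   pointers[1] = { 0, 1, 1, 3 } // positions r and r+1 delimit row r
///   indices[1]  = { 2, 0, 3 }    // column indices of the nonzeros
///   values      = { 1, 2, 3 }    // the nonzero values themselves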
template <typename P, typename I, typename V>
class SparseTensorStorage : public SparseTensorStorageBase {
public:
  /// Constructs sparse tensor storage scheme following the given
  /// per-rank dimension dense/sparse annotations.
  SparseTensorStorage(SparseTensor *tensor, uint8_t *sparsity)
      : sizes(tensor->getSizes()), pointers(getRank()), indices(getRank()) {
    // Provide hints on capacity.
    // TODO: needs fine-tuning based on sparsity
    uint64_t nnz = tensor->getElements().size();
    values.reserve(nnz);
    for (uint64_t d = 0, s = 1, rank = getRank(); d < rank; d++) {
      s *= sizes[d];
      if (sparsity[d] == kCompressed) {
        pointers[d].reserve(s + 1);
        indices[d].reserve(s);
        s = 1;
      } else {
        assert(sparsity[d] == kDense && "singleton not yet supported");
      }
    }
    // Then setup the tensor.
    traverse(tensor, sparsity, 0, nnz, 0);
  }

  virtual ~SparseTensorStorage() {}

  uint64_t getRank() const { return sizes.size(); }

  uint64_t getDimSize(uint64_t d) override { return sizes[d]; }

  // Partially specialize these three methods based on template types.
  void getPointers(std::vector<P> **out, uint64_t d) override {
    *out = &pointers[d];
  }
  void getIndices(std::vector<I> **out, uint64_t d) override {
    *out = &indices[d];
  }
  void getValues(std::vector<V> **out) override { *out = &values; }

private:
  /// Initializes sparse tensor storage scheme from a memory-resident
  /// representation of an external sparse tensor. This method prepares
  /// the pointers and indices arrays under the given per-rank dimension
  /// dense/sparse annotations.
  void traverse(SparseTensor *tensor, uint8_t *sparsity, uint64_t lo,
                uint64_t hi, uint64_t d) {
    const std::vector<Element> &elements = tensor->getElements();
    // Once dimensions are exhausted, insert the numerical values.
    if (d == getRank()) {
      values.push_back(lo < hi ? elements[lo].value : 0.0);
      return;
    }
    // Prepare a sparse pointer structure at this dimension.
    if (sparsity[d] == kCompressed && pointers[d].empty())
      pointers[d].push_back(0);
    // Visit all elements in this interval.
    uint64_t full = 0;
    while (lo < hi) {
      // Find segment in interval with same index elements in this dimension.
      uint64_t idx = elements[lo].indices[d];
      uint64_t seg = lo + 1;
      while (seg < hi && elements[seg].indices[d] == idx)
        seg++;
      // Handle segment in interval for sparse or dense dimension.
      if (sparsity[d] == kCompressed) {
        indices[d].push_back(idx);
      } else {
        for (; full < idx; full++)
          traverse(tensor, sparsity, 0, 0, d + 1); // pass empty
        full++;
      }
      traverse(tensor, sparsity, lo, seg, d + 1);
      // And move on to next segment in interval.
      lo = seg;
    }
    // Finalize the sparse pointer structure at this dimension.
    if (sparsity[d] == kCompressed) {
      pointers[d].push_back(indices[d].size());
    } else {
      for (uint64_t sz = tensor->getSizes()[d]; full < sz; full++)
        traverse(tensor, sparsity, 0, 0, d + 1); // pass empty
    }
  }

  std::vector<uint64_t> sizes; // per-rank dimension sizes
  std::vector<std::vector<P>> pointers;
  std::vector<std::vector<I>> indices;
  std::vector<V> values;
};

/// Helper to convert string to lower case.
static char *toLower(char *token) {
  for (char *c = token; *c; c++)
    *c = tolower(*c);
  return token;
}

/// Read the MME header of a general sparse matrix of type real.
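///
/// For example, a valid header section looks as follows (comment lines
/// start with '%'; the size line gives rows, columns, and nonzeros):
///   %%MatrixMarket matrix coordinate real general
///   % optional comments
///   5 5 3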
static void readMMEHeader(FILE *file, char *name, uint64_t *idata) {
  char line[1025];
  char header[64];
  char object[64];
  char format[64];
  char field[64];
  char symmetry[64];
  // Read header line.
  if (fscanf(file, "%63s %63s %63s %63s %63s\n", header, object, format, field,
             symmetry) != 5) {
    fprintf(stderr, "Corrupt header in %s\n", name);
    exit(1);
  }
  // Make sure this is a general sparse matrix.
  if (strcmp(toLower(header), "%%matrixmarket") ||
      strcmp(toLower(object), "matrix") ||
      strcmp(toLower(format), "coordinate") || strcmp(toLower(field), "real") ||
      strcmp(toLower(symmetry), "general")) {
    fprintf(stderr,
            "Cannot find a general sparse matrix with type real in %s\n", name);
    exit(1);
  }
  // Skip comments.
  while (1) {
    if (!fgets(line, 1025, file)) {
      fprintf(stderr, "Cannot find data in %s\n", name);
      exit(1);
    }
    if (line[0] != '%')
      break;
  }
  // Next line contains M N NNZ.
  idata[0] = 2; // rank
  if (sscanf(line, "%" PRIu64 "%" PRIu64 "%" PRIu64 "\n", idata + 2, idata + 3,
             idata + 1) != 3) {
    fprintf(stderr, "Cannot find size in %s\n", name);
    exit(1);
  }
}

/// Read the "extended" FROSTT header. Although not part of the documented
/// format, we assume that the file starts with optional comments followed
/// by two lines that define the rank, the number of nonzeros, and the
/// dimension sizes (one per rank) of the sparse tensor.
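///
/// For example, a rank-3 tensor with 2 nonzeros could start as follows
/// (comment lines start with '#'; the sizes shown are hypothetical):
///   # optional comments
///   3 2
///   10 10 10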
static void readExtFROSTTHeader(FILE *file, char *name, uint64_t *idata) {
  char line[1025];
  // Skip comments.
  while (1) {
    if (!fgets(line, 1025, file)) {
      fprintf(stderr, "Cannot find data in %s\n", name);
      exit(1);
    }
    if (line[0] != '#')
      break;
  }
  // Next line contains RANK and NNZ.
  if (sscanf(line, "%" PRIu64 "%" PRIu64 "\n", idata, idata + 1) != 2) {
    fprintf(stderr, "Cannot find metadata in %s\n", name);
    exit(1);
  }
  // Followed by a line with the dimension sizes (one per rank).
  for (uint64_t r = 0; r < idata[0]; r++) {
    if (fscanf(file, "%" PRIu64, idata + 2 + r) != 1) {
      fprintf(stderr, "Cannot find dimension size in %s\n", name);
      exit(1);
    }
  }
}

/// Reads a sparse tensor with the given filename into a memory-resident
/// sparse tensor in coordinate scheme.
static SparseTensor *openTensor(char *filename, uint64_t *perm) {
  // Open the file.
  FILE *file = fopen(filename, "r");
  if (!file) {
    fprintf(stderr, "Cannot find %s\n", filename);
    exit(1);
  }
  // Perform some file-format-dependent setup.
  uint64_t idata[512];
  if (strstr(filename, ".mtx")) {
    readMMEHeader(file, filename, idata);
  } else if (strstr(filename, ".tns")) {
    readExtFROSTTHeader(file, filename, idata);
  } else {
    fprintf(stderr, "Unknown format %s\n", filename);
    exit(1);
  }
  // Prepare sparse tensor object with per-rank dimension sizes
  // and the number of nonzeros as initial capacity.
  uint64_t rank = idata[0];
  uint64_t nnz = idata[1];
  std::vector<uint64_t> indices(rank);
  for (uint64_t r = 0; r < rank; r++)
    indices[perm[r]] = idata[2 + r];
  SparseTensor *tensor = new SparseTensor(indices, nnz);
  // Read all nonzero elements.
  for (uint64_t k = 0; k < nnz; k++) {
    uint64_t idx = -1;
    for (uint64_t r = 0; r < rank; r++) {
      if (fscanf(file, "%" PRIu64, &idx) != 1) {
        fprintf(stderr, "Cannot find next index in %s\n", filename);
        exit(1);
      }
      // Add 0-based index.
      indices[perm[r]] = idx - 1;
    }
    double value;
    if (fscanf(file, "%lg\n", &value) != 1) {
      fprintf(stderr, "Cannot find next value in %s\n", filename);
      exit(1);
    }
    tensor->add(indices, value);
  }
  // Close the file and return sorted tensor.
  fclose(file);
  tensor->sort(); // sort lexicographically
  return tensor;
}

/// Templated reader.
template <typename P, typename I, typename V>
void *newSparseTensor(char *filename, uint8_t *sparsity, uint64_t *perm,
                      uint64_t size) {
  SparseTensor *t = openTensor(filename, perm);
  assert(size == t->getRank()); // sparsity array must match rank
  SparseTensorStorageBase *tensor =
      new SparseTensorStorage<P, I, V>(t, sparsity);
  delete t;
  return tensor;
}

} // anonymous namespace

extern "C" {

/// Helper method to read a sparse tensor filename from the environment,
/// defined with the naming convention ${TENSOR0}, ${TENSOR1}, etc.
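///
/// For example (a sketch, assuming a shell environment and a hypothetical
/// runner binary): invoking
///   TENSOR0=/data/test.mtx ./runner
/// makes getTensorFilename(0) return "/data/test.mtx".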
char *getTensorFilename(uint64_t id) {
  char var[80];
  sprintf(var, "TENSOR%" PRIu64, id);
  char *env = getenv(var);
  return env;
}

//===----------------------------------------------------------------------===//
//
// Public API of the sparse runtime support library that supports an opaque
// implementation of a bufferized SparseTensor in MLIR. This could be replaced
// by actual codegen in MLIR.
//
// Because we cannot use C++ templates with C linkage, some macro magic is used
// to generate implementations for all required type combinations that can be
// called from MLIR-generated code.
//
//===----------------------------------------------------------------------===//
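//
// A minimal client-side call sequence, sketched for a rank-2 CSR-like
// tensor with 64-bit overhead storage and double values (the filename and
// all array contents below are hypothetical):
//
//   uint8_t sparsity[2] = {0, 1};  // kDense, kCompressed
//   uint64_t perm[2] = {0, 1};     // identity dimension permutation
//   void *t = newSparseTensor(filename, sparsity, sparsity, 0, 2, 1,
//                             perm, perm, 0, 2, 1, kU64, kU64, kF64);
//   MemRef1DU64 p = sparsePointers64(t, 1); // row pointers
//   MemRef1DU64 i = sparseIndices64(t, 1);  // column indices
//   MemRef1DF64 v = sparseValuesF64(t);     // nonzero values
//   delSparseTensor(t);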

#define TEMPLATE(NAME, TYPE)                                                   \
  struct NAME {                                                                \
    const TYPE *base;                                                          \
    const TYPE *data;                                                          \
    uint64_t off;                                                              \
    uint64_t sizes[1];                                                         \
    uint64_t strides[1];                                                       \
  }

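// TEMPLATE (above) mirrors the expanded form of a rank-1 MLIR memref
// descriptor (allocated and aligned pointers, offset, sizes, strides).
// CASE (below) instantiates the templated reader for one (pointer, index,
// value) type combination, and IMPL1/IMPL2 expand to functions that expose
// the primary and overhead storage as such rank-1 memref descriptors.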
#define CASE(p, i, v, P, I, V)                                                 \
  if (ptrTp == (p) && indTp == (i) && valTp == (v))                            \
    return newSparseTensor<P, I, V>(filename, sparsity, perm, asize)

#define IMPL1(RET, NAME, TYPE, LIB)                                            \
  RET NAME(void *tensor) {                                                     \
    std::vector<TYPE> *v;                                                      \
    static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v);                   \
    return {v->data(), v->data(), 0, {v->size()}, {1}};                        \
  }

#define IMPL2(RET, NAME, TYPE, LIB)                                            \
  RET NAME(void *tensor, uint64_t d) {                                         \
    std::vector<TYPE> *v;                                                      \
    static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d);                \
    return {v->data(), v->data(), 0, {v->size()}, {1}};                        \
  }

TEMPLATE(MemRef1DU64, uint64_t);
TEMPLATE(MemRef1DU32, uint32_t);
TEMPLATE(MemRef1DU16, uint16_t);
TEMPLATE(MemRef1DU8, uint8_t);
TEMPLATE(MemRef1DI64, int64_t);
TEMPLATE(MemRef1DI32, int32_t);
TEMPLATE(MemRef1DI16, int16_t);
TEMPLATE(MemRef1DI8, int8_t);
TEMPLATE(MemRef1DF64, double);
TEMPLATE(MemRef1DF32, float);

enum OverheadTypeEnum : uint64_t { kU64 = 1, kU32 = 2, kU16 = 3, kU8 = 4 };

enum PrimaryTypeEnum : uint64_t {
  kF64 = 1,
  kF32 = 2,
  kI64 = 3,
  kI32 = 4,
  kI16 = 5,
  kI8 = 6
};

void *newSparseTensor(char *filename, uint8_t *abase, uint8_t *adata,
                      uint64_t aoff, uint64_t asize, uint64_t astride,
                      uint64_t *pbase, uint64_t *pdata, uint64_t poff,
                      uint64_t psize, uint64_t pstride, uint64_t ptrTp,
                      uint64_t indTp, uint64_t valTp) {
  assert(astride == 1 && pstride == 1);
  uint8_t *sparsity = adata + aoff;
  uint64_t *perm = pdata + poff;

  // Double matrices with all combinations of overhead storage.
  CASE(kU64, kU64, kF64, uint64_t, uint64_t, double);
  CASE(kU64, kU32, kF64, uint64_t, uint32_t, double);
  CASE(kU64, kU16, kF64, uint64_t, uint16_t, double);
  CASE(kU64, kU8, kF64, uint64_t, uint8_t, double);
  CASE(kU32, kU64, kF64, uint32_t, uint64_t, double);
  CASE(kU32, kU32, kF64, uint32_t, uint32_t, double);
  CASE(kU32, kU16, kF64, uint32_t, uint16_t, double);
  CASE(kU32, kU8, kF64, uint32_t, uint8_t, double);
  CASE(kU16, kU64, kF64, uint16_t, uint64_t, double);
  CASE(kU16, kU32, kF64, uint16_t, uint32_t, double);
  CASE(kU16, kU16, kF64, uint16_t, uint16_t, double);
  CASE(kU16, kU8, kF64, uint16_t, uint8_t, double);
  CASE(kU8, kU64, kF64, uint8_t, uint64_t, double);
  CASE(kU8, kU32, kF64, uint8_t, uint32_t, double);
  CASE(kU8, kU16, kF64, uint8_t, uint16_t, double);
  CASE(kU8, kU8, kF64, uint8_t, uint8_t, double);

  // Float matrices with all combinations of overhead storage.
  CASE(kU64, kU64, kF32, uint64_t, uint64_t, float);
  CASE(kU64, kU32, kF32, uint64_t, uint32_t, float);
  CASE(kU64, kU16, kF32, uint64_t, uint16_t, float);
  CASE(kU64, kU8, kF32, uint64_t, uint8_t, float);
  CASE(kU32, kU64, kF32, uint32_t, uint64_t, float);
  CASE(kU32, kU32, kF32, uint32_t, uint32_t, float);
  CASE(kU32, kU16, kF32, uint32_t, uint16_t, float);
  CASE(kU32, kU8, kF32, uint32_t, uint8_t, float);
  CASE(kU16, kU64, kF32, uint16_t, uint64_t, float);
  CASE(kU16, kU32, kF32, uint16_t, uint32_t, float);
  CASE(kU16, kU16, kF32, uint16_t, uint16_t, float);
  CASE(kU16, kU8, kF32, uint16_t, uint8_t, float);
  CASE(kU8, kU64, kF32, uint8_t, uint64_t, float);
  CASE(kU8, kU32, kF32, uint8_t, uint32_t, float);
  CASE(kU8, kU16, kF32, uint8_t, uint16_t, float);
  CASE(kU8, kU8, kF32, uint8_t, uint8_t, float);

  // Integral matrices with the same overhead storage.
  CASE(kU64, kU64, kI64, uint64_t, uint64_t, int64_t);
  CASE(kU64, kU64, kI32, uint64_t, uint64_t, int32_t);
  CASE(kU64, kU64, kI16, uint64_t, uint64_t, int16_t);
  CASE(kU64, kU64, kI8, uint64_t, uint64_t, int8_t);
  CASE(kU32, kU32, kI32, uint32_t, uint32_t, int32_t);
  CASE(kU32, kU32, kI16, uint32_t, uint32_t, int16_t);
  CASE(kU32, kU32, kI8, uint32_t, uint32_t, int8_t);
  CASE(kU16, kU16, kI32, uint16_t, uint16_t, int32_t);
  CASE(kU16, kU16, kI16, uint16_t, uint16_t, int16_t);
  CASE(kU16, kU16, kI8, uint16_t, uint16_t, int8_t);
  CASE(kU8, kU8, kI32, uint8_t, uint8_t, int32_t);
  CASE(kU8, kU8, kI16, uint8_t, uint8_t, int16_t);
  CASE(kU8, kU8, kI8, uint8_t, uint8_t, int8_t);

  // Unsupported case (add above if needed).
  fputs("unsupported combination of types\n", stderr);
  exit(1);
}

uint64_t sparseDimSize(void *tensor, uint64_t d) {
  return static_cast<SparseTensorStorageBase *>(tensor)->getDimSize(d);
}

IMPL2(MemRef1DU64, sparsePointers, uint64_t, getPointers)
IMPL2(MemRef1DU64, sparsePointers64, uint64_t, getPointers)
IMPL2(MemRef1DU32, sparsePointers32, uint32_t, getPointers)
IMPL2(MemRef1DU16, sparsePointers16, uint16_t, getPointers)
IMPL2(MemRef1DU8, sparsePointers8, uint8_t, getPointers)
IMPL2(MemRef1DU64, sparseIndices, uint64_t, getIndices)
IMPL2(MemRef1DU64, sparseIndices64, uint64_t, getIndices)
IMPL2(MemRef1DU32, sparseIndices32, uint32_t, getIndices)
IMPL2(MemRef1DU16, sparseIndices16, uint16_t, getIndices)
IMPL2(MemRef1DU8, sparseIndices8, uint8_t, getIndices)
IMPL1(MemRef1DF64, sparseValuesF64, double, getValues)
IMPL1(MemRef1DF32, sparseValuesF32, float, getValues)
IMPL1(MemRef1DI64, sparseValuesI64, int64_t, getValues)
IMPL1(MemRef1DI32, sparseValuesI32, int32_t, getValues)
IMPL1(MemRef1DI16, sparseValuesI16, int16_t, getValues)
IMPL1(MemRef1DI8, sparseValuesI8, int8_t, getValues)

void delSparseTensor(void *tensor) {
  delete static_cast<SparseTensorStorageBase *>(tensor);
}

#undef TEMPLATE
#undef CASE
#undef IMPL1
#undef IMPL2

} // extern "C"

#endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS