1 /*
2 * SPDX-License-Identifier: Apache-2.0
3 */
4
5 // ATTENTION: The code in this file is highly EXPERIMENTAL.
6 // Adventurous users should note that the APIs will probably change.
7
8 #pragma once
9
#include <cmath>
#include <cstdint>
#include <functional>
#include <numeric>
#include "onnx/common/assertions.h"
#include "onnx/onnx_pb.h"
15
16 namespace ONNX_NAMESPACE {
17
18 struct Tensor final {
19 private:
20 bool is_segment_;
21 int64_t segment_begin_;
22 int64_t segment_end_;
23 bool has_name_;
24 std::string name_;
25 int32_t elem_type_;
26 std::vector<int64_t> sizes_;
27
28 std::vector<float> float_data_;
29 std::vector<double> double_data_;
30 std::vector<int32_t> int32_data_;
31 std::vector<int64_t> int64_data_;
32 std::vector<uint64_t> uint64_data_;
33 std::vector<std::string> string_data_;
34
35 bool is_raw_data_;
36 std::string raw_data_;
37
38 template <typename F, typename T>
39 void bin_func(const F& f, T* ptr, const T* a_ptr);
40
41 template <typename F, typename T>
42 void un_func(const F& f, T* ptr);
43
44 template <typename T>
45 void scale_dim(T* ptr, const T* s_ptr);
46
47 public:
Tensorfinal48 Tensor()
49 : is_segment_(false)
50 , segment_begin_(0)
51 , segment_end_(0)
52 , has_name_(false)
53 , elem_type_(ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED)
54 , is_raw_data_(false)
55 {}
56
Tensorfinal57 Tensor(const Tensor &other)
58 : is_segment_(other.is_segment_)
59 , segment_begin_(other.segment_begin_)
60 , segment_end_(other.segment_end_)
61 , has_name_(other.has_name_)
62 , elem_type_(other.elem_type_)
63 , sizes_(other.sizes_)
64 , float_data_(other.float_data_)
65 , double_data_(other.double_data_)
66 , int32_data_(other.int32_data_)
67 , int64_data_(other.int64_data_)
68 , uint64_data_(other.uint64_data_)
69 , is_raw_data_(other.is_raw_data_) {
70 // Deep copy. Avoid copy on write when using gcc<5.0
71 string_data_.resize(other.string_data_.size());
72 for(unsigned int i=0; i<other.string_data_.size(); ++i) {
73 string_data_[i] = std::string( other.string_data_[i].data(), other.string_data_[i].size() );
74 }
75 name_ = std::string(other.name_.data(), other.name_.size());
76 raw_data_ = std::string(other.raw_data_.data(), other.raw_data_.size());
77 }
78
swapfinal79 friend void swap(Tensor& first, Tensor& second){
80 using std::swap;
81 swap(first.is_segment_, second.is_segment_);
82 swap(first.segment_begin_, second.segment_begin_);
83 swap(first.segment_end_, second.segment_end_);
84 swap(first.has_name_, second.has_name_);
85 swap(first.name_, second.name_);
86 swap(first.elem_type_, second.elem_type_);
87 swap(first.sizes_, second.sizes_);
88 swap(first.float_data_, second.float_data_);
89 swap(first.double_data_, second.double_data_);
90 swap(first.int32_data_, second.int32_data_);
91 swap(first.int64_data_, second.int64_data_);
92 swap(first.uint64_data_, second.uint64_data_);
93 swap(first.is_raw_data_, second.is_raw_data_);
94 swap(first.string_data_, second.string_data_);
95 swap(first.raw_data_, second.raw_data_);
96 }
97
98 Tensor& operator=(Tensor other) noexcept {
99 swap(*this, other);
100 return *this;
101 }
102
sizesfinal103 const std::vector<int64_t>& sizes() const {
104 return sizes_;
105 }
sizesfinal106 std::vector<int64_t>& sizes() {
107 return sizes_;
108 }
109
size_from_dimfinal110 int64_t size_from_dim(int dim) const {
111 if (dim < 0) {
112 dim += (int)sizes_.size();
113 }
114 ONNX_ASSERT(dim >= 0 && (size_t)dim < sizes_.size());
115 return std::accumulate(sizes_.begin() + dim, sizes_.end(), (int64_t)1, std::multiplies<int64_t>{});
116 }
117
elem_typefinal118 int32_t elem_type() const {
119 return elem_type_;
120 }
121
elem_typefinal122 int32_t& elem_type() {
123 return elem_type_;
124 }
125
stringsfinal126 std::vector<std::string>& strings() {
127 return string_data_;
128 }
129
stringsfinal130 const std::vector<std::string>& strings() const {
131 return string_data_;
132 }
133
floatsfinal134 std::vector<float>& floats() {
135 return float_data_;
136 }
137
floatsfinal138 const std::vector<float>& floats() const {
139 return float_data_;
140 }
141
doublesfinal142 std::vector<double>& doubles() {
143 return double_data_;
144 }
145
doublesfinal146 const std::vector<double>& doubles() const {
147 return double_data_;
148 }
149
int32sfinal150 std::vector<int32_t>& int32s() {
151 return int32_data_;
152 }
153
int32sfinal154 const std::vector<int32_t>& int32s() const {
155 return int32_data_;
156 }
157
int64sfinal158 std::vector<int64_t>& int64s() {
159 return int64_data_;
160 }
161
int64sfinal162 const std::vector<int64_t>& int64s() const {
163 return int64_data_;
164 }
165
uint64sfinal166 std::vector<uint64_t>& uint64s() {
167 return uint64_data_;
168 }
169
uint64sfinal170 const std::vector<uint64_t>& uint64s() const {
171 return uint64_data_;
172 }
173
rawfinal174 const std::string& raw() const {
175 return raw_data_;
176 }
177
set_raw_datafinal178 void set_raw_data(std::string raw_data) {
179 is_raw_data_ = true;
180 raw_data_ = std::move(raw_data);
181 }
182
183 template <typename T>
184 T* data();
185
186 template <typename T>
187 const T* data() const;
188
is_segmentfinal189 bool is_segment() const {
190 return is_segment_;
191 }
192
segment_beginfinal193 int64_t segment_begin() const {
194 return segment_begin_;
195 }
196
segment_endfinal197 int64_t segment_end() const {
198 return segment_end_;
199 }
200
set_segment_begin_and_endfinal201 void set_segment_begin_and_end(int64_t begin, int64_t end) {
202 is_segment_ = true;
203 segment_begin_ = begin;
204 segment_end_ = end;
205 }
206
hasNamefinal207 bool hasName() const {
208 return has_name_;
209 }
210
namefinal211 const std::string& name() const {
212 return name_;
213 }
214
setNamefinal215 void setName(std::string name) {
216 has_name_ = true;
217 name_ = std::move(name);
218 }
219
is_raw_datafinal220 bool is_raw_data() const {
221 return is_raw_data_;
222 }
223
224 //this += a
225 //Supported for
226 //FLOAT, BOOL, INT8, INT16, INT32, UINT8, UINT16, INT64,
227 //UINT32, UINT64, DOUBLE,
228 //TODO: Support for FLOAT16, COMPLEX64, COMPLEX128
229 void add(const Tensor& a);
230
231 //this -= a
232 //Supported for
233 //FLOAT, BOOL, INT8, INT16, INT32, UINT8, UINT16, INT64,
234 //UINT32, UINT64, DOUBLE
235 //TODO: Support for FLOAT16, COMPLEX64, COMPLEX128
236 void subtract(const Tensor& a);
237
238 //this *= a
239 //Supported for
240 //FLOAT, BOOL, INT8, INT16, INT32, UINT8, UINT16, INT64,
241 //UINT32, UINT64, DOUBLE
242 //TODO: Support for FLOAT16, COMPLEX64, COMPLEX128
243 void multiply(const Tensor& a);
244
245 //this /= a
246 //Supported for
247 //FLOAT, INT8, INT16, INT32, UINT8, UINT16, INT64,
248 //UINT32, UINT64, DOUBLE
249 //TODO: Support for FLOAT16, COMPLEX64, COMPLEX128
250 void divide(const Tensor& a);
251
252 //Element-wise square root of This
253 //Supported for
254 //FLOAT, DOUBLE,
255 //TODO: Support for FLOAT16
256 void sqrt();
257
258 //Element wise scaling of tensor s
259 //s is one dimensional, has size M, where M is size of first dimension of tensor
260 //s must have has data type corresponding to this
261 //Supported for
262 //FLOAT16, FLOAT, DOUBLE
263 void scale_by_first_dim(const Tensor& s);
264 };
265
266 #define define_data(type, field) \
267 template <> \
268 inline type* Tensor::data<type>() { \
269 if (is_raw_data_) { \
270 return (type*)&raw_data_.data()[0]; \
271 } else { \
272 return field.data(); \
273 } \
274 } \
275 \
276 template <> \
277 inline const type* Tensor::data<type>() const { \
278 if (is_raw_data_) { \
279 return (type*)(raw_data_.data()); \
280 } else { \
281 return field.data(); \
282 } \
283 }
284
285 define_data(float, float_data_);
286 define_data(double, double_data_);
287 define_data(int32_t, int32_data_);
288 define_data(int64_t, int64_data_);
289 define_data(uint64_t, uint64_data_);
290 define_data(std::string, string_data_);
291 #undef define_data
292
293 template <typename F, typename T>
bin_func(const F & f,T * ptr,const T * a_ptr)294 inline void Tensor::bin_func(const F& f, T* ptr, const T* a_ptr) {
295 const int64_t num_elements = size_from_dim(0);
296 for (int64_t i = 0; i < num_elements; ++i) {
297 ptr[i] = f(ptr[i], a_ptr[i]);
298 }
299 }
300
301 template <typename F, typename T>
un_func(const F & f,T * ptr)302 inline void Tensor::un_func(const F& f, T* ptr) {
303 const int64_t num_elements = size_from_dim(0);
304 for (int64_t i = 0; i < num_elements; ++i) {
305 ptr[i] = f(ptr[i]);
306 }
307 }
308
309 template <typename T>
scale_dim(T * ptr,const T * s_ptr)310 inline void Tensor::scale_dim(T* ptr, const T* s_ptr) {
311 int64_t elems_per_first_dim = size_from_dim(1);
312 int64_t first_dim_size = sizes_[0];
313 int64_t counter = 0;
314 for (int64_t i = 0; i < first_dim_size; ++i) {
315 for (int64_t j = 0; j < elems_per_first_dim; ++j) {
316 ptr[counter++] *= s_ptr[i];
317 }
318 }
319 }
320
321 #define APPLY_BINARY_FUNCTION(op_name, f) \
322 inline void Tensor::op_name(const Tensor& other) { \
323 TENSOR_ASSERTM( \
324 other.elem_type() == elem_type_, \
325 "Tensor types do not match: %s != %s", \
326 to_string(elem_type_).c_str(), \
327 " vs. ", \
328 to_string(other.elem_type()).c_str()); \
329 TENSOR_ASSERTM(other.sizes() == sizes_, "Tensor sizes do not match."); \
330 switch (elem_type_) { \
331 case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { \
332 bin_func(f<float>(), data<float>(), other.data<float>()); \
333 break; \
334 } \
335 case ONNX_NAMESPACE::TensorProto_DataType_BOOL: \
336 case ONNX_NAMESPACE::TensorProto_DataType_INT8: \
337 case ONNX_NAMESPACE::TensorProto_DataType_INT16: \
338 case ONNX_NAMESPACE::TensorProto_DataType_INT32: \
339 case ONNX_NAMESPACE::TensorProto_DataType_UINT8: \
340 case ONNX_NAMESPACE::TensorProto_DataType_UINT16: { \
341 bin_func(f<int32_t>(), data<int32_t>(), other.data<int32_t>()); \
342 break; \
343 } \
344 case ONNX_NAMESPACE::TensorProto_DataType_INT64: { \
345 bin_func(f<int64_t>(), data<int64_t>(), other.data<int64_t>()); \
346 break; \
347 } \
348 case ONNX_NAMESPACE::TensorProto_DataType_UINT32: \
349 case ONNX_NAMESPACE::TensorProto_DataType_UINT64: { \
350 bin_func(f<uint64_t>(), data<uint64_t>(), other.data<uint64_t>()); \
351 break; \
352 } \
353 case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: { \
354 bin_func(f<double>(), data<double>(), other.data<double>()); \
355 break; \
356 } \
357 default: \
358 TENSOR_ASSERTM( \
359 false, \
360 "Operation %s not supported for data type %s", \
361 #op_name, \
362 " not supported for data type ", \
363 to_string(elem_type_).c_str()); \
364 } \
365 }
366
APPLY_BINARY_FUNCTION(add,std::plus)367 APPLY_BINARY_FUNCTION(add, std::plus)
368 APPLY_BINARY_FUNCTION(subtract, std::minus)
369 APPLY_BINARY_FUNCTION(multiply, std::multiplies)
370 APPLY_BINARY_FUNCTION(divide, std::divides)
371
372 #undef APPLY_BINARY_FUNCTION
373
374 inline void Tensor::sqrt() {
375 switch(elem_type_) {
376 case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
377 un_func<float (*)(float), float>(std::sqrt, data<float>());
378 break;
379 }
380 case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: {
381 un_func<double (*)(double), double>(std::sqrt, data<double>());
382 break;
383 }
384 default:
385 TENSOR_ASSERTM(
386 false,
387 "Operation sqrt not supported for data type %s",
388 to_string(elem_type_).c_str());
389 }
390 }
391
scale_by_first_dim(const Tensor & other)392 inline void Tensor::scale_by_first_dim(const Tensor& other) {
393 ONNX_ASSERT(
394 sizes_.size() > 1 && other.sizes().size() == 1 &&
395 other.sizes()[0] == sizes_[0]);
396 ONNX_ASSERT(other.elem_type() == elem_type_);
397
398 switch(elem_type_) {
399 case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
400 scale_dim(data<float>(), other.data<float>());
401 break;
402 }
403 case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: {
404 scale_dim(data<int32_t>(), other.data<int32_t>());
405 break;
406 }
407 case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: {
408 scale_dim(data<double>(), other.data<double>());
409 break;
410 }
411 default:
412 TENSOR_ASSERTM(
413 false,
414 "Operation scale_by_first_dim not supported for data type %s",
415 to_string(elem_type_).c_str());
416 }
417 }
418
419 } // namespace ONNX_NAMESPACE
420