1 /*
2  * SPDX-License-Identifier: Apache-2.0
3  */
4 
5 // ATTENTION: The code in this file is highly EXPERIMENTAL.
6 // Adventurous users should note that the APIs will probably change.
7 
8 #pragma once
9 
10 #include <cmath>
11 #include <functional>
12 #include <numeric>
13 #include "onnx/common/assertions.h"
14 #include "onnx/onnx_pb.h"
15 
16 namespace ONNX_NAMESPACE {
17 
18 struct Tensor final {
19 private:
20   bool is_segment_;
21   int64_t segment_begin_;
22   int64_t segment_end_;
23   bool has_name_;
24   std::string name_;
25   int32_t elem_type_;
26   std::vector<int64_t> sizes_;
27 
28   std::vector<float> float_data_;
29   std::vector<double> double_data_;
30   std::vector<int32_t> int32_data_;
31   std::vector<int64_t> int64_data_;
32   std::vector<uint64_t> uint64_data_;
33   std::vector<std::string> string_data_;
34 
35   bool is_raw_data_;
36   std::string raw_data_;
37 
38   template <typename F, typename T>
39   void bin_func(const F& f, T* ptr, const T* a_ptr);
40 
41   template <typename F, typename T>
42   void un_func(const F& f, T* ptr);
43 
44   template <typename T>
45   void scale_dim(T* ptr, const T* s_ptr);
46 
47  public:
Tensorfinal48   Tensor()
49   : is_segment_(false)
50   , segment_begin_(0)
51   , segment_end_(0)
52   , has_name_(false)
53   , elem_type_(ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED)
54   , is_raw_data_(false)
55   {}
56 
Tensorfinal57   Tensor(const Tensor &other)
58   : is_segment_(other.is_segment_)
59   , segment_begin_(other.segment_begin_)
60   , segment_end_(other.segment_end_)
61   , has_name_(other.has_name_)
62   , elem_type_(other.elem_type_)
63   , sizes_(other.sizes_)
64   , float_data_(other.float_data_)
65   , double_data_(other.double_data_)
66   , int32_data_(other.int32_data_)
67   , int64_data_(other.int64_data_)
68   , uint64_data_(other.uint64_data_)
69   , is_raw_data_(other.is_raw_data_) {
70     // Deep copy. Avoid copy on write when using gcc<5.0
71     string_data_.resize(other.string_data_.size());
72     for(unsigned int i=0; i<other.string_data_.size(); ++i) {
73       string_data_[i] = std::string( other.string_data_[i].data(), other.string_data_[i].size() );
74     }
75     name_ = std::string(other.name_.data(), other.name_.size());
76     raw_data_ = std::string(other.raw_data_.data(), other.raw_data_.size());
77   }
78 
swapfinal79   friend void swap(Tensor& first, Tensor& second){
80     using std::swap;
81     swap(first.is_segment_, second.is_segment_);
82     swap(first.segment_begin_, second.segment_begin_);
83     swap(first.segment_end_, second.segment_end_);
84     swap(first.has_name_, second.has_name_);
85     swap(first.name_, second.name_);
86     swap(first.elem_type_, second.elem_type_);
87     swap(first.sizes_, second.sizes_);
88     swap(first.float_data_, second.float_data_);
89     swap(first.double_data_, second.double_data_);
90     swap(first.int32_data_, second.int32_data_);
91     swap(first.int64_data_, second.int64_data_);
92     swap(first.uint64_data_, second.uint64_data_);
93     swap(first.is_raw_data_, second.is_raw_data_);
94     swap(first.string_data_, second.string_data_);
95     swap(first.raw_data_, second.raw_data_);
96   }
97 
98   Tensor& operator=(Tensor other) noexcept {
99     swap(*this, other);
100     return *this;
101   }
102 
sizesfinal103   const std::vector<int64_t>& sizes() const {
104     return sizes_;
105   }
sizesfinal106   std::vector<int64_t>& sizes() {
107     return sizes_;
108   }
109 
size_from_dimfinal110   int64_t size_from_dim(int dim) const {
111     if (dim < 0) {
112       dim += (int)sizes_.size();
113     }
114     ONNX_ASSERT(dim >= 0 && (size_t)dim < sizes_.size());
115     return std::accumulate(sizes_.begin() + dim, sizes_.end(), (int64_t)1, std::multiplies<int64_t>{});
116   }
117 
elem_typefinal118   int32_t elem_type() const {
119     return elem_type_;
120   }
121 
elem_typefinal122   int32_t& elem_type() {
123     return elem_type_;
124   }
125 
stringsfinal126   std::vector<std::string>& strings() {
127     return string_data_;
128   }
129 
stringsfinal130   const std::vector<std::string>& strings() const {
131     return string_data_;
132   }
133 
floatsfinal134   std::vector<float>& floats() {
135     return float_data_;
136   }
137 
floatsfinal138   const std::vector<float>& floats() const {
139     return float_data_;
140   }
141 
doublesfinal142   std::vector<double>& doubles() {
143     return double_data_;
144   }
145 
doublesfinal146   const std::vector<double>& doubles() const {
147     return double_data_;
148   }
149 
int32sfinal150   std::vector<int32_t>& int32s() {
151     return int32_data_;
152   }
153 
int32sfinal154   const std::vector<int32_t>& int32s() const {
155     return int32_data_;
156   }
157 
int64sfinal158   std::vector<int64_t>& int64s() {
159     return int64_data_;
160   }
161 
int64sfinal162   const std::vector<int64_t>& int64s() const {
163     return int64_data_;
164   }
165 
uint64sfinal166   std::vector<uint64_t>& uint64s() {
167     return uint64_data_;
168   }
169 
uint64sfinal170   const std::vector<uint64_t>& uint64s() const {
171     return uint64_data_;
172   }
173 
rawfinal174   const std::string& raw() const {
175     return raw_data_;
176   }
177 
set_raw_datafinal178   void set_raw_data(std::string raw_data) {
179     is_raw_data_ = true;
180     raw_data_ = std::move(raw_data);
181   }
182 
183   template <typename T>
184   T* data();
185 
186   template <typename T>
187   const T* data() const;
188 
is_segmentfinal189   bool is_segment() const {
190     return is_segment_;
191   }
192 
segment_beginfinal193   int64_t segment_begin() const {
194     return segment_begin_;
195   }
196 
segment_endfinal197   int64_t segment_end() const {
198     return segment_end_;
199   }
200 
set_segment_begin_and_endfinal201   void set_segment_begin_and_end(int64_t begin, int64_t end) {
202     is_segment_ = true;
203     segment_begin_ = begin;
204     segment_end_ = end;
205   }
206 
hasNamefinal207   bool hasName() const {
208     return has_name_;
209   }
210 
namefinal211   const std::string& name() const {
212     return name_;
213   }
214 
setNamefinal215   void setName(std::string name) {
216     has_name_ = true;
217     name_ = std::move(name);
218   }
219 
is_raw_datafinal220   bool is_raw_data() const {
221     return is_raw_data_;
222   }
223 
224   //this += a
225   //Supported for
226   //FLOAT, BOOL, INT8, INT16, INT32, UINT8, UINT16, INT64,
227   //UINT32, UINT64, DOUBLE,
228   //TODO: Support for FLOAT16, COMPLEX64, COMPLEX128
229   void add(const Tensor& a);
230 
231   //this -= a
232   //Supported for
233   //FLOAT, BOOL, INT8, INT16, INT32, UINT8, UINT16, INT64,
234   //UINT32, UINT64, DOUBLE
235   //TODO: Support for FLOAT16, COMPLEX64, COMPLEX128
236   void subtract(const Tensor& a);
237 
238   //this *= a
239   //Supported for
240   //FLOAT, BOOL, INT8, INT16, INT32, UINT8, UINT16, INT64,
241   //UINT32, UINT64, DOUBLE
242   //TODO: Support for FLOAT16, COMPLEX64, COMPLEX128
243   void multiply(const Tensor& a);
244 
245   //this /= a
246   //Supported for
247   //FLOAT, INT8, INT16, INT32, UINT8, UINT16, INT64,
248   //UINT32, UINT64, DOUBLE
249   //TODO: Support for FLOAT16, COMPLEX64, COMPLEX128
250   void divide(const Tensor& a);
251 
252   //Element-wise square root of This
253   //Supported for
254   //FLOAT, DOUBLE,
255   //TODO: Support for FLOAT16
256   void sqrt();
257 
258   //Element wise scaling of tensor s
259   //s is one dimensional, has size M, where M is size of first dimension of tensor
260   //s must have has data type corresponding to this
261   //Supported for
262   //FLOAT16, FLOAT, DOUBLE
263   void scale_by_first_dim(const Tensor& s);
264 };
265 
266 #define define_data(type, field)                  \
267   template <>                                     \
268   inline type* Tensor::data<type>() {             \
269     if (is_raw_data_) {                           \
270       return (type*)&raw_data_.data()[0];         \
271     } else {                                      \
272       return field.data();                        \
273     }                                             \
274   }                                               \
275                                                   \
276   template <>                                     \
277   inline const type* Tensor::data<type>() const { \
278     if (is_raw_data_) {                           \
279       return (type*)(raw_data_.data());           \
280     } else {                                      \
281       return field.data();                        \
282     }                                             \
283   }
284 
285 define_data(float, float_data_);
286 define_data(double, double_data_);
287 define_data(int32_t, int32_data_);
288 define_data(int64_t, int64_data_);
289 define_data(uint64_t, uint64_data_);
290 define_data(std::string, string_data_);
291 #undef define_data
292 
293 template <typename F, typename T>
bin_func(const F & f,T * ptr,const T * a_ptr)294 inline void Tensor::bin_func(const F& f, T* ptr, const T* a_ptr) {
295   const int64_t num_elements = size_from_dim(0);
296   for (int64_t i = 0; i < num_elements; ++i) {
297     ptr[i] = f(ptr[i], a_ptr[i]);
298   }
299 }
300 
301 template <typename F, typename T>
un_func(const F & f,T * ptr)302 inline void Tensor::un_func(const F& f, T* ptr) {
303   const int64_t num_elements = size_from_dim(0);
304   for (int64_t i = 0; i < num_elements; ++i) {
305     ptr[i] = f(ptr[i]);
306   }
307 }
308 
309 template <typename T>
scale_dim(T * ptr,const T * s_ptr)310 inline void Tensor::scale_dim(T* ptr, const T* s_ptr) {
311   int64_t elems_per_first_dim = size_from_dim(1);
312   int64_t first_dim_size = sizes_[0];
313   int64_t counter = 0;
314   for (int64_t i = 0; i < first_dim_size; ++i) {
315     for (int64_t j = 0; j < elems_per_first_dim; ++j) {
316       ptr[counter++] *= s_ptr[i];
317     }
318   }
319 }
320 
321 #define APPLY_BINARY_FUNCTION(op_name, f)                                  \
322   inline void Tensor::op_name(const Tensor& other) {                       \
323     TENSOR_ASSERTM(                                                        \
324         other.elem_type() == elem_type_,                                   \
325         "Tensor types do not match: %s != %s",                             \
326         to_string(elem_type_).c_str(),                                     \
327         " vs. ",                                                           \
328         to_string(other.elem_type()).c_str());                             \
329     TENSOR_ASSERTM(other.sizes() == sizes_, "Tensor sizes do not match."); \
330     switch (elem_type_) {                                                  \
331       case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {                   \
332         bin_func(f<float>(), data<float>(), other.data<float>());          \
333         break;                                                             \
334       }                                                                    \
335       case ONNX_NAMESPACE::TensorProto_DataType_BOOL:                      \
336       case ONNX_NAMESPACE::TensorProto_DataType_INT8:                      \
337       case ONNX_NAMESPACE::TensorProto_DataType_INT16:                     \
338       case ONNX_NAMESPACE::TensorProto_DataType_INT32:                     \
339       case ONNX_NAMESPACE::TensorProto_DataType_UINT8:                     \
340       case ONNX_NAMESPACE::TensorProto_DataType_UINT16: {                  \
341         bin_func(f<int32_t>(), data<int32_t>(), other.data<int32_t>());    \
342         break;                                                             \
343       }                                                                    \
344       case ONNX_NAMESPACE::TensorProto_DataType_INT64: {                   \
345         bin_func(f<int64_t>(), data<int64_t>(), other.data<int64_t>());    \
346         break;                                                             \
347       }                                                                    \
348       case ONNX_NAMESPACE::TensorProto_DataType_UINT32:                    \
349       case ONNX_NAMESPACE::TensorProto_DataType_UINT64: {                  \
350         bin_func(f<uint64_t>(), data<uint64_t>(), other.data<uint64_t>()); \
351         break;                                                             \
352       }                                                                    \
353       case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: {                  \
354         bin_func(f<double>(), data<double>(), other.data<double>());       \
355         break;                                                             \
356       }                                                                    \
357       default:                                                             \
358         TENSOR_ASSERTM(                                                    \
359             false,                                                         \
360             "Operation %s not supported for data type %s",                 \
361             #op_name,                                                      \
362             " not supported for data type ",                               \
363             to_string(elem_type_).c_str());                                \
364     }                                                                      \
365   }
366 
APPLY_BINARY_FUNCTION(add,std::plus)367 APPLY_BINARY_FUNCTION(add, std::plus)
368 APPLY_BINARY_FUNCTION(subtract, std::minus)
369 APPLY_BINARY_FUNCTION(multiply, std::multiplies)
370 APPLY_BINARY_FUNCTION(divide, std::divides)
371 
372 #undef APPLY_BINARY_FUNCTION
373 
374 inline void Tensor::sqrt() {
375   switch(elem_type_) {
376     case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
377       un_func<float (*)(float), float>(std::sqrt, data<float>());
378       break;
379     }
380     case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: {
381       un_func<double (*)(double), double>(std::sqrt, data<double>());
382       break;
383     }
384     default:
385       TENSOR_ASSERTM(
386           false,
387           "Operation sqrt not supported for data type %s",
388           to_string(elem_type_).c_str());
389   }
390 }
391 
scale_by_first_dim(const Tensor & other)392 inline void Tensor::scale_by_first_dim(const Tensor& other) {
393   ONNX_ASSERT(
394       sizes_.size() > 1 && other.sizes().size() == 1 &&
395       other.sizes()[0] == sizes_[0]);
396   ONNX_ASSERT(other.elem_type() == elem_type_);
397 
398   switch(elem_type_) {
399     case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
400       scale_dim(data<float>(), other.data<float>());
401       break;
402     }
403     case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: {
404       scale_dim(data<int32_t>(), other.data<int32_t>());
405       break;
406     }
407     case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: {
408       scale_dim(data<double>(), other.data<double>());
409       break;
410     }
411     default:
412       TENSOR_ASSERTM(
413           false,
414           "Operation scale_by_first_dim not supported for data type %s",
415           to_string(elem_type_).c_str());
416   }
417 }
418 
419 } // namespace ONNX_NAMESPACE
420