1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file tensor_container.h
22  * \brief tensor container that does memory allocation and resize like STL
23  * \author Tianqi Chen
24  */
25 #ifndef MSHADOW_TENSOR_CONTAINER_H_
26 #define MSHADOW_TENSOR_CONTAINER_H_
27 #include "./tensor.h"
28 #include "./io.h"
29 
30 namespace mshadow {
31 /*!
32  * \brief tensor container that does memory allocation and resize like STL,
33  *        use it to save the lines of FreeSpace in class.
34  *        Do not abuse it, efficiency can come from pre-allocation and no re-allocation
35  *
36  * \tparam Device which device the tensor is on
37  * \tparam dimension dimension of the tensor
38  */
39 template<typename Device, int dimension, typename DType = default_real_t>
40 class TensorContainer: public Tensor<Device, dimension, DType> {
41  public:
42   /*!
43    * \brief constructor
44    * \param pad whether use padding alignment in space allocation
45    */
46   explicit TensorContainer(bool pad = MSHADOW_ALLOC_PAD) {
47     this->pad_ = pad;
48     this->dptr_ = data_.dptr_ = NULL;
49     this->shape_[0] = 0;
50     this->stride_ = 0;
51     this->data_.stride_ = 0;
52     this->data_.shape_[0] = 0;
53   }
54   /*!
55    * \brief constructor
56    * \param shape intial shape
57    */
TensorContainer(const Shape<dimension> & shape)58   explicit TensorContainer(const Shape<dimension> &shape) {
59     this->pad_ = MSHADOW_ALLOC_PAD;
60     data_.dptr_ = NULL;
61     this->AllocByShape(shape);
62   }
63   /*!
64    * \brief constructor
65    * \param shape intial shape
66    * \param initv intial value
67    */
TensorContainer(const Shape<dimension> & shape,DType initv)68   explicit TensorContainer(const Shape<dimension> &shape, DType initv) {
69     this->pad_ = MSHADOW_ALLOC_PAD;
70     data_.dptr_ = NULL;
71     this->AllocByShape(shape);
72     (*this) = initv;
73   }
74   /*!
75    * \brief copy constructor
76    * \param src source value
77    */
TensorContainer(const TensorContainer<Device,dimension,DType> & src)78   TensorContainer
79   (const TensorContainer<Device, dimension, DType> &src)
80       : pad_(src.pad_) {
81     this->dptr_ = data_.dptr_ = NULL;
82     this->shape_[0] = 0;
83     this->stride_ = 0;
84     this->data_.stride_ = 0;
85     this->data_.shape_[0] = 0;
86     this->stream_ = src.stream_;
87     if (src.dptr_ != NULL) {
88       this->AllocByShape(src.shape_);
89       mshadow::Copy(*this, src, this->stream_);
90     }
91   }
~TensorContainer(void)92   ~TensorContainer(void) MSHADOW_THROW_EXCEPTION {
93     this->Release();
94   }
95   /*!
96    * \brief resize the container to given shape, content is NOT preserved
97    * \param shape target shape
98    */
Resize(const Shape<dimension> & shape)99   inline void Resize(const Shape<dimension> &shape) {
100     Shape<2> s2 = shape.FlatTo2D();
101     if (s2.shape_[1] > data_.stride_ || s2.shape_[0] > data_.size(0)) {
102       this->AllocByShape(shape);
103     } else {
104       this->shape_ = shape;
105       if (this->pad_) {
106         this->stride_ = data_.stride_;
107       } else {
108         this->stride_ = s2.shape_[1];
109       }
110     }
111   }
112   /*!
113    * \brief resize the container to given shape, and initialize, content is NOT preserved
114    * \param shape target shape
115    * \param initv initialization value
116    */
Resize(const Shape<dimension> & shape,DType initv)117   inline void Resize(const Shape<dimension> &shape, DType initv) {
118     this->Resize(shape);
119     (*this) = initv;
120   }
121   /*! \brief set whether padding is allowed in tensor */
set_pad(bool pad)122   inline void set_pad(bool pad) {
123     this->pad_ = pad;
124   }
125   /*!
126    * \brief save by binary format
127    * \param fo output binary stream
128    * \tparam TStream type of stream, need to support Read, Write, one example is utils::IStream.
129    */
130   template<typename TStream>
SaveBinary(TStream & fo)131   inline void SaveBinary(TStream &fo) const { // NOLINT(*)
132     mshadow::SaveBinary(fo, *this);
133   }
134   /*!
135    * \brief load by binary format, a temp Tensor<cpu,dim> storage will be allocated
136    * \param fi input binary stream
137    * \tparam TStream type of stream, need to support Read, Write, one example is utils::IStream.
138    */
139   template<typename TStream>
LoadBinary(TStream & fi)140   inline void LoadBinary(TStream &fi) { // NOLINT(*)
141     Tensor<cpu, dimension, DType> tmp;
142     mshadow::LoadBinary(fi, &tmp, false);
143     this->Resize(tmp.shape_);
144     Stream<Device> stream;
145     Copy(*this, tmp, &stream);
146     mshadow::FreeSpace(&tmp);
147   }
148   /*!
149    * \brief assign operator from TensorContainer
150    * \param src source value
151    * \return reference of self
152    */
153   inline TensorContainer &operator=
154   (const TensorContainer<Device, dimension, DType> &src) {
155     this->pad_ = src.pad_;
156     this->stream_ = src.stream_;
157     if (src.dptr_ != NULL) {
158       this->Resize(src.shape_);
159       mshadow::Copy(*this, src, this->stream_);
160     }
161     return *this;
162   }
163   /*!\brief functions to fit expression template */
164   inline Tensor<Device, dimension, DType> &operator=(DType s) {
165     return this->__assign(s);
166   }
167   /*!\brief functions to fit expression template */
168   template<typename E>
169   inline Tensor<Device, dimension, DType> &
170   operator=(const expr::Exp<E, DType, expr::type::kMapper> &exp) {
171     return this->__assign(exp);
172   }
173   /*!\brief functions to fit expression template */
174   template<typename E>
175   inline Tensor<Device, dimension, DType> &
176   operator=(const expr::Exp<E, DType, expr::type::kChainer> &exp) {
177     return this->__assign(exp);
178   }
179   /*!\brief functions to fit expression template */
180   template<typename E>
181   inline Tensor<Device, dimension, DType> &
182   operator=(const expr::Exp<E, DType, expr::type::kComplex> &exp) {
183     return this->__assign(exp);
184   }
185   /*!
186    * \brief Release the llocated space,
187    *  The TensorContainer is still functionable,
188    *  but will restart allocating space when Resize is called.
189    */
Release(void)190   inline void Release(void) {
191     if (data_.dptr_ != NULL) {
192       this->shape_[0] = 0;
193       this->stride_ = 0;
194       this->data_.stride_ = 0;
195       this->data_.shape_[0] = 0;
196       try {
197         mshadow::FreeSpace(&data_);
198       } catch (const dmlc::Error &e) {
199         this->dptr_ = data_.dptr_ = NULL;
200         throw e;
201       }
202       this->dptr_ = data_.dptr_ = NULL;
203     }
204   }
205 
206  private:
207   /*! \brief whether we do padding in the space */
208   bool pad_;
209   /*! \brief the shape of data_ is actually current data space */
210   Tensor<Device, 2, DType> data_;
211 
AllocByShape(const Shape<dimension> & shape)212   inline void AllocByShape(const Shape<dimension>& shape) {
213     if (data_.dptr_ != NULL) this->Release();
214     data_.shape_ = shape.FlatTo2D();
215     mshadow::AllocSpace(&data_, pad_);
216     this->dptr_ = data_.dptr_;
217     this->shape_ = shape;
218     if (this->pad_) {
219       this->stride_ = data_.stride_;
220     } else {
221       this->stride_ = data_.size(1);
222     }
223   }
224 };
225 }  // namespace mshadow
226 #endif  // MSHADOW_TENSOR_CONTAINER_H_
227