1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file random/mt_random_engine.cc
22  * \brief mt19937 random engine
23  */
24 #include <dmlc/logging.h>
25 #include <tvm/runtime/device_api.h>
26 #include <tvm/runtime/ndarray.h>
27 
28 #include <algorithm>
29 #include <ctime>
30 #include <random>
31 
32 #include "../3rdparty/compiler-rt/builtin_fp16.h"
33 
34 namespace tvm {
35 namespace contrib {
36 
37 /*!
38  * \brief An interface for generating [tensors of] random numbers.
39  */
40 class RandomEngine {
41  public:
42   /*!
43    * \brief Creates a RandomEngine using a default seed.
44    */
RandomEngine()45   RandomEngine() { this->Seed(time(0)); }
46 
47   /*!
48    * \brief Creates a RandomEngine, suggesting the use of a provided seed.
49    */
RandomEngine(unsigned seed)50   explicit RandomEngine(unsigned seed) { this->Seed(seed); }
51 
52   /*!
53    * \brief Seeds the underlying RNG, if possible.
54    */
Seed(unsigned seed)55   inline void Seed(unsigned seed) {
56     rnd_engine_.seed(seed);
57     this->rseed_ = static_cast<unsigned>(seed);
58   }
59 
60   /*!
61    * \return the seed associated with the underlying RNG.
62    */
GetSeed() const63   inline unsigned GetSeed() const { return rseed_; }
64 
65   /*!
66    * \return a random integer sampled from the RNG.
67    */
GetRandInt()68   inline unsigned GetRandInt() { return rnd_engine_(); }
69 
70   /*!
71    * \brief Fills a tensor with values drawn from Unif(low, high)
72    */
SampleUniform(DLTensor * data,float low,float high)73   void SampleUniform(DLTensor* data, float low, float high) {
74     CHECK_GT(high, low) << "high must be bigger than low";
75     CHECK(data->strides == nullptr);
76 
77     DLDataType dtype = data->dtype;
78     int64_t size = 1;
79     for (int i = 0; i < data->ndim; ++i) {
80       size *= data->shape[i];
81     }
82 
83     CHECK(dtype.code == kDLFloat && dtype.bits == 32 && dtype.lanes == 1);
84 
85     if (data->ctx.device_type == kDLCPU) {
86       std::uniform_real_distribution<float> uniform_dist(low, high);
87       std::generate_n(static_cast<float*>(data->data), size,
88                       [&]() { return uniform_dist(rnd_engine_); });
89     } else {
90       LOG(FATAL) << "Do not support random.uniform on this device yet";
91     }
92   }
93 
94   /*!
95    * \brief Fills a tensor with values drawn from Normal(loc, scale**2)
96    */
SampleNormal(DLTensor * data,float loc,float scale)97   void SampleNormal(DLTensor* data, float loc, float scale) {
98     CHECK_GT(scale, 0) << "standard deviation must be positive";
99     CHECK(data->strides == nullptr);
100 
101     DLDataType dtype = data->dtype;
102     int64_t size = 1;
103     for (int i = 0; i < data->ndim; ++i) {
104       size *= data->shape[i];
105     }
106 
107     CHECK(dtype.code == kDLFloat && dtype.bits == 32 && dtype.lanes == 1);
108 
109     if (data->ctx.device_type == kDLCPU) {
110       std::normal_distribution<float> normal_dist(loc, scale);
111       std::generate_n(static_cast<float*>(data->data), size,
112                       [&]() { return normal_dist(rnd_engine_); });
113     } else {
114       LOG(FATAL) << "Do not support random.normal on this device yet";
115     }
116   }
117 
RandomFill(DLTensor * data)118   void RandomFill(DLTensor* data) {
119     int64_t size = 1;
120     for (int i = 0; i < data->ndim; ++i) {
121       size *= data->shape[i];
122     }
123 
124     if (data->ctx.device_type == kDLCPU) {
125       FillData(data, size);
126     } else {
127       runtime::NDArray local = runtime::NDArray::Empty(
128           std::vector<int64_t>{data->shape, data->shape + data->ndim}, data->dtype, {kDLCPU, 0});
129       FillData(&local.ToDLPack()->dl_tensor, size);
130       runtime::NDArray::CopyFromTo(&local.ToDLPack()->dl_tensor, data);
131     }
132   }
133 
134  private:
FillData(DLTensor * tensor,int64_t size)135   void FillData(DLTensor* tensor, int64_t size) {
136     // Make the value be 1.0 - 10.0, not (0.0 - 1.0) so that we could satisfy
137     // quantized dtype (uint8 / int8) data non-empty requirement
138     std::uniform_real_distribution<> dist(1.0, 10.0);
139     // Use float representation could make us work well on float / int type too.
140     if (tensor->dtype.bits == 1) {
141       std::generate_n(static_cast<bool*>(tensor->data), size, [&]() { return dist(rnd_engine_); });
142     } else if (tensor->dtype.bits == 8) {
143       std::generate_n(static_cast<uint8_t*>(tensor->data), size,
144                       [&]() { return dist(rnd_engine_); });
145     } else if (tensor->dtype.bits == 16) {
146       std::generate_n(static_cast<uint16_t*>(tensor->data), size, [&]() {
147         return __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(
148             static_cast<float>(dist(rnd_engine_)));
149       });
150     } else if (tensor->dtype.bits == 32) {
151       std::generate_n(static_cast<float*>(tensor->data), size, [&]() { return dist(rnd_engine_); });
152     } else if (tensor->dtype.bits == 64) {
153       std::generate_n(static_cast<double*>(tensor->data), size,
154                       [&]() { return dist(rnd_engine_); });
155     } else {
156       LOG(FATAL) << "Doesn't support dtype code " << tensor->dtype.code << " dtype bits "
157                  << tensor->dtype.bits;
158     }
159   }
160 
161  private:
162   std::mt19937 rnd_engine_;
163   unsigned rseed_;
164 };
165 
166 }  // namespace contrib
167 }  // namespace tvm
168