/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file runtime.cc
 * \brief Generic VTA runtime in C++11.
 *
 *  The runtime depends on the specific instruction
 *  stream spec as specified in hw_spec.h
 */
#include "runtime.h"

#include <dmlc/logging.h>
#include <tvm/runtime/c_runtime_api.h>
#include <vta/driver.h>
#include <vta/hw_spec.h>

#include <algorithm>
#include <cassert>
#include <cstring>
#include <memory>
#include <vector>

namespace vta {

// Avoid bad configurations.
static_assert(VTA_UOP_WIDTH == sizeof(VTAUop) * 8, "VTA_UOP_WIDTH does not match VTAUop size");

/*! \brief Enable coherent access of data buffers between VTA and CPU */
static const bool kBufferCoherent = VTA_COHERENT_ACCESSES;
/*! \brief Always cache buffers (otherwise, write back to DRAM from CPU) */
static const bool kAlwaysCache = true;

/*!
 * \brief Data buffer represents data on CMA.
 */
struct DataBuffer {
  /*! \return Virtual address of the data. */
  void* virt_addr() const { return data_; }
  /*! \return Physical address of the data. */
  vta_phy_addr_t phy_addr() const { return phy_addr_; }
  /*!
   * \brief Invalidate the cache of given location in data buffer.
   * \param offset The offset to the data.
   * \param size The size of the data.
   */
  void InvalidateCache(size_t offset, size_t size) {
    if (!kBufferCoherent && kAlwaysCache) {
      VTAInvalidateCache(reinterpret_cast<char*>(data_) + offset, phy_addr_ + offset, size);
    }
  }
  /*!
   * \brief Flush the cache of given location in data buffer.
   * \param offset The offset to the data.
   * \param size The size of the data.
   */
  void FlushCache(size_t offset, size_t size) {
    if (!kBufferCoherent && kAlwaysCache) {
      VTAFlushCache(reinterpret_cast<char*>(data_) + offset, phy_addr_ + offset, size);
    }
  }
  /*!
   * \brief Performs a copy operation from host memory to a buffer allocated with VTAMemAlloc.
   * \param dst The destination buffer in FPGA-accessible memory. Has to be allocated with
   * VTAMemAlloc().
   * \param src The source buffer in host memory.
   * \param size Size of the region in Bytes.
   */
  void MemCopyFromHost(void* dst, const void* src, size_t size) {
    VTAMemCopyFromHost(dst, src, size);
  }
  /*!
   * \brief Performs a copy operation from a buffer allocated with VTAMemAlloc to host memory.
   * \param dst The destination buffer in host memory.
   * \param src The source buffer in FPGA-accessible memory. Has to be allocated with VTAMemAlloc().
   * \param size Size of the region in Bytes.
   */
  void MemCopyToHost(void* dst, const void* src, size_t size) { VTAMemCopyToHost(dst, src, size); }
  /*!
   * \brief Allocate a buffer of a given size.
   * \param size The size of the buffer.
   */
  static DataBuffer* Alloc(size_t size) {
    void* data = VTAMemAlloc(size, kAlwaysCache);
    CHECK(data != nullptr);
    DataBuffer* buffer = new DataBuffer();
    buffer->data_ = data;
    buffer->phy_addr_ = VTAMemGetPhyAddr(data);
    return buffer;
  }
  /*!
   * \brief Free the data buffer.
   * \param buffer The buffer to be freed.
   */
  static void Free(DataBuffer* buffer) {
    VTAMemFree(buffer->data_);
    delete buffer;
  }
  /*!
   * \brief Create data buffer header from buffer ptr.
   * \param buffer The buffer pointer.
   * \return The corresponding data buffer header.
   */
  static DataBuffer* FromHandle(const void* buffer) {
    return const_cast<DataBuffer*>(reinterpret_cast<const DataBuffer*>(buffer));
  }

 private:
  /*! \brief The internal data. */
  void* data_;
  /*! \brief The physical address of the buffer, excluding header. */
  vta_phy_addr_t phy_addr_;
};
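// Usage sketch (illustrative only, not part of the runtime): DataBuffer is
// normally reached through the C API defined at the bottom of this file
// rather than used directly:
//
//   void* buf = VTABufferAlloc(4096);              // wraps DataBuffer::Alloc
//   vta::DataBuffer* db = vta::DataBuffer::FromHandle(buf);
//   db->FlushCache(0, 4096);                       // make CPU writes visible
//   VTABufferFree(buf);                            // wraps DataBuffer::Free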

/*!
 * \brief Micro op kernel.
 *  Contains the functions, prefixed with Push, used to construct the kernel.
 */
class UopKernel {
 public:
  /*! \brief Loop information. */
  struct LoopEntry {
    uint32_t extent;
    uint32_t dst_factor;
    uint32_t src_factor;
    uint32_t wgt_factor;
  };
  /*!
   * \brief Construct UopKernel with signature.
   * \param signature The pointer to signature.
   * \param nbytes Number of bytes.
   */
  UopKernel(const char* signature, int nbytes) : signature_(signature, signature + nbytes) {}
  /*!
   * \brief Verify if the signature is correct.
   * \param signature Signature ptr.
   * \param nbytes Number of bytes.
   */
  bool MatchSignature(void* signature, int nbytes) const {
    if (static_cast<size_t>(nbytes) != signature_.size()) return false;
    return memcmp(signature, signature_.data(), nbytes) == 0;
  }
  /*! \return Whether the kernel is cached in SRAM. */
  bool cached() const { return sram_begin_ != sram_end_; }
  /*! \return The length of the micro op sequence. */
  size_t size() const { return seq_.size(); }
  /*! \return The micro-op data. */
  const VTAUop* data() const { return seq_.data(); }
  /*! \return The loop structure. */
  const std::vector<LoopEntry>& loop() const { return loop_; }
  /*!
   * \brief Declare loop start.
   * \param extent The loop extent.
   * \param dst_factor Loop factor of accum index.
   * \param src_factor Loop factor of input index.
   * \param wgt_factor Loop factor of weight index.
   */
  void PushLoopBegin(uint32_t extent, uint32_t dst_factor, uint32_t src_factor,
                     uint32_t wgt_factor) {
    LoopEntry le;
    le.extent = extent;
    le.dst_factor = dst_factor;
    le.src_factor = src_factor;
    le.wgt_factor = wgt_factor;
    CHECK_EQ(seq_.size(), 0U);
    CHECK_LT(loop_.size(), 2U);
    loop_.push_back(le);
    ++loop_ptr_;
  }
  /*!
   * \brief Declare loop end.
   */
  void PushLoopEnd() { --loop_ptr_; }
  /*!
   * \brief Push a micro op into the kernel.
   * \param mode GEMM mode if set to 0, ALU mode if set to 1.
   * \param reset_out Resets the accum to 0.
   * \param dst_index The accum memory index.
   * \param src_index The input memory (gemm) / accum memory (alu) index.
   * \param wgt_index The weight memory index.
   * \param opcode The ALU opcode.
   * \param use_imm Use immediate in ALU mode if set to true.
   * \param imm_val Immediate value in ALU mode.
   */
  void Push(uint32_t mode, uint32_t reset_out, uint32_t dst_index, uint32_t src_index,
            uint32_t wgt_index, uint32_t opcode, uint32_t use_imm, int32_t imm_val) {
    // The loop nest structure
    VerifyDep(dst_index);
    VTAUop op;
    op.dst_idx = dst_index;
    op.src_idx = src_index;
    op.wgt_idx = wgt_index;
    seq_.push_back(op);
    // Ensure that mode is consistent if set
    if (mode_ == 0xFFFFFFFF) {
      mode_ = mode;
    } else {
      CHECK(mode_ == mode);
    }
    // Set reset_out field if unset
    if (reset_out_ == 0xFFFFFFFF) {
      reset_out_ = reset_out;
    } else {
      CHECK(reset_out_ == reset_out);
    }
    // Check that opcode and use_imm/imm_val stay consistent in ALU mode
    if (mode == 1) {
      if (opcode_ == 0xFFFFFFFF) {
        opcode_ = opcode;
        use_imm_ = use_imm;
        imm_val_ = imm_val;
      } else {
        CHECK(opcode_ == opcode);
        CHECK(use_imm_ == use_imm);
        CHECK(imm_val_ == imm_val);
      }
    }
  }
  /*! \brief Dump kernel micro ops to stdout. */
  void Dump() {
    uint32_t size = seq_.size();
    printf("There are %u uops\n", size);
    for (uint32_t i = 0; i < size; ++i) {
      printf("[%04u]\t acc=%u, inp=%u, wgt=%u\n", i, seq_[i].dst_idx, seq_[i].src_idx,
             seq_[i].wgt_idx);
    }
    printf("\n");
  }

 public:
  // The kernel's mode, opcode, immediate setting and value
  uint32_t mode_{0xFFFFFFFF};  // UOP type: 0xFFFFFFFF - unset, 0 - GEMM, 1 - ALU
  uint32_t opcode_{0xFFFFFFFF};
  uint32_t reset_out_{0xFFFFFFFF};
  bool use_imm_{false};
  int16_t imm_val_{0};

 private:
  // Verify that we don't write to the same acc_mem index two cycles in a row
  void VerifyDep(uint32_t dst_index) {
    size_t step = std::min(static_cast<size_t>(2U), seq_.size());
    for (size_t i = seq_.size() - step; i < seq_.size(); ++i) {
      CHECK(seq_[i].dst_idx != dst_index);
    }
  }
  // The uop buffer
  template <int, bool, bool>
  friend class UopQueue;
  friend class CommandQueue;
  // SRAM location if begin != end
  uint32_t sram_begin_{0};
  uint32_t sram_end_{0};
  // The signature used for verification
  std::vector<char> signature_;
  // Internal sequence
  std::vector<VTAUop> seq_;
  // The loop nest structure specific to ALU instructions
  std::vector<LoopEntry> loop_;
  // The loop pointer
  size_t loop_ptr_{0};
};
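// Recording sketch (illustrative only): a kernel body passed to VTAPushGEMMOp
// or VTAPushALUOp builds its micro-op sequence through the C API at the bottom
// of this file; the indices below are placeholders:
//
//   VTAUopLoopBegin(extent, dst_factor, src_factor, wgt_factor);
//   VTAUopPush(/*mode=*/0, /*reset_out=*/0, dst, src, wgt,
//              /*opcode=*/0, /*use_imm=*/0, /*imm_val=*/0);
//   VTAUopLoopEnd();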

/*!
 * \brief Base class of all queues to send and recv serial data.
 */
template <class T>
class BaseQueue {
 public:
  virtual ~BaseQueue() {
    if (fpga_buff_ != nullptr) {
      VTAMemFree(fpga_buff_);
    }
  }
  /*! \return Content of DRAM buffer. */
  const char* dram_buffer() const { return reinterpret_cast<const char*>(dram_buffer_.data()); }
  /*! \return Physical address of DRAM. */
  vta_phy_addr_t dram_phy_addr() const {
    CHECK(fpga_buff_phy_);
    return fpga_buff_phy_;
  }
  /*! \return Whether there is pending information. */
  bool pending() const { return sram_begin_ != sram_end_; }
  /*! \brief Initialize the space of the buffer. */
  void InitSpace(uint32_t elem_bytes, uint32_t max_bytes, bool coherent, bool always_cache) {
    coherent_ = coherent;
    always_cache_ = always_cache;
    elem_bytes_ = elem_bytes;
    // Allocate buffer ahead of time
    fpga_buff_ = static_cast<char*>(VTAMemAlloc(max_bytes, coherent_ || always_cache_));
    CHECK(fpga_buff_ != nullptr);
    fpga_buff_phy_ = VTAMemGetPhyAddr(fpga_buff_);
  }
  /*!
   * \brief Reset the pointer of the buffer.
   *  Set SRAM pointer to be the current end.
   */
  virtual void Reset() {
    dram_buffer_.clear();
    // Reset to 0, as we always copy data to the area starting from the fpga_buff base;
    // we do a mem copy for every DeviceRun.
    sram_end_ = 0;
    sram_begin_ = sram_end_;
  }

 protected:
  // Cache coherence access (shared memory only)
  bool coherent_{false};
  // Make the buffer cacheable
  bool always_cache_{false};
  // Element bytes
  uint32_t elem_bytes_{0};
  // Begin location of current SRAM read in FIFO mode
  uint32_t sram_begin_{0};
  // End location of current SRAM write in FIFO mode
  uint32_t sram_end_{0};
  // The buffer in DRAM
  std::vector<T> dram_buffer_;
  // FPGA accessible buffer
  void* fpga_buff_{nullptr};
  // Physical address of the FPGA buffer
  vta_phy_addr_t fpga_buff_phy_{0};
};
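// NOTE: each queue holds its contents twice: dram_buffer_ is a host-side
// staging vector, while fpga_buff_ is the contiguous, physically addressed
// buffer that VTA actually reads. The ReadBarrier() of each subclass copies
// the former into the latter and, on non-coherent systems, flushes the cache.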

/*!
 * \brief Micro op buffer that manages the micro op cache.
 */
template <int kMaxBytes, bool kCoherent, bool kAlwaysCache>
class UopQueue : public BaseQueue<VTAUop> {
 public:
  void InitSpace() { BaseQueue::InitSpace(kElemBytes, kMaxBytes, kCoherent, kAlwaysCache); }
  // Push data to the queue
  template <typename FAutoSync>
  void Push(UopKernel* kernel, FAutoSync fautosync) {
    // If the micro-op kernel is already cached in VTA SRAM, skip it
    if (kernel->cached()) return;
    // Check if we've exceeded the size of the allocated FPGA readable buffer
    size_t num_op = kernel->size();
    if (dram_buffer_.size() + num_op > kMaxElems) {
      fautosync();
      CHECK(dram_buffer_.size() <= kMaxElems);
    }
    // Cannot have a micro-op kernel larger than the SRAM buffer
    CHECK(num_op <= kMaxNumUop);
    uint32_t uop_begin = 0;
    if (sram_end_ + num_op > kMaxNumUop) {
      // Need to evict
      cache_idx_ = 0;
      sram_begin_ = 0;
      sram_end_ = num_op;
    } else {
      uop_begin = sram_end_;
      sram_end_ += num_op;
    }
    // Simple eviction policy
    uint32_t evict_begin = cache_idx_;
    for (; cache_idx_ < cache_.size(); ++cache_idx_) {
      if (cache_[cache_idx_]->sram_begin_ >= sram_end_) break;
      // Mark the kernel as "invalid"
      cache_[cache_idx_]->sram_begin_ = 0;
      cache_[cache_idx_]->sram_end_ = 0;
    }
    // Record the kernel's new SRAM location and register it in the cache
    kernel->sram_begin_ = uop_begin;
    kernel->sram_end_ = sram_end_;
    CHECK(kernel->cached());
    cache_.insert(cache_.begin() + cache_idx_, kernel);
    cache_.erase(cache_.begin() + evict_begin, cache_.begin() + cache_idx_);
    cache_idx_ = evict_begin + 1;
  }
  // Flush micro op load instruction
  void FlushUopLoad(VTAMemInsn* insn) {
    if (sram_begin_ != sram_end_) {
      // Derive offset in FPGA-readable buffer
      int32_t offset = 0;
      for (uint32_t i = 0; i < cache_idx_ - 1; ++i) {
        offset += cache_[i]->size() * kElemBytes;
      }
      insn->memory_type = VTA_MEM_ID_UOP;
      insn->sram_base = sram_begin_;
      // Update cache idx to physical address map
      insn->dram_base = (fpga_buff_phy_ + offset) / kElemBytes;
      insn->y_size = 1;
      insn->x_size = (sram_end_ - sram_begin_);
      insn->x_stride = (sram_end_ - sram_begin_);
      insn->y_pad_0 = 0;
      insn->y_pad_1 = 0;
      insn->x_pad_0 = 0;
      insn->x_pad_1 = 0;
      // Reset indices
      sram_begin_ = sram_end_;
    }
  }
  /*! \brief Clear the cache and reset the base queue buffer. */
  void Reset() {
    // Unmark the "cached" status, as we cannot assume
    // the kernels are still in SRAM across DeviceRun calls
    for (UopKernel* kernel : cache_) {
      kernel->sram_begin_ = 0;
      kernel->sram_end_ = 0;
    }

    cache_.clear();
    cache_idx_ = 0;
    BaseQueue<VTAUop>::Reset();
  }
  void AutoReadBarrier() { ReadBarrier(); }
  /*! \brief Write barrier to make sure that data written by the CPU is visible to VTA. */
  void ReadBarrier() {
    CHECK(fpga_buff_ != nullptr);
    CHECK(fpga_buff_phy_);
    // Iterate over the cached kernels to compute the total buffer size
    uint32_t buff_size = 0;
    for (uint32_t i = 0; i < cache_.size(); ++i) {
      buff_size += cache_[i]->size() * kElemBytes;
    }
    CHECK(buff_size <= kMaxBytes);
    // Move kernel contents to the FPGA readable buffer
    uint32_t offset = 0;
    for (uint32_t i = 0; i < cache_.size(); ++i) {
      uint32_t ksize = cache_[i]->size() * kElemBytes;
      VTAMemCopyFromHost(static_cast<char*>(fpga_buff_) + offset, cache_[i]->data(), ksize);
      // Update offset
      offset += ksize;
    }
    // Flush if we're using a shared memory system
    // and if the interface is non-coherent
    if (!coherent_ && always_cache_) {
      VTAFlushCache(fpga_buff_, fpga_buff_phy_, offset);
    }
  }

 private:
  // Cache pointer
  uint32_t cache_idx_{0};
  // Cached ring, sorted by sram_begin
  std::vector<UopKernel*> cache_;
  // Constants
  static constexpr int kElemBytes = sizeof(VTAUop);
  static constexpr int kMaxNumUop = VTA_UOP_BUFF_DEPTH;
  static constexpr int kMaxElems = kMaxBytes / kElemBytes;
};

// Internal kernel structure
class UopKernelMap {
 public:
  // Simple hash map
  UopKernel** Get(void* signature, int nbytes) {
    uint32_t key = 0;
    CHECK(nbytes == 0 || nbytes == sizeof(int));
    if (nbytes == sizeof(int)) {
      memcpy(&key, signature, sizeof(int));
      key = key + 1;
    }
    CHECK_LT(key, 100);
    if (kmap_.size() <= key) {
      kmap_.resize(key + 1, nullptr);
    }
    return &(kmap_[key]);
  }

 private:
  std::vector<UopKernel*> kmap_;
};

enum PipelineStage : int { kNoneStage = 0, kLoadStage = 1, kComputeStage = 2, kStoreStage = 3 };
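// The stages above correspond to VTA's load, compute, and store instruction
// queues. Dependence tokens are only exchanged between adjacent stages;
// DepPop below CHECKs that no token flows directly between load and store.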

// Instruction Queue
template <int kMaxBytes, bool kCoherent, bool kAlwaysCache>
class InsnQueue : public BaseQueue<VTAGenericInsn> {
 public:
  /*! \brief Initialize the space. */
  void InitSpace() {
    BaseQueue::InitSpace(kElemBytes, kMaxBytes, kCoherent, kAlwaysCache);
    // Initialize the stage
    std::fill(pending_pop_prev_, pending_pop_prev_ + 4, 0);
    std::fill(pending_pop_next_, pending_pop_next_ + 4, 0);
  }
  /*! \return The data pointer. */
  VTAGenericInsn* data() { return dram_buffer_.data(); }
  /*! \return Number of instructions. */
  uint32_t count() { return dram_buffer_.size(); }
  // Insert a dependency pop
  void DepPop(int from, int to) {
    // NOTE: This instruction executes on queue[to]
    if (from < to) {
      if (pending_pop_prev_[to]) {
        this->CommitPendingPop(to);
      }
      pending_pop_prev_[to] = 1;
    } else {
      if (pending_pop_next_[to]) {
        this->CommitPendingPop(to);
      }
      pending_pop_next_[to] = 1;
    }
    // Tokens cannot flow directly between the load and store stages
    CHECK(from != kLoadStage || to != kStoreStage);
    CHECK(from != kStoreStage || to != kLoadStage);
  }
  // Insert a dependency push
  void DepPush(int from, int to) {
    // NOTE: this instruction executes on queue[from]
    this->CommitPendingPop(from);
    if (!dram_buffer_.empty()) {
      VTAMemInsn* mptr = reinterpret_cast<VTAMemInsn*>(&dram_buffer_.back());
      if (GetPipelineStage(mptr) == from) {
        if (from < to && !mptr->push_next_dep) {
          // push(LD->C) or push(C->ST)
          mptr->push_next_dep = true;
          return;
        } else if (from > to && !mptr->push_prev_dep) {
          // push(C->LD) or push(ST->C)
          mptr->push_prev_dep = true;
          return;
        }
      }
    }
    if (from < to) {
      // Push next dep
      PushNoop(from, false, true, false, false);
    } else {
      // Push prev dep
      PushNoop(from, true, false, false, false);
    }
  }
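  // Pairing sketch (illustrative only): a producer/consumer handshake issues a
  // matching push/pop on the same edge, as Synchronize() does before FINISH:
  //
  //   DepPush(kLoadStage, kComputeStage);  // load side: signal "data ready"
  //   DepPop(kLoadStage, kComputeStage);   // compute side: wait for the token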
  // Create a new instruction for a GEMM stage
  VTAGemInsn* CreateGemInsn() { return reinterpret_cast<VTAGemInsn*>(Create(kComputeStage)); }
  // Create a new instruction for an ALU stage
  VTAAluInsn* CreateAluInsn() { return reinterpret_cast<VTAAluInsn*>(Create(kComputeStage)); }
  // Create a new instruction for a memory stage
  VTAMemInsn* CreateMemInsn(int memory_type) {
    return reinterpret_cast<VTAMemInsn*>(Create(GetMemPipelineStage(memory_type)));
  }
  // Create a new instruction for a store stage
  VTAMemInsn* CreateStoreInsn() { return reinterpret_cast<VTAMemInsn*>(Create(kStoreStage)); }
  // Rewrite the instruction stream to force serial execution
  void RewriteForceSerial() {
    int insn_count = count();
    VTAMemInsn* mem_ptr = reinterpret_cast<VTAMemInsn*>(data());
    VTAMemInsn* mem_last_store_ptr = nullptr;
    VTAMemInsn* mem_last_ptr = nullptr;
    for (int i = 1; i < insn_count; ++i) {
      PipelineStage prev = GetPipelineStageAll(mem_ptr + i - 1);
      PipelineStage now = GetPipelineStageAll(mem_ptr + i);
      if (prev == kLoadStage && now == kComputeStage) {
        mem_ptr[i - 1].push_prev_dep = false;
        mem_ptr[i - 1].push_next_dep = true;
        mem_ptr[i].pop_prev_dep = true;
        mem_ptr[i].pop_next_dep = false;
      } else if (prev == kComputeStage && now == kLoadStage) {
        mem_ptr[i - 1].push_prev_dep = true;
        mem_ptr[i - 1].push_next_dep = false;
        mem_ptr[i].pop_prev_dep = false;
        mem_ptr[i].pop_next_dep = true;
      } else if (prev == kStoreStage && now == kComputeStage) {
        mem_ptr[i - 1].push_prev_dep = true;
        mem_ptr[i - 1].push_next_dep = false;
        mem_ptr[i].pop_prev_dep = false;
        mem_ptr[i].pop_next_dep = true;
      } else if (prev == kComputeStage && now == kStoreStage) {
        mem_ptr[i - 1].push_prev_dep = false;
        mem_ptr[i - 1].push_next_dep = true;
        mem_ptr[i].pop_prev_dep = true;
        mem_ptr[i].pop_next_dep = false;
      } else {
        mem_ptr[i - 1].push_prev_dep = false;
        mem_ptr[i - 1].push_next_dep = false;
        mem_ptr[i].pop_prev_dep = false;
        mem_ptr[i].pop_next_dep = false;
      }
      if (now == kStoreStage) {
        mem_last_store_ptr = &mem_ptr[i];
      }
      mem_last_ptr = &mem_ptr[i];
    }
    // Set dependencies to make sure all core instructions get executed
    // before the last FINISH instruction
    if (mem_last_store_ptr && mem_last_ptr == mem_last_store_ptr) {
      mem_last_store_ptr->push_prev_dep = true;
      if (!pending_pop_next_[kComputeStage]) {
        DepPop(kStoreStage, kComputeStage);
      }
      CommitPendingPop(kComputeStage);
    } else {
      pending_pop_next_[kComputeStage] = 0;
    }
    DepPush(kComputeStage, kLoadStage);
    DepPop(kLoadStage, kComputeStage);
    if (!pending_pop_next_[kLoadStage]) {
      DepPop(kComputeStage, kLoadStage);
    }
    CommitPendingPop(kLoadStage);
    DepPush(kLoadStage, kComputeStage);
    CommitPendingPop(kComputeStage);
  }
  // Helper function: Get opcode string
  const char* getOpcodeString(int opcode, bool use_imm) {
    // The string name
    if (opcode == VTA_ALU_OPCODE_MIN) {
      if (use_imm) {
        return "min imm";
      } else {
        return "min";
      }
    } else if (opcode == VTA_ALU_OPCODE_MAX) {
      if (use_imm) {
        return "max imm";
      } else {
        return "max";
      }
    } else if (opcode == VTA_ALU_OPCODE_ADD) {
      if (use_imm) {
        return "add imm";
      } else {
        return "add";
      }
    } else if (opcode == VTA_ALU_OPCODE_SHR) {
      return "shr";
    }

    return "unknown op";
  }
  // Dump instructions in the queue
  void DumpInsn() {
    // Keep tabs on dependence queues
    int l2g_queue = 0;
    int g2l_queue = 0;
    int s2g_queue = 0;
    int g2s_queue = 0;
    // Converter
    union VTAInsn c;
    // Iterate over all instructions
    int insn_count = count();
    const VTAGenericInsn* insn = data();
    printf("There are %d instructions\n", insn_count);
    for (int i = 0; i < insn_count; ++i) {
      // Fetch instruction and decode opcode
      c.generic = insn[i];
      printf("INSTRUCTION %d: ", i);
      if (c.mem.opcode == VTA_OPCODE_LOAD || c.mem.opcode == VTA_OPCODE_STORE) {
        if (c.mem.x_size == 0) {
          if (c.mem.opcode == VTA_OPCODE_STORE) {
            printf("NOP-STORE-STAGE\n");
          } else if (GetMemPipelineStage(c.mem.memory_type) == kComputeStage) {
            printf("NOP-COMPUTE-STAGE\n");
          } else {
            printf("NOP-MEMORY-STAGE\n");
          }
          printf("\tdep - pop prev: %d, pop next: %d, push prev: %d, push next: %d\n",
                 static_cast<int>(c.mem.pop_prev_dep), static_cast<int>(c.mem.pop_next_dep),
                 static_cast<int>(c.mem.push_prev_dep), static_cast<int>(c.mem.push_next_dep));
          // Count status in queues
          if (c.mem.opcode == VTA_OPCODE_STORE) {
            CHECK(c.mem.pop_next_dep == false);
            CHECK(c.mem.push_next_dep == false);
            if (c.mem.pop_prev_dep) g2s_queue--;
            if (c.mem.push_prev_dep) s2g_queue++;
          } else if (c.mem.opcode == VTA_OPCODE_LOAD &&
                     (c.mem.memory_type == VTA_MEM_ID_INP || c.mem.memory_type == VTA_MEM_ID_WGT)) {
            CHECK(c.mem.pop_prev_dep == false);
            CHECK(c.mem.push_prev_dep == false);
            if (c.mem.pop_next_dep) g2l_queue--;
            if (c.mem.push_next_dep) l2g_queue++;
          } else {
            if (c.mem.pop_prev_dep) l2g_queue--;
            if (c.mem.push_prev_dep) g2l_queue++;
            if (c.mem.pop_next_dep) s2g_queue--;
            if (c.mem.push_next_dep) g2s_queue++;
          }
          printf("\tl2g_queue = %d, g2l_queue = %d\n", l2g_queue, g2l_queue);
          printf("\ts2g_queue = %d, g2s_queue = %d\n", s2g_queue, g2s_queue);
          continue;
        }
        // Print instruction field information
        if (c.mem.opcode == VTA_OPCODE_LOAD) {
          printf("LOAD ");
          if (c.mem.memory_type == VTA_MEM_ID_UOP) printf("UOP\n");
          if (c.mem.memory_type == VTA_MEM_ID_WGT) printf("WGT\n");
          if (c.mem.memory_type == VTA_MEM_ID_INP) printf("INP\n");
          if (c.mem.memory_type == VTA_MEM_ID_ACC) printf("ACC\n");
        }
        if (c.mem.opcode == VTA_OPCODE_STORE) {
          printf("STORE:\n");
        }
        printf("\tdep - pop prev: %d, pop next: %d, push prev: %d, push next: %d\n",
               static_cast<int>(c.mem.pop_prev_dep), static_cast<int>(c.mem.pop_next_dep),
               static_cast<int>(c.mem.push_prev_dep), static_cast<int>(c.mem.push_next_dep));
        printf("\tDRAM: 0x%08x, SRAM:0x%04x\n", static_cast<int>(c.mem.dram_base),
               static_cast<int>(c.mem.sram_base));
        printf("\ty: size=%d, pad=[%d, %d]\n", static_cast<int>(c.mem.y_size),
               static_cast<int>(c.mem.y_pad_0), static_cast<int>(c.mem.y_pad_1));
        printf("\tx: size=%d, stride=%d, pad=[%d, %d]\n", static_cast<int>(c.mem.x_size),
               static_cast<int>(c.mem.x_stride), static_cast<int>(c.mem.x_pad_0),
               static_cast<int>(c.mem.x_pad_1));
      } else if (c.mem.opcode == VTA_OPCODE_GEMM) {
        // Print instruction field information
        printf("GEMM\n");

        printf("\tdep - pop prev: %d, pop next: %d, push prev: %d, push next: %d\n",
               static_cast<int>(c.mem.pop_prev_dep), static_cast<int>(c.mem.pop_next_dep),
               static_cast<int>(c.mem.push_prev_dep), static_cast<int>(c.mem.push_next_dep));
        printf("\treset_out: %d\n", static_cast<int>(c.gemm.reset_reg));
        printf("\trange (%d, %d)\n", static_cast<int>(c.gemm.uop_bgn),
               static_cast<int>(c.gemm.uop_end));
        printf("\touter loop - iter: %d, wgt: %d, inp: %d, acc: %d\n",
               static_cast<int>(c.gemm.iter_out), static_cast<int>(c.gemm.wgt_factor_out),
               static_cast<int>(c.gemm.src_factor_out), static_cast<int>(c.gemm.dst_factor_out));
        printf("\tinner loop - iter: %d, wgt: %d, inp: %d, acc: %d\n",
               static_cast<int>(c.gemm.iter_in), static_cast<int>(c.gemm.wgt_factor_in),
               static_cast<int>(c.gemm.src_factor_in), static_cast<int>(c.gemm.dst_factor_in));
      } else if (c.mem.opcode == VTA_OPCODE_ALU) {
        // Print instruction field information
        printf("ALU - %s\n", getOpcodeString(c.alu.alu_opcode, c.alu.use_imm));
        printf("\tdep - pop prev: %d, pop next: %d, push prev: %d, push next: %d\n",
               static_cast<int>(c.mem.pop_prev_dep), static_cast<int>(c.mem.pop_next_dep),
               static_cast<int>(c.mem.push_prev_dep), static_cast<int>(c.mem.push_next_dep));
        printf("\treset_out: %d\n", static_cast<int>(c.alu.reset_reg));
        printf("\trange (%d, %d)\n", static_cast<int>(c.alu.uop_bgn),
               static_cast<int>(c.alu.uop_end));
        printf("\touter loop - iter: %d, dst: %d, src: %d\n", static_cast<int>(c.alu.iter_out),
               static_cast<int>(c.alu.dst_factor_out), static_cast<int>(c.alu.src_factor_out));
        printf("\tinner loop - iter: %d, dst: %d, src: %d\n", static_cast<int>(c.alu.iter_in),
               static_cast<int>(c.alu.dst_factor_in), static_cast<int>(c.alu.src_factor_in));
      } else if (c.mem.opcode == VTA_OPCODE_FINISH) {
        printf("FINISH\n");
      }

      // Count status in queues
      if (c.mem.opcode == VTA_OPCODE_LOAD || c.mem.opcode == VTA_OPCODE_STORE) {
        if (c.mem.opcode == VTA_OPCODE_STORE) {
          CHECK(c.mem.pop_next_dep == false);
          CHECK(c.mem.push_next_dep == false);
          if (c.mem.pop_prev_dep) g2s_queue--;
          if (c.mem.push_prev_dep) s2g_queue++;
        } else if (c.mem.opcode == VTA_OPCODE_LOAD &&
                   (c.mem.memory_type == VTA_MEM_ID_INP || c.mem.memory_type == VTA_MEM_ID_WGT)) {
          CHECK(c.mem.pop_prev_dep == false);
          CHECK(c.mem.push_prev_dep == false);
          if (c.mem.pop_next_dep) g2l_queue--;
          if (c.mem.push_next_dep) l2g_queue++;
        } else {
          if (c.mem.pop_prev_dep) l2g_queue--;
          if (c.mem.push_prev_dep) g2l_queue++;
          if (c.mem.pop_next_dep) s2g_queue--;
          if (c.mem.push_next_dep) g2s_queue++;
        }
      } else if (c.mem.opcode == VTA_OPCODE_GEMM || c.mem.opcode == VTA_OPCODE_ALU) {
        // Count status in queues
        if (c.gemm.pop_prev_dep) l2g_queue--;
        if (c.gemm.push_prev_dep) g2l_queue++;
        if (c.gemm.pop_next_dep) s2g_queue--;
        if (c.gemm.push_next_dep) g2s_queue++;
      }
      printf("\tl2g_queue = %d, g2l_queue = %d\n", l2g_queue, g2l_queue);
      printf("\ts2g_queue = %d, g2s_queue = %d\n", s2g_queue, g2s_queue);
    }
  }
  // Commit all pending pops of the corresponding stage
  void CommitPendingPop(int stage) {
    // Handle the LD<->compute queue
    // NOTE: pop executes on target(stage)
    CHECK(stage > 0 && stage < 4);
    if (pending_pop_prev_[stage] || pending_pop_next_[stage]) {
      PushNoop(stage, false, false, pending_pop_prev_[stage], pending_pop_next_[stage]);
      pending_pop_prev_[stage] = 0;
      pending_pop_next_[stage] = 0;
    }
  }
  void CommitPending() {
    for (int i = kLoadStage; i <= kStoreStage; ++i) {
      CommitPendingPop(i);
    }
  }
  bool PendingPop() {
    for (int i = kLoadStage; i <= kStoreStage; ++i) {
      if (pending_pop_prev_[i]) return true;
      if (pending_pop_next_[i]) return true;
    }
    return false;
  }
  void AutoReadBarrier() { ReadBarrier(); }
  /*! \brief Write barrier to make sure that data written by the CPU is visible to VTA. */
  void ReadBarrier() {
    CHECK(fpga_buff_ != nullptr);
    CHECK(fpga_buff_phy_);
    uint32_t buff_size = dram_buffer_.size() * elem_bytes_;
    CHECK(buff_size <= kMaxBytes);
    // Copy contents of the DRAM buffer to the FPGA buffer
    VTAMemCopyFromHost(fpga_buff_, dram_buffer_.data(), buff_size);
    // Flush if we're using a shared memory system
    // and if the interface is non-coherent
    if (!coherent_ && always_cache_) {
      VTAFlushCache(fpga_buff_, fpga_buff_phy_, buff_size);
    }
  }

 protected:
  /*! \return Pointer to a newly allocated instruction slot at the end of the buffer. */
  VTAGenericInsn* NextInsn() {
    VTAGenericInsn insn;
    dram_buffer_.push_back(insn);
    return &dram_buffer_.back();
  }
  // Create a new instruction for a given stage
  VTAGenericInsn* Create(PipelineStage stage) {
    VTAGenericInsn* gptr = NextInsn();
    VTAMemInsn* mptr = reinterpret_cast<VTAMemInsn*>(gptr);
    mptr->pop_prev_dep = pending_pop_prev_[stage];
    mptr->pop_next_dep = pending_pop_next_[stage];
    mptr->push_prev_dep = false;
    mptr->push_next_dep = false;
    pending_pop_prev_[stage] = 0;
    pending_pop_next_[stage] = 0;
    return gptr;
  }
  // Get stage of the memory
  static PipelineStage GetMemPipelineStage(int memory_type) {
    if (memory_type == VTA_MEM_ID_ACC) return kComputeStage;
    if (memory_type == VTA_MEM_ID_UOP) return kComputeStage;
    return kLoadStage;
  }
  // Get stage of the computation
  static PipelineStage GetPipelineStage(VTAMemInsn* insn) {
    if (insn->opcode == VTA_OPCODE_GEMM) return kComputeStage;
    if (insn->opcode == VTA_OPCODE_ALU) return kComputeStage;
    if (insn->opcode == VTA_OPCODE_LOAD) {
      if (insn->x_size == 0) return kNoneStage;
      if (insn->memory_type == VTA_MEM_ID_ACC) return kComputeStage;
      if (insn->memory_type == VTA_MEM_ID_UOP) return kComputeStage;
      return kLoadStage;
    }
    if (insn->opcode == VTA_OPCODE_STORE) {
      // FIXME: Right now memory_type is a 2-bit field, which means that
      //        VTA_MEM_ID_OUT will appear as 0. For now we'll refrain from
      //        checking the memory_type to avoid a CHECK error...
      return kStoreStage;
    }
    LOG(FATAL) << "not reached";
    return kNoneStage;
  }

  // Get stage of memory and computation
  static PipelineStage GetPipelineStageAll(VTAMemInsn* insn) {
    PipelineStage stage = GetPipelineStage(insn);
    if (stage != kNoneStage) return stage;
    return GetMemPipelineStage(insn->memory_type);
  }

  // Push no-op
  void PushNoop(int stage, bool push_prev_dep, bool push_next_dep, bool pop_prev_dep,
                bool pop_next_dep) {
    VTAMemInsn* insn = reinterpret_cast<VTAMemInsn*>(NextInsn());
    insn->opcode = (stage == kStoreStage ? VTA_OPCODE_STORE : VTA_OPCODE_LOAD);
    insn->push_prev_dep = push_prev_dep;
    insn->push_next_dep = push_next_dep;
    insn->pop_prev_dep = pop_prev_dep;
    insn->pop_next_dep = pop_next_dep;
    insn->sram_base = 0;
    insn->dram_base = 0;
    insn->y_size = 0;
    insn->x_size = 0;
    insn->x_stride = 0;
    insn->y_pad_0 = 0;
    insn->y_pad_1 = 0;
    insn->x_pad_0 = 0;
    insn->x_pad_1 = 0;
    insn->memory_type = (stage == kLoadStage ? VTA_MEM_ID_INP : VTA_MEM_ID_UOP);
  }

 private:
  // Pending pops of each instruction queue; qid=0 is not used
  int pending_pop_prev_[4];
  int pending_pop_next_[4];
  static constexpr int kElemBytes = sizeof(VTAGenericInsn);
  static constexpr int kMaxElems = kMaxBytes / kElemBytes;
};

/*!
 * \brief The command queue object that handles the request.
 */
class CommandQueue {
 public:
  CommandQueue() { this->InitSpace(); }
  void InitSpace() {
    uop_queue_.InitSpace();
    insn_queue_.InitSpace();
    device_ = VTADeviceAlloc();
    CHECK(device_ != nullptr);
  }

  ~CommandQueue() { VTADeviceFree(device_); }

  uint32_t GetElemBytes(uint32_t memory_id) {
    uint32_t elem_bytes = 0;
    switch (memory_id) {
      case VTA_MEM_ID_UOP:
        elem_bytes = VTA_UOP_ELEM_BYTES;
        break;
      case VTA_MEM_ID_INP:
        elem_bytes = VTA_INP_ELEM_BYTES;
        break;
      case VTA_MEM_ID_WGT:
        elem_bytes = VTA_WGT_ELEM_BYTES;
        break;
      case VTA_MEM_ID_ACC:
        elem_bytes = VTA_ACC_ELEM_BYTES;
        break;
      case VTA_MEM_ID_OUT:
        elem_bytes = VTA_OUT_ELEM_BYTES;
        break;
      default:
        LOG(FATAL) << "Memory id not recognized: " << memory_id;
        break;
    }
    // The element size should not be larger than VTA_PAGE_BYTES.
    CHECK_GE(VTA_PAGE_BYTES, elem_bytes);
    return elem_bytes;
  }

  void LoadBuffer2D(void* src_dram_addr, uint32_t src_elem_offset, uint32_t x_size, uint32_t y_size,
                    uint32_t x_stride, uint32_t x_pad_before, uint32_t y_pad_before,
                    uint32_t x_pad_after, uint32_t y_pad_after, uint32_t dst_sram_index,
                    uint32_t dst_memory_type) {
    VTAMemInsn* insn = insn_queue_.CreateMemInsn(dst_memory_type);
    insn->opcode = VTA_OPCODE_LOAD;
    insn->memory_type = dst_memory_type;
    insn->sram_base = dst_sram_index;
    DataBuffer* src = DataBuffer::FromHandle(src_dram_addr);
    insn->dram_base = src->phy_addr() / GetElemBytes(dst_memory_type) + src_elem_offset;
    insn->y_size = y_size;
    insn->x_size = x_size;
    insn->x_stride = x_stride;
    insn->y_pad_0 = y_pad_before;
    insn->y_pad_1 = y_pad_after;
    insn->x_pad_0 = x_pad_before;
    insn->x_pad_1 = x_pad_after;
    this->CheckInsnOverFlow();
  }

  void StoreBuffer2D(uint32_t src_sram_index, uint32_t src_memory_type, void* dst_dram_addr,
                     uint32_t dst_elem_offset, uint32_t x_size, uint32_t y_size,
                     uint32_t x_stride) {
    VTAMemInsn* insn = insn_queue_.CreateStoreInsn();
    insn->opcode = VTA_OPCODE_STORE;
    insn->memory_type = src_memory_type;
    insn->sram_base = src_sram_index;
    DataBuffer* dst = DataBuffer::FromHandle(dst_dram_addr);
    insn->dram_base = dst->phy_addr() / GetElemBytes(src_memory_type) + dst_elem_offset;
    insn->y_size = y_size;
    insn->x_size = x_size;
    insn->x_stride = x_stride;
    insn->y_pad_0 = 0;
    insn->y_pad_1 = 0;
    insn->x_pad_0 = 0;
    insn->x_pad_1 = 0;
    this->CheckInsnOverFlow();
  }

  void DepPush(int from_qid, int to_qid) { insn_queue_.DepPush(from_qid, to_qid); }

  void DepPop(int from_qid, int to_qid) { insn_queue_.DepPop(from_qid, to_qid); }

  void ReadBarrier(void* buffer, uint32_t elem_bits, uint32_t start, uint32_t extent) {
    if (!(debug_flag_ & VTA_DEBUG_SKIP_READ_BARRIER)) {
      uint32_t elem_bytes = (elem_bits + 8 - 1) / 8;
      DataBuffer::FromHandle(buffer)->FlushCache(elem_bytes * start, elem_bytes * extent);
    }
  }

  void WriteBarrier(void* buffer, uint32_t elem_bits, uint32_t start, uint32_t extent) {
    if (!(debug_flag_ & VTA_DEBUG_SKIP_WRITE_BARRIER)) {
      uint32_t elem_bytes = (elem_bits + 8 - 1) / 8;
      DataBuffer::FromHandle(buffer)->InvalidateCache(elem_bytes * start, elem_bytes * extent);
    }
  }

  void Synchronize(uint32_t wait_cycles) {
    // Insert dependences to force serialization
    if (debug_flag_ & VTA_DEBUG_FORCE_SERIAL) {
      insn_queue_.RewriteForceSerial();
    } else {
      // This will issue finish after the last store finishes
      insn_queue_.DepPush(kStoreStage, kComputeStage);
      insn_queue_.DepPush(kLoadStage, kComputeStage);
      insn_queue_.DepPop(kStoreStage, kComputeStage);
      insn_queue_.DepPop(kLoadStage, kComputeStage);
      insn_queue_.CommitPendingPop(kComputeStage);
    }
    // NOTE: FINISH cannot contain a pop
    VTAGemInsn* insn = insn_queue_.CreateGemInsn();
    insn->opcode = VTA_OPCODE_FINISH;
    CHECK(!insn_queue_.PendingPop());
    // Check if there are no instructions to execute at all
    if (insn_queue_.count() == 0) return;
    // Synchronization for the queues
    uop_queue_.AutoReadBarrier();
    insn_queue_.AutoReadBarrier();
    // Dump instructions if debug enabled
    if (debug_flag_ & VTA_DEBUG_DUMP_INSN) {
      insn_queue_.DumpInsn();
    }
    // Make sure that the last instruction is a finish instruction
    CHECK(reinterpret_cast<VTAMemInsn*>(insn_queue_.data())[insn_queue_.count() - 1].opcode ==
          VTA_OPCODE_FINISH);

    // Make sure that we don't exceed contiguous physical memory limits
    CHECK(insn_queue_.count() * sizeof(VTAGenericInsn) < VTA_MAX_XFER);
    int timeout =
        VTADeviceRun(device_, insn_queue_.dram_phy_addr(), insn_queue_.count(), wait_cycles);
    CHECK_EQ(timeout, 0);
    // Reset buffers
    uop_queue_.Reset();
    insn_queue_.Reset();
  }

  // Get the kernel currently being recorded
  UopKernel* record_kernel() const {
    CHECK(record_kernel_ != nullptr);
    return record_kernel_;
  }

  // Set debug flag
  void SetDebugFlag(int debug_flag) { debug_flag_ = debug_flag; }

  void PushGEMMOp(void** uop_handle, int (*finit)(void*), void* signature, int nbytes) {
    UopKernelMap** uptr = reinterpret_cast<UopKernelMap**>(uop_handle);
    if (uptr[0] == nullptr) {
      uptr[0] = new UopKernelMap();
    }
    UopKernel** kptr = uptr[0]->Get(signature, nbytes);
    if (kptr[0] == nullptr) {
      record_kernel_ = new UopKernel(static_cast<char*>(signature), nbytes);
      CHECK_EQ((*finit)(signature), 0);
      kptr[0] = static_cast<UopKernel*>(record_kernel_);
      if (debug_flag_ & VTA_DEBUG_DUMP_UOP) {
        record_kernel_->Dump();
      }
      record_kernel_ = nullptr;
    }
    this->PushGEMMOp(static_cast<UopKernel*>(kptr[0]));
    this->CheckInsnOverFlow();
  }

  void PushALUUop(void** uop_handle, int (*finit)(void*), void* signature, int nbytes) {
    UopKernelMap** uptr = reinterpret_cast<UopKernelMap**>(uop_handle);
    if (uptr[0] == nullptr) {
      uptr[0] = new UopKernelMap();
    }
    UopKernel** kptr = uptr[0]->Get(signature, nbytes);
    if (kptr[0] == nullptr) {
      record_kernel_ = new UopKernel(static_cast<char*>(signature), nbytes);
      CHECK_EQ((*finit)(signature), 0);
      kptr[0] = static_cast<UopKernel*>(record_kernel_);
      if (debug_flag_ & VTA_DEBUG_DUMP_UOP) {
        record_kernel_->Dump();
      }
      record_kernel_ = nullptr;
    }
    this->PushALUUop(static_cast<UopKernel*>(kptr[0]));
    this->CheckInsnOverFlow();
  }

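  // NOTE: despite its name, this returns a function-local static shared by all
  // threads (a process-wide singleton), not a per-thread instance.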
  static std::shared_ptr<CommandQueue>& ThreadLocal() {
    static std::shared_ptr<CommandQueue> inst = std::make_shared<CommandQueue>();
    if (inst == nullptr) {
      inst = std::make_shared<CommandQueue>();
    }
    return inst;
  }

  static void Shutdown() { ThreadLocal().reset(); }

 private:
  // Push GEMM uop to the command buffer
  void PushGEMMOp(UopKernel* kernel) {
    uop_queue_.Push(kernel, [this]() { this->AutoSync(); });
    if (uop_queue_.pending()) {
      VTAMemInsn* insn = insn_queue_.CreateMemInsn(VTA_MEM_ID_UOP);
      insn->opcode = VTA_OPCODE_LOAD;
      uop_queue_.FlushUopLoad(insn);
    }
    VTAGemInsn* insn = insn_queue_.CreateGemInsn();
    insn->opcode = VTA_OPCODE_GEMM;
    insn->reset_reg = kernel->reset_out_;
    insn->uop_bgn = kernel->sram_begin_;
    insn->uop_end = kernel->sram_end_;
    const std::vector<UopKernel::LoopEntry>& loop = kernel->loop();
    if (loop.size() > 0) {
      insn->iter_out = loop[0].extent;
      insn->wgt_factor_out = loop[0].wgt_factor;
      insn->src_factor_out = loop[0].src_factor;
      insn->dst_factor_out = loop[0].dst_factor;
    } else {
      insn->iter_out = 1;
      insn->wgt_factor_out = 0;
      insn->src_factor_out = 0;
      insn->dst_factor_out = 0;
    }
    if (loop.size() > 1) {
      insn->iter_in = loop[1].extent;
      insn->wgt_factor_in = loop[1].wgt_factor;
      insn->src_factor_in = loop[1].src_factor;
      insn->dst_factor_in = loop[1].dst_factor;
    } else {
      insn->iter_in = 1;
      insn->wgt_factor_in = 0;
      insn->src_factor_in = 0;
      insn->dst_factor_in = 0;
    }
  }

  // Push ALU uop to the command buffer
  void PushALUUop(UopKernel* kernel) {
    uop_queue_.Push(kernel, [this]() { this->AutoSync(); });
    if (uop_queue_.pending()) {
      VTAMemInsn* insn = insn_queue_.CreateMemInsn(VTA_MEM_ID_UOP);
      insn->opcode = VTA_OPCODE_LOAD;
      uop_queue_.FlushUopLoad(insn);
    }
    VTAAluInsn* insn = insn_queue_.CreateAluInsn();
    insn->opcode = VTA_OPCODE_ALU;
    insn->reset_reg = kernel->reset_out_;
    insn->uop_bgn = kernel->sram_begin_;
    insn->uop_end = kernel->sram_end_;
    insn->alu_opcode = kernel->opcode_;
    insn->use_imm = kernel->use_imm_;
    insn->imm = kernel->imm_val_;
    const std::vector<UopKernel::LoopEntry>& loop = kernel->loop();
    if (loop.size() == 0) {
      insn->iter_out = 1;
      insn->dst_factor_out = 0;
      insn->src_factor_out = 0;
      insn->iter_in = 1;
      insn->dst_factor_in = 0;
      insn->src_factor_in = 0;
    } else if (loop.size() == 1) {
      insn->iter_out = 1;
      insn->dst_factor_out = 0;
      insn->src_factor_out = 0;
      insn->iter_in = loop[0].extent;
      insn->dst_factor_in = loop[0].dst_factor;
      insn->src_factor_in = loop[0].src_factor;
    } else {
      insn->iter_out = loop[0].extent;
      insn->dst_factor_out = loop[0].dst_factor;
      insn->src_factor_out = loop[0].src_factor;
      insn->iter_in = loop[1].extent;
      insn->dst_factor_in = loop[1].dst_factor;
      insn->src_factor_in = loop[1].src_factor;
    }
  }

  void CheckInsnOverFlow() {
    // At each API call, we can at most commit:
    // one pending store, one pending load, and one uop
    if ((insn_queue_.count() + 4) * sizeof(VTAGenericInsn) >= VTA_MAX_XFER) {
      this->AutoSync();
    }
  }
  // Auto-sync when the instruction queue overflows
  void AutoSync() { this->Synchronize(1U << 31); }

  // Internal debug flag
  int debug_flag_{0};
  // The kernel we are currently recording
  UopKernel* record_kernel_{nullptr};
  // Micro op queue
  UopQueue<VTA_MAX_XFER, kBufferCoherent, kAlwaysCache> uop_queue_;
  // Instruction queue
  InsnQueue<VTA_MAX_XFER, kBufferCoherent, kAlwaysCache> insn_queue_;
  // Device handle
  VTADeviceHandle device_{nullptr};
};

}  // namespace vta

void* VTABufferAlloc(size_t size) { return vta::DataBuffer::Alloc(size); }

void VTABufferFree(void* buffer) { vta::DataBuffer::Free(vta::DataBuffer::FromHandle(buffer)); }

void VTABufferCopy(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size,
                   int kind_mask) {
  vta::DataBuffer* from_buffer = nullptr;
  vta::DataBuffer* to_buffer = nullptr;

  if (kind_mask & 2) {
    from_buffer = vta::DataBuffer::FromHandle(from);
    from = from_buffer->virt_addr();
  }
  if (kind_mask & 1) {
    to_buffer = vta::DataBuffer::FromHandle(to);
    to = to_buffer->virt_addr();
  }

  if (from_buffer) {
    // This is an FPGA-to-host mem transfer
    from_buffer->InvalidateCache(from_offset, size);
    from_buffer->MemCopyToHost(static_cast<char*>(to) + to_offset,
                               static_cast<const char*>(from) + from_offset, size);
  } else if (to_buffer) {
    // This is a host-to-FPGA mem transfer
    to_buffer->MemCopyFromHost(static_cast<char*>(to) + to_offset,
                               static_cast<const char*>(from) + from_offset, size);
    to_buffer->FlushCache(to_offset, size);
  }
}

VTACommandHandle VTATLSCommandHandle() { return vta::CommandQueue::ThreadLocal().get(); }
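
// End-to-end sketch (illustrative only; sizes, indices, and cycle counts are
// placeholders):
//
//   VTACommandHandle cmd = VTATLSCommandHandle();
//   void* inp = VTABufferAlloc(size);
//   VTALoadBuffer2D(cmd, inp, 0, x_size, y_size, x_stride,
//                   0, 0, 0, 0, sram_index, VTA_MEM_ID_INP);
//   /* push GEMM/ALU kernels via VTAPushGEMMOp / VTAPushALUOp */
//   VTAStoreBuffer2D(cmd, sram_index, VTA_MEM_ID_OUT, out, 0,
//                    x_size, y_size, x_stride);
//   VTASynchronize(cmd, /*wait_cycles=*/1U << 20);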

void VTARuntimeShutdown() { vta::CommandQueue::Shutdown(); }

void VTASetDebugMode(VTACommandHandle cmd, int debug_flag) {
  static_cast<vta::CommandQueue*>(cmd)->SetDebugFlag(debug_flag);
}

void* VTABufferCPUPtr(VTACommandHandle cmd, void* buffer) {
  return vta::DataBuffer::FromHandle(buffer)->virt_addr();
}

void VTAWriteBarrier(VTACommandHandle cmd, void* buffer, uint32_t elem_bits, uint32_t start,
                     uint32_t extent) {
  static_cast<vta::CommandQueue*>(cmd)->WriteBarrier(buffer, elem_bits, start, extent);
}

void VTAReadBarrier(VTACommandHandle cmd, void* buffer, uint32_t elem_bits, uint32_t start,
                    uint32_t extent) {
  static_cast<vta::CommandQueue*>(cmd)->ReadBarrier(buffer, elem_bits, start, extent);
}

void VTALoadBuffer2D(VTACommandHandle cmd, void* src_dram_addr, uint32_t src_elem_offset,
                     uint32_t x_size, uint32_t y_size, uint32_t x_stride, uint32_t x_pad_before,
                     uint32_t y_pad_before, uint32_t x_pad_after, uint32_t y_pad_after,
                     uint32_t dst_sram_index, uint32_t dst_memory_type) {
  static_cast<vta::CommandQueue*>(cmd)->LoadBuffer2D(
      src_dram_addr, src_elem_offset, x_size, y_size, x_stride, x_pad_before, y_pad_before,
      x_pad_after, y_pad_after, dst_sram_index, dst_memory_type);
}

void VTAStoreBuffer2D(VTACommandHandle cmd, uint32_t src_sram_index, uint32_t src_memory_type,
                      void* dst_dram_addr, uint32_t dst_elem_offset, uint32_t x_size,
                      uint32_t y_size, uint32_t x_stride) {
  static_cast<vta::CommandQueue*>(cmd)->StoreBuffer2D(
      src_sram_index, src_memory_type, dst_dram_addr, dst_elem_offset, x_size, y_size, x_stride);
}

void VTAUopPush(uint32_t mode, uint32_t reset_out, uint32_t dst_index, uint32_t src_index,
                uint32_t wgt_index, uint32_t opcode, uint32_t use_imm, int32_t imm_val) {
  vta::CommandQueue::ThreadLocal()->record_kernel()->Push(mode, reset_out, dst_index, src_index,
                                                          wgt_index, opcode, use_imm, imm_val);
}

void VTAUopLoopBegin(uint32_t extent, uint32_t dst_factor, uint32_t src_factor,
                     uint32_t wgt_factor) {
  vta::CommandQueue::ThreadLocal()->record_kernel()->PushLoopBegin(extent, dst_factor, src_factor,
                                                                   wgt_factor);
}

void VTAUopLoopEnd() { vta::CommandQueue::ThreadLocal()->record_kernel()->PushLoopEnd(); }

int VTAPushGEMMOp(void** uop_handle, int (*finit)(void*), void* signature, int nbytes) {
  vta::CommandQueue::ThreadLocal()->PushGEMMOp(uop_handle, finit, signature, nbytes);
  return 0;
}

int VTAPushALUOp(void** uop_handle, int (*finit)(void*), void* signature, int nbytes) {
  vta::CommandQueue::ThreadLocal()->PushALUUop(uop_handle, finit, signature, nbytes);
  return 0;
}

int VTADepPush(VTACommandHandle cmd, int from_qid, int to_qid) {
  static_cast<vta::CommandQueue*>(cmd)->DepPush(from_qid, to_qid);
  return 0;
}

int VTADepPop(VTACommandHandle cmd, int from_qid, int to_qid) {
  static_cast<vta::CommandQueue*>(cmd)->DepPop(from_qid, to_qid);
  return 0;
}

void VTASynchronize(VTACommandHandle cmd, uint32_t wait_cycles) {
  static_cast<vta::CommandQueue*>(cmd)->Synchronize(wait_cycles);
}