1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements.  See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership.  The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License.  You may obtain a copy of the License at
9 *
10 *   http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied.  See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20package vta.core
21
22import chisel3._
23import chisel3.util._
24import chisel3.experimental._
25import vta.util.config._
26import scala.math.pow
27
/** Pipelined multiply and accumulate.
  *
  * All three operands are captured in registers, so `y` carries
  * `a * b + c` for the operands that were presented one clock cycle earlier.
  *
  * @param aBits width of the signed operand `a`
  * @param bBits width of the signed operand `b`
  * @param cBits width of the signed addend `c`
  */
class MAC(aBits: Int = 8, bBits: Int = 8, cBits: Int = 16) extends Module {
  // One bit wider than the wider of (product, addend) so the expanding add
  // below never loses its carry.
  val outBits = (aBits + bBits).max(cBits) + 1
  val io = IO(new Bundle {
    val a = Input(SInt(aBits.W))
    val b = Input(SInt(bBits.W))
    val c = Input(SInt(cBits.W))
    val y = Output(SInt(outBits.W))
  })

  // Input registers: the only pipeline stage of this unit.
  val rA = RegNext(io.a)
  val rB = RegNext(io.b)
  val rC = RegNext(io.c)

  // Full-width product of the registered operands.
  val mult = Wire(SInt((aBits + bBits).W))
  mult := rA * rB

  // Width-expanding accumulate.
  val add = Wire(SInt(outBits.W))
  add := rC +& mult

  io.y := add
}
48
/** PipeAdder
  *
  * Registers both operands and produces their sum from the registered values,
  * so the result appears one clock cycle after the inputs are applied.
  *
  * @param aBits width of the signed operand `a`
  * @param bBits width of the signed operand `b`
  */
class PipeAdder(aBits: Int = 8, bBits: Int = 8) extends Module {
  // One extra bit keeps the carry of the expanding add.
  val outBits = aBits.max(bBits) + 1
  val io = IO(new Bundle {
    val a = Input(SInt(aBits.W))
    val b = Input(SInt(bBits.W))
    val y = Output(SInt(outBits.W))
  })

  // Pipeline stage: capture the operands.
  val rA = RegNext(io.a)
  val rB = RegNext(io.b)

  // Sum of the registered operands.
  val add = Wire(SInt(outBits.W))
  add := rA +& rB

  io.y := add
}
66
/** Adder
  *
  * Purely combinational counterpart of PipeAdder: the operands go through
  * plain wires into the adder, so the sum is available in the same cycle.
  *
  * @param aBits width of the signed operand `a`
  * @param bBits width of the signed operand `b`
  */
class Adder(aBits: Int = 8, bBits: Int = 8) extends Module {
  // One extra bit keeps the carry of the expanding add.
  val outBits = aBits.max(bBits) + 1
  val io = IO(new Bundle {
    val a = Input(SInt(aBits.W))
    val b = Input(SInt(bBits.W))
    val y = Output(SInt(outBits.W))
  })

  // Operand wires (no registers: nothing is delayed here).
  val rA = Wire(SInt(aBits.W))
  rA := io.a
  val rB = Wire(SInt(bBits.W))
  rB := io.b

  // Combinational sum.
  val add = Wire(SInt(outBits.W))
  add := rA +& rB

  io.y := add
}
87
/** Pipelined DotProduct based on MAC and PipeAdder.
  *
  * Multiplies `a` and `b` element-wise with one MAC per element pair and sums
  * the products through a binary adder tree. Only the MACs and the first adder
  * layer (PipeAdder) register their inputs; every further layer uses the
  * combinational Adder.
  *
  * @param aBits width of each signed element of `a`
  * @param bBits width of each signed element of `b`
  * @param size number of element pairs; must be a power of 2 and at least 2
  */
class DotProduct(aBits: Int = 8, bBits: Int = 8, size: Int = 16)
    extends Module {
  // The message states the bound that is actually enforced below (size >= 2).
  val errorMsg =
    "\n\n[VTA] [DotProduct] size must be at least 2 and a power of 2\n\n"
  require(size >= 2 && isPow2(size), errorMsg)
  val b = aBits + bBits // width of a single product
  // Product width plus one bit per adder layer plus the MAC's extra output bit.
  val outBits = b + log2Ceil(size) + 1
  val io = IO(new Bundle {
    val a = Input(Vec(size, SInt(aBits.W)))
    val b = Input(Vec(size, SInt(bBits.W)))
    val y = Output(SInt(outBits.W))
  })
  // s(i) = number of units in layer i: s(0) = size MACs, then size/2, ..., 1 adders.
  val s = Seq.tabulate(log2Ceil(size + 1))(i =>
    pow(2, log2Ceil(size) - i).toInt) // # of total layers
  val p = log2Ceil(size / 2) + 1 // # of adder layers
  // One MAC per element pair; cBits = 1 because the addend is tied to zero below.
  val m = Seq.fill(s(0))(Module(new MAC(aBits, bBits, cBits = 1))) // # of total vector pairs
  // Adder tree: layer i takes operands of (b + i + 1) bits and grows them by one bit.
  val a = Seq.tabulate(p)(
    i =>
      Seq.fill(s(i + 1))(
        if (i == 0)
          Module(new PipeAdder(aBits = (b + i + 1), bBits = (b + i + 1)))
        else
          Module(new Adder(aBits = (b + i + 1), bBits = (b + i + 1))))) // # adders within each layer

  // Vector MACs: pure multipliers, the accumulate input is held at zero.
  for (i <- 0 until s(0)) {
    m(i).io.a := io.a(i)
    m(i).io.b := io.b(i)
    m(i).io.c := 0.S
  }

  // Adder reduction: each adder sums two neighbouring outputs of the previous layer.
  for (i <- 0 until p) {
    for (j <- 0 until s(i + 1)) {
      if (i == 0) {
        // First layer of PipeAdders is fed by the MACs.
        a(i)(j).io.a := m(2 * j).io.y
        a(i)(j).io.b := m(2 * j + 1).io.y
      } else {
        a(i)(j).io.a := a(i - 1)(2 * j).io.y
        a(i)(j).io.b := a(i - 1)(2 * j + 1).io.y
      }
    }
  }

  // The last layer holds a single adder whose output is the dot product.
  io.y := a(p - 1)(0).io.y
}
137
/** Matrix-vector multiplication built from one DotProduct per output element.
  *
  * Output element `i` is the dot product of the input vector with weight row
  * `i`, added to the incoming accumulator element `i`. The accumulator value
  * travels through a two-stage Pipe so that it meets the DotProduct result.
  * While `io.reset` is asserted, zeros are driven on `acc_o` instead.
  */
class MatrixVectorMultiplication(implicit p: Parameters) extends Module {
  val accBits = p(CoreKey).accBits
  // NOTE(review): blockOut is used for both the number of dot products and
  // their length — presumably blockIn == blockOut is assumed; confirm.
  val size = p(CoreKey).blockOut
  val inpBits = p(CoreKey).inpBits
  val wgtBits = p(CoreKey).wgtBits
  val outBits = p(CoreKey).outBits
  val io = IO(new Bundle {
    val reset = Input(Bool()) // FIXME: reset should be replaced by a load-acc instr
    val inp = new TensorMasterData(tensorType = "inp")
    val wgt = new TensorMasterData(tensorType = "wgt")
    val acc_i = new TensorMasterData(tensorType = "acc")
    val acc_o = new TensorClientData(tensorType = "acc")
    val out = new TensorClientData(tensorType = "out")
  })
  val dot = Seq.fill(size)(
    Module(new DotProduct(aBits = inpBits, bBits = wgtBits, size)))
  // Two stages of delay: one for the MAC input registers and one for the
  // PipeAdder layer inside DotProduct.
  val acc = Seq.fill(size)(Module(new Pipe(UInt(accBits.W), latency = 2)))
  val add = Seq.fill(size)(Wire(SInt(accBits.W)))
  val vld = Wire(Vec(size, Bool()))

  (0 until size).foreach { i =>
    // Operands of dot product i: the input vector and weight row i.
    (0 until size).foreach { j =>
      dot(i).io.a(j) := io.inp.data.bits(0)(j).asSInt
      dot(i).io.b(j) := io.wgt.data.bits(i)(j).asSInt
    }

    // Delay the incoming accumulator element alongside the dot product.
    acc(i).io.enq.bits := io.acc_i.data.bits(0)(i)
    acc(i).io.enq.valid := io.inp.data.valid & io.wgt.data.valid & io.acc_i.data.valid & ~io.reset
    vld(i) := acc(i).io.deq.valid

    // Accumulate and fan the sum out to both result ports.
    add(i) := acc(i).io.deq.bits.asSInt + dot(i).io.y
    io.out.data.bits(0)(i) := add(i).asUInt
    io.acc_o.data.bits(0)(i) := Mux(io.reset, 0.U, add(i).asUInt)
  }

  // Results are valid once every lane's delayed accumulator is valid;
  // acc_o is additionally valid while resetting.
  io.out.data.valid := vld.asUInt.andR
  io.acc_o.data.valid := vld.asUInt.andR | io.reset
}
176
/** TensorGemm.
  *
  * This unit instantiates the MatrixVectorMultiplication and iterates over the
  * micro-ops (uops), which supply the indices used to read inputs, weights and
  * accumulator values, and writes the results back to the acc and out
  * scratchpads.
  *
  * The reset field of the Gemm instruction is used to zero out the
  * acc-scratchpad locations addressed by the micro-ops instead of computing.
  *
  * Per uop the state machine runs sReadUop -> sComputeIdx -> sReadTensor ->
  * sExe; after the last uop of the last iteration it returns to sIdle,
  * passing through sWait while results are still in flight.
  *
  * @param debug when true, printf statements tracing uops and tensors are elaborated
  */
class TensorGemm(debug: Boolean = false)(implicit p: Parameters)
    extends Module {
  val io = IO(new Bundle {
    val start = Input(Bool())
    val done = Output(Bool())
    val inst = Input(UInt(INST_BITS.W))
    val uop = new UopMaster
    val inp = new TensorMaster(tensorType = "inp")
    val wgt = new TensorMaster(tensorType = "wgt")
    val acc = new TensorMaster(tensorType = "acc")
    val out = new TensorMaster(tensorType = "out")
  })
  val sIdle :: sReadUop :: sComputeIdx :: sReadTensor :: sExe :: sWait :: Nil =
    Enum(6)
  val state = RegInit(sIdle)
  val mvc = Module(new MatrixVectorMultiplication)
  val dec = io.inst.asTypeOf(new GemmDecode)

  // uop iteration and the scratchpad indices derived from the current uop
  val uop_idx = Reg(chiselTypeOf(dec.uop_end))
  val uop_end = dec.uop_end
  val uop_acc = Reg(chiselTypeOf(dec.uop_end))
  val uop_inp = Reg(chiselTypeOf(dec.uop_end))
  val uop_wgt = Reg(chiselTypeOf(dec.uop_end))

  // outer loop: iteration counter and index offsets
  val cnt_o = Reg(chiselTypeOf(dec.lp_0))
  val acc_o = Reg(chiselTypeOf(dec.uop_end))
  val inp_o = Reg(chiselTypeOf(dec.uop_end))
  val wgt_o = Reg(chiselTypeOf(dec.uop_end))

  // inner loop: iteration counter and index offsets
  val cnt_i = Reg(chiselTypeOf(dec.lp_1))
  val acc_i = Reg(chiselTypeOf(dec.uop_end))
  val inp_i = Reg(chiselTypeOf(dec.uop_end))
  val wgt_i = Reg(chiselTypeOf(dec.uop_end))

  val pBits = log2Ceil(p(CoreKey).blockOut) + 1
  val inflight = Reg(UInt(pBits.W))
  // The write index is delayed by two cycles: one for the MAC input registers
  // and one for the pipelined first adder layer of the accumulator.
  val wrpipe = Module(new Pipe(chiselTypeOf(dec.uop_end), latency = 2))

  // Loop-bound tests. These are defs, so each use elaborates its own
  // comparator, exactly as if the expression were written out in place.
  private def lastUop: Bool = uop_idx === uop_end - 1.U
  private def lastInner: Bool = cnt_i === dec.lp_1 - 1.U
  private def lastOuter: Bool = cnt_o === dec.lp_0 - 1.U
  private def issuing: Bool = state === sReadTensor
  private def committing: Bool = mvc.io.acc_o.data.valid

  // Done when nothing is in flight and either the final uop is executing or
  // the unit is draining in sWait.
  val done = inflight === 0.U &
    ((state === sExe & lastOuter & lastInner & lastUop) |
      state === sWait)

  switch(state) {
    is(sIdle) {
      when(io.start) {
        state := sReadUop
      }
    }
    is(sReadUop) {
      state := sComputeIdx
    }
    is(sComputeIdx) {
      state := sReadTensor
    }
    is(sReadTensor) {
      state := sExe
    }
    is(sExe) {
      // Default: fetch the next uop. Overridden after the very last one.
      state := sReadUop
      when(lastOuter && lastInner && lastUop) {
        state := Mux(inflight =/= 0.U, sWait, sIdle)
      }
    }
    is(sWait) {
      when(inflight === 0.U) {
        state := sIdle
      }
    }
  }

  // In-flight tensor counter (not tracked for reset instructions). A
  // simultaneous issue and commit leaves the count unchanged.
  when(state === sIdle) {
    inflight := 0.U
  }.elsewhen(!dec.reset) {
    when(issuing && !committing) { // issue a tensor
      inflight := inflight + 1.U
    }.elsewhen(committing && !issuing) { // commit a tensor
      inflight := inflight - 1.U
    }
  }

  // uop index: rewind at idle or after the last uop, otherwise step forward.
  when(state === sIdle || (state === sExe && lastUop)) {
    uop_idx := dec.uop_begin
  }.elsewhen(state === sExe && dec.uop_begin =/= uop_end) {
    uop_idx := uop_idx + 1.U
  }

  // Outer loop: advances once the inner loop has executed its last uop.
  when(state === sIdle) {
    Seq(cnt_o, acc_o, inp_o, wgt_o).foreach(_ := 0.U)
  }.elsewhen(state === sExe && lastUop && lastInner) {
    cnt_o := cnt_o + 1.U
    acc_o := acc_o + dec.acc_0
    inp_o := inp_o + dec.inp_0
    wgt_o := wgt_o + dec.wgt_0
  }

  // Inner loop: restarts from the outer offsets when its count reaches lp_1,
  // otherwise advances after the last uop of an iteration.
  when(state === sIdle) {
    Seq(cnt_i, acc_i, inp_i, wgt_i).foreach(_ := 0.U)
  }.elsewhen(state === sReadUop && cnt_i === dec.lp_1) {
    cnt_i := 0.U
    acc_i := acc_o
    inp_i := inp_o
    wgt_i := wgt_o
  }.elsewhen(state === sExe && lastUop) {
    cnt_i := cnt_i + 1.U
    acc_i := acc_i + dec.acc_1
    inp_i := inp_i + dec.inp_1
    wgt_i := wgt_i + dec.wgt_1
  }

  // Scratchpad indices: uop fields plus the current inner-loop offsets.
  when(state === sComputeIdx && io.uop.data.valid) {
    uop_acc := io.uop.data.bits.u0 + acc_i
    uop_inp := io.uop.data.bits.u1 + inp_i
    uop_wgt := io.uop.data.bits.u2 + wgt_i
  }

  // Write index travels alongside the computation.
  wrpipe.io.enq.bits := uop_acc
  wrpipe.io.enq.valid := state === sExe & ~dec.reset

  // uop
  io.uop.idx.bits := uop_idx
  io.uop.idx.valid := state === sReadUop

  // inp
  io.inp.rd.idx.bits := uop_inp
  io.inp.rd.idx.valid := state === sReadTensor
  io.inp.tieoffWrite() // read-only

  // wgt
  io.wgt.rd.idx.bits := uop_wgt
  io.wgt.rd.idx.valid := state === sReadTensor
  io.wgt.tieoffWrite() // read-only

  // acc_i
  io.acc.rd.idx.bits := uop_acc
  io.acc.rd.idx.valid := state === sReadTensor

  // mvc
  mvc.io.reset := dec.reset & state === sExe
  mvc.io.inp.data <> io.inp.rd.data
  mvc.io.wgt.data <> io.wgt.rd.data
  mvc.io.acc_i.data <> io.acc.rd.data

  // acc_o: a reset instruction writes directly at uop_acc, a normal one at
  // the delayed index coming out of wrpipe.
  io.acc.wr.valid := mvc.io.acc_o.data.valid &
    Mux(dec.reset, true.B, wrpipe.io.deq.valid)
  io.acc.wr.bits.idx := Mux(dec.reset, uop_acc, wrpipe.io.deq.bits)
  io.acc.wr.bits.data <> mvc.io.acc_o.data.bits

  // out
  io.out.wr.valid := mvc.io.out.data.valid & wrpipe.io.deq.valid
  io.out.wr.bits.idx := wrpipe.io.deq.bits
  io.out.wr.bits.data <> mvc.io.out.data.bits
  io.out.tieoffRead() // write-only

  io.done := done

  // Optional simulation trace; nothing is printed for reset instructions.
  if (debug) {
    when(state === sReadUop && ~dec.reset) {
      printf("[TensorGemm] [uop] idx:%x\n", uop_idx)
    }

    when(state === sReadTensor && ~dec.reset) {
      printf("[TensorGemm] [uop] acc:%x inp:%x wgt:%x\n",
             uop_acc,
             uop_inp,
             uop_wgt)
    }

    for ((r, i) <- io.inp.rd.data.bits.zipWithIndex) {
      when(io.inp.rd.data.valid && ~dec.reset) {
        printf("[TensorGemm] [inp] i:%x val:%x\n", i.U, r.asUInt)
      }
    }

    for ((r, i) <- io.wgt.rd.data.bits.zipWithIndex) {
      when(io.wgt.rd.data.valid && ~dec.reset) {
        printf("[TensorGemm] [wgt] i:%x val:%x\n", i.U, r.asUInt)
      }
    }

    for (tensor <- io.acc.rd.data.bits; (elem, i) <- tensor.zipWithIndex) {
      when(io.acc.rd.data.valid && ~dec.reset) {
        printf("[TensorGemm] [acc_i] i:%x val:%x\n", i.U, elem)
      }
    }

    for (tensor <- mvc.io.acc_o.data.bits; (elem, i) <- tensor.zipWithIndex) {
      when(mvc.io.acc_o.data.valid && ~dec.reset) {
        printf("[TensorGemm] [acc_o] i:%x val:%x\n", i.U, elem)
      }
    }

    for (tensor <- mvc.io.out.data.bits; (elem, i) <- tensor.zipWithIndex) {
      when(mvc.io.out.data.valid && ~dec.reset) {
        printf("[TensorGemm] [out] i:%x val:%x\n", i.U, elem)
      }
    }
  }
}
422