/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package vta.core

import chisel3._
import chisel3.util._
import chisel3.experimental._
import vta.util.config._
import scala.math.pow

/** Pipelined multiply and accumulate.
 *
 *  Registers the three operands, then computes y = (a * b) + c one cycle
 *  later (single-cycle latency). The output is wide enough to hold the
 *  full product plus the accumulator without overflow.
 */
class MAC(aBits: Int = 8, bBits: Int = 8, cBits: Int = 16) extends Module {
  // One extra bit on top of max(product width, accumulator width) so the
  // expanding add (+&) never truncates.
  val outBits = Math.max(aBits + bBits, cBits) + 1
  val io = IO(new Bundle {
    val a = Input(SInt(aBits.W))
    val b = Input(SInt(bBits.W))
    val c = Input(SInt(cBits.W))
    val y = Output(SInt(outBits.W))
  })
  val mult = Wire(SInt((aBits + bBits).W))
  val add = Wire(SInt(outBits.W))
  // Input registers: the multiply-add below therefore reflects the inputs
  // of the previous cycle (pipeline stage 1).
  val rA = RegNext(io.a)
  val rB = RegNext(io.b)
  val rC = RegNext(io.c)

  mult := rA * rB
  add := rC +& mult // +& is the width-expanding add; result fits in outBits

  io.y := add
}

/** PipeAdder
 *
 *  This unit loads input bits into registers and performs the addition in
 *  the next cycle, i.e. it is an adder with one cycle of latency.
 */
class PipeAdder(aBits: Int = 8, bBits: Int = 8) extends Module {
  // Expanding add needs one bit more than the wider operand.
  val outBits = Math.max(aBits, bBits) + 1
  val io = IO(new Bundle {
    val a = Input(SInt(aBits.W))
    val b = Input(SInt(bBits.W))
    val y = Output(SInt(outBits.W))
  })
  val add = Wire(SInt(outBits.W))
  // Register the operands; the sum appears one cycle after the inputs.
  val rA = RegNext(io.a)
  val rB = RegNext(io.b)
  add := rA +& rB
  io.y := add
}

/** Adder
 *
 *  This unit wires input bits to an adder directly.
 *  The output comes out of combinational logic without waiting for another
 *  cycle (zero-latency counterpart of [[PipeAdder]]).
 */
class Adder(aBits: Int = 8, bBits: Int = 8) extends Module {
  val outBits = Math.max(aBits, bBits) + 1
  val io = IO(new Bundle {
    val a = Input(SInt(aBits.W))
    val b = Input(SInt(bBits.W))
    val y = Output(SInt(outBits.W))
  })
  val add = Wire(SInt(outBits.W))
  // Pure wires (no registers) — combinational pass-through to the adder.
  val rA = Wire(SInt(aBits.W))
  val rB = Wire(SInt(bBits.W))
  rA := io.a
  rB := io.b
  add := rA +& rB
  io.y := add
}

/** Pipelined DotProduct based on MAC and PipeAdder.
 *
 *  Computes the dot product of two `size`-element signed vectors with a
 *  layer of MACs followed by a binary reduction tree of adders. The first
 *  reduction layer uses [[PipeAdder]] (registered), the remaining layers
 *  use the combinational [[Adder]], giving a total latency of two cycles
 *  (one in the MACs, one in the first adder layer).
 */
class DotProduct(aBits: Int = 8, bBits: Int = 8, size: Int = 16)
    extends Module {
  // NOTE: the message matches the actual constraint below (size >= 2 and a
  // power of two); it previously claimed "greater than 4" incorrectly.
  val errorMsg =
    "\n\n[VTA] [DotProduct] size must be greater than or equal to 2 and a power of 2\n\n"
  require(size >= 2 && isPow2(size), errorMsg)
  val b = aBits + bBits
  // Each reduction layer can grow the result by one bit.
  val outBits = b + log2Ceil(size) + 1
  val io = IO(new Bundle {
    val a = Input(Vec(size, SInt(aBits.W)))
    val b = Input(Vec(size, SInt(bBits.W)))
    val y = Output(SInt(outBits.W))
  })
  // s(i) = number of partial results entering layer i (size, size/2, ..., 1)
  val s = Seq.tabulate(log2Ceil(size + 1))(i =>
    pow(2, log2Ceil(size) - i).toInt) // # of total layers
  val p = log2Ceil(size / 2) + 1 // # of adder layers
  // cBits = 1: the accumulator input of each MAC is tied to zero below.
  val m = Seq.fill(s(0))(Module(new MAC(aBits, bBits, cBits = 1))) // # of total vector pairs
  val a = Seq.tabulate(p)(
    i =>
      Seq.fill(s(i + 1))(
        if (i == 0)
          Module(new PipeAdder(aBits = (b + i + 1), bBits = (b + i + 1)))
        else
          Module(new Adder(aBits = (b + i + 1), bBits = (b + i + 1))))) // # adders within each layer

  // Vector MACs: elementwise products a(i) * b(i)
  for (i <- 0 until s(0)) {
    m(i).io.a := io.a(i)
    m(i).io.b := io.b(i)
    m(i).io.c := 0.S
  }

  // Adder-tree reduction: layer 0 pairs up MAC outputs (registered),
  // deeper layers pair up the previous layer's outputs (combinational).
  for (i <- 0 until p) {
    for (j <- 0 until s(i + 1)) {
      if (i == 0) {
        // First layer of PipeAdders
        a(i)(j).io.a := m(2 * j).io.y
        a(i)(j).io.b := m(2 * j + 1).io.y
      } else {
        a(i)(j).io.a := a(i - 1)(2 * j).io.y
        a(i)(j).io.b := a(i - 1)(2 * j + 1).io.y
      }
    }
  }

  // last adder holds the fully reduced dot product
  io.y := a(p - 1)(0).io.y
}

/** Perform matrix-vector-multiplication based on DotProduct.
 *
 *  Multiplies a `size x size` weight matrix by a `size`-element input
 *  vector and accumulates into `acc`. When `io.reset` is asserted the
 *  accumulator output is zeroed instead (used to clear acc scratchpad
 *  entries).
 */
class MatrixVectorMultiplication(implicit p: Parameters) extends Module {
  val accBits = p(CoreKey).accBits
  val size = p(CoreKey).blockOut
  val inpBits = p(CoreKey).inpBits
  val wgtBits = p(CoreKey).wgtBits
  val outBits = p(CoreKey).outBits
  val io = IO(new Bundle {
    val reset = Input(Bool()) // FIXME: reset should be replaced by a load-acc instr
    val inp = new TensorMasterData(tensorType = "inp")
    val wgt = new TensorMasterData(tensorType = "wgt")
    val acc_i = new TensorMasterData(tensorType = "acc")
    val acc_o = new TensorClientData(tensorType = "acc")
    val out = new TensorClientData(tensorType = "out")
  })
  // One DotProduct per output element (per weight-matrix row).
  val dot = Seq.fill(size)(
    Module(new DotProduct(aBits = inpBits, bBits = wgtBits, size)))
  // Latency is defined as two in the following, because there is one cycle in the MAC module,
  // and another cycle in the pipelined adders as the first layer of the accumulator
  val acc = Seq.fill(size)(Module(new Pipe(UInt(accBits.W), latency = 2)))
  val add = Seq.fill(size)(Wire(SInt(accBits.W)))
  val vld = Wire(Vec(size, Bool()))

  for (i <- 0 until size) {
    // Delay the incoming accumulator value so it lines up with the
    // two-cycle DotProduct latency.
    acc(i).io.enq.valid := io.inp.data.valid & io.wgt.data.valid & io.acc_i.data.valid & ~io.reset
    acc(i).io.enq.bits := io.acc_i.data.bits(0)(i)
    for (j <- 0 until size) {
      dot(i).io.a(j) := io.inp.data.bits(0)(j).asSInt
      dot(i).io.b(j) := io.wgt.data.bits(i)(j).asSInt
    }
    add(i) := acc(i).io.deq.bits.asSInt + dot(i).io.y
    // On reset, write zeros to the acc scratchpad instead of the sum.
    io.acc_o.data.bits(0)(i) := Mux(io.reset, 0.U, add(i).asUInt)
    io.out.data.bits(0)(i) := add(i).asUInt
    vld(i) := acc(i).io.deq.valid
  }
  // Valid when every lane's pipe has produced a value (or on reset for acc).
  io.acc_o.data.valid := vld.asUInt.andR | io.reset
  io.out.data.valid := vld.asUInt.andR
}

/** TensorGemm.
 *
 *  This unit instantiates the MatrixVectorMultiplication and goes over the
 *  micro-ops (uops) which are used to read inputs, weights and biases,
 *  and writes results back to the acc and out scratchpads.
 *
 *  Also, the TensorGemm uses the reset field in the Gemm instruction to
 *  clear or zero-out the acc-scratchpad locations based on the micro-ops.
 */
class TensorGemm(debug: Boolean = false)(implicit p: Parameters)
    extends Module {
  val io = IO(new Bundle {
    val start = Input(Bool())
    val done = Output(Bool())
    val inst = Input(UInt(INST_BITS.W))
    val uop = new UopMaster
    val inp = new TensorMaster(tensorType = "inp")
    val wgt = new TensorMaster(tensorType = "wgt")
    val acc = new TensorMaster(tensorType = "acc")
    val out = new TensorMaster(tensorType = "out")
  })
  val sIdle :: sReadUop :: sComputeIdx :: sReadTensor :: sExe :: sWait :: Nil =
    Enum(6)
  val state = RegInit(sIdle)
  val mvc = Module(new MatrixVectorMultiplication)
  val dec = io.inst.asTypeOf(new GemmDecode)
  val uop_idx = Reg(chiselTypeOf(dec.uop_end))
  val uop_end = dec.uop_end
  // Scratchpad indices resolved from the current uop plus loop offsets.
  val uop_acc = Reg(chiselTypeOf(dec.uop_end))
  val uop_inp = Reg(chiselTypeOf(dec.uop_end))
  val uop_wgt = Reg(chiselTypeOf(dec.uop_end))
  // Outer-loop (lp_0) counter and per-tensor base offsets.
  val cnt_o = Reg(chiselTypeOf(dec.lp_0))
  val acc_o = Reg(chiselTypeOf(dec.uop_end))
  val inp_o = Reg(chiselTypeOf(dec.uop_end))
  val wgt_o = Reg(chiselTypeOf(dec.uop_end))
  // Inner-loop (lp_1) counter and per-tensor offsets.
  val cnt_i = Reg(chiselTypeOf(dec.lp_1))
  val acc_i = Reg(chiselTypeOf(dec.uop_end))
  val inp_i = Reg(chiselTypeOf(dec.uop_end))
  val wgt_i = Reg(chiselTypeOf(dec.uop_end))
  val pBits = log2Ceil(p(CoreKey).blockOut) + 1
  // Number of tensors issued to the MVM pipeline but not yet committed.
  val inflight = Reg(UInt(pBits.W))
  // Latency is defined as two in the following, because there is one cycle in the MAC module,
  // and another cycle in the pipelined adders as the first layer of the accumulator
  val wrpipe = Module(new Pipe(chiselTypeOf(dec.uop_end), latency = 2))
  // Done when the pipeline has drained and either the last iteration just
  // executed or we are parked in sWait waiting for the drain.
  val done = inflight === 0.U &
    ((state === sExe &
      cnt_o === dec.lp_0 - 1.U &
      cnt_i === dec.lp_1 - 1.U &
      uop_idx === uop_end - 1.U &
      inflight === 0.U) |
      state === sWait)

  // Control FSM: fetch uop -> compute indices -> read tensors -> execute,
  // looping until all (lp_0 x lp_1 x uop-range) iterations are issued.
  switch(state) {
    is(sIdle) {
      when(io.start) {
        state := sReadUop
      }
    }
    is(sReadUop) {
      state := sComputeIdx
    }
    is(sComputeIdx) {
      state := sReadTensor
    }
    is(sReadTensor) {
      state := sExe
    }
    is(sExe) {
      when(
        (cnt_o === dec.lp_0 - 1.U) &&
          (cnt_i === dec.lp_1 - 1.U) &&
          (uop_idx === uop_end - 1.U)) {
        when(inflight =/= 0.U) {
          state := sWait
        }.otherwise {
          state := sIdle
        }
      }.otherwise {
        state := sReadUop
      }
    }
    is(sWait) {
      when(inflight === 0.U) {
        state := sIdle
      }
    }
  }

  // Inflight bookkeeping (only when not a reset-gemm, which bypasses the
  // write pipe): +1 on issue, -1 on commit, unchanged when both coincide.
  when(state === sIdle) {
    inflight := 0.U
  }.elsewhen(!dec.reset) {
    when((state === sReadTensor) && mvc.io.acc_o.data.valid) { // issue & commit
      inflight := inflight
    }.elsewhen(state === sReadTensor) { // issue a tensor
      inflight := inflight + 1.U
    }.elsewhen(mvc.io.acc_o.data.valid) { // commit a tensor
      inflight := inflight - 1.U
    }
  }

  // uop index: wrap back to uop_begin at the end of the uop range.
  when(
    state === sIdle ||
      (state === sExe &&
        uop_idx === uop_end - 1.U)) {
    uop_idx := dec.uop_begin
  }.elsewhen(state === sExe && dec.uop_begin =/= uop_end) {
    uop_idx := uop_idx + 1.U
  }

  // Outer loop: advance counters/offsets when the inner loop completes.
  when(state === sIdle) {
    cnt_o := 0.U
    acc_o := 0.U
    inp_o := 0.U
    wgt_o := 0.U
  }.elsewhen(
    state === sExe &&
      uop_idx === uop_end - 1.U &&
      cnt_i === dec.lp_1 - 1.U) {
    cnt_o := cnt_o + 1.U
    acc_o := acc_o + dec.acc_0
    inp_o := inp_o + dec.inp_0
    wgt_o := wgt_o + dec.wgt_0
  }

  // Inner loop: reload from the outer-loop offsets at wrap, otherwise step
  // by the per-iteration strides when the uop range completes.
  when(state === sIdle) {
    cnt_i := 0.U
    acc_i := 0.U
    inp_i := 0.U
    wgt_i := 0.U
  }.elsewhen(state === sReadUop && cnt_i === dec.lp_1) {
    cnt_i := 0.U
    acc_i := acc_o
    inp_i := inp_o
    wgt_i := wgt_o
  }.elsewhen(state === sExe && uop_idx === uop_end - 1.U) {
    cnt_i := cnt_i + 1.U
    acc_i := acc_i + dec.acc_1
    inp_i := inp_i + dec.inp_1
    wgt_i := wgt_i + dec.wgt_1
  }

  // Resolve scratchpad addresses: uop fields plus current loop offsets.
  when(state === sComputeIdx && io.uop.data.valid) {
    uop_acc := io.uop.data.bits.u0 + acc_i
    uop_inp := io.uop.data.bits.u1 + inp_i
    uop_wgt := io.uop.data.bits.u2 + wgt_i
  }

  // Delay the acc write index to match the 2-cycle MVM latency.
  wrpipe.io.enq.valid := state === sExe & ~dec.reset
  wrpipe.io.enq.bits := uop_acc

  // uop
  io.uop.idx.valid := state === sReadUop
  io.uop.idx.bits := uop_idx

  // inp
  io.inp.rd.idx.valid := state === sReadTensor
  io.inp.rd.idx.bits := uop_inp
  io.inp.tieoffWrite() // read-only

  // wgt
  io.wgt.rd.idx.valid := state === sReadTensor
  io.wgt.rd.idx.bits := uop_wgt
  io.wgt.tieoffWrite() // read-only

  // acc_i
  io.acc.rd.idx.valid := state === sReadTensor
  io.acc.rd.idx.bits := uop_acc

  // mvc
  mvc.io.reset := dec.reset & state === sExe
  mvc.io.inp.data <> io.inp.rd.data
  mvc.io.wgt.data <> io.wgt.rd.data
  mvc.io.acc_i.data <> io.acc.rd.data

  // acc_o: on reset-gemm the index comes straight from uop_acc (no MVM
  // result to wait for); otherwise it comes delayed through wrpipe.
  io.acc.wr.valid := mvc.io.acc_o.data.valid & Mux(dec.reset,
    true.B,
    wrpipe.io.deq.valid)
  io.acc.wr.bits.idx := Mux(dec.reset, uop_acc, wrpipe.io.deq.bits)
  io.acc.wr.bits.data <> mvc.io.acc_o.data.bits

  // out
  io.out.wr.valid := mvc.io.out.data.valid & wrpipe.io.deq.valid
  io.out.wr.bits.idx := wrpipe.io.deq.bits
  io.out.wr.bits.data <> mvc.io.out.data.bits
  io.out.tieoffRead() // write-only

  io.done := done

  if (debug) {
    when(state === sReadUop && ~dec.reset) {
      printf("[TensorGemm] [uop] idx:%x\n", uop_idx)
    }

    when(state === sReadTensor && ~dec.reset) {
      printf("[TensorGemm] [uop] acc:%x inp:%x wgt:%x\n",
        uop_acc,
        uop_inp,
        uop_wgt)
    }

    io.inp.rd.data.bits.zipWithIndex.foreach {
      case (r, i) =>
        when(io.inp.rd.data.valid && ~dec.reset) {
          printf("[TensorGemm] [inp] i:%x val:%x\n", i.U, r.asUInt)
        }
    }

    io.wgt.rd.data.bits.zipWithIndex.foreach {
      case (r, i) =>
        when(io.wgt.rd.data.valid && ~dec.reset) {
          printf("[TensorGemm] [wgt] i:%x val:%x\n", i.U, r.asUInt)
        }
    }

    io.acc.rd.data.bits.foreach { tensor =>
      tensor.zipWithIndex.foreach {
        case (elem, i) =>
          when(io.acc.rd.data.valid && ~dec.reset) {
            printf("[TensorGemm] [acc_i] i:%x val:%x\n", i.U, elem)
          }
      }
    }

    mvc.io.acc_o.data.bits.foreach { tensor =>
      tensor.zipWithIndex.foreach {
        case (elem, i) =>
          when(mvc.io.acc_o.data.valid && ~dec.reset) {
            printf("[TensorGemm] [acc_o] i:%x val:%x\n", i.U, elem)
          }
      }
    }

    mvc.io.out.data.bits.foreach { tensor =>
      tensor.zipWithIndex.foreach {
        case (elem, i) =>
          when(mvc.io.out.data.valid && ~dec.reset) {
            printf("[TensorGemm] [out] i:%x val:%x\n", i.U, elem)
          }
      }
    }
  }
}