1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements.  See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership.  The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License.  You may obtain a copy of the License at
9 *
10 *   http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied.  See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20package vta.core
21
22import chisel3._
23import chisel3.util._
24import vta.util.config._
25import vta.shell._
26
/** Compute.
  *
  * Sequencer for the compute stage. Instructions arriving on `io.inst` are
  * buffered in `inst_q`, decoded by `ComputeDecode`, and handed to the unit
  * selected by the decoder:
  * - `loadUop`    loads micro-ops from memory through `io.vme_rd(0)`
  * - `tensorAcc`  loads biases (acc) from memory through `io.vme_rd(1)`
  * - `tensorGemm` executes GEMM instructions
  * - `tensorAlu`  executes ALU instructions
  *
  * A three-state controller drives the flow:
  * - `sIdle`: waits until the head of the queue is valid and every dependency
  *   the instruction pops is reported ready by its semaphore
  * - `sSync`: entered for sync instructions, lasts a single cycle
  * - `sExe` : held until the selected unit reports `done`
  *
  * The instruction is dequeued, and the decoded push tokens are emitted on
  * `io.o_post`, in the cycle the instruction leaves `sSync` or completes in
  * `sExe`. `io.finish` is raised when a finish instruction completes.
  *
  * @param debug when true, printf tracing of instruction start/done is elaborated
  */
class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
  val mp = p(ShellKey).memParams
  val io = IO(new Bundle {
    val i_post = Vec(2, Input(Bool()))
    val o_post = Vec(2, Output(Bool()))
    val inst = Flipped(Decoupled(UInt(INST_BITS.W)))
    val uop_baddr = Input(UInt(mp.addrBits.W))
    val acc_baddr = Input(UInt(mp.addrBits.W))
    val vme_rd = Vec(2, new VMEReadMaster)
    val inp = new TensorMaster(tensorType = "inp")
    val wgt = new TensorMaster(tensorType = "wgt")
    val out = new TensorMaster(tensorType = "out")
    val finish = Output(Bool())
  })

  // Controller state: idle, retiring a sync, or executing an instruction.
  val sIdle :: sSync :: sExe :: Nil = Enum(3)
  val state = RegInit(sIdle)

  // Dependency semaphores: index 0 pairs with the *_prev decode bits,
  // index 1 with the *_next decode bits.
  val s = Seq.fill(2)(Module(new Semaphore(counterBits = 8, counterInitValue = 0)))

  // Execution units.
  val loadUop = Module(new LoadUop)
  val tensorAcc = Module(new TensorLoad(tensorType = "acc"))
  val tensorGemm = Module(new TensorGemm)
  val tensorAlu = Module(new TensorAlu)

  // Instruction buffer; its head is the instruction currently being handled.
  val inst_q = Module(new Queue(UInt(INST_BITS.W), p(CoreKey).instQueueEntries))

  // Decoder is fed straight from the head of the queue.
  val dec = Module(new ComputeDecode)
  dec.io.inst := inst_q.io.deq.bits

  // Instruction class packed as {finish, alu, gemm, loadAcc, loadUop}.
  val inst_type = Cat(
    dec.io.isFinish,
    dec.io.isAlu,
    dec.io.isGemm,
    dec.io.isLoadAcc,
    dec.io.isLoadUop)

  // A dependency only gates the instruction when the decoder asks to pop it.
  val sprev = inst_q.io.deq.valid & Mux(dec.io.pop_prev, s(0).io.sready, true.B)
  val snext = inst_q.io.deq.valid & Mux(dec.io.pop_next, s(1).io.sready, true.B)
  val start = sprev & snext

  // Completion of the selected unit. Any inst_type pattern not listed below
  // falls through to the default and never reports done.
  val done = MuxLookup(
    inst_type,
    false.B,
    Seq(
      "h_01".U -> loadUop.io.done,
      "h_02".U -> tensorAcc.io.done,
      "h_04".U -> tensorGemm.io.done,
      "h_08".U -> tensorAlu.io.done,
      "h_10".U -> true.B // finish completes immediately
    )
  )

  // Shared control strobes.
  private val instIssue = state === sIdle & start // instruction leaves sIdle this cycle
  private val exeDone = state === sExe & done // executing instruction completes
  private val instRetire = exeDone | (state === sSync) // head of queue is consumed

  // Controller.
  switch(state) {
    is(sIdle) {
      when(start & dec.io.isSync) {
        state := sSync
      }.elsewhen(start & inst_type.orR) {
        state := sExe
      }
    }
    is(sSync) {
      state := sIdle
    }
    is(sExe) {
      when(done) { state := sIdle }
    }
  }

  // Instruction queue: filled from the outside, drained on retire.
  inst_q.io.enq <> io.inst
  inst_q.io.deq.ready := instRetire

  // Every unit observes the instruction at the head of the queue.
  Seq(loadUop.io.inst, tensorAcc.io.inst, tensorGemm.io.inst, tensorAlu.io.inst)
    .foreach(_ := inst_q.io.deq.bits)

  // Micro-op loader. Its read index comes from whichever of GEMM/ALU the
  // decoder selects (ALU whenever the instruction is not a GEMM).
  loadUop.io.start := instIssue & dec.io.isLoadUop
  loadUop.io.baddr := io.uop_baddr
  io.vme_rd(0) <> loadUop.io.vme_rd
  loadUop.io.uop.idx <> Mux(dec.io.isGemm,
                            tensorGemm.io.uop.idx,
                            tensorAlu.io.uop.idx)

  // Accumulator loader. Read index and write port are shared between
  // GEMM and ALU using the same selection.
  tensorAcc.io.start := instIssue & dec.io.isLoadAcc
  tensorAcc.io.baddr := io.acc_baddr
  tensorAcc.io.tensor.rd.idx <> Mux(dec.io.isGemm,
                                    tensorGemm.io.acc.rd.idx,
                                    tensorAlu.io.acc.rd.idx)
  tensorAcc.io.tensor.wr <> Mux(dec.io.isGemm,
                                tensorGemm.io.acc.wr,
                                tensorAlu.io.acc.wr)
  io.vme_rd(1) <> tensorAcc.io.vme_rd

  // GEMM unit. Returned read data is broadcast; the valid bits are
  // qualified with isGemm so only the selected unit consumes it.
  tensorGemm.io.start := instIssue & dec.io.isGemm
  tensorGemm.io.inp <> io.inp
  tensorGemm.io.wgt <> io.wgt
  tensorGemm.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isGemm
  tensorGemm.io.uop.data.bits <> loadUop.io.uop.data.bits
  tensorGemm.io.acc.rd.data.valid := tensorAcc.io.tensor.rd.data.valid & dec.io.isGemm
  tensorGemm.io.acc.rd.data.bits <> tensorAcc.io.tensor.rd.data.bits
  tensorGemm.io.out.rd.data.valid := io.out.rd.data.valid & dec.io.isGemm
  tensorGemm.io.out.rd.data.bits <> io.out.rd.data.bits

  // ALU unit. Same read-data wiring, qualified with isAlu.
  tensorAlu.io.start := instIssue & dec.io.isAlu
  tensorAlu.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isAlu
  tensorAlu.io.uop.data.bits <> loadUop.io.uop.data.bits
  tensorAlu.io.acc.rd.data.valid := tensorAcc.io.tensor.rd.data.valid & dec.io.isAlu
  tensorAlu.io.acc.rd.data.bits <> tensorAcc.io.tensor.rd.data.bits
  tensorAlu.io.out.rd.data.valid := io.out.rd.data.valid & dec.io.isAlu
  tensorAlu.io.out.rd.data.bits <> io.out.rd.data.bits

  // Output tensor port, shared between GEMM and ALU.
  io.out.rd.idx <> Mux(dec.io.isGemm,
                       tensorGemm.io.out.rd.idx,
                       tensorAlu.io.out.rd.idx)
  io.out.wr <> Mux(dec.io.isGemm, tensorGemm.io.out.wr, tensorAlu.io.out.wr)

  // Semaphores: spost follows the incoming token, swait fires when an
  // instruction that pops the dependency is issued, and the outgoing token
  // is raised when an instruction that pushes it retires.
  private val popReq = Seq(dec.io.pop_prev, dec.io.pop_next)
  private val pushReq = Seq(dec.io.push_prev, dec.io.push_next)
  for (i <- 0 until 2) {
    s(i).io.spost := io.i_post(i)
    s(i).io.swait := popReq(i) & instIssue
    io.o_post(i) := pushReq(i) & instRetire
  }

  // Finish is reported when a finish instruction completes execution.
  io.finish := exeDone & dec.io.isFinish

  // Optional simulation tracing.
  if (debug) {
    // Instruction issue.
    when(instIssue) {
      when(dec.io.isSync) {
        printf("[Compute] start sync\n")
      }.elsewhen(dec.io.isLoadUop) {
        printf("[Compute] start load uop\n")
      }.elsewhen(dec.io.isLoadAcc) {
        printf("[Compute] start load acc\n")
      }.elsewhen(dec.io.isGemm) {
        printf("[Compute] start gemm\n")
      }.elsewhen(dec.io.isAlu) {
        printf("[Compute] start alu\n")
      }.elsewhen(dec.io.isFinish) {
        printf("[Compute] start finish\n")
      }
    }
    // Instruction completion.
    when(state === sSync) {
      printf("[Compute] done sync\n")
    }
    when(exeDone) {
      when(dec.io.isLoadUop) {
        printf("[Compute] done load uop\n")
      }.elsewhen(dec.io.isLoadAcc) {
        printf("[Compute] done load acc\n")
      }.elsewhen(dec.io.isGemm) {
        printf("[Compute] done gemm\n")
      }.elsewhen(dec.io.isAlu) {
        printf("[Compute] done alu\n")
      }.elsewhen(dec.io.isFinish) {
        printf("[Compute] done finish\n")
      }
    }
  }
}
220