/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package vta.core

import chisel3._
import chisel3.util._
import vta.util.config._
import vta.shell._

/** Compute.
 *
 * Dispatcher of the compute stage. Instructions arrive on `io.inst`, are
 * buffered in `inst_q`, decoded by `ComputeDecode`, and handed to one of:
 *  - `LoadUop`: loads micro-ops from memory through `io.vme_rd(0)`
 *  - `TensorLoad("acc")`: loads acc (bias) data from memory through `io.vme_rd(1)`
 *  - `TensorGemm`: executes GEMM instructions
 *  - `TensorAlu`: executes ALU instructions
 *
 * Instruction life cycle, tracked by `state`:
 *  - `sIdle`: wait until the queue head is valid and every semaphore the
 *    instruction pops is ready; then take the tokens, pulse the selected
 *    engine's `start`, and move to `sSync` or `sExe`.
 *  - `sSync`: sync instructions spend exactly one cycle here.
 *  - `sExe`: wait for the selected engine's `done`; a finish instruction is
 *    done immediately and raises `io.finish`.
 * Leaving `sSync` or `sExe` dequeues the instruction and pulses `io.o_post`
 * for every push the instruction requests.
 *
 * @param debug elaborate printf tracing of instruction start/done events
 */
class Compute(debug: Boolean = false)(implicit p: Parameters) extends Module {
  val mp = p(ShellKey).memParams
  val io = IO(new Bundle {
    val i_post = Vec(2, Input(Bool()))
    val o_post = Vec(2, Output(Bool()))
    val inst = Flipped(Decoupled(UInt(INST_BITS.W)))
    val uop_baddr = Input(UInt(mp.addrBits.W))
    val acc_baddr = Input(UInt(mp.addrBits.W))
    val vme_rd = Vec(2, new VMEReadMaster)
    val inp = new TensorMaster(tensorType = "inp")
    val wgt = new TensorMaster(tensorType = "wgt")
    val out = new TensorMaster(tensorType = "out")
    val finish = Output(Bool())
  })
  val sIdle :: sSync :: sExe :: Nil = Enum(3)
  val state = RegInit(sIdle)

  // s(0) is waited on for pop_prev / posted by i_post(0);
  // s(1) is waited on for pop_next / posted by i_post(1).
  val s = Seq.fill(2)(Module(new Semaphore(counterBits = 8, counterInitValue = 0)))

  val loadUop = Module(new LoadUop)
  val tensorAcc = Module(new TensorLoad(tensorType = "acc"))
  val tensorGemm = Module(new TensorGemm)
  val tensorAlu = Module(new TensorAlu)

  // Incoming instruction stream.
  val inst_q = Module(new Queue(UInt(INST_BITS.W), p(CoreKey).instQueueEntries))
  inst_q.io.enq <> io.inst

  // Decode the instruction sitting at the head of the queue.
  val dec = Module(new ComputeDecode)
  dec.io.inst := inst_q.io.deq.bits

  // One bit per instruction kind, LSB first: uop load, acc load, gemm, alu, finish.
  val inst_type =
    Cat(dec.io.isFinish, dec.io.isAlu, dec.io.isGemm, dec.io.isLoadAcc, dec.io.isLoadUop)

  // The head may start only when every semaphore it pops has a token available.
  val sprev = inst_q.io.deq.valid & (!dec.io.pop_prev | s(0).io.sready)
  val snext = inst_q.io.deq.valid & (!dec.io.pop_next | s(1).io.sready)
  val start = sprev & snext

  // Completion of the instruction currently being executed.
  val done = MuxLookup(
    inst_type,
    false.B,
    Seq(
      1.U -> loadUop.io.done,
      2.U -> tensorAcc.io.done,
      4.U -> tensorGemm.io.done,
      8.U -> tensorAlu.io.done,
      16.U -> true.B // finish has nothing to wait for
    )
  )

  // Cycle in which the head instruction is launched from sIdle.
  private val launch = state === sIdle & start
  // Cycle in which the current instruction leaves the pipeline (dequeue + post).
  private val retire = (state === sExe & done) | (state === sSync)

  // Control FSM: sIdle -> sSync (one cycle) | sExe (until done) -> sIdle.
  // NOTE(review): an instruction that is neither sync nor any inst_type kind
  // leaves the FSM in sIdle, so it is never dequeued, and `launch` keeps
  // asserting swait. Presumably ComputeDecode rules this out — confirm.
  when(state === sIdle) {
    when(start & dec.io.isSync) {
      state := sSync
    }.elsewhen(start & inst_type.orR) {
      state := sExe
    }
  }.elsewhen(state === sSync) {
    state := sIdle
  }.elsewhen(state === sExe & done) {
    state := sIdle
  }

  inst_q.io.deq.ready := retire

  // Start pulses: the decoded kind selects which engine is launched.
  loadUop.io.start := launch & dec.io.isLoadUop
  tensorAcc.io.start := launch & dec.io.isLoadAcc
  tensorGemm.io.start := launch & dec.io.isGemm
  tensorAlu.io.start := launch & dec.io.isAlu

  // Every engine sees the raw instruction at the head of the queue.
  for (instPort <- Seq(loadUop.io.inst, tensorAcc.io.inst, tensorGemm.io.inst, tensorAlu.io.inst)) {
    instPort := inst_q.io.deq.bits
  }

  // Memory-facing loaders.
  loadUop.io.baddr := io.uop_baddr
  tensorAcc.io.baddr := io.acc_baddr
  io.vme_rd(0) <> loadUop.io.vme_rd
  io.vme_rd(1) <> tensorAcc.io.vme_rd

  // Requests into the shared uop/acc/out ports: driven by the GEMM engine when
  // a GEMM instruction is decoded, otherwise by the ALU engine.
  loadUop.io.uop.idx <> Mux(dec.io.isGemm, tensorGemm.io.uop.idx, tensorAlu.io.uop.idx)
  tensorAcc.io.tensor.rd.idx <> Mux(dec.io.isGemm, tensorGemm.io.acc.rd.idx, tensorAlu.io.acc.rd.idx)
  tensorAcc.io.tensor.wr <> Mux(dec.io.isGemm, tensorGemm.io.acc.wr, tensorAlu.io.acc.wr)
  io.out.rd.idx <> Mux(dec.io.isGemm, tensorGemm.io.out.rd.idx, tensorAlu.io.out.rd.idx)
  io.out.wr <> Mux(dec.io.isGemm, tensorGemm.io.out.wr, tensorAlu.io.out.wr)

  // Read responses fan out to both engines; `valid` is only forwarded to the
  // engine matching the decoded instruction.
  tensorGemm.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isGemm
  tensorAlu.io.uop.data.valid := loadUop.io.uop.data.valid & dec.io.isAlu
  tensorGemm.io.uop.data.bits <> loadUop.io.uop.data.bits
  tensorAlu.io.uop.data.bits <> loadUop.io.uop.data.bits

  tensorGemm.io.acc.rd.data.valid := tensorAcc.io.tensor.rd.data.valid & dec.io.isGemm
  tensorAlu.io.acc.rd.data.valid := tensorAcc.io.tensor.rd.data.valid & dec.io.isAlu
  tensorGemm.io.acc.rd.data.bits <> tensorAcc.io.tensor.rd.data.bits
  tensorAlu.io.acc.rd.data.bits <> tensorAcc.io.tensor.rd.data.bits

  tensorGemm.io.out.rd.data.valid := io.out.rd.data.valid & dec.io.isGemm
  tensorAlu.io.out.rd.data.valid := io.out.rd.data.valid & dec.io.isAlu
  tensorGemm.io.out.rd.data.bits <> io.out.rd.data.bits
  tensorAlu.io.out.rd.data.bits <> io.out.rd.data.bits

  // Operands used only by GEMM.
  tensorGemm.io.inp <> io.inp
  tensorGemm.io.wgt <> io.wgt

  // Semaphores: index 0 pairs with pop_prev/push_prev, index 1 with pop_next/push_next.
  Seq(dec.io.pop_prev -> dec.io.push_prev, dec.io.pop_next -> dec.io.push_next)
    .zipWithIndex
    .foreach { case ((pop, push), i) =>
      s(i).io.spost := io.i_post(i)
      s(i).io.swait := pop & launch
      io.o_post(i) := push & retire
    }

  io.finish := state === sExe & done & dec.io.isFinish

  if (debug) {
    // Print the message of the first condition that holds (priority order).
    def report(cases: (Bool, String)*): Unit =
      cases.headOption.foreach { case (cond, msg) =>
        when(cond) {
          printf(msg)
        }.otherwise {
          report(cases.tail: _*)
        }
      }

    when(launch) {
      report(
        dec.io.isSync -> "[Compute] start sync\n",
        dec.io.isLoadUop -> "[Compute] start load uop\n",
        dec.io.isLoadAcc -> "[Compute] start load acc\n",
        dec.io.isGemm -> "[Compute] start gemm\n",
        dec.io.isAlu -> "[Compute] start alu\n",
        dec.io.isFinish -> "[Compute] start finish\n"
      )
    }
    when(state === sSync) {
      printf("[Compute] done sync\n")
    }
    when(state === sExe & done) {
      report(
        dec.io.isLoadUop -> "[Compute] done load uop\n",
        dec.io.isLoadAcc -> "[Compute] done load acc\n",
        dec.io.isGemm -> "[Compute] done gemm\n",
        dec.io.isAlu -> "[Compute] done alu\n",
        dec.io.isFinish -> "[Compute] done finish\n"
      )
    }
  }
}