Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Zmmul: Multiply without Divide Extension #3391

Open
wants to merge 5 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/main/scala/rocket/CSR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ class CSRFile(
io.pmp := reg_pmp.map(PMP(_))

val isaMaskString =
(if (usingMulDiv) "M" else "") +
(if (usingMulDiv && coreParams.mulDiv.get.divEnabled) "M" else "") +
(if (usingAtomics) "A" else "") +
(if (fLen >= 32) "F" else "") +
(if (fLen >= 64) "D" else "") +
Expand Down
17 changes: 13 additions & 4 deletions src/main/scala/rocket/IDecode.scala
Original file line number Diff line number Diff line change
Expand Up @@ -244,29 +244,38 @@ class Hypervisor64Decode(aluFn: ALUFN = ALUFN())(implicit val p: Parameters) ext
HLV_WU-> List(Y,N,N,N,N,N,N,Y,N,N,N,N,A2_ZERO, A1_RS1, IMM_X, DW_XPR, aluFn.FN_ADD, Y,M_XRD, N,N,N,N,N,N,Y,CSR.I,N,N,N,N))
}

class MDecode(pipelinedMul: Boolean, aluFn: ALUFN = ALUFN())(implicit val p: Parameters) extends DecodeConstants
class MMulDecode(pipelinedMul: Boolean, aluFn: ALUFN = ALUFN())(implicit val p: Parameters) extends DecodeConstants
{
val M = if (pipelinedMul) Y else N
val D = if (pipelinedMul) N else Y

val table: Array[(BitPat, List[BitPat])] = Array(
MUL-> List(Y,N,N,N,N,N,Y,Y,N,N,N,N,A2_RS2, A1_RS1, IMM_X, DW_XPR,aluFn.FN_MUL, N,M_X, N,N,N,N,M,D,Y,CSR.N,N,N,N,N),
MULH-> List(Y,N,N,N,N,N,Y,Y,N,N,N,N,A2_RS2, A1_RS1, IMM_X, DW_XPR,aluFn.FN_MULH, N,M_X, N,N,N,N,M,D,Y,CSR.N,N,N,N,N),
MULHU-> List(Y,N,N,N,N,N,Y,Y,N,N,N,N,A2_RS2, A1_RS1, IMM_X, DW_XPR,aluFn.FN_MULHU, N,M_X, N,N,N,N,M,D,Y,CSR.N,N,N,N,N),
MULHSU-> List(Y,N,N,N,N,N,Y,Y,N,N,N,N,A2_RS2, A1_RS1, IMM_X, DW_XPR,aluFn.FN_MULHSU,N,M_X, N,N,N,N,M,D,Y,CSR.N,N,N,N,N),
MULHSU-> List(Y,N,N,N,N,N,Y,Y,N,N,N,N,A2_RS2, A1_RS1, IMM_X, DW_XPR,aluFn.FN_MULHSU,N,M_X, N,N,N,N,M,D,Y,CSR.N,N,N,N,N))
}

class MDivDecode(aluFn: ALUFN = ALUFN())(implicit val p: Parameters) extends DecodeConstants
{
val table: Array[(BitPat, List[BitPat])] = Array(
DIV-> List(Y,N,N,N,N,N,Y,Y,N,N,N,N,A2_RS2, A1_RS1, IMM_X, DW_XPR,aluFn.FN_DIV, N,M_X, N,N,N,N,N,Y,Y,CSR.N,N,N,N,N),
DIVU-> List(Y,N,N,N,N,N,Y,Y,N,N,N,N,A2_RS2, A1_RS1, IMM_X, DW_XPR,aluFn.FN_DIVU, N,M_X, N,N,N,N,N,Y,Y,CSR.N,N,N,N,N),
REM-> List(Y,N,N,N,N,N,Y,Y,N,N,N,N,A2_RS2, A1_RS1, IMM_X, DW_XPR,aluFn.FN_REM, N,M_X, N,N,N,N,N,Y,Y,CSR.N,N,N,N,N),
REMU-> List(Y,N,N,N,N,N,Y,Y,N,N,N,N,A2_RS2, A1_RS1, IMM_X, DW_XPR,aluFn.FN_REMU, N,M_X, N,N,N,N,N,Y,Y,CSR.N,N,N,N,N))
}

class M64Decode(pipelinedMul: Boolean, aluFn: ALUFN = ALUFN())(implicit val p: Parameters) extends DecodeConstants
class MMul64Decode(pipelinedMul: Boolean, aluFn: ALUFN = ALUFN())(implicit val p: Parameters) extends DecodeConstants
{
val M = if (pipelinedMul) Y else N
val D = if (pipelinedMul) N else Y
val table: Array[(BitPat, List[BitPat])] = Array(
MULW-> List(Y,N,N,N,N,N,Y,Y,N,N,N,N,A2_RS2, A1_RS1, IMM_X, DW_32, aluFn.FN_MUL, N,M_X, N,N,N,N,M,D,Y,CSR.N,N,N,N,N),
MULW-> List(Y,N,N,N,N,N,Y,Y,N,N,N,N,A2_RS2, A1_RS1, IMM_X, DW_32, aluFn.FN_MUL, N,M_X, N,N,N,N,M,D,Y,CSR.N,N,N,N,N))
}

class MDiv64Decode(aluFn: ALUFN = ALUFN())(implicit val p: Parameters) extends DecodeConstants
{
val table: Array[(BitPat, List[BitPat])] = Array(
DIVW-> List(Y,N,N,N,N,N,Y,Y,N,N,N,N,A2_RS2, A1_RS1, IMM_X, DW_32, aluFn.FN_DIV, N,M_X, N,N,N,N,N,Y,Y,CSR.N,N,N,N,N),
DIVUW-> List(Y,N,N,N,N,N,Y,Y,N,N,N,N,A2_RS2, A1_RS1, IMM_X, DW_32, aluFn.FN_DIVU, N,M_X, N,N,N,N,N,Y,Y,CSR.N,N,N,N,N),
REMW-> List(Y,N,N,N,N,N,Y,Y,N,N,N,N,A2_RS2, A1_RS1, IMM_X, DW_32, aluFn.FN_REM, N,M_X, N,N,N,N,N,Y,Y,CSR.N,N,N,N,N),
Expand Down
11 changes: 6 additions & 5 deletions src/main/scala/rocket/Multiplier.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class MultiplierIO(val dataBits: Int, val tagBits: Int, aluFn: ALUFN = new ALUFN
case class MulDivParams(
mulUnroll: Int = 1,
divUnroll: Int = 1,
divEnabled: Boolean = true,
mulEarlyOut: Boolean = false,
divEarlyOut: Boolean = false,
divEarlyOutGranularity: Int = 1
Expand All @@ -49,7 +50,7 @@ class MulDiv(cfg: MulDivParams, width: Int, nXpr: Int = 32, aluFn: ALUFN = new A

val req = Reg(chiselTypeOf(io.req.bits))
val count = Reg(UInt(log2Ceil(
((cfg.divUnroll != 0).option(w/cfg.divUnroll + 1).toSeq ++
((cfg.divEnabled && cfg.divUnroll != 0).option(w/cfg.divUnroll + 1).toSeq ++
(cfg.mulUnroll != 0).option(mulw/cfg.mulUnroll)).reduce(_ max _)).W))
val neg_out = Reg(Bool())
val isHi = Reg(Bool())
Expand All @@ -69,7 +70,7 @@ class MulDiv(cfg: MulDivParams, width: Int, nXpr: Int = 32, aluFn: ALUFN = new A
aluFn.FN_REMU -> List(N, Y, N, N))
val cmdMul :: cmdHi :: lhsSigned :: rhsSigned :: Nil =
DecodeLogic(io.req.bits.fn, List(X, X, X, X),
(if (cfg.divUnroll != 0) divDecode else Nil) ++ (if (cfg.mulUnroll != 0) mulDecode else Nil)).map(_.asBool)
(if (cfg.divEnabled && cfg.divUnroll != 0) divDecode else Nil) ++ (if (cfg.mulUnroll != 0) mulDecode else Nil)).map(_.asBool)

require(w == 32 || w == 64)
def halfWidth(req: MultiplierReq) = (w > 32).B && req.dw === DW_32
Expand All @@ -86,7 +87,7 @@ class MulDiv(cfg: MulDivParams, width: Int, nXpr: Int = 32, aluFn: ALUFN = new A
val result = Mux(resHi, remainder(2*w, w+1), remainder(w-1, 0))
val negated_remainder = -result

if (cfg.divUnroll != 0) when (state === s_neg_inputs) {
if (cfg.divEnabled && cfg.divUnroll != 0) when (state === s_neg_inputs) {
when (remainder(w-1)) {
remainder := negated_remainder
}
Expand All @@ -95,7 +96,7 @@ class MulDiv(cfg: MulDivParams, width: Int, nXpr: Int = 32, aluFn: ALUFN = new A
}
state := s_div
}
if (cfg.divUnroll != 0) when (state === s_neg_output) {
if (cfg.divEnabled && cfg.divUnroll != 0) when (state === s_neg_output) {
remainder := negated_remainder
state := s_done_div
resHi := false.B
Expand Down Expand Up @@ -123,7 +124,7 @@ class MulDiv(cfg: MulDivParams, width: Int, nXpr: Int = 32, aluFn: ALUFN = new A
resHi := isHi
}
}
if (cfg.divUnroll != 0) when (state === s_div) {
if (cfg.divEnabled && cfg.divUnroll != 0) when (state === s_div) {
val unrolls = ((0 until cfg.divUnroll) scanLeft remainder) { case (rem, i) =>
// the special case for iteration 0 is to save HW, not for correctness
val difference = if (i == 0) subtractor else rem(2*w,w) - divisor(w-1,0)
Expand Down
5 changes: 3 additions & 2 deletions src/main/scala/rocket/RocketCore.scala
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,8 @@ class Rocket(tile: RocketTile)(implicit p: Parameters) extends CoreModule()(p)
val pipelinedMul = usingMulDiv && mulDivParams.mulUnroll == xLen
val decode_table = {
require(!usingRoCC || !rocketParams.useSCIE)
(if (usingMulDiv) new MDecode(pipelinedMul, aluFn) +: (xLen > 32).option(new M64Decode(pipelinedMul, aluFn)).toSeq else Nil) ++:
(if (usingMulDiv) new MMulDecode(pipelinedMul, aluFn) +: (xLen > 32).option(new MMul64Decode(pipelinedMul, aluFn)).toSeq else Nil) ++:
(if (usingMulDiv && mulDivParams.divEnabled) new MDivDecode(aluFn) +: (xLen > 32).option(new MDiv64Decode(aluFn)).toSeq else Nil) ++:
(if (usingAtomics) new ADecode(aluFn) +: (xLen > 32).option(new A64Decode(aluFn)).toSeq else Nil) ++:
(if (fLen >= 32) new FDecode(aluFn) +: (xLen > 32).option(new F64Decode(aluFn)).toSeq else Nil) ++:
(if (fLen >= 64) new DDecode(aluFn) +: (xLen > 32).option(new D64Decode(aluFn)).toSeq else Nil) ++:
Expand Down Expand Up @@ -327,7 +328,7 @@ class Rocket(tile: RocketTile)(implicit p: Parameters) extends CoreModule()(p)
}
val id_illegal_rnum = if (usingCryptoNIST) (id_ctrl.zkn && aluFn.isKs1(id_ctrl.alu_fn) && id_inst(0)(23,20) > 0xA.U(4.W)) else false.B
val id_illegal_insn = !id_ctrl.legal ||
(id_ctrl.mul || id_ctrl.div) && !csr.io.status.isa('m'-'a') ||
id_ctrl.div && !csr.io.status.isa('m'-'a') ||
id_ctrl.amo && !csr.io.status.isa('a'-'a') ||
id_ctrl.fp && (csr.io.decode(0).fp_illegal || io.fpu.illegal_rm) ||
id_ctrl.dp && !csr.io.status.isa('d'-'a') ||
Expand Down
9 changes: 9 additions & 0 deletions src/main/scala/subsystem/Configs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,15 @@ class WithoutMulDiv extends Config((site, here, up) => {
}
})

class WithoutDiv extends Config((site, here, up) => {
case TilesLocated(InSubsystem) => up(TilesLocated(InSubsystem), site) map {
case tp: RocketTileAttachParams => tp.copy(tileParams = tp.tileParams.copy(
core = tp.tileParams.core.copy(
mulDiv = tp.tileParams.core.mulDiv.map(_.copy(divEnabled = false)))))
case t => t
}
})

class WithoutFPU extends Config((site, here, up) => {
case TilesLocated(InSubsystem) => up(TilesLocated(InSubsystem), site) map {
case tp: RocketTileAttachParams => tp.copy(tileParams = tp.tileParams.copy(
Expand Down
3 changes: 2 additions & 1 deletion src/main/scala/tile/BaseTile.scala
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ trait HasNonDiplomaticTileParameters {
// TODO merge with isaString in CSR.scala
def isaDTS: String = {
val ie = if (tileParams.core.useRVE) "e" else "i"
val m = if (tileParams.core.mulDiv.nonEmpty) "m" else ""
val m = if (tileParams.core.mulDiv.nonEmpty && tileParams.core.mulDiv.get.divEnabled) "m" else ""
val a = if (tileParams.core.useAtomics) "a" else ""
val f = if (tileParams.core.fpu.nonEmpty) "f" else ""
val d = if (tileParams.core.fpu.nonEmpty && tileParams.core.fpu.get.fLen > 32) "d" else ""
Expand All @@ -112,6 +112,7 @@ trait HasNonDiplomaticTileParameters {
//Some(Seq("zicntr")) ++
Option.when(tileParams.core.useConditionalZero)(Seq("zicond")) ++
Some(Seq("zicsr", "zifencei", "zihpm")) ++
Option.when(tileParams.core.mulDiv.nonEmpty)(Seq("zmmul")) ++
Option.when(tileParams.core.fpu.nonEmpty && tileParams.core.fpu.get.fLen >= 16 && tileParams.core.fpu.get.minFLen <= 16)(Seq("zfh")) ++
Option.when(tileParams.core.useBitManip)(Seq("zba", "zbb", "zbc")) ++
Option.when(tileParams.core.hasBitManipCrypto)(Seq("zbkb", "zbkc", "zbkx")) ++
Expand Down
Loading