Initial commit
This commit is contained in:
commit
79febf1a8f
|
@ -0,0 +1,8 @@
|
|||
out/
|
||||
.bsp/
|
||||
.idea/
|
||||
.idea_modules/
|
||||
build/
|
||||
test_run_dir
|
||||
testfloat_gen
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
[submodule "berkeley-softfloat-3"]
|
||||
path = berkeley-softfloat-3
|
||||
url = https://github.com/ucb-bar/berkeley-softfloat-3.git
|
||||
[submodule "berkeley-testfloat-3"]
|
||||
path = berkeley-testfloat-3
|
||||
url = https://github.com/ucb-bar/berkeley-testfloat-3.git
|
|
@ -0,0 +1 @@
|
|||
0.9.7
|
|
@ -0,0 +1,26 @@
|
|||
version = 2.6.4
|
||||
|
||||
maxColumn = 80
|
||||
align = most
|
||||
continuationIndent.defnSite = 2
|
||||
assumeStandardLibraryStripMargin = true
|
||||
docstrings = ScalaDoc
|
||||
lineEndings = preserve
|
||||
includeCurlyBraceInSelectChains = false
|
||||
danglingParentheses = true
|
||||
|
||||
align.tokens.add = [
|
||||
{
|
||||
code = ":"
|
||||
}
|
||||
]
|
||||
|
||||
newlines.alwaysBeforeCurlyBraceLambdaParams = false
|
||||
newlines.alwaysBeforeMultilineDef = false
|
||||
newlines.implicitParamListModifierForce = [before]
|
||||
|
||||
verticalMultiline.atDefnSite = true
|
||||
|
||||
optIn.annotationNewlines = true
|
||||
|
||||
rewrite.rules = [SortImports, PreferCurlyFors, AvoidInfix]
|
|
@ -0,0 +1,66 @@
|
|||
compile:
|
||||
mill -i fudian.compile
|
||||
|
||||
bsp:
|
||||
mill -i mill.bsp.BSP/install
|
||||
|
||||
clean:
|
||||
rm -rf ./build
|
||||
|
||||
reformat:
|
||||
mill -i __.reformat
|
||||
|
||||
checkformat:
|
||||
mill -i __.checkFormat
|
||||
|
||||
berkeley-softfloat-3/build/Linux-x86_64-GCC/softfloat.a: berkeley-softfloat-3/.git
|
||||
$(MAKE) -C berkeley-softfloat-3/build/Linux-x86_64-GCC SPECIALIZE_TYPE=RISCV
|
||||
|
||||
berkeley-testfloat-3/build/Linux-x86_64-GCC/testfloat_gen: berkeley-testfloat-3/.git \
|
||||
berkeley-softfloat-3/build/Linux-x86_64-GCC/softfloat.a
|
||||
$(MAKE) -C berkeley-testfloat-3/build/Linux-x86_64-GCC SPECIALIZE_TYPE=RISCV
|
||||
|
||||
TEST_FLOAT_GEN = berkeley-testfloat-3/build/Linux-x86_64-GCC/testfloat_gen
|
||||
BUILD_DIR = $(abspath ./build)
|
||||
CSRC_DIR = $(abspath ./src/test/resources/csrc)
|
||||
SCALA_SRC = $(shell find ./src/main/scala -name "*.scala")
|
||||
|
||||
all_tests: f32_add_tests f32_sub_tests f64_add_tests f64_sub_tests
|
||||
|
||||
define test_template
|
||||
|
||||
$(1)_emu = $$(BUILD_DIR)/$(2)_$(3)/$(2).emu
|
||||
$(1)_v = $$(BUILD_DIR)/$(2)_$(3)/$(2).v
|
||||
|
||||
$$($(1)_v): $$(SCALA_SRC)
|
||||
mill fudian.runMain fudian.$(2) --full-stacktrace -td $$(@D) $(3)
|
||||
|
||||
$$($(1)_emu): $$($(1)_v) $$(CSRC_DIR)/$(2)_Test.cpp
|
||||
verilator --cc --exe $$^ -Mdir $$(@D) -o $$@ --build
|
||||
|
||||
$(1)_test_rnear_even: $$($(1)_emu)
|
||||
$$(TEST_FLOAT_GEN) $(1) -tininessafter -rnear_even -level 2 | $$< -rnear_even $(4)
|
||||
|
||||
$(1)_test_rminMag: $$($(1)_emu)
|
||||
$$(TEST_FLOAT_GEN) $(1) -tininessafter -rminMag -level 2 | $$< -rminMag $(4)
|
||||
|
||||
$(1)_test_rmin: $$($(1)_emu)
|
||||
$$(TEST_FLOAT_GEN) $(1) -tininessafter -rmin -level 2 | $$< -rmin $(4)
|
||||
|
||||
$(1)_test_rmax: $$($(1)_emu)
|
||||
$$(TEST_FLOAT_GEN) $(1) -tininessafter -rmax -level 2 | $$< -rmax $(4)
|
||||
|
||||
$(1)_test_rnear_maxMag: $$($(1)_emu)
|
||||
$$(TEST_FLOAT_GEN) $(1) -tininessafter -rnear_maxMag -level 2 | $$< -rnear_maxMag $(4)
|
||||
|
||||
$(1)_tests: $(1)_test_rnear_even \
|
||||
$(1)_test_rminMag \
|
||||
$(1)_test_rmin \
|
||||
$(1)_test_rmax \
|
||||
$(1)_test_rnear_maxMag
|
||||
endef
|
||||
|
||||
$(eval $(call test_template,f32_add,FADD,32,add))
|
||||
$(eval $(call test_template,f32_sub,FADD,32,sub))
|
||||
$(eval $(call test_template,f64_add,FADD,64,add))
|
||||
$(eval $(call test_template,f64_sub,FADD,64,sub))
|
|
@ -0,0 +1 @@
|
|||
Subproject commit b64af41c3276f97f0e181920400ee056b9c88037
|
|
@ -0,0 +1 @@
|
|||
Subproject commit 06b20075dd3c1a5d0dd007a93643282832221612
|
|
@ -0,0 +1,50 @@
|
|||
// import Mill dependency
|
||||
import mill._
|
||||
import scalalib._
|
||||
import scalafmt._
|
||||
import coursier.maven.MavenRepository
|
||||
|
||||
val defaultVersions = Map(
|
||||
"chisel3" -> "3.5-SNAPSHOT",
|
||||
"chisel3-plugin" -> "3.5-SNAPSHOT",
|
||||
"scala" -> "2.12.13",
|
||||
"chiseltest" -> "latest.integration",
|
||||
"scalatest" -> "3.2.7"
|
||||
)
|
||||
|
||||
def getVersion(dep: String, org: String = "edu.berkeley.cs", cross: Boolean = false) = {
|
||||
val version = sys.env.getOrElse(dep + "Version", defaultVersions(dep))
|
||||
if (cross)
|
||||
ivy"$org:::$dep:$version"
|
||||
else
|
||||
ivy"$org::$dep:$version"
|
||||
}
|
||||
|
||||
object fudian extends SbtModule with ScalaModule with ScalafmtModule {
|
||||
|
||||
override def millSourcePath = millOuterCtx.millSourcePath
|
||||
|
||||
override def repositoriesTask = T.task {
|
||||
super.repositoriesTask() ++ Seq(
|
||||
MavenRepository("https://oss.sonatype.org/content/repositories/snapshots")
|
||||
)
|
||||
}
|
||||
|
||||
def scalaVersion = defaultVersions("scala")
|
||||
|
||||
|
||||
override def scalacPluginIvyDeps = super.scalacPluginIvyDeps() ++ Agg(getVersion("chisel3-plugin", cross = true))
|
||||
|
||||
override def ivyDeps = super.ivyDeps() ++ Agg(
|
||||
getVersion("chisel3"),
|
||||
getVersion("chiseltest")
|
||||
)
|
||||
|
||||
object tests extends Tests {
|
||||
override def ivyDeps = super.ivyDeps() ++ Agg(
|
||||
getVersion("scalatest","org.scalatest")
|
||||
)
|
||||
override def testFramework: T[String] = T("org.scalatest.tools.testFramework")
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,404 @@
|
|||
package fudian
|
||||
|
||||
import chisel3._
|
||||
import chisel3.stage.{ChiselGeneratorAnnotation, ChiselStage}
|
||||
import chisel3.util._
|
||||
import chisel3.util.experimental.decode._
|
||||
import fudian.utils._
|
||||
|
||||
class NearPath(val expWidth: Int, val precision: Int) extends Module {
|
||||
val io = IO(new Bundle() {
|
||||
val in = Input(new Bundle() {
|
||||
val a, b = new RawFloat(expWidth, precision)
|
||||
val need_shift_b = Bool()
|
||||
})
|
||||
val out = Output(new Bundle() {
|
||||
val result = new RawFloat(expWidth, precision + 2)
|
||||
val sig_is_zero = Bool()
|
||||
val a_lt_b = Bool()
|
||||
})
|
||||
})
|
||||
// we assue a >= b
|
||||
val (a, b) = (io.in.a, io.in.b)
|
||||
val need_shift = io.in.need_shift_b
|
||||
val a_sig = Cat(a.sig, 0.U(1.W))
|
||||
val b_sig = (Cat(b.sig, 0.U(1.W)) >> need_shift).asUInt()
|
||||
val b_neg = (~b_sig).asUInt()
|
||||
// extend 1 bit to get 'a_lt_b'
|
||||
val a_minus_b = Cat(0.U(1.W), a_sig) + Cat(1.U(1.W), b_neg) + 1.U
|
||||
val a_lt_b = a_minus_b.head(1).asBool()
|
||||
// we do not need carry out here
|
||||
val sig_raw = a_minus_b.tail(1)
|
||||
val lza_ab = Module(new LZA(precision + 1))
|
||||
lza_ab.io.a := a_sig
|
||||
lza_ab.io.b := b_neg
|
||||
val lza_str = lza_ab.io.f
|
||||
val lza_str_zero = !Cat(lza_str).orR()
|
||||
|
||||
// need to limit the shamt? (if a.exp is not large enough, a.exp-lzc may < 1)
|
||||
val need_shift_lim = a.exp < (precision + 1).U
|
||||
val mask_table_k_width = log2Up(precision + 1)
|
||||
val shift_lim_mask_raw = decoder(
|
||||
QMCMinimizer,
|
||||
a.exp(mask_table_k_width - 1, 0),
|
||||
TruthTable(
|
||||
(1 to precision + 1).map { i =>
|
||||
BitPat(i.U(mask_table_k_width.W)) -> BitPat(
|
||||
(BigInt(1) << (precision + 1 - i)).U((precision + 1).W)
|
||||
)
|
||||
},
|
||||
BitPat.dontCare(precision + 1)
|
||||
)
|
||||
)
|
||||
val shift_lim_mask = Mux(need_shift_lim, shift_lim_mask_raw, 0.U)
|
||||
val lzc_str = shift_lim_mask | lza_str
|
||||
val lzc = CLZ(lzc_str)
|
||||
val int_bit_mask = Wire(Vec(precision + 1, Bool()))
|
||||
for (i <- int_bit_mask.indices) {
|
||||
int_bit_mask(i) := {
|
||||
if (i == int_bit_mask.size - 1) {
|
||||
lzc_str(i)
|
||||
} else {
|
||||
lzc_str(i) & !lzc_str.head(int_bit_mask.size - i - 1).orR()
|
||||
}
|
||||
}
|
||||
}
|
||||
val exceed_lim_mask = Wire(Vec(precision + 1, Bool()))
|
||||
for (i <- exceed_lim_mask.indices) {
|
||||
exceed_lim_mask(i) := {
|
||||
if (i == exceed_lim_mask.size - 1) {
|
||||
false.B
|
||||
} else {
|
||||
lza_str.head(exceed_lim_mask.size - i - 1).orR()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val exceed_lim =
|
||||
need_shift_lim && !(exceed_lim_mask.asUInt() & shift_lim_mask).orR()
|
||||
val int_bit_predicted =
|
||||
((int_bit_mask.asUInt() | lza_str_zero) & sig_raw).orR()
|
||||
val lza_error = !int_bit_predicted && !exceed_lim
|
||||
val int_bit = Mux(
|
||||
lza_error,
|
||||
((int_bit_mask.asUInt() >> 1.U).asUInt() & sig_raw).orR(),
|
||||
int_bit_predicted
|
||||
)
|
||||
|
||||
val exp_s1 = a.exp - lzc
|
||||
val exp_s2 = exp_s1 - lza_error
|
||||
val sig_s1 = (sig_raw << lzc)(precision, 0)
|
||||
val sig_s2 = Mux(lza_error, Cat(sig_s1.tail(1), 0.U(1.W)), sig_s1)
|
||||
val near_path_sig = sig_s2
|
||||
val near_path_exp = Mux(int_bit, exp_s2, 0.U)
|
||||
val near_path_sign = Mux(a_lt_b, b.sign, a.sign)
|
||||
|
||||
val result = Wire(new RawFloat(expWidth, precision + 2))
|
||||
result.sign := near_path_sign
|
||||
result.exp := near_path_exp
|
||||
result.sig := Cat(near_path_sig, false.B) // 'sticky' always 0
|
||||
io.out.result := result
|
||||
io.out.sig_is_zero := lza_str_zero && !sig_raw(0)
|
||||
io.out.a_lt_b := a_lt_b
|
||||
}
|
||||
|
||||
class FarPath(val expWidth: Int, val precision: Int) extends Module {
|
||||
val io = IO(new Bundle() {
|
||||
val in = Input(new Bundle() {
|
||||
val a, b = new RawFloat(expWidth, precision)
|
||||
val expDiff = UInt(expWidth.W)
|
||||
val effSub = Bool()
|
||||
val smallAdd = Bool()
|
||||
})
|
||||
val out = Output(new Bundle() {
|
||||
val result = new RawFloat(expWidth, precision + 2)
|
||||
})
|
||||
})
|
||||
|
||||
val in = io.in
|
||||
val (a, b, expDiff, effSub, smallAdd) =
|
||||
(in.a, in.b, in.expDiff, in.effSub, in.smallAdd)
|
||||
|
||||
// shamt <- [2, precision + 2]
|
||||
val sig_b_shamt =
|
||||
Mux(
|
||||
expDiff > (precision + 2).U,
|
||||
(precision + 2).U,
|
||||
expDiff(log2Up(precision + 2) - 1, 0)
|
||||
)
|
||||
|
||||
val sig_b_main = Cat(b.sig, 0.U(2.W)) >> sig_b_shamt
|
||||
|
||||
val sticky_mask = ((1.U(1.W) << sig_b_shamt).asUInt() - 1.U)(precision + 1, 2)
|
||||
val sig_b_sticky = (sticky_mask & b.sig).orR()
|
||||
|
||||
val adder_in_sig_b = Cat(0.U(1.W), sig_b_main, sig_b_sticky)
|
||||
val adder_in_sig_a = Cat(0.U(1.W), a.sig, 0.U(3.W))
|
||||
val adder_result =
|
||||
adder_in_sig_a +
|
||||
Mux(effSub, ~adder_in_sig_b, adder_in_sig_b).asUInt() + effSub
|
||||
|
||||
val exp_a_plus_1 = a.exp + 1.U
|
||||
val exp_a_minus_1 = a.exp - 1.U
|
||||
|
||||
val cout = adder_result.head(1).asBool()
|
||||
val keep = adder_result.head(2) === 1.U
|
||||
val cancellation = adder_result.head(2) === 0.U
|
||||
|
||||
val far_path_sig = Mux1H(
|
||||
Seq(cout, keep || smallAdd, cancellation && !smallAdd),
|
||||
Seq(
|
||||
adder_result.head(precision + 1),
|
||||
Cat(adder_result.tail(1).head(precision + 1)),
|
||||
Cat(adder_result.tail(2).head(precision + 1))
|
||||
)
|
||||
)
|
||||
|
||||
val far_path_sticky = Mux1H(
|
||||
Seq(cout, keep || smallAdd, cancellation && !smallAdd),
|
||||
Seq(
|
||||
adder_result.tail(precision + 1).orR(),
|
||||
adder_result.tail(precision + 2).orR(),
|
||||
adder_result.tail(precision + 3).orR()
|
||||
)
|
||||
)
|
||||
|
||||
val far_path_exp = Mux1H(
|
||||
Seq(cout, keep, cancellation),
|
||||
Seq(exp_a_plus_1, a.exp, exp_a_minus_1)
|
||||
)
|
||||
|
||||
val result = Wire(new RawFloat(expWidth, precision + 2))
|
||||
result.sign := a.sign
|
||||
result.exp := far_path_exp
|
||||
result.sig := Cat(far_path_sig, far_path_sticky)
|
||||
io.out.result := result
|
||||
}
|
||||
|
||||
class FADD(val expWidth: Int, val precision: Int) extends Module {
|
||||
val io = IO(new Bundle() {
|
||||
val a, b = Input(UInt((expWidth + precision).W))
|
||||
val rm = Input(UInt(3.W))
|
||||
val do_sub = Input(Bool())
|
||||
val result = Output(UInt((expWidth + precision).W))
|
||||
val fflags = Output(UInt(5.W))
|
||||
})
|
||||
|
||||
val fp_a = FloatPoint.fromUInt(io.a, expWidth, precision)
|
||||
val fp_b = FloatPoint.fromUInt(
|
||||
Cat(io.b.head(1) ^ io.do_sub, io.b.tail(1)),
|
||||
expWidth,
|
||||
precision
|
||||
)
|
||||
val decode_a = fp_a.decode
|
||||
val decode_b = fp_b.decode
|
||||
val raw_a = RawFloat.fromFP(fp_a, Some(decode_a.expNotZero))
|
||||
val raw_b = RawFloat.fromFP(fp_b, Some(decode_b.expNotZero))
|
||||
val eff_sub = raw_a.sign ^ raw_b.sign
|
||||
|
||||
val small_add = decode_a.expIsZero && decode_b.expIsZero
|
||||
|
||||
// deal with special cases
|
||||
val special_path_hasNaN = decode_a.isNaN || decode_b.isNaN
|
||||
val special_path_hasSNaN = decode_a.isSNaN || decode_b.isSNaN
|
||||
val special_path_hasInf = decode_a.isInf || decode_b.isInf
|
||||
val special_path_inf_iv = decode_a.isInf && decode_b.isInf && eff_sub
|
||||
|
||||
val exp_diff_a_b = Cat(0.U(1.W), raw_a.exp) - Cat(0.U(1.W), raw_b.exp)
|
||||
val exp_diff_b_a = Cat(0.U(1.W), raw_b.exp) - Cat(0.U(1.W), raw_a.exp)
|
||||
val need_swap = exp_diff_a_b.head(1).asBool()
|
||||
|
||||
// make sure exp_a >= exp_b in near path and far path
|
||||
def swap[T <: Data](swp: Bool, x: T, y: T): (T, T) = {
|
||||
val nx = Mux(swp, y, x)
|
||||
val ny = Mux(swp, x, y)
|
||||
(nx, ny)
|
||||
}
|
||||
val (a, b) = swap(need_swap, raw_a, raw_b)
|
||||
|
||||
val ea_minus_eb = Mux(need_swap, exp_diff_b_a.tail(1), exp_diff_a_b.tail(1))
|
||||
val sel_far_path = !eff_sub || ea_minus_eb > 1.U
|
||||
|
||||
/*
|
||||
Far path
|
||||
*/
|
||||
|
||||
val far_path_inputs = Seq(
|
||||
(raw_a, raw_b, exp_diff_a_b),
|
||||
(raw_b, raw_a, exp_diff_b_a)
|
||||
)
|
||||
|
||||
val far_path_mods = far_path_inputs.map { in =>
|
||||
val far_path = Module(new FarPath(expWidth, precision))
|
||||
far_path.io.in.a := in._1
|
||||
far_path.io.in.b := in._2
|
||||
far_path.io.in.expDiff := in._3
|
||||
far_path.io.in.effSub := eff_sub
|
||||
far_path.io.in.smallAdd := small_add
|
||||
far_path
|
||||
}
|
||||
|
||||
val far_path_res = Mux1H(
|
||||
Seq(!need_swap, need_swap),
|
||||
far_path_mods.map(_.io.out.result)
|
||||
)
|
||||
val far_path_exp = far_path_res.exp
|
||||
val far_path_sig = far_path_res.sig
|
||||
|
||||
val far_path_rounder = Module(new RoundingUnit(precision - 1))
|
||||
far_path_rounder.io.in := far_path_sig.tail(1).head(precision - 1)
|
||||
far_path_rounder.io.roundIn := far_path_sig(1)
|
||||
far_path_rounder.io.stickyIn := far_path_sig(0)
|
||||
far_path_rounder.io.signIn := far_path_res.sign
|
||||
far_path_rounder.io.rm := io.rm
|
||||
|
||||
val far_path_exp_rounded = far_path_rounder.io.cout + far_path_exp
|
||||
val far_path_sig_rounded = far_path_rounder.io.out
|
||||
|
||||
val far_path_may_uf = (far_path_exp === 0.U) && !far_path_rounder.io.cout
|
||||
val far_path_of = Mux(
|
||||
far_path_rounder.io.cout,
|
||||
far_path_exp === ((BigInt(1) << expWidth) - 2).U,
|
||||
far_path_exp === ((BigInt(1) << expWidth) - 1).U
|
||||
)
|
||||
val far_path_ix = far_path_rounder.io.inexact | far_path_of
|
||||
val far_path_uf = far_path_may_uf & far_path_ix
|
||||
|
||||
val rmin =
|
||||
io.rm === RTZ || (io.rm === RDN && !a.sign) || (io.rm === RUP && a.sign)
|
||||
val far_path_result_exp = Mux(
|
||||
far_path_of && rmin,
|
||||
((BigInt(1) << expWidth) - 2).U(expWidth.W),
|
||||
far_path_exp_rounded
|
||||
)
|
||||
val far_path_result_sig = Mux(
|
||||
far_path_of,
|
||||
Mux(rmin, Fill(precision - 1, 1.U(1.W)), 0.U((precision - 1).W)),
|
||||
far_path_sig_rounded
|
||||
)
|
||||
val far_path_result = Cat(a.sign, far_path_result_exp, far_path_result_sig)
|
||||
|
||||
/*
|
||||
Near path
|
||||
*/
|
||||
|
||||
val near_path_inputs = Seq(
|
||||
(raw_a, raw_b, false.B),
|
||||
(raw_a, raw_b, true.B),
|
||||
(raw_b, raw_a, false.B),
|
||||
(raw_b, raw_a, true.B)
|
||||
)
|
||||
val near_path_mods = near_path_inputs.map { in =>
|
||||
val near_path = Module(new NearPath(expWidth, precision))
|
||||
near_path.io.in.a := in._1
|
||||
near_path.io.in.b := in._2
|
||||
near_path.io.in.need_shift_b := in._3
|
||||
near_path
|
||||
}
|
||||
val exp_eq = raw_a.exp === raw_b.exp
|
||||
/*
|
||||
exp_eq => (a - b, b - a)
|
||||
expa > expb => a - b_shift
|
||||
expb > expa => b - a_shift
|
||||
*/
|
||||
val near_path_out = Mux1H(
|
||||
Seq(
|
||||
exp_eq && !near_path_mods.head.io.out.a_lt_b, // exp_eq && a_sig >= b_sig
|
||||
!exp_eq && !need_swap, // expa > expb
|
||||
exp_eq && near_path_mods.head.io.out.a_lt_b, // exp_eq && a_sig < b_sig
|
||||
need_swap // expb > expa
|
||||
),
|
||||
near_path_mods.map(_.io.out)
|
||||
)
|
||||
|
||||
val near_path_res = near_path_out.result
|
||||
val near_path_sign = near_path_res.sign
|
||||
val near_path_exp = near_path_res.exp
|
||||
val near_path_sig = near_path_res.sig
|
||||
val near_path_sig_zero = near_path_out.sig_is_zero
|
||||
val near_path_is_zero = near_path_exp === 0.U && near_path_sig_zero
|
||||
|
||||
val near_path_rounder = Module(new RoundingUnit(precision - 1))
|
||||
near_path_rounder.io.in := near_path_sig.tail(1).head(precision - 1)
|
||||
near_path_rounder.io.signIn := near_path_res.sign
|
||||
near_path_rounder.io.roundIn := near_path_sig(1)
|
||||
near_path_rounder.io.stickyIn := false.B
|
||||
near_path_rounder.io.rm := io.rm
|
||||
|
||||
val near_path_exp_rounded = near_path_rounder.io.cout + near_path_exp
|
||||
val near_path_sig_rounded = near_path_rounder.io.out
|
||||
val near_path_zero_sign = io.rm === RDN
|
||||
val near_path_result = Cat(
|
||||
(near_path_sign && !near_path_is_zero) || (near_path_zero_sign && near_path_is_zero),
|
||||
near_path_exp_rounded,
|
||||
near_path_sig_rounded
|
||||
)
|
||||
|
||||
val near_path_ix = near_path_rounder.io.inexact
|
||||
val near_path_may_uf = (near_path_exp === 0.U) && !near_path_rounder.io.cout
|
||||
val near_path_uf = near_path_may_uf && near_path_ix
|
||||
|
||||
/*
|
||||
Final result <- [special, far, near]
|
||||
*/
|
||||
|
||||
val iv = special_path_hasSNaN || special_path_inf_iv
|
||||
val dz = false.B
|
||||
val of = Mux(
|
||||
special_path_hasNaN || special_path_hasInf,
|
||||
false.B,
|
||||
sel_far_path && far_path_of
|
||||
)
|
||||
val uf = Mux(
|
||||
special_path_hasNaN || special_path_hasInf,
|
||||
false.B,
|
||||
(sel_far_path && far_path_uf) || (!sel_far_path && near_path_uf)
|
||||
)
|
||||
val ix = Mux(
|
||||
special_path_hasNaN || special_path_hasInf,
|
||||
false.B,
|
||||
(sel_far_path && far_path_ix) || (!sel_far_path && near_path_ix)
|
||||
)
|
||||
|
||||
io.result := Mux(
|
||||
special_path_hasNaN || special_path_inf_iv,
|
||||
FloatPoint.defaultNaNUInt(expWidth, precision),
|
||||
Mux(
|
||||
special_path_hasInf,
|
||||
Mux1H(
|
||||
Seq(
|
||||
decode_a.isInf -> fp_a.asUInt(),
|
||||
decode_b.isInf -> fp_b.asUInt()
|
||||
)
|
||||
),
|
||||
Mux1H(
|
||||
Seq(
|
||||
sel_far_path -> far_path_result,
|
||||
!sel_far_path -> near_path_result
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
io.fflags := Cat(iv, dz, of, uf, ix)
|
||||
}
|
||||
|
||||
object FADD extends App {
|
||||
override def main(args: Array[String]): Unit = {
|
||||
// arg fmt: -td ... 32 / -td ... 64
|
||||
val (expWidth, precision) = args.last match {
|
||||
case "32" =>
|
||||
(8, 24)
|
||||
case "64" =>
|
||||
(11, 53)
|
||||
case _ =>
|
||||
println("usage: runMain fudian.FADD -td <build dir> <ftype>")
|
||||
sys.exit(-1)
|
||||
}
|
||||
(new ChiselStage).execute(
|
||||
args,
|
||||
Seq(
|
||||
ChiselGeneratorAnnotation(() => new FADD(expWidth, precision))
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,36 @@
|
|||
package fudian
|
||||
|
||||
import chisel3._
|
||||
import chisel3.util._
|
||||
|
||||
class RoundingUnit(val width: Int) extends Module {
|
||||
val io = IO(new Bundle() {
|
||||
val in = Input(UInt(width.W))
|
||||
val roundIn = Input(Bool())
|
||||
val stickyIn = Input(Bool())
|
||||
val signIn = Input(Bool())
|
||||
val rm = Input(UInt(3.W))
|
||||
val out = Output(UInt(width.W))
|
||||
val inexact = Output(Bool())
|
||||
val cout = Output(Bool())
|
||||
})
|
||||
|
||||
val (g, r, s) = (io.in(0).asBool(), io.roundIn, io.stickyIn)
|
||||
val inexact = r | s
|
||||
val r_up = MuxLookup(
|
||||
io.rm,
|
||||
false.B,
|
||||
Seq(
|
||||
RNE -> ((r && s) || (r && !s && g)),
|
||||
RTZ -> false.B,
|
||||
RUP -> (inexact & !io.signIn),
|
||||
RDN -> (inexact & io.signIn),
|
||||
RMM -> r
|
||||
)
|
||||
)
|
||||
val out_r_up = io.in + 1.U
|
||||
io.out := Mux(r_up, out_r_up, io.in)
|
||||
io.inexact := inexact
|
||||
// r_up && io.in === 111...1
|
||||
io.cout := r_up && io.in.andR()
|
||||
}
|
|
@ -0,0 +1,83 @@
|
|||
import chisel3._
|
||||
import chisel3.util._
|
||||
|
||||
package object fudian {
|
||||
|
||||
def RNE: UInt = 0.U(3.W)
|
||||
def RTZ: UInt = 1.U(3.W)
|
||||
def RDN: UInt = 2.U(3.W)
|
||||
def RUP: UInt = 3.U(3.W)
|
||||
def RMM: UInt = 4.U(3.W)
|
||||
|
||||
class FPDecodeBundle extends Bundle {
|
||||
val expNotZero = Bool()
|
||||
val expIsZero = Bool()
|
||||
val expIsOnes = Bool()
|
||||
val sigNotZero = Bool()
|
||||
val sigIsZero = Bool()
|
||||
val isSubnormal = Bool()
|
||||
val isInf = Bool()
|
||||
val isZero = Bool()
|
||||
val isNaN = Bool()
|
||||
val isSNaN = Bool()
|
||||
val isQNaN = Bool()
|
||||
}
|
||||
|
||||
class FloatPoint(val expWidth: Int, val precision: Int) extends Bundle {
|
||||
def sigWidth = precision - 1
|
||||
val sign = Bool()
|
||||
val exp = UInt(expWidth.W)
|
||||
val sig = UInt(sigWidth.W)
|
||||
def decode: FPDecodeBundle = {
|
||||
val expNotZero = exp.orR()
|
||||
val expIsOnes = exp.andR()
|
||||
val sigNotZero = sig.orR()
|
||||
val bundle = Wire(new FPDecodeBundle)
|
||||
bundle.expNotZero := expNotZero
|
||||
bundle.expIsZero := !expNotZero
|
||||
bundle.expIsOnes := expIsOnes
|
||||
bundle.sigNotZero := sigNotZero
|
||||
bundle.sigIsZero := !sigNotZero
|
||||
bundle.isSubnormal := bundle.expIsZero && sigNotZero
|
||||
bundle.isInf := bundle.expIsOnes && bundle.sigIsZero
|
||||
bundle.isZero := bundle.expIsZero && bundle.sigIsZero
|
||||
bundle.isNaN := bundle.expIsOnes && bundle.sigNotZero
|
||||
bundle.isSNaN := bundle.isNaN && !sig.head(1).asBool()
|
||||
bundle.isQNaN := bundle.isNaN && sig.head(1).asBool()
|
||||
bundle
|
||||
}
|
||||
}
|
||||
object FloatPoint {
|
||||
def fromUInt(x: UInt, expWidth: Int, pc: Int): FloatPoint = {
|
||||
val fp = Wire(new FloatPoint(expWidth, pc))
|
||||
fp.sign := x(expWidth + pc - 1)
|
||||
fp.exp := x(expWidth + pc - 2, pc - 1)
|
||||
fp.sig := x(pc - 2, 0)
|
||||
fp
|
||||
}
|
||||
def defaultNaNUInt(expWidth: Int, pc: Int): UInt = {
|
||||
Cat(0.U(1.W), Fill(expWidth + 1, 1.U(1.W)), 0.U((pc - 2).W))
|
||||
}
|
||||
def defaultNaN(expWidth: Int, pc: Int): FloatPoint = {
|
||||
fromUInt(defaultNaNUInt(expWidth, pc), expWidth, pc)
|
||||
}
|
||||
}
|
||||
|
||||
class RawFloat(val expWidth: Int, val precision: Int) extends Bundle {
|
||||
val sign = Bool()
|
||||
val exp = UInt(expWidth.W)
|
||||
val sig = UInt(precision.W)
|
||||
}
|
||||
|
||||
object RawFloat {
|
||||
def fromFP(fp: FloatPoint, expNotZero: Option[Bool] = None): RawFloat = {
|
||||
val inner = Wire(new RawFloat(fp.expWidth, fp.precision))
|
||||
val nz = if (expNotZero.isDefined) expNotZero.get else fp.exp.orR()
|
||||
inner.sign := fp.sign
|
||||
inner.exp := fp.exp | !nz
|
||||
inner.sig := Cat(nz, fp.sig)
|
||||
inner
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
package fudian.utils
|
||||
|
||||
import chisel3._
|
||||
import chisel3.util._
|
||||
import chisel3.util.experimental.decode._
|
||||
|
||||
class CLZ(len: Int, zero: Boolean) extends Module {
|
||||
|
||||
val inWidth = len
|
||||
val outWidth = (inWidth - 1).U.getWidth
|
||||
|
||||
val io = IO(new Bundle() {
|
||||
val in = Input(UInt(inWidth.W))
|
||||
val out = Output(UInt(outWidth.W))
|
||||
})
|
||||
|
||||
val normalTerms = Seq.tabulate(inWidth) { i =>
|
||||
BitPat("b" + ("0" * i) + "1" + ("?" * (inWidth - i - 1))) -> BitPat(
|
||||
i.U(outWidth.W)
|
||||
)
|
||||
}
|
||||
val zeroTerm = BitPat(0.U(inWidth.W)) -> BitPat((inWidth - 1).U(outWidth.W))
|
||||
val terms = if (zero) normalTerms :+ zeroTerm else normalTerms
|
||||
val table = TruthTable(terms, BitPat.dontCare(outWidth))
|
||||
io.out := decoder(QMCMinimizer, io.in, table)
|
||||
}
|
||||
|
||||
object CLZ {
|
||||
def apply(value: UInt): UInt = {
|
||||
val clz = Module(new CLZ(value.getWidth, true))
|
||||
clz.io.in := value
|
||||
clz.io.out
|
||||
}
|
||||
def apply(xs: Seq[Bool]): UInt = {
|
||||
apply(Cat(xs.reverse))
|
||||
}
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
package fudian.utils
|
||||
|
||||
import chisel3._
|
||||
import chisel3.util._
|
||||
|
||||
class LzaIO(val len: Int) extends Bundle {
|
||||
val a, b = Input(UInt(len.W))
|
||||
val f = Output(UInt(len.W))
|
||||
}
|
||||
|
||||
class LZA(len: Int) extends Module {
|
||||
val io = IO(new LzaIO(len))
|
||||
|
||||
val (a, b) = (io.a, io.b)
|
||||
|
||||
val p, k, f = Wire(Vec(len, Bool()))
|
||||
for (i <- 0 until len) {
|
||||
p(i) := a(i) ^ b(i)
|
||||
k(i) := (!a(i)) & (!b(i))
|
||||
if (i == 0) {
|
||||
f(i) := false.B
|
||||
} else {
|
||||
f(i) := p(i) ^ (!k(i - 1))
|
||||
}
|
||||
}
|
||||
io.f := Cat(f.reverse)
|
||||
}
|
|
@ -0,0 +1,68 @@
|
|||
#include <stdio.h>
|
||||
#include <stdint.h>
|
||||
#include <verilated.h>
|
||||
#include <VFADD.h>
|
||||
#include "common.h"
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
if(argc != 3){
|
||||
printf("usage: %s <rounding-mode> <op>\n", argv[0]);
|
||||
return -1;
|
||||
}
|
||||
|
||||
int rm = get_str_index(argv[1], rounding_modes, 5);
|
||||
if(rm == -1){
|
||||
printf("unknown rounding mode: %s\n", argv[1]);
|
||||
return -1;
|
||||
}
|
||||
|
||||
const char* op_list[] = {"add", "sub"};
|
||||
int op = get_str_index(argv[2], op_list, 2);
|
||||
if(op == -1){
|
||||
printf("unknown op: %s\n", argv[2]);
|
||||
return -1;
|
||||
}
|
||||
|
||||
VFADD module;
|
||||
|
||||
for(int i = 0; i<10; i++){
|
||||
module.reset = 1;
|
||||
module.clock = 0;
|
||||
module.eval();
|
||||
module.clock = 1;
|
||||
module.eval();
|
||||
}
|
||||
module.reset = 0;
|
||||
module.clock =0;
|
||||
module.eval();
|
||||
module.clock = 1;
|
||||
module.eval();
|
||||
|
||||
uint64_t a, b, ref_sum, ref_fflags;
|
||||
uint64_t dut_sum, dut_fflags;
|
||||
|
||||
uint64_t cnt = 0;
|
||||
uint64_t error = 0;
|
||||
|
||||
module.io_rm = rm;
|
||||
module.io_do_sub = op;
|
||||
while(scanf("%lx %lx %lx %lx", &a, &b, &ref_sum, &ref_fflags) != EOF){
|
||||
module.io_a = a;
|
||||
module.io_b = b;
|
||||
module.clock = 0;
|
||||
module.eval();
|
||||
module.clock = 1;
|
||||
module.eval();
|
||||
dut_sum = module.io_result;
|
||||
dut_fflags = module.io_fflags;
|
||||
if( (dut_sum != ref_sum || dut_fflags != ref_fflags) ){
|
||||
printf("[%ld] input: %lx %lx\n", cnt, a, b);
|
||||
printf("[%ld] dut_sum: %lx dut_fflags: %lx\n", cnt, dut_sum, dut_fflags);
|
||||
printf("[%ld] ref_sum: %lx ref_fflags: %lx\n", cnt, ref_sum, ref_fflags);
|
||||
error++;
|
||||
return -1;
|
||||
}
|
||||
cnt++;
|
||||
}
|
||||
printf("cnt = %ld error=%ld\n", cnt, error);
|
||||
}
|
|
@ -0,0 +1,25 @@
|
|||
#ifndef __COMMON_H__
|
||||
#define __COMMON_H__
|
||||
#include <string.h>
|
||||
const char* rounding_modes[] = {
|
||||
"-rnear_even",
|
||||
"-rminMag", // rtz
|
||||
"-rmin", // rdown
|
||||
"-rmax", // rup
|
||||
"-rnear_maxMag", // rmm
|
||||
};
|
||||
|
||||
int get_rounding_mode(char* rm_str) {
|
||||
for(int i = 0; i < 5; i++){
|
||||
if(strcmp(rm_str, rounding_modes[i]) == 0) return i;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
int get_str_index(const char* key, const char* str_lst[], int len) {
|
||||
for(int i = 0; i < len; i++){
|
||||
if(strcmp(key, str_lst[i]) == 0) return i;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
#endif
|
Loading…
Reference in New Issue