Calyx Binary Floating Point AddF Operator (#7089)

* binary floating point add operator for IEEE754
This commit is contained in:
Jiahan Xie 2024-10-31 17:41:21 -04:00 committed by GitHub
parent d22a6957da
commit 963d6950a8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 202 additions and 16 deletions

View File

@ -335,6 +335,49 @@ def AndLibOp : CombinationalArithBinaryLibraryOp<"and"> {}
def OrLibOp : CombinationalArithBinaryLibraryOp<"or"> {}
def XorLibOp : CombinationalArithBinaryLibraryOp<"xor"> {}
class ArithBinaryFloatingPointLibraryOp<string mnemonic> : ArithBinaryLibraryOp<mnemonic, [
SameTypeConstraint<"left", "out">]> {}
def AddFNOp : ArithBinaryFloatingPointLibraryOp<"addFN"> {
let results = (outs I1:$clk, I1:$reset, I1:$go, I1:$control, I1:$subOp,
AnyFloat:$left, AnyFloat:$right, AnySignlessInteger:$roundingMode, AnyFloat:$out,
AnySignlessInteger:$exceptionalFlags, I1:$done);
let extraClassDefinition = [{
SmallVector<StringRef> $cppClass::portNames() {
return {clkPort, resetPort, goPort, "control", "subOp",
"left", "right", "roundingMode", "out", "exceptionalFlags", donePort
};
}
SmallVector<Direction> $cppClass::portDirections() {
return {Input, Input, Input, Input, Input, Input, Input, Input, Output, Output, Output};
}
void $cppClass::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
getCellAsmResultNames(setNameFn, *this, this->portNames());
}
bool $cppClass::isCombinational() { return false; }
SmallVector<DictionaryAttr> $cppClass::portAttributes() {
IntegerAttr isSet = IntegerAttr::get(IntegerType::get(getContext(), 1), 1);
NamedAttrList go, clk, reset, done;
go.append(goPort, isSet);
clk.append(clkPort, isSet);
reset.append(resetPort, isSet);
done.append(donePort, isSet);
return {clk.getDictionary(getContext()), reset.getDictionary(getContext()),
go.getDictionary(getContext()), DictionaryAttr::get(getContext()),
DictionaryAttr::get(getContext()), DictionaryAttr::get(getContext()),
DictionaryAttr::get(getContext()), DictionaryAttr::get(getContext()),
DictionaryAttr::get(getContext()), done.getDictionary(getContext()),
DictionaryAttr::get(getContext())
};
}
}];
}
def MuxLibOp : CalyxLibraryOp<"mux", [
Combinational, SameTypeConstraint<"tru", "fal">, SameTypeConstraint<"tru", "out">
]> {

View File

@ -29,6 +29,7 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include <variant>
@ -281,6 +282,9 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
AddIOp, SubIOp, CmpIOp, ShLIOp, ShRUIOp, ShRSIOp,
AndIOp, XOrIOp, OrIOp, ExtUIOp, ExtSIOp, TruncIOp,
MulIOp, DivUIOp, DivSIOp, RemUIOp, RemSIOp,
/// floating point
AddFOp,
/// others
SelectOp, IndexCastOp, CallOp>(
[&](auto op) { return buildOp(rewriter, op).succeeded(); })
.template Case<FuncOp, scf::ConditionOp>([&](auto) {
@ -314,6 +318,7 @@ private:
LogicalResult buildOp(PatternRewriter &rewriter, DivSIOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, RemUIOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, RemSIOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, AddFOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, ShRUIOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, ShRSIOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, ShLIOp op) const;
@ -409,7 +414,7 @@ private:
// Pass the result from the Operation to the Calyx primitive.
op.getResult().replaceAllUsesWith(out);
auto reg = createRegister(
op.getLoc(), rewriter, getComponent(), width.getIntOrFloatBitWidth(),
op.getLoc(), rewriter, getComponent(), width,
getState<ComponentLoweringState>().getUniqueName(opName));
// Operation pipelines are not combinational, so a GroupOp is required.
auto group = createGroupForOp<calyx::GroupOp>(rewriter, op);
@ -434,6 +439,19 @@ private:
// The group is done when the register write is complete.
rewriter.create<calyx::GroupDoneOp>(loc, reg.getDone());
if (isa<calyx::AddFNOp>(opPipe)) {
auto opFN = cast<calyx::AddFNOp>(opPipe);
hw::ConstantOp subOp;
if (isa<arith::AddFOp>(op)) {
subOp = createConstant(loc, rewriter, getComponent(), /*width=*/1,
/*subtract=*/0);
} else {
subOp = createConstant(loc, rewriter, getComponent(), /*width=*/1,
/*subtract=*/1);
}
rewriter.create<calyx::AssignOp>(loc, opFN.getSubOp(), subOp);
}
// Register the values for the pipeline.
getState<ComponentLoweringState>().registerEvaluatingGroup(out, group);
getState<ComponentLoweringState>().registerEvaluatingGroup(opPipe.getLeft(),
@ -666,6 +684,21 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
/*out=*/remPipe.getOut());
}
LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
AddFOp addf) const {
Location loc = addf.getLoc();
Type width = addf.getResult().getType();
IntegerType one = rewriter.getI1Type(), three = rewriter.getIntegerType(3),
five = rewriter.getIntegerType(5);
auto addFN =
getState<ComponentLoweringState>()
.getNewLibraryOpInstance<calyx::AddFNOp>(
rewriter, loc,
{one, one, one, one, one, width, width, three, width, five, one});
return buildLibraryBinaryPipeOp<calyx::AddFNOp>(rewriter, addf, addFN,
addFN.getOut());
}
template <typename TAllocOp>
static LogicalResult buildAllocOp(ComponentLoweringState &componentState,
PatternRewriter &rewriter, TAllocOp allocOp) {
@ -1868,7 +1901,7 @@ public:
ShRSIOp, AndIOp, XOrIOp, OrIOp, ExtUIOp, TruncIOp,
CondBranchOp, BranchOp, MulIOp, DivUIOp, DivSIOp, RemUIOp,
RemSIOp, ReturnOp, arith::ConstantOp, IndexCastOp, FuncOp,
ExtSIOp, CallOp>();
ExtSIOp, CallOp, AddFOp>();
RewritePatternSet legalizePatterns(&getContext());
legalizePatterns.add<DummyPattern>(&getContext());

View File

@ -149,6 +149,10 @@ private:
static constexpr std::string_view sFloat = "float";
return {sFloat};
})
.Case<AddFNOp>([&](auto op) -> FailureOr<StringRef> {
static constexpr std::string_view sFloatingPoint = "float/addFN";
return {sFloatingPoint};
})
.Default([&](auto op) {
auto diag = op->emitOpError() << "not supported for emission";
return diag;
@ -288,6 +292,9 @@ struct Emitter {
void emitLibraryPrimTypedByFirstOutputPort(
Operation *op, std::optional<StringRef> calyxLibName = {});
// Emits a library floating point primitives
void emitLibraryFloatingPoint(Operation *op);
private:
/// Used to track which imports are required for this program.
ImportTracker importTracker;
@ -668,6 +675,7 @@ void Emitter::emitComponent(ComponentInterface op) {
emitLibraryPrimTypedByFirstOutputPort(
op, /*calyxLibName=*/{"std_sdiv_pipe"});
})
.Case<AddFNOp>([&](auto op) { emitLibraryFloatingPoint(op); })
.Default([&](auto op) {
emitOpError(op, "not supported for emission inside component");
});
@ -964,6 +972,43 @@ void Emitter::emitLibraryPrimTypedByFirstOutputPort(
<< LParen() << bitWidth << RParen() << semicolonEndL();
}
void Emitter::emitLibraryFloatingPoint(Operation *op) {
auto cell = cast<CellInterface>(op);
unsigned bitWidth =
cell.getOutputPorts()[0].getType().getIntOrFloatBitWidth();
// Since Calyx interacts with HardFloat, we'll also only be using expWidth and
// sigWidth. See
// http://www.jhauser.us/arithmetic/HardFloat-1/doc/HardFloat-Verilog.html
unsigned expWidth, sigWidth;
switch (bitWidth) {
case 16:
expWidth = 5;
sigWidth = 11;
break;
case 32:
expWidth = 8;
sigWidth = 24;
break;
case 64:
expWidth = 11;
sigWidth = 53;
break;
case 128:
expWidth = 15;
sigWidth = 113;
break;
default:
op->emitError("The supported bitwidths are 16, 32, 64, and 128");
return;
}
StringRef opName = op->getName().getStringRef();
indent() << getAttributes(op, /*atFormat=*/true) << cell.instanceName()
<< space() << equals() << space() << removeCalyxPrefix(opName)
<< LParen() << expWidth << comma() << sigWidth << comma() << bitWidth
<< RParen() << semicolonEndL();
}
void Emitter::emitAssignment(AssignOp op) {
emitValue(op.getDest(), /*isIndented=*/true);

View File

@ -657,10 +657,10 @@ void InlineCombGroups::recurseInlineCombGroups(
// LateSSAReplacement)
if (isa<BlockArgument>(src) ||
isa<calyx::RegisterOp, calyx::MemoryOp, calyx::SeqMemoryOp,
calyx::ConstantOp, hw::ConstantOp, mlir::arith::ConstantOp,
calyx::MultPipeLibOp, calyx::DivUPipeLibOp, calyx::DivSPipeLibOp,
calyx::RemSPipeLibOp, calyx::RemUPipeLibOp, mlir::scf::WhileOp,
calyx::InstanceOp>(src.getDefiningOp()))
hw::ConstantOp, mlir::arith::ConstantOp, calyx::MultPipeLibOp,
calyx::DivUPipeLibOp, calyx::DivSPipeLibOp, calyx::RemSPipeLibOp,
calyx::RemUPipeLibOp, mlir::scf::WhileOp, calyx::InstanceOp,
calyx::ConstantOp, calyx::AddFNOp>(src.getDefiningOp()))
continue;
auto srcCombGroup = dyn_cast<calyx::CombGroupOp>(

View File

@ -233,3 +233,27 @@ module {
return %arg0, %0, %1 : f32, i32, f32
}
}
// -----
// Test floating point add
// CHECK: calyx.group @bb0_0 {
// CHECK-DAG: calyx.assign %std_addFN_0.left = %in0 : f32
// CHECK-DAG: calyx.assign %std_addFN_0.right = %cst : f32
// CHECK-DAG: calyx.assign %addf_0_reg.in = %std_addFN_0.out : f32
// CHECK-DAG: calyx.assign %addf_0_reg.write_en = %std_addFN_0.done : i1
// CHECK-DAG: %0 = comb.xor %std_addFN_0.done, %true : i1
// CHECK-DAG: calyx.assign %std_addFN_0.go = %0 ? %true : i1
// CHECK-DAG: calyx.assign %std_addFN_0.subOp = %false : i1
// CHECK-DAG: calyx.group_done %addf_0_reg.done : i1
// CHECK-DAG: }
module {
func.func @main(%arg0 : f32) -> f32 {
%0 = arith.constant 4.2 : f32
%1 = arith.addf %arg0, %0 : f32
return %1 : f32
}
}

View File

@ -1,15 +1,5 @@
// RUN: circt-opt --lower-scf-to-calyx %s -split-input-file -verify-diagnostics
module {
func.func @f(%arg0 : f32, %arg1 : f32) -> f32 {
// expected-error @+1 {{failed to legalize operation 'arith.addf' that was explicitly marked illegal}}
%2 = arith.addf %arg0, %arg1 : f32
return %2 : f32
}
}
// -----
// expected-error @+1 {{Module contains multiple functions, but no top level function was set. Please see --top-level-function}}
module {
func.func @f1() {

View File

@ -280,3 +280,54 @@ module attributes {calyx.entrypoint = "main"} {
} {toplevel}
}
// -----
module attributes {calyx.entrypoint = "main"} {
// CHECK: import "primitives/float/addFN.futil";
calyx.component @main(%in0: f32, %clk: i1 {clk}, %reset: i1 {reset}, %go: i1 {go}) -> (%out0: f32, %done: i1 {done}) {
// CHECK: std_addFN_0 = std_addFN(8, 24, 32);
%cst = calyx.constant {sym_name = "cst_0"} 4.200000e+00 : f32
%true = hw.constant true
%false = hw.constant false
%addf_0_reg.in, %addf_0_reg.write_en, %addf_0_reg.clk, %addf_0_reg.reset, %addf_0_reg.out, %addf_0_reg.done = calyx.register @addf_0_reg : f32, i1, i1, i1, f32, i1
%std_addFN_0.clk, %std_addFN_0.reset, %std_addFN_0.go, %std_addFN_0.control, %std_addFN_0.subOp, %std_addFN_0.left, %std_addFN_0.right, %std_addFN_0.roundingMode, %std_addFN_0.out, %std_addFN_0.exceptionalFlags, %std_addFN_0.done = calyx.std_addFN @std_addFN_0 : i1, i1, i1, i1, i1, f32, f32, i3, f32, i5, i1
%ret_arg0_reg.in, %ret_arg0_reg.write_en, %ret_arg0_reg.clk, %ret_arg0_reg.reset, %ret_arg0_reg.out, %ret_arg0_reg.done = calyx.register @ret_arg0_reg : f32, i1, i1, i1, f32, i1
calyx.wires {
calyx.assign %out0 = %ret_arg0_reg.out : f32
// CHECK-LABEL: group bb0_0 {
// CHECK-NEXT: std_addFN_0.left = in0;
// CHECK-NEXT: std_addFN_0.right = cst_0.out;
// CHECK-NEXT: addf_0_reg.in = std_addFN_0.out;
// CHECK-NEXT: addf_0_reg.write_en = std_addFN_0.done;
// CHECK-NEXT: std_addFN_0.go = !std_addFN_0.done ? 1'd1;
// CHECK-NEXT: std_addFN_0.subOp = 1'd0;
// CHECK-NEXT: bb0_0[done] = addf_0_reg.done;
// CHECK-NEXT: }
calyx.group @bb0_0 {
calyx.assign %std_addFN_0.left = %in0 : f32
calyx.assign %std_addFN_0.right = %cst : f32
calyx.assign %addf_0_reg.in = %std_addFN_0.out : f32
calyx.assign %addf_0_reg.write_en = %std_addFN_0.done : i1
%0 = comb.xor %std_addFN_0.done, %true : i1
calyx.assign %std_addFN_0.go = %0 ? %true : i1
calyx.assign %std_addFN_0.subOp = %false : i1
calyx.group_done %addf_0_reg.done : i1
}
calyx.group @ret_assign_0 {
calyx.assign %ret_arg0_reg.in = %std_addFN_0.out : f32
calyx.assign %ret_arg0_reg.write_en = %true : i1
calyx.group_done %ret_arg0_reg.done : i1
}
}
calyx.control {
calyx.seq {
calyx.seq {
calyx.enable @bb0_0
calyx.enable @ret_assign_0
}
}
}
} {toplevel}
}