Calyx ConstantOp Support (#7086)

* create register based on type

* support calyx constant op and the corresponding emitter

Co-authored-by: Chris Gyurgyik <Gyurgyikcp@gmail.com>

---------

Co-authored-by: Chris Gyurgyik <Gyurgyikcp@gmail.com>
This commit is contained in:
Jiahan Xie 2024-10-30 19:29:01 -04:00 committed by GitHub
parent f22f11aa40
commit b49d2b3adc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 235 additions and 17 deletions

View File

@ -30,6 +30,10 @@ calyx::RegisterOp createRegister(Location loc, OpBuilder &builder,
ComponentOp component, size_t width,
Twine prefix);
calyx::RegisterOp createRegister(Location loc, OpBuilder &builder,
ComponentOp component, Type type,
Twine prefix);
/// A helper function to create constants in the HW dialect.
hw::ConstantOp createConstant(Location loc, OpBuilder &builder,
ComponentOp component, size_t width,

View File

@ -10,6 +10,8 @@
//
//===----------------------------------------------------------------------===//
include "mlir/IR/BuiltinAttributeInterfaces.td"
/// Base class for Calyx primitives.
class CalyxPrimitive<string mnemonic, list<Trait> traits = []> :
CalyxCell<mnemonic, traits> {
@ -18,6 +20,40 @@ class CalyxPrimitive<string mnemonic, list<Trait> traits = []> :
let skipDefaultBuilders = 1;
}
def ConstantOp: CalyxPrimitive<"constant",
[ConstantLike, FirstAttrDerivedResultType,
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
AllTypesMatch<["value", "out"]>
]> {
let summary = "integer or floating point constant";
let description = [{
The `constant` operation produces an SSA value equal to some integer or
floating-point constant specified by an attribute.
Example:
```
// Integer constant
%1 = calyx.constant 42 : i32
// Floating point constant
%1 = calyx.constant 42.00+e00 : f32
```
}];
let arguments = (ins TypedAttrInterface:$value);
let results = (outs SignlessIntegerOrFloatLike:$out);
let builders = [
/// Build a ConstantOp from a prebuilt attribute.
OpBuilder <(ins "StringRef":$sym_name, "TypedAttr":$attr)>,
];
let hasFolder = 1;
let assemblyFormat = "attr-dict $value";
let hasVerifier = 1;
}
/// The n-bit, undef op which only provides the out signal
def UndefLibOp: CalyxPrimitive<"undefined", []> {
let summary = "An undefined signal";

View File

@ -891,12 +891,23 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
arith::ConstantOp constOp) const {
/// Move constant operations to the compOp body as hw::ConstantOp's.
APInt value;
calyx::matchConstantOp(constOp, value);
auto hwConstOp = rewriter.replaceOpWithNewOp<hw::ConstantOp>(constOp, value);
hwConstOp->moveAfter(getComponent().getBodyBlock(),
getComponent().getBodyBlock()->begin());
if (isa<IntegerType>(constOp.getType())) {
/// Move constant operations to the compOp body as hw::ConstantOp's.
APInt value;
calyx::matchConstantOp(constOp, value);
auto hwConstOp =
rewriter.replaceOpWithNewOp<hw::ConstantOp>(constOp, value);
hwConstOp->moveAfter(getComponent().getBodyBlock(),
getComponent().getBodyBlock()->begin());
} else {
std::string name = getState<ComponentLoweringState>().getUniqueName("cst");
auto calyxConstOp = rewriter.create<calyx::ConstantOp>(
constOp.getLoc(), name, constOp.getValueAttr());
calyxConstOp->moveAfter(getComponent().getBodyBlock(),
getComponent().getBodyBlock()->begin());
rewriter.replaceAllUsesWith(constOp, calyxConstOp.getOut());
}
return success();
}

View File

@ -1953,6 +1953,74 @@ ParseResult GroupDoneOp::parse(OpAsmParser &parser, OperationState &result) {
return parseGroupPort(parser, result);
}
//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//
void ConstantOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
if (isa<FloatAttr>(getValue())) {
setNameFn(getResult(), "cst");
return;
}
auto intCst = llvm::dyn_cast<IntegerAttr>(getValue());
auto intType = llvm::dyn_cast<IntegerType>(getType());
// Sugar i1 constants with 'true' and 'false'.
if (intType && intType.getWidth() == 1)
return setNameFn(getResult(), intCst.getInt() > 0 ? "true" : "false");
// Otherwise, build a complex name with the value and type.
SmallString<32> specialNameBuffer;
llvm::raw_svector_ostream specialName(specialNameBuffer);
specialName << 'c' << intCst.getValue();
if (intType)
specialName << '_' << getType();
setNameFn(getResult(), specialName.str());
}
LogicalResult ConstantOp::verify() {
auto type = getType();
// The value's type must match the return type.
if (auto valType = getValue().getType(); valType != type) {
return emitOpError() << "value type " << valType
<< " must match return type: " << type;
}
// Integer values must be signless.
if (llvm::isa<IntegerType>(type) &&
!llvm::cast<IntegerType>(type).isSignless())
return emitOpError("integer return type must be signless");
// Any float or integers attribute are acceptable.
if (!llvm::isa<IntegerAttr, FloatAttr>(getValue())) {
return emitOpError("value must be an integer or float attribute");
}
return success();
}
OpFoldResult calyx::ConstantOp::fold(FoldAdaptor adaptor) {
return getValueAttr();
}
void calyx::ConstantOp::build(OpBuilder &builder, OperationState &state,
StringRef symName, TypedAttr attr) {
state.addAttribute(SymbolTable::getSymbolAttrName(),
builder.getStringAttr(symName));
state.addAttribute("value", attr);
SmallVector<Type> types;
types.push_back(attr.getType()); // Out
state.addTypes(types);
}
SmallVector<StringRef> ConstantOp::portNames() { return {"out"}; }
SmallVector<Direction> ConstantOp::portDirections() { return {Output}; }
SmallVector<DictionaryAttr> ConstantOp::portAttributes() {
return {DictionaryAttr::get(getContext())};
}
bool ConstantOp::isCombinational() { return true; }
//===----------------------------------------------------------------------===//
// RegisterOp
//===----------------------------------------------------------------------===//

View File

@ -22,7 +22,10 @@
#include "mlir/Tools/mlir-translate/Translation.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include <bitset>
#include <string>
using namespace circt;
using namespace calyx;
@ -142,6 +145,10 @@ private:
static constexpr std::string_view sMemories = "memories/seq";
return {sMemories};
})
.Case<ConstantOp>([&](auto op) -> FailureOr<StringRef> {
static constexpr std::string_view sFloat = "float";
return {sFloat};
})
.Default([&](auto op) {
auto diag = op->emitOpError() << "not supported for emission";
return diag;
@ -253,6 +260,9 @@ struct Emitter {
// Invoke emission
void emitInvoke(InvokeOp invoke);
// Floating point Constant emission
void emitConstant(ConstantOp constant);
// Emits a library primitive with template parameters based on all in- and
// output ports.
// e.g.:
@ -445,7 +455,7 @@ private:
return;
}
auto definingOp = value.getDefiningOp();
auto *definingOp = value.getDefiningOp();
assert(definingOp && "Value does not have a defining operation.");
TypeSwitch<Operation *>(definingOp)
@ -638,6 +648,7 @@ void Emitter::emitComponent(ComponentInterface op) {
.Case<MemoryOp>([&](auto op) { emitMemory(op); })
.Case<SeqMemoryOp>([&](auto op) { emitSeqMemory(op); })
.Case<hw::ConstantOp>([&](auto op) { /*Do nothing*/ })
.Case<calyx::ConstantOp>([&](auto op) { emitConstant(op); })
.Case<SliceLibOp, PadLibOp, ExtSILibOp>(
[&](auto op) { emitLibraryPrimTypedByAllPorts(op); })
.Case<LtLibOp, GtLibOp, EqLibOp, NeqLibOp, GeLibOp, LeLibOp, SltLibOp,
@ -899,6 +910,23 @@ void Emitter::emitInvoke(InvokeOp invoke) {
os << RParen() << semicolonEndL();
}
void Emitter::emitConstant(ConstantOp constantOp) {
TypedAttr attr = constantOp.getValueAttr();
assert(isa<FloatAttr>(attr) && "must be a floating point constant");
auto fltAttr = cast<FloatAttr>(attr);
APFloat value = fltAttr.getValue();
auto type = cast<FloatType>(fltAttr.getType());
double doubleValue = value.convertToDouble();
auto floatBits = value.getSizeInBits(type.getFloatSemantics());
indent() << constantOp.getName().str() << space() << equals() << space()
<< "std_float_const";
// Currently defaults to IEEE-754 representation [1].
// [1]: https://github.com/calyxir/calyx/blob/main/primitives/float.futil
static constexpr int32_t IEEE754 = 0;
os << LParen() << std::to_string(IEEE754) << comma() << floatBits << comma()
<< std::to_string(doubleValue) << RParen() << semicolonEndL();
}
/// Calling getName() on a calyx operation will return "calyx.${opname}". This
/// function returns whatever is left after the first '.' in the string,
/// removing the 'calyx' prefix.
@ -954,8 +982,8 @@ void Emitter::emitWires(WiresOp op) {
TypeSwitch<Operation *>(&bodyOp)
.Case<GroupInterface>([&](auto op) { emitGroup(op); })
.Case<AssignOp>([&](auto op) { emitAssignment(op); })
.Case<hw::ConstantOp, comb::AndOp, comb::OrOp, comb::XorOp, CycleOp>(
[&](auto op) { /* Do nothing. */ })
.Case<hw::ConstantOp, calyx::ConstantOp, comb::AndOp, comb::OrOp,
comb::XorOp, CycleOp>([&](auto op) { /* Do nothing. */ })
.Default([&](auto op) {
emitOpError(op, "not supported for emission inside wires section");
});

View File

@ -27,6 +27,14 @@ calyx::RegisterOp createRegister(Location loc, OpBuilder &builder,
return builder.create<RegisterOp>(loc, (prefix + "_reg").str(), width);
}
calyx::RegisterOp createRegister(Location loc, OpBuilder &builder,
ComponentOp component, Type type,
Twine prefix) {
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPointToStart(component.getBodyBlock());
return builder.create<RegisterOp>(loc, (prefix + "_reg").str(), type);
}
hw::ConstantOp createConstant(Location loc, OpBuilder &builder,
ComponentOp component, size_t width,
size_t value) {

View File

@ -657,10 +657,10 @@ void InlineCombGroups::recurseInlineCombGroups(
// LateSSAReplacement)
if (isa<BlockArgument>(src) ||
isa<calyx::RegisterOp, calyx::MemoryOp, calyx::SeqMemoryOp,
hw::ConstantOp, mlir::arith::ConstantOp, calyx::MultPipeLibOp,
calyx::DivUPipeLibOp, calyx::DivSPipeLibOp, calyx::RemSPipeLibOp,
calyx::RemUPipeLibOp, mlir::scf::WhileOp, calyx::InstanceOp>(
src.getDefiningOp()))
calyx::ConstantOp, hw::ConstantOp, mlir::arith::ConstantOp,
calyx::MultPipeLibOp, calyx::DivUPipeLibOp, calyx::DivSPipeLibOp,
calyx::RemSPipeLibOp, calyx::RemUPipeLibOp, mlir::scf::WhileOp,
calyx::InstanceOp>(src.getDefiningOp()))
continue;
auto srcCombGroup = dyn_cast<calyx::CombGroupOp>(
@ -753,11 +753,11 @@ BuildReturnRegs::partiallyLowerFuncToComp(mlir::func::FuncOp funcOp,
for (auto argType : enumerate(funcOp.getResultTypes())) {
auto convArgType = calyx::convIndexType(rewriter, argType.value());
assert(isa<IntegerType>(convArgType) && "unsupported return type");
unsigned width = convArgType.getIntOrFloatBitWidth();
assert((isa<IntegerType>(convArgType) || isa<FloatType>(convArgType)) &&
"unsupported return type");
std::string name = "ret_arg" + std::to_string(argType.index());
auto reg =
createRegister(funcOp.getLoc(), rewriter, getComponent(), width, name);
auto reg = createRegister(funcOp.getLoc(), rewriter, getComponent(),
convArgType, name);
getState().addReturnReg(reg, argType.index());
rewriter.setInsertionPointToStart(

View File

@ -209,3 +209,27 @@ module {
return %0, %1 : i8, i8
}
}
// -----
// Test integer and floating point constant
// CHECK: calyx.group @ret_assign_0 {
// CHECK-DAG: calyx.assign %ret_arg0_reg.in = %in0 : f32
// CHECK-DAG: calyx.assign %ret_arg0_reg.write_en = %true : i1
// CHECK-DAG: calyx.assign %ret_arg1_reg.in = %c42_i32 : i32
// CHECK-DAG: calyx.assign %ret_arg1_reg.write_en = %true : i1
// CHECK-DAG: calyx.assign %ret_arg2_reg.in = %cst : f32
// CHECK-DAG: calyx.assign %ret_arg2_reg.write_en = %true : i1
// CHECK-DAG: %0 = comb.and %ret_arg2_reg.done, %ret_arg1_reg.done, %ret_arg0_reg.done : i1
// CHECK-DAG: calyx.group_done %0 ? %true : i1
// CHECK-DAG: }
module {
func.func @main(%arg0 : f32) -> (f32, i32, f32) {
%0 = arith.constant 42 : i32
%1 = arith.constant 4.2e+1 : f32
return %arg0, %0, %1 : f32, i32, f32
}
}

View File

@ -241,3 +241,42 @@ module attributes {calyx.entrypoint = "main"} {
}
}
}
// -----
module attributes {calyx.entrypoint = "main"} {
calyx.component @main(%clk: i1 {clk}, %reset: i1 {reset}, %go: i1 {go}) -> (%out0: i32, %out1: f32, %done: i1 {done}) {
// CHECK: cst_0 = std_float_const(0, 32, 4.200000);
%c42_i32 = hw.constant 42 : i32
%cst = calyx.constant {sym_name = "cst_0"} 4.200000e+00 : f32
%true = hw.constant true
%ret_arg1_reg.in, %ret_arg1_reg.write_en, %ret_arg1_reg.clk, %ret_arg1_reg.reset, %ret_arg1_reg.out, %ret_arg1_reg.done = calyx.register @ret_arg1_reg : f32, i1, i1, i1, f32, i1
%ret_arg0_reg.in, %ret_arg0_reg.write_en, %ret_arg0_reg.clk, %ret_arg0_reg.reset, %ret_arg0_reg.out, %ret_arg0_reg.done = calyx.register @ret_arg0_reg : i32, i1, i1, i1, i32, i1
calyx.wires {
calyx.assign %out1 = %ret_arg1_reg.out : f32
calyx.assign %out0 = %ret_arg0_reg.out : i32
// CHECK-LABEL: group ret_assign_0 {
// CHECK-NEXT: ret_arg0_reg.in = 32'd42;
// CHECK-NEXT: ret_arg0_reg.write_en = 1'd1;
// CHECK-NEXT: ret_arg1_reg.in = cst_0.out;
// CHECK-NEXT: ret_arg1_reg.write_en = 1'd1;
// CHECK-NEXT: ret_assign_0[done] = (ret_arg1_reg.done & ret_arg0_reg.done) ? 1'd1;
// CHECK-NEXT: }
calyx.group @ret_assign_0 {
calyx.assign %ret_arg0_reg.in = %c42_i32 : i32
calyx.assign %ret_arg0_reg.write_en = %true : i1
calyx.assign %ret_arg1_reg.in = %cst : f32
calyx.assign %ret_arg1_reg.write_en = %true : i1
%0 = comb.and %ret_arg1_reg.done, %ret_arg0_reg.done : i1
calyx.group_done %0 ? %true : i1
}
}
calyx.control {
calyx.seq {
calyx.enable @ret_assign_0
}
}
} {toplevel}
}