[Calyx] Support lowering of `arith.select` (#5857)

* Add calyx mux op

* legalize lowering for arith.select in scf-to-calyx

* hitting assert

* basic lowering works

* lowering works!

* calyx emitter supports std_mux

* add arith lowering test

* emission test

* line length nit
This commit is contained in:
Rachit Nigam 2023-08-17 16:43:52 +05:30 committed by GitHub
parent 33d868974a
commit b7d08125b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 71 additions and 19 deletions

View File

@ -123,7 +123,7 @@ def MemoryOp : CalyxPrimitive<"memory", []> {
def SeqMemoryOp : CalyxPrimitive<"seq_mem", []> {
let summary = "Defines a memory with sequential read";
let description = [{
The "calyx.seq_mem" op defines a memory with sequential reads. Memories can
The "calyx.seq_mem" op defines a memory with sequential reads. Memories can
have any number of dimensions, as specified by the length of the `$sizes` and
`$addrSizes` arrays. The `$addrSizes` specify the bitwidth of each dimension's
address, and should be wide enough to address the range of the corresponding
@ -267,6 +267,12 @@ def AndLibOp : CombinationalArithBinaryLibraryOp<"and"> {}
def OrLibOp : CombinationalArithBinaryLibraryOp<"or"> {}
def XorLibOp : CombinationalArithBinaryLibraryOp<"xor"> {}
def MuxLibOp : CalyxLibraryOp<"mux", [
Combinational, SameTypeConstraint<"tru", "fal">, SameTypeConstraint<"tru", "out">
]> {
let results = (outs I1:$sel, AnyType:$tru, AnyType:$fal, AnyType:$out);
}
class ArithBinaryPipeLibraryOp<string mnemonic> : ArithBinaryLibraryOp<mnemonic # "_pipe", [
SameTypeConstraint<"left", "out">
]> {

View File

@ -237,6 +237,20 @@ private:
.Case([&](XorLibOp op) {
convertArithBinaryOp<XorLibOp, XorOp>(op, wires, b);
})
.Case([&](MuxLibOp op) {
auto sel = wireIn(op.getSel(), op.instanceName(),
op.portName(op.getSel()), b);
auto tru = wireIn(op.getTru(), op.instanceName(),
op.portName(op.getTru()), b);
auto fal = wireIn(op.getFal(), op.instanceName(),
op.portName(op.getFal()), b);
auto mux = b.create<MuxOp>(sel, tru, fal);
auto out =
wireOut(mux, op.instanceName(), op.portName(op.getOut()), b);
wires.append({sel.getInput(), tru.getInput(), fal.getInput(), out});
})
// Pipelined arithmetic operations.
.Case([&](MultPipeLibOp op) {
convertPipelineOp<MultPipeLibOp, comb::MulOp>(op, wires, b);

View File

@ -210,7 +210,7 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
AddIOp, SubIOp, CmpIOp, ShLIOp, ShRUIOp, ShRSIOp,
AndIOp, XOrIOp, OrIOp, ExtUIOp, ExtSIOp, TruncIOp,
MulIOp, DivUIOp, DivSIOp, RemUIOp, RemSIOp,
IndexCastOp>(
SelectOp, IndexCastOp>(
[&](auto op) { return buildOp(rewriter, op).succeeded(); })
.template Case<FuncOp, scf::ConditionOp>([&](auto) {
/// Skip: these special cases will be handled separately.
@ -235,6 +235,7 @@ private:
BranchOpInterface brOp) const;
LogicalResult buildOp(PatternRewriter &rewriter,
arith::ConstantOp constOp) const;
LogicalResult buildOp(PatternRewriter &rewriter, SelectOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, AddIOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, SubIOp op) const;
LogicalResult buildOp(PatternRewriter &rewriter, MulIOp op) const;
@ -788,6 +789,10 @@ LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
XOrIOp op) const {
return buildLibraryOp<calyx::CombGroupOp, calyx::XorLibOp>(rewriter, op);
}
LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
SelectOp op) const {
return buildLibraryOp<calyx::CombGroupOp, calyx::MuxLibOp>(rewriter, op);
}
LogicalResult BuildOpGroups::buildOp(PatternRewriter &rewriter,
CmpIOp op) const {
@ -1557,10 +1562,11 @@ public:
// Only accept std operations which we've added lowerings for
target.addIllegalDialect<FuncDialect>();
target.addIllegalDialect<ArithDialect>();
target.addLegalOp<AddIOp, SubIOp, CmpIOp, ShLIOp, ShRUIOp, ShRSIOp, AndIOp,
XOrIOp, OrIOp, ExtUIOp, TruncIOp, CondBranchOp, BranchOp,
MulIOp, DivUIOp, DivSIOp, RemUIOp, RemSIOp, ReturnOp,
arith::ConstantOp, IndexCastOp, FuncOp, ExtSIOp>();
target.addLegalOp<AddIOp, SelectOp, SubIOp, CmpIOp, ShLIOp, ShRUIOp,
ShRSIOp, AndIOp, XOrIOp, OrIOp, ExtUIOp, TruncIOp,
CondBranchOp, BranchOp, MulIOp, DivUIOp, DivSIOp, RemUIOp,
RemSIOp, ReturnOp, arith::ConstantOp, IndexCastOp, FuncOp,
ExtSIOp>();
RewritePatternSet legalizePatterns(&getContext());
legalizePatterns.add<DummyPattern>(&getContext());

View File

@ -2851,6 +2851,21 @@ LogicalResult SliceLibOp::verify() {
return success();
}
SmallVector<StringRef> MuxLibOp::portNames() {
return {"sel", "tru", "fal", "out"};
}
SmallVector<Direction> MuxLibOp::portDirections() {
return {Input, Input, Input, Output};
}
void MuxLibOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
getCellAsmResultNames(setNameFn, *this, this->portNames());
}
bool MuxLibOp::isCombinational() { return true; }
SmallVector<DictionaryAttr> MuxLibOp::portAttributes() {
return {DictionaryAttr::get(getContext()), DictionaryAttr::get(getContext()),
DictionaryAttr::get(getContext()), DictionaryAttr::get(getContext())};
}
#define ImplBinPipeOpCellInterface(OpType, outName) \
SmallVector<StringRef> OpType::portNames() { \
return {"clk", "reset", "go", "left", "right", outName, "done"}; \

View File

@ -128,11 +128,11 @@ private:
return TypeSwitch<Operation *, FailureOr<StringRef>>(op)
.Case<MemoryOp, RegisterOp, NotLibOp, AndLibOp, OrLibOp, XorLibOp,
AddLibOp, SubLibOp, GtLibOp, LtLibOp, EqLibOp, NeqLibOp, GeLibOp,
LeLibOp, LshLibOp, RshLibOp, SliceLibOp, PadLibOp, WireLibOp>(
[&](auto op) -> FailureOr<StringRef> {
static constexpr std::string_view sCore = "core";
return {sCore};
})
LeLibOp, LshLibOp, RshLibOp, SliceLibOp, PadLibOp, WireLibOp,
MuxLibOp>([&](auto op) -> FailureOr<StringRef> {
static constexpr std::string_view sCore = "core";
return {sCore};
})
.Case<SgtLibOp, SltLibOp, SeqLibOp, SneqLibOp, SgeLibOp, SleLibOp,
SrshLibOp, MultPipeLibOp, RemUPipeLibOp, RemSPipeLibOp,
DivUPipeLibOp, DivSPipeLibOp>(
@ -628,6 +628,8 @@ void Emitter::emitComponent(ComponentInterface op) {
SubLibOp, ShruLibOp, RshLibOp, SrshLibOp, LshLibOp, AndLibOp,
NotLibOp, OrLibOp, XorLibOp, WireLibOp>(
[&](auto op) { emitLibraryPrimTypedByFirstInputPort(op); })
.Case<MuxLibOp>(
[&](auto op) { emitLibraryPrimTypedByFirstOutputPort(op); })
.Case<MultPipeLibOp>(
[&](auto op) { emitLibraryPrimTypedByFirstOutputPort(op); })
.Case<RemUPipeLibOp, DivUPipeLibOp>([&](auto op) {

View File

@ -1,7 +1,7 @@
// RUN: circt-opt %s --lower-scf-to-calyx -canonicalize -split-input-file | FileCheck %s
// CHECK: module attributes {calyx.entrypoint = "main"} {
// CHECK-LABEL: calyx.component @main(%in0: i32, %in1: i32, %clk: i1 {clk}, %reset: i1 {reset}, %go: i1 {go}) -> (%out0: i32, %done: i1 {done}) {
// CHECK-LABEL: calyx.component @main(%in0: i1, %in1: i32, %in2: i32, %clk: i1 {clk}, %reset: i1 {reset}, %go: i1 {go}) -> (%out0: i32, %done: i1 {done}) {
// CHECK-DAG: %true = hw.constant true
// CHECK-DAG: %std_sub_0.left, %std_sub_0.right, %std_sub_0.out = calyx.std_sub @std_sub_0 : i32, i32, i32
// CHECK-DAG: %std_lsh_0.left, %std_lsh_0.right, %std_lsh_0.out = calyx.std_lsh @std_lsh_0 : i32, i32, i32
@ -9,15 +9,18 @@
// CHECK-DAG: %ret_arg0_reg.in, %ret_arg0_reg.write_en, %ret_arg0_reg.clk, %ret_arg0_reg.reset, %ret_arg0_reg.out, %ret_arg0_reg.done = calyx.register @ret_arg0_reg : i32, i1, i1, i1, i32, i1
// CHECK-NEXT: calyx.wires {
// CHECK-NEXT: calyx.assign %out0 = %ret_arg0_reg.out : i32
// CHECK-NEXT: calyx.group @ret_assign_0 {
// CHECK-NEXT: calyx.assign %ret_arg0_reg.in = %std_sub_0.out : i32
// CHECK-NEXT: calyx.group @ret_assign_0 {
// CHECK-NEXT: calyx.assign %ret_arg0_reg.in = %std_mux_0.out : i32
// CHECK-NEXT: calyx.assign %ret_arg0_reg.write_en = %true : i1
// CHECK-NEXT: calyx.assign %std_mux_0.sel = %in0 : i1
// CHECK-NEXT: calyx.assign %std_mux_0.tru = %std_sub_0.out : i32
// CHECK-NEXT: calyx.assign %std_sub_0.left = %std_lsh_0.out : i32
// CHECK-NEXT: calyx.assign %std_lsh_0.left = %std_add_0.out : i32
// CHECK-NEXT: calyx.assign %std_add_0.left = %in0 : i32
// CHECK-NEXT: calyx.assign %std_add_0.right = %in1 : i32
// CHECK-NEXT: calyx.assign %std_lsh_0.right = %in0 : i32
// CHECK-NEXT: calyx.assign %std_add_0.left = %in1 : i32
// CHECK-NEXT: calyx.assign %std_add_0.right = %in2 : i32
// CHECK-NEXT: calyx.assign %std_lsh_0.right = %in1 : i32
// CHECK-NEXT: calyx.assign %std_sub_0.right = %std_add_0.out : i32
// CHECK-NEXT: calyx.assign %std_mux_0.fal = %std_add_0.out : i32
// CHECK-NEXT: calyx.group_done %ret_arg0_reg.done : i1
// CHECK-NEXT: }
// CHECK-NEXT: }
@ -29,11 +32,12 @@
// CHECK-NEXT: } {toplevel}
// CHECK-NEXT: }
module {
func.func @main(%a0 : i32, %a1 : i32) -> i32 {
func.func @main(%sel : i1, %a0 : i32, %a1 : i32) -> i32 {
%0 = arith.addi %a0, %a1 : i32
%1 = arith.shli %0, %a0 : i32
%2 = arith.subi %1, %0 : i32
return %2 : i32
%3 = arith.select %sel, %2, %0 : i32
return %3 : i32
}
}

View File

@ -6,11 +6,16 @@ module attributes {calyx.entrypoint = "main"} {
%0 = hw.constant 1 : i32
// CHECK: add0 = std_add(32);
%1:3 = calyx.std_add @add0 : i32, i32, i32
// CHECK = mux0 = std_mux(32);
%2:4 = calyx.std_mux @mux0 : i1, i32, i32, i32
%3 = hw.constant 1 : i1
calyx.wires {
// CHECK: add0.left = in;
calyx.assign %1#0 = %in : i32
// CHECK: add0.right = 32'd1;
calyx.assign %1#1 = %0 : i32
// CHECK: mux0.sel = 1'd1;
calyx.assign %2#0 = %3 : i1
// CHECK: out = add0.out;
calyx.assign %out = %1#2 : i32
}