[comb] Build Reduction Ops whenever possible (#3394)

This commit adds a canonicalizer to build reduction ops when possible to 
reduce operations of the form `or(a[0], a[1], ..., a[n])` to `icmp ne(a, 0)`.

This optimization is implemented for the `and`, `or` and `xor` operations

Co-authored-by: Hideto Ueno <uenoku.tokotoko@gmail.com>
This commit is contained in:
Schottkyc137 2022-10-17 17:31:49 +02:00 committed by GitHub
parent 3786f40cc2
commit ecd49163a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 77 additions and 0 deletions

View File

@ -786,6 +786,42 @@ OpFoldResult AndOp::fold(ArrayRef<Attribute> constants) {
return constFoldAssociativeOp(constants, hw::PEO::And);
}
/// Returns a single common operand that all inputs of the operation `op` can
/// be traced back to, or an empty `Value` if no such operand exists.
///
/// For example for `or(a[0], a[1], ..., a[n-1])` this function returns `a`
/// (assuming the bit-width of `a` is `n`).
template <typename Op>
static Value getCommonOperand(Op op) {
if (!op.getType().isInteger(1))
return Value();
auto inputs = op.getInputs();
size_t size = inputs.size();
auto sourceOp = inputs[0].template getDefiningOp<ExtractOp>();
if (!sourceOp)
return Value();
Value source = sourceOp.getOperand();
// Fast path: the input size is not equal to the width of the source.
if (size != source.getType().getIntOrFloatBitWidth())
return Value();
// Tracks the bits that were encountered.
llvm::BitVector bits(size);
bits.set(sourceOp.getLowBit());
for (size_t i = 1; i != size; ++i) {
auto extractOp = inputs[i].template getDefiningOp<ExtractOp>();
if (!extractOp || extractOp.getOperand() != source)
return Value();
bits.set(extractOp.getLowBit());
}
return bits.all() ? source : Value();
}
/// Canonicalize an idempotent operation `op` so that only one input of any kind
/// occurs.
///
@ -927,6 +963,15 @@ LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) {
if (narrowOperationWidth(op, true, rewriter))
return success();
// and(a[0], a[1], ..., a[n]) -> icmp eq(a, -1)
if (auto source = getCommonOperand(op)) {
auto cmpAgainst =
rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getAllOnes(size));
replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, ICmpPredicate::eq,
source, cmpAgainst);
return success();
}
/// TODO: and(..., x, not(x)) -> and(..., 0) -- complement
return failure();
}
@ -1129,6 +1174,15 @@ LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
if (narrowOperationWidth(op, true, rewriter))
return success();
// or(a[0], a[1], ..., a[n]) -> icmp ne(a, 0)
if (auto source = getCommonOperand(op)) {
auto cmpAgainst =
rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(size));
replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, ICmpPredicate::ne,
source, cmpAgainst);
return success();
}
/// TODO: or(..., x, not(x)) -> or(..., '1) -- complement
return failure();
}
@ -1250,6 +1304,12 @@ LogicalResult XorOp::canonicalize(XorOp op, PatternRewriter &rewriter) {
if (narrowOperationWidth(op, true, rewriter))
return success();
// xor(a[0], a[1], ..., a[n]) -> parity(a)
if (auto source = getCommonOperand(op)) {
replaceOpWithNewOpAndCopyName<ParityOp>(rewriter, op, source);
return success();
}
return failure();
}

View File

@ -1376,6 +1376,23 @@ hw.module @propagateNamehint(%x: i16) -> (o: i1) {
hw.output %0 : i1
}
// CHECK-LABEL: @extractToReductionOps
hw.module @extractToReductionOps(%a: i1, %b: i2) -> (c: i1, d: i1, e: i1) {
// CHECK-NEXT: %c-1_i2 = hw.constant -1 : i2
// CHECK-NEXT: %c0_i2 = hw.constant 0 : i2
// CHECK-NEXT: %0 = comb.icmp ne %b, %c0_i2 : i2
// CHECK-NEXT: %1 = comb.icmp eq %b, %c-1_i2 : i2
// CHECK-NEXT: %2 = comb.parity %b : i2
// CHECK-NEXT: hw.output %0, %1, %2 : i1, i1, i1
%0 = comb.extract %b from 1 : (i2) -> i1
%1 = comb.extract %b from 0 : (i2) -> i1
%2 = comb.or %0, %1 : i1
%3 = comb.and %0, %1 : i1
%4 = comb.xor %0, %1 : i1
hw.output %2, %3, %4 : i1, i1, i1
}
// https://github.com/llvm/circt/issues/2546
// CHECK-LABEL: @Issue2546
hw.module @Issue2546() -> (b: i1) {