mirror of https://github.com/llvm/circt.git
[comb] Build Reduction Ops whenever possible (#3394)
This commit adds a canonicalizer that builds reduction ops whenever possible, rewriting operations of the form `or(a[0], a[1], ..., a[n])` into `icmp ne(a, 0)`. This optimization is implemented for the `and`, `or`, and `xor` operations. Co-authored-by: Hideto Ueno <uenoku.tokotoko@gmail.com>
This commit is contained in:
parent
3786f40cc2
commit
ecd49163a5
|
@ -786,6 +786,42 @@ OpFoldResult AndOp::fold(ArrayRef<Attribute> constants) {
|
|||
return constFoldAssociativeOp(constants, hw::PEO::And);
|
||||
}
|
||||
|
||||
/// Returns a single common operand that all inputs of the operation `op` can
|
||||
/// be traced back to, or an empty `Value` if no such operand exists.
|
||||
///
|
||||
/// For example for `or(a[0], a[1], ..., a[n-1])` this function returns `a`
|
||||
/// (assuming the bit-width of `a` is `n`).
|
||||
template <typename Op>
|
||||
static Value getCommonOperand(Op op) {
|
||||
if (!op.getType().isInteger(1))
|
||||
return Value();
|
||||
|
||||
auto inputs = op.getInputs();
|
||||
size_t size = inputs.size();
|
||||
|
||||
auto sourceOp = inputs[0].template getDefiningOp<ExtractOp>();
|
||||
if (!sourceOp)
|
||||
return Value();
|
||||
Value source = sourceOp.getOperand();
|
||||
|
||||
// Fast path: the input size is not equal to the width of the source.
|
||||
if (size != source.getType().getIntOrFloatBitWidth())
|
||||
return Value();
|
||||
|
||||
// Tracks the bits that were encountered.
|
||||
llvm::BitVector bits(size);
|
||||
bits.set(sourceOp.getLowBit());
|
||||
|
||||
for (size_t i = 1; i != size; ++i) {
|
||||
auto extractOp = inputs[i].template getDefiningOp<ExtractOp>();
|
||||
if (!extractOp || extractOp.getOperand() != source)
|
||||
return Value();
|
||||
bits.set(extractOp.getLowBit());
|
||||
}
|
||||
|
||||
return bits.all() ? source : Value();
|
||||
}
|
||||
|
||||
/// Canonicalize an idempotent operation `op` so that only one input of any kind
|
||||
/// occurs.
|
||||
///
|
||||
|
@ -927,6 +963,15 @@ LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) {
|
|||
if (narrowOperationWidth(op, true, rewriter))
|
||||
return success();
|
||||
|
||||
// and(a[0], a[1], ..., a[n]) -> icmp eq(a, -1)
|
||||
if (auto source = getCommonOperand(op)) {
|
||||
auto cmpAgainst =
|
||||
rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getAllOnes(size));
|
||||
replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, ICmpPredicate::eq,
|
||||
source, cmpAgainst);
|
||||
return success();
|
||||
}
|
||||
|
||||
/// TODO: and(..., x, not(x)) -> and(..., 0) -- complement
|
||||
return failure();
|
||||
}
|
||||
|
@ -1129,6 +1174,15 @@ LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
|
|||
if (narrowOperationWidth(op, true, rewriter))
|
||||
return success();
|
||||
|
||||
// or(a[0], a[1], ..., a[n]) -> icmp ne(a, 0)
|
||||
if (auto source = getCommonOperand(op)) {
|
||||
auto cmpAgainst =
|
||||
rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(size));
|
||||
replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, ICmpPredicate::ne,
|
||||
source, cmpAgainst);
|
||||
return success();
|
||||
}
|
||||
|
||||
/// TODO: or(..., x, not(x)) -> or(..., '1) -- complement
|
||||
return failure();
|
||||
}
|
||||
|
@ -1250,6 +1304,12 @@ LogicalResult XorOp::canonicalize(XorOp op, PatternRewriter &rewriter) {
|
|||
if (narrowOperationWidth(op, true, rewriter))
|
||||
return success();
|
||||
|
||||
// xor(a[0], a[1], ..., a[n]) -> parity(a)
|
||||
if (auto source = getCommonOperand(op)) {
|
||||
replaceOpWithNewOpAndCopyName<ParityOp>(rewriter, op, source);
|
||||
return success();
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
|
||||
|
|
|
@ -1376,6 +1376,23 @@ hw.module @propagateNamehint(%x: i16) -> (o: i1) {
|
|||
hw.output %0 : i1
|
||||
}
|
||||
|
||||
// Verifies that `or`, `and` and `xor` applied to both single-bit extracts of
// %b are rewritten into `icmp ne`, `icmp eq` and `parity` on %b itself.
// CHECK-LABEL: @extractToReductionOps
|
||||
hw.module @extractToReductionOps(%a: i1, %b: i2) -> (c: i1, d: i1, e: i1) {
|
||||
// CHECK-NEXT: %c-1_i2 = hw.constant -1 : i2
|
||||
// CHECK-NEXT: %c0_i2 = hw.constant 0 : i2
|
||||
// CHECK-NEXT: %0 = comb.icmp ne %b, %c0_i2 : i2
|
||||
// CHECK-NEXT: %1 = comb.icmp eq %b, %c-1_i2 : i2
|
||||
// CHECK-NEXT: %2 = comb.parity %b : i2
|
||||
// CHECK-NEXT: hw.output %0, %1, %2 : i1, i1, i1
|
||||
%0 = comb.extract %b from 1 : (i2) -> i1
|
||||
%1 = comb.extract %b from 0 : (i2) -> i1
|
||||
%2 = comb.or %0, %1 : i1
|
||||
%3 = comb.and %0, %1 : i1
|
||||
%4 = comb.xor %0, %1 : i1
|
||||
|
||||
hw.output %2, %3, %4 : i1, i1, i1
|
||||
}
|
||||
|
||||
// https://github.com/llvm/circt/issues/2546
|
||||
// CHECK-LABEL: @Issue2546
|
||||
hw.module @Issue2546() -> (b: i1) {
|
||||
|
|
Loading…
Reference in New Issue