[mlir][arith] Clean up ExpandOps pass

This commit is contained in:
Mogball 2021-12-20 21:58:39 +00:00
parent 557a17eec0
commit 8cb785cad1
1 changed files with 24 additions and 30 deletions

View File

@ -15,6 +15,13 @@
using namespace mlir;
/// Create an integer or index constant.
static Value createConst(Location loc, Type type, int value,
PatternRewriter &rewriter) {
return rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(type, value));
}
namespace {
/// Expands CeilDivUIOp (n, m) into
@ -26,17 +33,14 @@ struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
Location loc = op.getLoc();
Value a = op.getLhs();
Value b = op.getRhs();
Value zero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(a.getType(), 0));
Value zero = createConst(loc, a.getType(), 0, rewriter);
Value compare =
rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero);
Value one = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(a.getType(), 1));
Value one = createConst(loc, a.getType(), 1, rewriter);
Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one);
Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b);
Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
Value res = rewriter.create<SelectOp>(loc, compare, zero, plusOne);
rewriter.replaceOp(op, {res});
rewriter.replaceOpWithNewOp<SelectOp>(op, compare, zero, plusOne);
return success();
}
};
@ -49,16 +53,12 @@ struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
PatternRewriter &rewriter) const final {
Location loc = op.getLoc();
auto signedCeilDivIOp = cast<arith::CeilDivSIOp>(op);
Type type = signedCeilDivIOp.getType();
Value a = signedCeilDivIOp.getLhs();
Value b = signedCeilDivIOp.getRhs();
Value plusOne = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(type, 1));
Value zero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(type, 0));
Value minusOne = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(type, -1));
Type type = op.getType();
Value a = op.getLhs();
Value b = op.getRhs();
Value plusOne = createConst(loc, type, 1, rewriter);
Value zero = createConst(loc, type, 0, rewriter);
Value minusOne = createConst(loc, type, -1, rewriter);
// Compute x = (b>0) ? -1 : 1.
Value compare =
rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
@ -90,9 +90,8 @@ struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bPos);
Value compareRes =
rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
Value res = rewriter.create<SelectOp>(loc, compareRes, posRes, negRes);
// Perform substitution and return success.
rewriter.replaceOp(op, {res});
rewriter.replaceOpWithNewOp<SelectOp>(op, compareRes, posRes, negRes);
return success();
}
};
@ -105,16 +104,12 @@ struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
PatternRewriter &rewriter) const final {
Location loc = op.getLoc();
arith::FloorDivSIOp signedFloorDivIOp = cast<arith::FloorDivSIOp>(op);
Type type = signedFloorDivIOp.getType();
Value a = signedFloorDivIOp.getLhs();
Value b = signedFloorDivIOp.getRhs();
Value plusOne = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(type, 1));
Value zero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(type, 0));
Value minusOne = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(type, -1));
Type type = op.getType();
Value a = op.getLhs();
Value b = op.getRhs();
Value plusOne = createConst(loc, type, 1, rewriter);
Value zero = createConst(loc, type, 0, rewriter);
Value minusOne = createConst(loc, type, -1, rewriter);
// Compute x = (b<0) ? 1 : -1.
Value compare =
rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
@ -144,9 +139,8 @@ struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bNeg);
Value compareRes =
rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
Value res = rewriter.create<SelectOp>(loc, compareRes, negRes, posRes);
// Perform substitution and return success.
rewriter.replaceOp(op, {res});
rewriter.replaceOpWithNewOp<SelectOp>(op, compareRes, negRes, posRes);
return success();
}
};