forked from OSchip/llvm-project
[mlir][arith] Clean up ExpandOps pass
This commit is contained in:
parent
557a17eec0
commit
8cb785cad1
|
@ -15,6 +15,13 @@
|
|||
|
||||
using namespace mlir;
|
||||
|
||||
/// Create an integer or index constant.
|
||||
static Value createConst(Location loc, Type type, int value,
|
||||
PatternRewriter &rewriter) {
|
||||
return rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(type, value));
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
/// Expands CeilDivUIOp (n, m) into
|
||||
|
@ -26,17 +33,14 @@ struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
|
|||
Location loc = op.getLoc();
|
||||
Value a = op.getLhs();
|
||||
Value b = op.getRhs();
|
||||
Value zero = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(a.getType(), 0));
|
||||
Value zero = createConst(loc, a.getType(), 0, rewriter);
|
||||
Value compare =
|
||||
rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero);
|
||||
Value one = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(a.getType(), 1));
|
||||
Value one = createConst(loc, a.getType(), 1, rewriter);
|
||||
Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one);
|
||||
Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b);
|
||||
Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
|
||||
Value res = rewriter.create<SelectOp>(loc, compare, zero, plusOne);
|
||||
rewriter.replaceOp(op, {res});
|
||||
rewriter.replaceOpWithNewOp<SelectOp>(op, compare, zero, plusOne);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -49,16 +53,12 @@ struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
|
|||
LogicalResult matchAndRewrite(arith::CeilDivSIOp op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
Location loc = op.getLoc();
|
||||
auto signedCeilDivIOp = cast<arith::CeilDivSIOp>(op);
|
||||
Type type = signedCeilDivIOp.getType();
|
||||
Value a = signedCeilDivIOp.getLhs();
|
||||
Value b = signedCeilDivIOp.getRhs();
|
||||
Value plusOne = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(type, 1));
|
||||
Value zero = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(type, 0));
|
||||
Value minusOne = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(type, -1));
|
||||
Type type = op.getType();
|
||||
Value a = op.getLhs();
|
||||
Value b = op.getRhs();
|
||||
Value plusOne = createConst(loc, type, 1, rewriter);
|
||||
Value zero = createConst(loc, type, 0, rewriter);
|
||||
Value minusOne = createConst(loc, type, -1, rewriter);
|
||||
// Compute x = (b>0) ? -1 : 1.
|
||||
Value compare =
|
||||
rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, b, zero);
|
||||
|
@ -90,9 +90,8 @@ struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> {
|
|||
Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bPos);
|
||||
Value compareRes =
|
||||
rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
|
||||
Value res = rewriter.create<SelectOp>(loc, compareRes, posRes, negRes);
|
||||
// Perform substitution and return success.
|
||||
rewriter.replaceOp(op, {res});
|
||||
rewriter.replaceOpWithNewOp<SelectOp>(op, compareRes, posRes, negRes);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -105,16 +104,12 @@ struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
|
|||
LogicalResult matchAndRewrite(arith::FloorDivSIOp op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
Location loc = op.getLoc();
|
||||
arith::FloorDivSIOp signedFloorDivIOp = cast<arith::FloorDivSIOp>(op);
|
||||
Type type = signedFloorDivIOp.getType();
|
||||
Value a = signedFloorDivIOp.getLhs();
|
||||
Value b = signedFloorDivIOp.getRhs();
|
||||
Value plusOne = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(type, 1));
|
||||
Value zero = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(type, 0));
|
||||
Value minusOne = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(type, -1));
|
||||
Type type = op.getType();
|
||||
Value a = op.getLhs();
|
||||
Value b = op.getRhs();
|
||||
Value plusOne = createConst(loc, type, 1, rewriter);
|
||||
Value zero = createConst(loc, type, 0, rewriter);
|
||||
Value minusOne = createConst(loc, type, -1, rewriter);
|
||||
// Compute x = (b<0) ? 1 : -1.
|
||||
Value compare =
|
||||
rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero);
|
||||
|
@ -144,9 +139,8 @@ struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> {
|
|||
Value secondTerm = rewriter.create<arith::AndIOp>(loc, aPos, bNeg);
|
||||
Value compareRes =
|
||||
rewriter.create<arith::OrIOp>(loc, firstTerm, secondTerm);
|
||||
Value res = rewriter.create<SelectOp>(loc, compareRes, negRes, posRes);
|
||||
// Perform substitution and return success.
|
||||
rewriter.replaceOp(op, {res});
|
||||
rewriter.replaceOpWithNewOp<SelectOp>(op, compareRes, negRes, posRes);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue