diff --git a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp index d9ab92751238..d06c3043664d 100644 --- a/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/Arithmetic/Transforms/ExpandOps.cpp @@ -15,6 +15,13 @@ using namespace mlir; +/// Create an integer or index constant. +static Value createConst(Location loc, Type type, int value, + PatternRewriter &rewriter) { + return rewriter.create( + loc, rewriter.getIntegerAttr(type, value)); +} + namespace { /// Expands CeilDivUIOp (n, m) into @@ -26,17 +33,14 @@ struct CeilDivUIOpConverter : public OpRewritePattern { Location loc = op.getLoc(); Value a = op.getLhs(); Value b = op.getRhs(); - Value zero = rewriter.create( - loc, rewriter.getIntegerAttr(a.getType(), 0)); + Value zero = createConst(loc, a.getType(), 0, rewriter); Value compare = rewriter.create(loc, arith::CmpIPredicate::eq, a, zero); - Value one = rewriter.create( - loc, rewriter.getIntegerAttr(a.getType(), 1)); + Value one = createConst(loc, a.getType(), 1, rewriter); Value minusOne = rewriter.create(loc, a, one); Value quotient = rewriter.create(loc, minusOne, b); Value plusOne = rewriter.create(loc, quotient, one); - Value res = rewriter.create(loc, compare, zero, plusOne); - rewriter.replaceOp(op, {res}); + rewriter.replaceOpWithNewOp(op, compare, zero, plusOne); return success(); } }; @@ -49,16 +53,12 @@ struct CeilDivSIOpConverter : public OpRewritePattern { LogicalResult matchAndRewrite(arith::CeilDivSIOp op, PatternRewriter &rewriter) const final { Location loc = op.getLoc(); - auto signedCeilDivIOp = cast(op); - Type type = signedCeilDivIOp.getType(); - Value a = signedCeilDivIOp.getLhs(); - Value b = signedCeilDivIOp.getRhs(); - Value plusOne = rewriter.create( - loc, rewriter.getIntegerAttr(type, 1)); - Value zero = rewriter.create( - loc, rewriter.getIntegerAttr(type, 0)); - Value minusOne = rewriter.create( - loc, rewriter.getIntegerAttr(type, -1)); + Type type = op.getType(); + Value a = op.getLhs(); + Value b = op.getRhs(); + Value plusOne = createConst(loc, type, 1, rewriter); + Value zero = createConst(loc, type, 0, rewriter); + Value minusOne = createConst(loc, type, -1, rewriter); // Compute x = (b>0) ? -1 : 1. Value compare = rewriter.create(loc, arith::CmpIPredicate::sgt, b, zero); @@ -90,9 +90,8 @@ struct CeilDivSIOpConverter : public OpRewritePattern { Value secondTerm = rewriter.create(loc, aPos, bPos); Value compareRes = rewriter.create(loc, firstTerm, secondTerm); - Value res = rewriter.create(loc, compareRes, posRes, negRes); // Perform substitution and return success. - rewriter.replaceOp(op, {res}); + rewriter.replaceOpWithNewOp(op, compareRes, posRes, negRes); return success(); } }; @@ -105,16 +104,12 @@ struct FloorDivSIOpConverter : public OpRewritePattern { LogicalResult matchAndRewrite(arith::FloorDivSIOp op, PatternRewriter &rewriter) const final { Location loc = op.getLoc(); - arith::FloorDivSIOp signedFloorDivIOp = cast(op); - Type type = signedFloorDivIOp.getType(); - Value a = signedFloorDivIOp.getLhs(); - Value b = signedFloorDivIOp.getRhs(); - Value plusOne = rewriter.create( - loc, rewriter.getIntegerAttr(type, 1)); - Value zero = rewriter.create( - loc, rewriter.getIntegerAttr(type, 0)); - Value minusOne = rewriter.create( - loc, rewriter.getIntegerAttr(type, -1)); + Type type = op.getType(); + Value a = op.getLhs(); + Value b = op.getRhs(); + Value plusOne = createConst(loc, type, 1, rewriter); + Value zero = createConst(loc, type, 0, rewriter); + Value minusOne = createConst(loc, type, -1, rewriter); // Compute x = (b<0) ? 1 : -1. Value compare = rewriter.create(loc, arith::CmpIPredicate::slt, b, zero); @@ -144,9 +139,8 @@ struct FloorDivSIOpConverter : public OpRewritePattern { Value secondTerm = rewriter.create(loc, aPos, bNeg); Value compareRes = rewriter.create(loc, firstTerm, secondTerm); - Value res = rewriter.create(loc, compareRes, negRes, posRes); // Perform substitution and return success. - rewriter.replaceOp(op, {res}); + rewriter.replaceOpWithNewOp(op, compareRes, negRes, posRes); return success(); } };