[mlir](arithmetic) Add ceildivui to the arithmetic dialect

The specific description is [[ https://llvm.discourse.group/t/adding-unsigned-integer-ceil-and-floor-in-std-dialect/4541 | Adding unsigned integer ceil in Std Dialect ]] .

When we lower ceilDivOp this will generate below code, sometimes we know m and n are unsigned intergal.Here are some redundant judgments about positive and negative.
So we need to add some unsigned operations to simplify the instructions.
```
ceilDiv(n, m)
  x = (m > 0) ? -1 : 1
  return (n*m>0) ? ((n+x) / m) + 1 : - (-n / m)
```
unsigned operations:
```
ceilDivU(n, m)
  return n ==0 ?  0 :  ((n - 1) / m) + 1
```

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D113363
This commit is contained in:
lipracer 2021-11-11 01:47:47 +00:00 committed by Mogball
parent 9303c7da39
commit 8165eaa885
8 changed files with 202 additions and 9 deletions

View File

@ -276,6 +276,30 @@ def Arith_DivSIOp : Arith_IntBinaryOp<"divsi"> {
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// CeilDivUIOp
//===----------------------------------------------------------------------===//
def Arith_CeilDivUIOp : Arith_IntBinaryOp<"ceildivui"> {
let summary = "unsigned ceil integer division operation";
let description = [{
Unsigned integer division. Rounds towards positive infinity. Treats the
leading bit as the most significant, i.e. for `i16` given two's complement
representation, `6 / -2 = 6 / (2^16 - 2) = 1`.
Note: the semantics of division by zero is TBD; do NOT assume any specific
behavior.
Example:
```mlir
// Scalar unsigned integer division.
%a = arith.ceildivui %b, %c : i64
```
}];
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// CeilDivSIOp
//===----------------------------------------------------------------------===//

View File

@ -306,6 +306,36 @@ static APInt signedCeilNonnegInputs(APInt a, APInt b, bool &overflow) {
return val.sadd_ov(one, overflow);
}
//===----------------------------------------------------------------------===//
// CeilDivUIOp
//===----------------------------------------------------------------------===//
OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
bool overflowOrDiv0 = false;
auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
if (overflowOrDiv0 || !b) {
overflowOrDiv0 = true;
return a;
}
APInt quotient = a.udiv(b);
if (!a.urem(b))
return quotient;
APInt one(a.getBitWidth(), 1, true);
return quotient.uadd_ov(one, overflowOrDiv0);
});
// Fold out ceil division by one. Assumes all tensors of all ones are
// splats.
if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
if (rhs.getValue() == 1)
return getLhs();
} else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
return getLhs();
}
return overflowOrDiv0 ? Attribute() : result;
}
//===----------------------------------------------------------------------===//
// CeilDivSIOp
//===----------------------------------------------------------------------===//
@ -342,7 +372,7 @@ OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
return zero.ssub_ov(div, overflowOrDiv0);
});
// Fold out floor division by one. Assumes all tensors of all ones are
// Fold out ceil division by one. Assumes all tensors of all ones are
// splats.
if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
if (rhs.getValue() == 1)

View File

@ -13,6 +13,30 @@ using namespace mlir;
namespace {
/// Expands CeilDivUIOp (n, m) into
/// n == 0 ? 0 : ((n-1) / m) + 1
struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
PatternRewriter &rewriter) const final {
Location loc = op.getLoc();
Value a = op.lhs();
Value b = op.rhs();
Value zero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(a.getType(), 0));
Value compare =
rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero);
Value one = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(a.getType(), 1));
Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one);
Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b);
Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
Value res = rewriter.create<SelectOp>(loc, compare, zero, plusOne);
rewriter.replaceOp(op, {res});
return success();
}
};
/// Expands CeilDivSIOp (n, m) into
/// 1) x = (m > 0) ? -1 : 1
/// 2) (n*m>0) ? ((n+x) / m) + 1 : - (-n / m)
@ -132,7 +156,8 @@ struct ArithmeticExpandOpsPass
arith::populateArithmeticExpandOpsPatterns(patterns);
target.addLegalDialect<arith::ArithmeticDialect, StandardOpsDialect>();
target.addIllegalOp<arith::CeilDivSIOp, arith::FloorDivSIOp>();
target.addIllegalOp<arith::CeilDivUIOp, arith::CeilDivSIOp,
arith::FloorDivSIOp>();
if (failed(
applyPartialConversion(getFunction(), target, std::move(patterns))))
@ -144,8 +169,9 @@ struct ArithmeticExpandOpsPass
void mlir::arith::populateArithmeticExpandOpsPatterns(
RewritePatternSet &patterns) {
patterns.add<CeilDivSIOpConverter, FloorDivSIOpConverter>(
patterns.getContext());
patterns
.add<CeilDivUIOpConverter, CeilDivSIOpConverter, FloorDivSIOpConverter>(
patterns.getContext());
}
std::unique_ptr<Pass> mlir::arith::createArithmeticExpandOpsPass() {

View File

@ -175,7 +175,8 @@ struct StdExpandOpsPass : public StdExpandOpsBase<StdExpandOpsPass> {
target.addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect,
StandardOpsDialect>();
target.addIllegalOp<arith::CeilDivSIOp, arith::FloorDivSIOp>();
target.addIllegalOp<arith::CeilDivUIOp, arith::CeilDivSIOp,
arith::FloorDivSIOp>();
target.addDynamicallyLegalOp<AtomicRMWOp>([](AtomicRMWOp op) {
return op.getKind() != AtomicRMWKind::maxf &&
op.getKind() != AtomicRMWKind::minf;

View File

@ -111,3 +111,37 @@ func @floordivi_index(%arg0: index, %arg1: index) -> (index) {
// CHECK: [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
// CHECK: [[RES:%.+]] = select [[CMP2]], [[TRUE3]], [[FALSE]] : index
}
// -----
// Test ceil divide with unsigned integer
// CHECK-LABEL: func @ceildivui
// CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> i32 {
func @ceildivui(%arg0: i32, %arg1: i32) -> (i32) {
%res = arith.ceildivui %arg0, %arg1 : i32
return %res : i32
// CHECK: [[ZERO:%.+]] = arith.constant 0 : i32
// CHECK: [[ISZERO:%.+]] = arith.cmpi eq, %arg0, [[ZERO]] : i32
// CHECK: [[ONE:%.+]] = arith.constant 1 : i32
// CHECK: [[SUB:%.+]] = arith.subi %arg0, [[ONE]] : i32
// CHECK: [[DIV:%.+]] = arith.divui [[SUB]], %arg1 : i32
// CHECK: [[REM:%.+]] = arith.addi [[DIV]], [[ONE]] : i32
// CHECK: [[RES:%.+]] = select [[ISZERO]], [[ZERO]], [[REM]] : i32
}
// -----
// Test unsigned ceil divide with index
// CHECK-LABEL: func @ceildivui_index
// CHECK-SAME: ([[ARG0:%.+]]: index, [[ARG1:%.+]]: index) -> index {
func @ceildivui_index(%arg0: index, %arg1: index) -> (index) {
%res = arith.ceildivui %arg0, %arg1 : index
return %res : index
// CHECK: [[ZERO:%.+]] = arith.constant 0 : index
// CHECK: [[ISZERO:%.+]] = arith.cmpi eq, %arg0, [[ZERO]] : index
// CHECK: [[ONE:%.+]] = arith.constant 1 : index
// CHECK: [[SUB:%.+]] = arith.subi %arg0, [[ONE]] : index
// CHECK: [[DIV:%.+]] = arith.divui [[SUB]], %arg1 : index
// CHECK: [[REM:%.+]] = arith.addi [[DIV]], [[ONE]] : index
// CHECK: [[RES:%.+]] = select [[ISZERO]], [[ZERO]], [[REM]] : index
}

View File

@ -17,6 +17,7 @@ func @entry() {
%c20 = arith.constant 20: i32
%c10 = arith.constant 10: i32
%cmin10 = arith.constant -10: i32
%cmax_int = arith.constant 2147483647: i32
%A = memref.alloc() : memref<40xi32>
// print numerator
@ -64,20 +65,39 @@ func @entry() {
}
call @transfer_read_2d(%A, %c0) : (memref<40xi32>, index) -> ()
// test with ceildivui(*, 10)
affine.for %i = 0 to 40 {
%ii = arith.index_cast %i: index to i32
%val = arith.ceildivui %ii, %c10 : i32
memref.store %val, %A[%i] : memref<40xi32>
}
call @transfer_read_2d(%A, %c0) : (memref<40xi32>, index) -> ()
// test with ceildivui(*, -1)
affine.for %i = 0 to 40 {
%ii = arith.index_cast %i: index to i32
%ii30 = arith.subi %ii, %c20 : i32
%val = arith.ceildivui %ii30, %cmax_int : i32
memref.store %val, %A[%i] : memref<40xi32>
}
call @transfer_read_2d(%A, %c0) : (memref<40xi32>, index) -> ()
memref.dealloc %A : memref<40xi32>
return
}
// List below is aligned for easy manual check
// legend: num, ceil(num, 10), floor(num, 10), ceil(num, -10), floor(num, -10)
// legend: num, signed_ceil(num, 10), floor(num, 10), signed_ceil(num, -10), floor(num, -10), unsigned_ceil(num, 10), unsigned_ceil(num, max_int)
// ( -20, -19, -18, -17, -16, -15, -14, -13, -12, -11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19 )
// ( -2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2 )
// ( -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -1, -1, -1,-1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
// ( 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
// ( 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -2, -2, -2, -2 )
// ( -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
// ( 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
// ( 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -2, -2, -2, -2 )
// CHECK:( -20, -19, -18, -17, -16, -15, -14, -13, -12, -11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19 )
// CHECK:( -2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2 )
// CHECK:( -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
// CHECK:( 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
// CHECK:( 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -2, -2, -2, -2 )
// CHECK:( 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4 )
// CHECK:( 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )

View File

@ -1028,6 +1028,26 @@ func @tensor_arith.ceildivsi_by_one(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
// -----
// CHECK-LABEL: func @arith.ceildivui_by_one
// CHECK-SAME: %[[ARG:[a-zA-Z0-9]+]]
func @arith.ceildivui_by_one(%arg0: i32) -> (i32) {
%c1 = arith.constant 1 : i32
%res = arith.ceildivui %arg0, %c1 : i32
// CHECK: return %[[ARG]]
return %res : i32
}
// CHECK-LABEL: func @tensor_arith.ceildivui_by_one
// CHECK-SAME: %[[ARG:[a-zA-Z0-9]+]]
func @tensor_arith.ceildivui_by_one(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
%c1 = arith.constant dense<1> : tensor<4x5xi32>
%res = arith.ceildivui %arg0, %c1 : tensor<4x5xi32>
// CHECK: return %[[ARG]]
return %res : tensor<4x5xi32>
}
// -----
// CHECK-LABEL: func @memref_cast_folding_subview
func @memref_cast_folding_subview(%arg0: memref<4x5xf32>, %i: index) -> (memref<?x?xf32, offset:? , strides: [?, ?]>) {
%0 = memref.cast %arg0 : memref<4x5xf32> to memref<?x?xf32>

View File

@ -478,6 +478,44 @@ func @simple_arith.ceildivsi() -> (i32, i32, i32, i32, i32) {
// -----
// CHECK-LABEL: func @simple_arith.ceildivui
func @simple_arith.ceildivui() -> (i32, i32, i32, i32, i32) {
// CHECK-DAG: [[C0:%.+]] = arith.constant 0
%z = arith.constant 0 : i32
// CHECK-DAG: [[C6:%.+]] = arith.constant 7
%0 = arith.constant 7 : i32
%1 = arith.constant 2 : i32
// ceil(7, 2) = 4
// CHECK-NEXT: [[C3:%.+]] = arith.constant 4 : i32
%2 = arith.ceildivui %0, %1 : i32
%3 = arith.constant -2 : i32
// ceil(7, -2) = 0
// CHECK-NEXT: [[CM1:%.+]] = arith.constant 1 : i32
%4 = arith.ceildivui %0, %3 : i32
%5 = arith.constant -8 : i32
// ceil(-8, 2) = 2147483644
// CHECK-NEXT: [[CM4:%.+]] = arith.constant 2147483644 : i32
%6 = arith.ceildivui %5, %1 : i32
%7 = arith.constant -15 : i32
// ceil(-15, -2) = 0
// CHECK-NOT: arith.constant 1 : i32
%8 = arith.ceildivui %7, %3 : i32
// CHECK-NEXT: [[XZ:%.+]] = arith.ceildivui [[C6]], [[C0]]
%9 = arith.ceildivui %0, %z : i32
return %2, %4, %6, %8, %9 : i32, i32, i32, i32, i32
}
// -----
// CHECK-LABEL: func @simple_arith.remsi
func @simple_arith.remsi(%a : i32) -> (i32, i32, i32) {
%0 = arith.constant 5 : i32