forked from OSchip/llvm-project
[mlir](arithmetic) Add ceildivui to the arithmetic dialect
The detailed motivation is described in [[ https://llvm.discourse.group/t/adding-unsigned-integer-ceil-and-floor-in-std-dialect/4541 | Adding unsigned integer ceil in Std Dialect ]]. Lowering ceilDivOp generates the code below. Sometimes we know that m and n are unsigned integers, in which case the checks on the signs of the operands are redundant, so we add an unsigned operation to simplify the generated instructions. ``` ceilDiv(n, m) x = (m > 0) ? -1 : 1 return (n*m>0) ? ((n+x) / m) + 1 : - (-n / m) ``` unsigned operations: ``` ceilDivU(n, m) return n == 0 ? 0 : ((n - 1) / m) + 1 ``` Reviewed By: Mogball Differential Revision: https://reviews.llvm.org/D113363
This commit is contained in:
parent
9303c7da39
commit
8165eaa885
|
@ -276,6 +276,30 @@ def Arith_DivSIOp : Arith_IntBinaryOp<"divsi"> {
|
|||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CeilDivUIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Unsigned counterpart of `arith.ceildivsi`: both operands are interpreted as
// unsigned integers and the quotient is rounded up.
def Arith_CeilDivUIOp : Arith_IntBinaryOp<"ceildivui"> {
  let summary = "unsigned ceil integer division operation";
  let description = [{
    Unsigned integer division. Rounds towards positive infinity. Treats the
    leading bit as the most significant, i.e. for `i16` given two's complement
    representation, `ceildivui(6, -2) = ceil(6 / (2^16 - 2)) = 1`.

    Note: the semantics of division by zero is TBD; do NOT assume any specific
    behavior.

    Example:

    ```mlir
    // Scalar unsigned integer ceil division.
    %a = arith.ceildivui %b, %c : i64
    ```
  }];
  // Constant operands and division by one are simplified by the op's folder.
  let hasFolder = 1;
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CeilDivSIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -306,6 +306,36 @@ static APInt signedCeilNonnegInputs(APInt a, APInt b, bool &overflow) {
|
|||
return val.sadd_ov(one, overflow);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CeilDivUIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Folder for `arith.ceildivui`.
///
/// Constant operands are folded to the unsigned quotient rounded up. A
/// right-hand side that is the constant one (scalar or splat) folds to the
/// left-hand operand. Nothing is folded when a divisor is zero or the
/// rounding increment overflows.
OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
  // Sticky failure flag: once set by any element pair, the whole fold is
  // abandoned after the element-wise computation finishes.
  bool foldFailed = false;

  // Element-wise unsigned ceiling division used for the constant fold.
  auto ceilDivide = [&foldFailed](APInt dividend, APInt divisor) -> APInt {
    if (foldFailed || !divisor) {
      // Division by zero (or an earlier failure): the returned value is a
      // placeholder that is discarded below.
      foldFailed = true;
      return dividend;
    }
    APInt floorQuotient = dividend.udiv(divisor);
    const bool dividesExactly = !dividend.urem(divisor);
    if (dividesExactly)
      return floorQuotient;
    // A non-zero remainder means the floor quotient must be bumped by one.
    APInt increment(dividend.getBitWidth(), 1, /*isSigned=*/true);
    return floorQuotient.uadd_ov(increment, foldFailed);
  };
  auto folded = constFoldBinaryOp<IntegerAttr>(operands, ceilDivide);

  // Fold out ceil division by one. Assumes all tensors of all ones are
  // splats.
  bool divisorIsOne = false;
  if (auto scalarRhs = operands[1].dyn_cast_or_null<IntegerAttr>())
    divisorIsOne = scalarRhs.getValue() == 1;
  else if (auto splatRhs = operands[1].dyn_cast_or_null<SplatElementsAttr>())
    divisorIsOne = splatRhs.getSplatValue<IntegerAttr>().getValue() == 1;
  if (divisorIsOne)
    return getLhs();

  if (foldFailed)
    return Attribute();
  return folded;
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CeilDivSIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -342,7 +372,7 @@ OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
|
|||
return zero.ssub_ov(div, overflowOrDiv0);
|
||||
});
|
||||
|
||||
// Fold out floor division by one. Assumes all tensors of all ones are
|
||||
// Fold out ceil division by one. Assumes all tensors of all ones are
|
||||
// splats.
|
||||
if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
|
||||
if (rhs.getValue() == 1)
|
||||
|
|
|
@ -13,6 +13,30 @@ using namespace mlir;
|
|||
|
||||
namespace {
|
||||
|
||||
/// Expands CeilDivUIOp (n, m) into
|
||||
/// n == 0 ? 0 : ((n-1) / m) + 1
|
||||
struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(arith::CeilDivUIOp op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
Location loc = op.getLoc();
|
||||
Value a = op.lhs();
|
||||
Value b = op.rhs();
|
||||
Value zero = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(a.getType(), 0));
|
||||
Value compare =
|
||||
rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero);
|
||||
Value one = rewriter.create<arith::ConstantOp>(
|
||||
loc, rewriter.getIntegerAttr(a.getType(), 1));
|
||||
Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one);
|
||||
Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b);
|
||||
Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one);
|
||||
Value res = rewriter.create<SelectOp>(loc, compare, zero, plusOne);
|
||||
rewriter.replaceOp(op, {res});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Expands CeilDivSIOp (n, m) into
|
||||
/// 1) x = (m > 0) ? -1 : 1
|
||||
/// 2) (n*m>0) ? ((n+x) / m) + 1 : - (-n / m)
|
||||
|
@ -132,7 +156,8 @@ struct ArithmeticExpandOpsPass
|
|||
arith::populateArithmeticExpandOpsPatterns(patterns);
|
||||
|
||||
target.addLegalDialect<arith::ArithmeticDialect, StandardOpsDialect>();
|
||||
target.addIllegalOp<arith::CeilDivSIOp, arith::FloorDivSIOp>();
|
||||
target.addIllegalOp<arith::CeilDivUIOp, arith::CeilDivSIOp,
|
||||
arith::FloorDivSIOp>();
|
||||
|
||||
if (failed(
|
||||
applyPartialConversion(getFunction(), target, std::move(patterns))))
|
||||
|
@ -144,8 +169,9 @@ struct ArithmeticExpandOpsPass
|
|||
|
||||
void mlir::arith::populateArithmeticExpandOpsPatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
patterns.add<CeilDivSIOpConverter, FloorDivSIOpConverter>(
|
||||
patterns.getContext());
|
||||
patterns
|
||||
.add<CeilDivUIOpConverter, CeilDivSIOpConverter, FloorDivSIOpConverter>(
|
||||
patterns.getContext());
|
||||
}
|
||||
|
||||
std::unique_ptr<Pass> mlir::arith::createArithmeticExpandOpsPass() {
|
||||
|
|
|
@ -175,7 +175,8 @@ struct StdExpandOpsPass : public StdExpandOpsBase<StdExpandOpsPass> {
|
|||
|
||||
target.addLegalDialect<arith::ArithmeticDialect, memref::MemRefDialect,
|
||||
StandardOpsDialect>();
|
||||
target.addIllegalOp<arith::CeilDivSIOp, arith::FloorDivSIOp>();
|
||||
target.addIllegalOp<arith::CeilDivUIOp, arith::CeilDivSIOp,
|
||||
arith::FloorDivSIOp>();
|
||||
target.addDynamicallyLegalOp<AtomicRMWOp>([](AtomicRMWOp op) {
|
||||
return op.getKind() != AtomicRMWKind::maxf &&
|
||||
op.getKind() != AtomicRMWKind::minf;
|
||||
|
|
|
@ -111,3 +111,37 @@ func @floordivi_index(%arg0: index, %arg1: index) -> (index) {
|
|||
// CHECK: [[CMP2:%.+]] = arith.ori [[TERM1]], [[TERM2]] : i1
|
||||
// CHECK: [[RES:%.+]] = select [[CMP2]], [[TRUE3]], [[FALSE]] : index
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test ceil divide with unsigned integer
|
||||
// CHECK-LABEL: func @ceildivui
|
||||
// CHECK-SAME: ([[ARG0:%.+]]: i32, [[ARG1:%.+]]: i32) -> i32 {
|
||||
func @ceildivui(%arg0: i32, %arg1: i32) -> (i32) {
|
||||
%res = arith.ceildivui %arg0, %arg1 : i32
|
||||
return %res : i32
|
||||
// CHECK: [[ZERO:%.+]] = arith.constant 0 : i32
|
||||
// CHECK: [[ISZERO:%.+]] = arith.cmpi eq, %arg0, [[ZERO]] : i32
|
||||
// CHECK: [[ONE:%.+]] = arith.constant 1 : i32
|
||||
// CHECK: [[SUB:%.+]] = arith.subi %arg0, [[ONE]] : i32
|
||||
// CHECK: [[DIV:%.+]] = arith.divui [[SUB]], %arg1 : i32
|
||||
// CHECK: [[REM:%.+]] = arith.addi [[DIV]], [[ONE]] : i32
|
||||
// CHECK: [[RES:%.+]] = select [[ISZERO]], [[ZERO]], [[REM]] : i32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test unsigned ceil divide with index
|
||||
// CHECK-LABEL: func @ceildivui_index
|
||||
// CHECK-SAME: ([[ARG0:%.+]]: index, [[ARG1:%.+]]: index) -> index {
|
||||
func @ceildivui_index(%arg0: index, %arg1: index) -> (index) {
|
||||
%res = arith.ceildivui %arg0, %arg1 : index
|
||||
return %res : index
|
||||
// CHECK: [[ZERO:%.+]] = arith.constant 0 : index
|
||||
// CHECK: [[ISZERO:%.+]] = arith.cmpi eq, %arg0, [[ZERO]] : index
|
||||
// CHECK: [[ONE:%.+]] = arith.constant 1 : index
|
||||
// CHECK: [[SUB:%.+]] = arith.subi %arg0, [[ONE]] : index
|
||||
// CHECK: [[DIV:%.+]] = arith.divui [[SUB]], %arg1 : index
|
||||
// CHECK: [[REM:%.+]] = arith.addi [[DIV]], [[ONE]] : index
|
||||
// CHECK: [[RES:%.+]] = select [[ISZERO]], [[ZERO]], [[REM]] : index
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@ func @entry() {
|
|||
%c20 = arith.constant 20: i32
|
||||
%c10 = arith.constant 10: i32
|
||||
%cmin10 = arith.constant -10: i32
|
||||
%cmax_int = arith.constant 2147483647: i32
|
||||
%A = memref.alloc() : memref<40xi32>
|
||||
|
||||
// print numerator
|
||||
|
@ -64,20 +65,39 @@ func @entry() {
|
|||
}
|
||||
call @transfer_read_2d(%A, %c0) : (memref<40xi32>, index) -> ()
|
||||
|
||||
// test with ceildivui(*, 10)
|
||||
affine.for %i = 0 to 40 {
|
||||
%ii = arith.index_cast %i: index to i32
|
||||
%val = arith.ceildivui %ii, %c10 : i32
|
||||
memref.store %val, %A[%i] : memref<40xi32>
|
||||
}
|
||||
call @transfer_read_2d(%A, %c0) : (memref<40xi32>, index) -> ()
|
||||
|
||||
// test with ceildivui(*, max_int)
|
||||
affine.for %i = 0 to 40 {
|
||||
%ii = arith.index_cast %i: index to i32
|
||||
%ii30 = arith.subi %ii, %c20 : i32
|
||||
%val = arith.ceildivui %ii30, %cmax_int : i32
|
||||
memref.store %val, %A[%i] : memref<40xi32>
|
||||
}
|
||||
call @transfer_read_2d(%A, %c0) : (memref<40xi32>, index) -> ()
|
||||
|
||||
memref.dealloc %A : memref<40xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// List below is aligned for easy manual check
|
||||
// legend: num, ceil(num, 10), floor(num, 10), ceil(num, -10), floor(num, -10)
|
||||
// legend: num, signed_ceil(num, 10), floor(num, 10), signed_ceil(num, -10), floor(num, -10), unsigned_ceil(num, 10), unsigned_ceil(num, max_int)
|
||||
// ( -20, -19, -18, -17, -16, -15, -14, -13, -12, -11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19 )
|
||||
// ( -2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2 )
|
||||
// ( -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -1, -1, -1,-1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
|
||||
// ( 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
|
||||
// ( 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -2, -2, -2, -2 )
|
||||
// ( -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
|
||||
// ( 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
|
||||
// ( 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -2, -2, -2, -2 )
|
||||
|
||||
// CHECK:( -20, -19, -18, -17, -16, -15, -14, -13, -12, -11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19 )
|
||||
// CHECK:( -2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2 )
|
||||
// CHECK:( -2, -2, -2, -2, -2, -2, -2, -2, -2, -2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
|
||||
// CHECK:( 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 )
|
||||
// CHECK:( 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -2, -2, -2, -2 )
|
||||
// CHECK:( 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4 )
|
||||
// CHECK:( 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 )
|
||||
|
|
|
@ -1028,6 +1028,26 @@ func @tensor_arith.ceildivsi_by_one(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @arith.ceildivui_by_one
|
||||
// CHECK-SAME: %[[ARG:[a-zA-Z0-9]+]]
|
||||
func @arith.ceildivui_by_one(%arg0: i32) -> (i32) {
|
||||
%c1 = arith.constant 1 : i32
|
||||
%res = arith.ceildivui %arg0, %c1 : i32
|
||||
// CHECK: return %[[ARG]]
|
||||
return %res : i32
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @tensor_arith.ceildivui_by_one
|
||||
// CHECK-SAME: %[[ARG:[a-zA-Z0-9]+]]
|
||||
func @tensor_arith.ceildivui_by_one(%arg0: tensor<4x5xi32>) -> tensor<4x5xi32> {
|
||||
%c1 = arith.constant dense<1> : tensor<4x5xi32>
|
||||
%res = arith.ceildivui %arg0, %c1 : tensor<4x5xi32>
|
||||
// CHECK: return %[[ARG]]
|
||||
return %res : tensor<4x5xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @memref_cast_folding_subview
|
||||
func @memref_cast_folding_subview(%arg0: memref<4x5xf32>, %i: index) -> (memref<?x?xf32, offset:? , strides: [?, ?]>) {
|
||||
%0 = memref.cast %arg0 : memref<4x5xf32> to memref<?x?xf32>
|
||||
|
|
|
@ -478,6 +478,44 @@ func @simple_arith.ceildivsi() -> (i32, i32, i32, i32, i32) {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @simple_arith.ceildivui
|
||||
func @simple_arith.ceildivui() -> (i32, i32, i32, i32, i32) {
|
||||
// CHECK-DAG: [[C0:%.+]] = arith.constant 0
|
||||
%z = arith.constant 0 : i32
|
||||
// CHECK-DAG: [[C6:%.+]] = arith.constant 7
|
||||
%0 = arith.constant 7 : i32
|
||||
%1 = arith.constant 2 : i32
|
||||
|
||||
// ceil(7, 2) = 4
|
||||
// CHECK-NEXT: [[C3:%.+]] = arith.constant 4 : i32
|
||||
%2 = arith.ceildivui %0, %1 : i32
|
||||
|
||||
%3 = arith.constant -2 : i32
|
||||
|
||||
// ceil(7, -2) = 1 (unsigned: 7 / 4294967294 rounds up to 1)
|
||||
// CHECK-NEXT: [[CM1:%.+]] = arith.constant 1 : i32
|
||||
%4 = arith.ceildivui %0, %3 : i32
|
||||
|
||||
%5 = arith.constant -8 : i32
|
||||
|
||||
// ceil(-8, 2) = 2147483644
|
||||
// CHECK-NEXT: [[CM4:%.+]] = arith.constant 2147483644 : i32
|
||||
%6 = arith.ceildivui %5, %1 : i32
|
||||
|
||||
%7 = arith.constant -15 : i32
|
||||
|
||||
// ceil(-15, -2) = 1 (unsigned: 4294967281 / 4294967294 rounds up to 1; the constant 1 already exists)
|
||||
// CHECK-NOT: arith.constant 1 : i32
|
||||
%8 = arith.ceildivui %7, %3 : i32
|
||||
|
||||
// CHECK-NEXT: [[XZ:%.+]] = arith.ceildivui [[C6]], [[C0]]
|
||||
%9 = arith.ceildivui %0, %z : i32
|
||||
|
||||
return %2, %4, %6, %8, %9 : i32, i32, i32, i32, i32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @simple_arith.remsi
|
||||
func @simple_arith.remsi(%a : i32) -> (i32, i32, i32) {
|
||||
%0 = arith.constant 5 : i32
|
||||
|
|
Loading…
Reference in New Issue