[mlir] add unsigned comparison builders to Affine EDSC

Current Affine comparison builders, which use operator overload, default to signed comparison.  This creates the possibility of misuse of these builders and potential correctness issues when dealing with unsigned integers.  This change makes the distinction between signed and unsigned comparison builders and forces the caller to make a choice between the two.

Differential Revision: https://reviews.llvm.org/D82323
This commit is contained in:
Adam D Straw 2020-06-29 19:35:11 +02:00 committed by Alex Zinenko
parent bd2c3014e1
commit 25055a4fb9
7 changed files with 198 additions and 57 deletions

View File

@ -66,10 +66,14 @@ Value operator^(Value lhs, Value rhs);
/// Comparison operator overloadings. /// Comparison operator overloadings.
Value eq(Value lhs, Value rhs); Value eq(Value lhs, Value rhs);
Value ne(Value lhs, Value rhs); Value ne(Value lhs, Value rhs);
Value operator<(Value lhs, Value rhs); Value slt(Value lhs, Value rhs);
Value operator<=(Value lhs, Value rhs); Value sle(Value lhs, Value rhs);
Value operator>(Value lhs, Value rhs); Value sgt(Value lhs, Value rhs);
Value operator>=(Value lhs, Value rhs); Value sge(Value lhs, Value rhs);
Value ult(Value lhs, Value rhs);
Value ule(Value lhs, Value rhs);
Value ugt(Value lhs, Value rhs);
Value uge(Value lhs, Value rhs);
} // namespace op } // namespace op
@ -159,24 +163,44 @@ Value TemplatedIndexedValue<Load, Store>::ne(Value e) {
return ne(value, e); return ne(value, e);
} }
template <typename Load, typename Store> template <typename Load, typename Store>
Value TemplatedIndexedValue<Load, Store>::operator<(Value e) { Value TemplatedIndexedValue<Load, Store>::slt(Value e) {
using op::operator<; using op::slt;
return static_cast<Value>(*this) < e; return slt(static_cast<Value>(*this), e);
} }
template <typename Load, typename Store> template <typename Load, typename Store>
Value TemplatedIndexedValue<Load, Store>::operator<=(Value e) { Value TemplatedIndexedValue<Load, Store>::sle(Value e) {
using op::operator<=; using op::sle;
return static_cast<Value>(*this) <= e; return sle(static_cast<Value>(*this), e);
} }
template <typename Load, typename Store> template <typename Load, typename Store>
Value TemplatedIndexedValue<Load, Store>::operator>(Value e) { Value TemplatedIndexedValue<Load, Store>::sgt(Value e) {
using op::operator>; using op::sgt;
return static_cast<Value>(*this) > e; return sgt(static_cast<Value>(*this), e);
} }
template <typename Load, typename Store> template <typename Load, typename Store>
Value TemplatedIndexedValue<Load, Store>::operator>=(Value e) { Value TemplatedIndexedValue<Load, Store>::sge(Value e) {
using op::operator>=; using op::sge;
return static_cast<Value>(*this) >= e; return sge(static_cast<Value>(*this), e);
}
template <typename Load, typename Store>
Value TemplatedIndexedValue<Load, Store>::ult(Value e) {
using op::ult;
return ult(static_cast<Value>(*this), e);
}
template <typename Load, typename Store>
Value TemplatedIndexedValue<Load, Store>::ule(Value e) {
using op::ule;
return ule(static_cast<Value>(*this), e);
}
template <typename Load, typename Store>
Value TemplatedIndexedValue<Load, Store>::ugt(Value e) {
using op::ugt;
return ugt(static_cast<Value>(*this), e);
}
template <typename Load, typename Store>
Value TemplatedIndexedValue<Load, Store>::uge(Value e) {
using op::uge;
return uge(static_cast<Value>(*this), e);
} }
} // namespace edsc } // namespace edsc

View File

@ -288,21 +288,37 @@ public:
/// Comparison operator overloadings. /// Comparison operator overloadings.
Value eq(Value e); Value eq(Value e);
Value ne(Value e); Value ne(Value e);
Value operator<(Value e); Value slt(Value e);
Value operator<=(Value e); Value sle(Value e);
Value operator>(Value e); Value sgt(Value e);
Value operator>=(Value e); Value sge(Value e);
Value operator<(TemplatedIndexedValue e) { Value ult(Value e);
return *this < static_cast<Value>(e); Value ule(Value e);
Value ugt(Value e);
Value uge(Value e);
Value slt(TemplatedIndexedValue e) {
return slt(*this, static_cast<Value>(e));
} }
Value operator<=(TemplatedIndexedValue e) { Value sle(TemplatedIndexedValue e) {
return *this <= static_cast<Value>(e); return sle(*this, static_cast<Value>(e));
} }
Value operator>(TemplatedIndexedValue e) { Value sgt(TemplatedIndexedValue e) {
return *this > static_cast<Value>(e); return sgt(*this, static_cast<Value>(e));
} }
Value operator>=(TemplatedIndexedValue e) { Value sge(TemplatedIndexedValue e) {
return *this >= static_cast<Value>(e); return sge(*this, static_cast<Value>(e));
}
Value ult(TemplatedIndexedValue e) {
return ult(*this, static_cast<Value>(e));
}
Value ule(TemplatedIndexedValue e) {
return ule(*this, static_cast<Value>(e));
}
Value ugt(TemplatedIndexedValue e) {
return ugt(*this, static_cast<Value>(e));
}
Value uge(TemplatedIndexedValue e) {
return uge(*this, static_cast<Value>(e));
} }
private: private:

View File

@ -187,7 +187,7 @@ Value NDTransferOpHelper<ConcreteOp>::emitInBoundsCondition(
using namespace mlir::edsc::op; using namespace mlir::edsc::op;
majorIvsPlusOffsets.push_back(iv + off); majorIvsPlusOffsets.push_back(iv + off);
if (xferOp.isMaskedDim(leadingRank + idx)) { if (xferOp.isMaskedDim(leadingRank + idx)) {
Value inBounds = majorIvsPlusOffsets.back() < ub; Value inBounds = slt(majorIvsPlusOffsets.back(), ub);
inBoundsCondition = inBoundsCondition =
(inBoundsCondition) ? (inBoundsCondition && inBounds) : inBounds; (inBoundsCondition) ? (inBoundsCondition && inBounds) : inBounds;
} }
@ -433,16 +433,16 @@ clip(TransferOpTy transfer, MemRefBoundsCapture &bounds, ArrayRef<Value> ivs) {
auto i = memRefAccess[memRefDim]; auto i = memRefAccess[memRefDim];
if (loopIndex < 0) { if (loopIndex < 0) {
auto N_minus_1 = N - one; auto N_minus_1 = N - one;
auto select_1 = std_select(i < N, i, N_minus_1); auto select_1 = std_select(slt(i, N), i, N_minus_1);
clippedScalarAccessExprs[memRefDim] = clippedScalarAccessExprs[memRefDim] =
std_select(i < zero, zero, select_1); std_select(slt(i, zero), zero, select_1);
} else { } else {
auto ii = ivs[loopIndex]; auto ii = ivs[loopIndex];
auto i_plus_ii = i + ii; auto i_plus_ii = i + ii;
auto N_minus_1 = N - one; auto N_minus_1 = N - one;
auto select_1 = std_select(i_plus_ii < N, i_plus_ii, N_minus_1); auto select_1 = std_select(slt(i_plus_ii, N), i_plus_ii, N_minus_1);
clippedScalarAccessExprs[memRefDim] = clippedScalarAccessExprs[memRefDim] =
std_select(i_plus_ii < zero, zero, select_1); std_select(slt(i_plus_ii, zero), zero, select_1);
} }
} }

View File

@ -221,29 +221,51 @@ Value mlir::edsc::op::ne(Value lhs, Value rhs) {
? createFComparisonExpr(CmpFPredicate::ONE, lhs, rhs) ? createFComparisonExpr(CmpFPredicate::ONE, lhs, rhs)
: createIComparisonExpr(CmpIPredicate::ne, lhs, rhs); : createIComparisonExpr(CmpIPredicate::ne, lhs, rhs);
} }
Value mlir::edsc::op::operator<(Value lhs, Value rhs) { Value mlir::edsc::op::slt(Value lhs, Value rhs) {
auto type = lhs.getType(); auto type = lhs.getType();
return type.isa<FloatType>() return type.isa<FloatType>()
? createFComparisonExpr(CmpFPredicate::OLT, lhs, rhs) ? createFComparisonExpr(CmpFPredicate::OLT, lhs, rhs)
: : createIComparisonExpr(CmpIPredicate::slt, lhs, rhs);
// TODO(ntv,zinenko): signed by default, how about unsigned?
createIComparisonExpr(CmpIPredicate::slt, lhs, rhs);
} }
Value mlir::edsc::op::operator<=(Value lhs, Value rhs) { Value mlir::edsc::op::sle(Value lhs, Value rhs) {
auto type = lhs.getType(); auto type = lhs.getType();
return type.isa<FloatType>() return type.isa<FloatType>()
? createFComparisonExpr(CmpFPredicate::OLE, lhs, rhs) ? createFComparisonExpr(CmpFPredicate::OLE, lhs, rhs)
: createIComparisonExpr(CmpIPredicate::sle, lhs, rhs); : createIComparisonExpr(CmpIPredicate::sle, lhs, rhs);
} }
Value mlir::edsc::op::operator>(Value lhs, Value rhs) { Value mlir::edsc::op::sgt(Value lhs, Value rhs) {
auto type = lhs.getType(); auto type = lhs.getType();
return type.isa<FloatType>() return type.isa<FloatType>()
? createFComparisonExpr(CmpFPredicate::OGT, lhs, rhs) ? createFComparisonExpr(CmpFPredicate::OGT, lhs, rhs)
: createIComparisonExpr(CmpIPredicate::sgt, lhs, rhs); : createIComparisonExpr(CmpIPredicate::sgt, lhs, rhs);
} }
Value mlir::edsc::op::operator>=(Value lhs, Value rhs) { Value mlir::edsc::op::sge(Value lhs, Value rhs) {
auto type = lhs.getType(); auto type = lhs.getType();
return type.isa<FloatType>() return type.isa<FloatType>()
? createFComparisonExpr(CmpFPredicate::OGE, lhs, rhs) ? createFComparisonExpr(CmpFPredicate::OGE, lhs, rhs)
: createIComparisonExpr(CmpIPredicate::sge, lhs, rhs); : createIComparisonExpr(CmpIPredicate::sge, lhs, rhs);
} }
Value mlir::edsc::op::ult(Value lhs, Value rhs) {
auto type = lhs.getType();
return type.isa<FloatType>()
? createFComparisonExpr(CmpFPredicate::OLT, lhs, rhs)
: createIComparisonExpr(CmpIPredicate::ult, lhs, rhs);
}
Value mlir::edsc::op::ule(Value lhs, Value rhs) {
auto type = lhs.getType();
return type.isa<FloatType>()
? createFComparisonExpr(CmpFPredicate::OLE, lhs, rhs)
: createIComparisonExpr(CmpIPredicate::ule, lhs, rhs);
}
Value mlir::edsc::op::ugt(Value lhs, Value rhs) {
auto type = lhs.getType();
return type.isa<FloatType>()
? createFComparisonExpr(CmpFPredicate::OGT, lhs, rhs)
: createIComparisonExpr(CmpIPredicate::ugt, lhs, rhs);
}
Value mlir::edsc::op::uge(Value lhs, Value rhs) {
auto type = lhs.getType();
return type.isa<FloatType>()
? createFComparisonExpr(CmpFPredicate::OGE, lhs, rhs)
: createIComparisonExpr(CmpIPredicate::uge, lhs, rhs);
}

View File

@ -169,8 +169,8 @@ Operation *mlir::edsc::ops::linalg_generic_pointwise_max(StructuredIndexed I1,
StructuredIndexed I2, StructuredIndexed I2,
StructuredIndexed O) { StructuredIndexed O) {
BinaryPointwiseOpBuilder binOp([](Value a, Value b) -> Value { BinaryPointwiseOpBuilder binOp([](Value a, Value b) -> Value {
using edsc::op::operator>; using edsc::op::sgt;
return std_select(a > b, a, b); return std_select(sgt(a, b), a, b);
}); });
return linalg_generic_pointwise(binOp, I1, I2, O); return linalg_generic_pointwise(binOp, I1, I2, O);
} }

View File

@ -263,16 +263,16 @@ Value getConvOpInput(ConvOp convOp, StdIndexedValue im,
continue; continue;
} }
using edsc::op::operator<; using edsc::op::sge;
using edsc::op::operator>=; using edsc::op::slt;
using edsc::op::operator||; using edsc::op::operator||;
Value leftOutOfBound = dim < zeroIndex; Value leftOutOfBound = slt(dim, zeroIndex);
if (conds.empty()) if (conds.empty())
conds.push_back(leftOutOfBound); conds.push_back(leftOutOfBound);
else else
conds.push_back(conds.back() || leftOutOfBound); conds.push_back(conds.back() || leftOutOfBound);
Value rightBound = std_dim(convOp.input(), idx); Value rightBound = std_dim(convOp.input(), idx);
conds.push_back(conds.back() || (dim >= rightBound)); conds.push_back(conds.back() || (sge(dim, rightBound)));
// When padding is involved, the indices will only be shifted to negative, // When padding is involved, the indices will only be shifted to negative,
// so having a max op is enough. // so having a max op is enough.
@ -337,8 +337,8 @@ void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMaxOp op) {
// Emit scalar form. // Emit scalar form.
Value lhs = std_load(op.output(), indices.outputs); Value lhs = std_load(op.output(), indices.outputs);
Value rhs = std_load(op.input(), indices.inputs); Value rhs = std_load(op.input(), indices.inputs);
using edsc::op::operator>; using edsc::op::sgt;
Value maxValue = std_select(lhs > rhs, lhs, rhs); Value maxValue = std_select(sgt(lhs, rhs), lhs, rhs);
std_store(maxValue, op.output(), indices.outputs); std_store(maxValue, op.output(), indices.outputs);
} }
template <typename IndexedValueType> template <typename IndexedValueType>
@ -347,8 +347,8 @@ void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMinOp op) {
// Emit scalar form. // Emit scalar form.
Value lhs = std_load(op.output(), indices.outputs); Value lhs = std_load(op.output(), indices.outputs);
Value rhs = std_load(op.input(), indices.inputs); Value rhs = std_load(op.input(), indices.inputs);
using edsc::op::operator<; using edsc::op::slt;
Value minValue = std_select(lhs < rhs, lhs, rhs); Value minValue = std_select(slt(lhs, rhs), lhs, rhs);
std_store(minValue, op.output(), indices.outputs); std_store(minValue, op.output(), indices.outputs);
} }
template <typename IndexedValueType> template <typename IndexedValueType>

View File

@ -459,9 +459,9 @@ TEST_FUNC(diviu_op_i32) {
TEST_FUNC(select_op_i32) { TEST_FUNC(select_op_i32) {
using namespace edsc::op; using namespace edsc::op;
auto f32Type = FloatType::getF32(&globalContext()); auto i32Type = IntegerType::get(32, &globalContext());
auto memrefType = MemRefType::get( auto memrefType = MemRefType::get(
{ShapedType::kDynamicSize, ShapedType::kDynamicSize}, f32Type, {}, 0); {ShapedType::kDynamicSize, ShapedType::kDynamicSize}, i32Type, {}, 0);
auto f = makeFunction("select_op", {}, {memrefType}); auto f = makeFunction("select_op", {}, {memrefType});
OpBuilder builder(f.getBody()); OpBuilder builder(f.getBody());
@ -470,7 +470,18 @@ TEST_FUNC(select_op_i32) {
MemRefBoundsCapture vA(f.getArgument(0)); MemRefBoundsCapture vA(f.getArgument(0));
AffineIndexedValue A(f.getArgument(0)); AffineIndexedValue A(f.getArgument(0));
affineLoopNestBuilder({zero, zero}, {one, one}, {1, 1}, [&](ValueRange ivs) { affineLoopNestBuilder({zero, zero}, {one, one}, {1, 1}, [&](ValueRange ivs) {
std_select(eq(ivs[0], zero), A(zero, zero), A(ivs[0], ivs[1])); using namespace edsc::op;
Value i = ivs[0], j = ivs[1];
std_select(eq(i, zero), A(zero, zero), A(i, j));
std_select(ne(i, zero), A(zero, zero), A(i, j));
std_select(slt(i, zero), A(zero, zero), A(i, j));
std_select(sle(i, zero), A(zero, zero), A(i, j));
std_select(sgt(i, zero), A(zero, zero), A(i, j));
std_select(sge(i, zero), A(zero, zero), A(i, j));
std_select(ult(i, zero), A(zero, zero), A(i, j));
std_select(ule(i, zero), A(zero, zero), A(i, j));
std_select(ugt(i, zero), A(zero, zero), A(i, j));
std_select(uge(i, zero), A(zero, zero), A(i, j));
}); });
// clang-format off // clang-format off
@ -481,6 +492,42 @@ TEST_FUNC(select_op_i32) {
// CHECK-DAG: {{.*}} = affine.load // CHECK-DAG: {{.*}} = affine.load
// CHECK-DAG: {{.*}} = affine.load // CHECK-DAG: {{.*}} = affine.load
// CHECK-NEXT: {{.*}} = select // CHECK-NEXT: {{.*}} = select
// CHECK-DAG: {{.*}} = cmpi "ne"
// CHECK-DAG: {{.*}} = affine.load
// CHECK-DAG: {{.*}} = affine.load
// CHECK-NEXT: {{.*}} = select
// CHECK-DAG: {{.*}} = cmpi "slt"
// CHECK-DAG: {{.*}} = affine.load
// CHECK-DAG: {{.*}} = affine.load
// CHECK-NEXT: {{.*}} = select
// CHECK-DAG: {{.*}} = cmpi "sle"
// CHECK-DAG: {{.*}} = affine.load
// CHECK-DAG: {{.*}} = affine.load
// CHECK-NEXT: {{.*}} = select
// CHECK-DAG: {{.*}} = cmpi "sgt"
// CHECK-DAG: {{.*}} = affine.load
// CHECK-DAG: {{.*}} = affine.load
// CHECK-NEXT: {{.*}} = select
// CHECK-DAG: {{.*}} = cmpi "sge"
// CHECK-DAG: {{.*}} = affine.load
// CHECK-DAG: {{.*}} = affine.load
// CHECK-NEXT: {{.*}} = select
// CHECK-DAG: {{.*}} = cmpi "ult"
// CHECK-DAG: {{.*}} = affine.load
// CHECK-DAG: {{.*}} = affine.load
// CHECK-NEXT: {{.*}} = select
// CHECK-DAG: {{.*}} = cmpi "ule"
// CHECK-DAG: {{.*}} = affine.load
// CHECK-DAG: {{.*}} = affine.load
// CHECK-NEXT: {{.*}} = select
// CHECK-DAG: {{.*}} = cmpi "ugt"
// CHECK-DAG: {{.*}} = affine.load
// CHECK-DAG: {{.*}} = affine.load
// CHECK-NEXT: {{.*}} = select
// CHECK-DAG: {{.*}} = cmpi "uge"
// CHECK-DAG: {{.*}} = affine.load
// CHECK-DAG: {{.*}} = affine.load
// CHECK-NEXT: {{.*}} = select
// clang-format on // clang-format on
f.print(llvm::outs()); f.print(llvm::outs());
f.erase(); f.erase();
@ -503,10 +550,14 @@ TEST_FUNC(select_op_f32) {
Value i = ivs[0], j = ivs[1]; Value i = ivs[0], j = ivs[1];
std_select(eq(B(i, j), B(i + one, j)), A(zero, zero), A(i, j)); std_select(eq(B(i, j), B(i + one, j)), A(zero, zero), A(i, j));
std_select(ne(B(i, j), B(i + one, j)), A(zero, zero), A(i, j)); std_select(ne(B(i, j), B(i + one, j)), A(zero, zero), A(i, j));
std_select(B(i, j) >= B(i + one, j), A(zero, zero), A(i, j)); std_select(sge(B(i, j), B(i + one, j)), A(zero, zero), A(i, j));
std_select(B(i, j) <= B(i + one, j), A(zero, zero), A(i, j)); std_select(sle(B(i, j), B(i + one, j)), A(zero, zero), A(i, j));
std_select(B(i, j) < B(i + one, j), A(zero, zero), A(i, j)); std_select(slt(B(i, j), B(i + one, j)), A(zero, zero), A(i, j));
std_select(B(i, j) > B(i + one, j), A(zero, zero), A(i, j)); std_select(sgt(B(i, j), B(i + one, j)), A(zero, zero), A(i, j));
std_select(uge(B(i, j), B(i + one, j)), A(zero, zero), A(i, j));
std_select(ule(B(i, j), B(i + one, j)), A(zero, zero), A(i, j));
std_select(ult(B(i, j), B(i + one, j)), A(zero, zero), A(i, j));
std_select(ugt(B(i, j), B(i + one, j)), A(zero, zero), A(i, j));
}); });
// CHECK-LABEL: @select_op // CHECK-LABEL: @select_op
@ -554,6 +605,34 @@ TEST_FUNC(select_op_f32) {
// CHECK-DAG: affine.load // CHECK-DAG: affine.load
// CHECK-DAG: affine.apply // CHECK-DAG: affine.apply
// CHECK-NEXT: select // CHECK-NEXT: select
// CHECK-DAG: cmpf "oge"
// CHECK-DAG: affine.load
// CHECK-DAG: affine.load
// CHECK-DAG: affine.load
// CHECK-DAG: affine.load
// CHECK-DAG: affine.apply
// CHECK-NEXT: select
// CHECK-DAG: cmpf "ole"
// CHECK-DAG: affine.load
// CHECK-DAG: affine.load
// CHECK-DAG: affine.load
// CHECK-DAG: affine.load
// CHECK-DAG: affine.apply
// CHECK-NEXT: select
// CHECK-DAG: cmpf "olt"
// CHECK-DAG: affine.load
// CHECK-DAG: affine.load
// CHECK-DAG: affine.load
// CHECK-DAG: affine.load
// CHECK-DAG: affine.apply
// CHECK-NEXT: select
// CHECK-DAG: cmpf "ogt"
// CHECK-DAG: affine.load
// CHECK-DAG: affine.load
// CHECK-DAG: affine.load
// CHECK-DAG: affine.load
// CHECK-DAG: affine.apply
// CHECK-NEXT: select
// clang-format on // clang-format on
f.print(llvm::outs()); f.print(llvm::outs());
f.erase(); f.erase();