diff --git a/mlir/include/mlir/Dialect/Affine/EDSC/Builders.h b/mlir/include/mlir/Dialect/Affine/EDSC/Builders.h index 5c99a430c862..96191e01296a 100644 --- a/mlir/include/mlir/Dialect/Affine/EDSC/Builders.h +++ b/mlir/include/mlir/Dialect/Affine/EDSC/Builders.h @@ -66,10 +66,14 @@ Value operator^(Value lhs, Value rhs); /// Comparison operator overloadings. Value eq(Value lhs, Value rhs); Value ne(Value lhs, Value rhs); -Value operator<(Value lhs, Value rhs); -Value operator<=(Value lhs, Value rhs); -Value operator>(Value lhs, Value rhs); -Value operator>=(Value lhs, Value rhs); +Value slt(Value lhs, Value rhs); +Value sle(Value lhs, Value rhs); +Value sgt(Value lhs, Value rhs); +Value sge(Value lhs, Value rhs); +Value ult(Value lhs, Value rhs); +Value ule(Value lhs, Value rhs); +Value ugt(Value lhs, Value rhs); +Value uge(Value lhs, Value rhs); } // namespace op @@ -159,24 +163,44 @@ Value TemplatedIndexedValue::ne(Value e) { return ne(value, e); } template -Value TemplatedIndexedValue::operator<(Value e) { - using op::operator<; - return static_cast(*this) < e; +Value TemplatedIndexedValue::slt(Value e) { + using op::slt; + return slt(static_cast(*this), e); } template -Value TemplatedIndexedValue::operator<=(Value e) { - using op::operator<=; - return static_cast(*this) <= e; +Value TemplatedIndexedValue::sle(Value e) { + using op::sle; + return sle(static_cast(*this), e); } template -Value TemplatedIndexedValue::operator>(Value e) { - using op::operator>; - return static_cast(*this) > e; +Value TemplatedIndexedValue::sgt(Value e) { + using op::sgt; + return sgt(static_cast(*this), e); } template -Value TemplatedIndexedValue::operator>=(Value e) { - using op::operator>=; - return static_cast(*this) >= e; +Value TemplatedIndexedValue::sge(Value e) { + using op::sge; + return sge(static_cast(*this), e); +} +template +Value TemplatedIndexedValue::ult(Value e) { + using op::ult; + return ult(static_cast(*this), e); +} +template +Value TemplatedIndexedValue::ule(Value e) { + using op::ule; + return ule(static_cast(*this), e); +} +template +Value TemplatedIndexedValue::ugt(Value e) { + using op::ugt; + return ugt(static_cast(*this), e); +} +template +Value TemplatedIndexedValue::uge(Value e) { + using op::uge; + return uge(static_cast(*this), e); } } // namespace edsc diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h index a7c5506f7ab0..64df2c9fe367 100644 --- a/mlir/include/mlir/EDSC/Builders.h +++ b/mlir/include/mlir/EDSC/Builders.h @@ -288,21 +288,37 @@ public: /// Comparison operator overloadings. Value eq(Value e); Value ne(Value e); - Value operator<(Value e); - Value operator<=(Value e); - Value operator>(Value e); - Value operator>=(Value e); - Value operator<(TemplatedIndexedValue e) { - return *this < static_cast(e); + Value slt(Value e); + Value sle(Value e); + Value sgt(Value e); + Value sge(Value e); + Value ult(Value e); + Value ule(Value e); + Value ugt(Value e); + Value uge(Value e); + Value slt(TemplatedIndexedValue e) { + return slt(*this, static_cast(e)); } - Value operator<=(TemplatedIndexedValue e) { - return *this <= static_cast(e); + Value sle(TemplatedIndexedValue e) { + return sle(*this, static_cast(e)); } - Value operator>(TemplatedIndexedValue e) { - return *this > static_cast(e); + Value sgt(TemplatedIndexedValue e) { + return sgt(*this, static_cast(e)); } - Value operator>=(TemplatedIndexedValue e) { - return *this >= static_cast(e); + Value sge(TemplatedIndexedValue e) { + return sge(*this, static_cast(e)); + } + Value ult(TemplatedIndexedValue e) { + return ult(*this, static_cast(e)); + } + Value ule(TemplatedIndexedValue e) { + return ule(*this, static_cast(e)); + } + Value ugt(TemplatedIndexedValue e) { + return ugt(*this, static_cast(e)); + } + Value uge(TemplatedIndexedValue e) { + return uge(*this, static_cast(e)); } private: diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp index 99ded0686a54..cf3d9653d7df 100644 --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -187,7 +187,7 @@ Value NDTransferOpHelper::emitInBoundsCondition( using namespace mlir::edsc::op; majorIvsPlusOffsets.push_back(iv + off); if (xferOp.isMaskedDim(leadingRank + idx)) { - Value inBounds = majorIvsPlusOffsets.back() < ub; + Value inBounds = slt(majorIvsPlusOffsets.back(), ub); inBoundsCondition = (inBoundsCondition) ? (inBoundsCondition && inBounds) : inBounds; } @@ -433,16 +433,16 @@ clip(TransferOpTy transfer, MemRefBoundsCapture &bounds, ArrayRef ivs) { auto i = memRefAccess[memRefDim]; if (loopIndex < 0) { auto N_minus_1 = N - one; - auto select_1 = std_select(i < N, i, N_minus_1); + auto select_1 = std_select(slt(i, N), i, N_minus_1); clippedScalarAccessExprs[memRefDim] = - std_select(i < zero, zero, select_1); + std_select(slt(i, zero), zero, select_1); } else { auto ii = ivs[loopIndex]; auto i_plus_ii = i + ii; auto N_minus_1 = N - one; - auto select_1 = std_select(i_plus_ii < N, i_plus_ii, N_minus_1); + auto select_1 = std_select(slt(i_plus_ii, N), i_plus_ii, N_minus_1); clippedScalarAccessExprs[memRefDim] = - std_select(i_plus_ii < zero, zero, select_1); + std_select(slt(i_plus_ii, zero), zero, select_1); } } diff --git a/mlir/lib/Dialect/Affine/EDSC/Builders.cpp b/mlir/lib/Dialect/Affine/EDSC/Builders.cpp index 559f375c5dff..e5bf1c015e02 100644 --- a/mlir/lib/Dialect/Affine/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/Affine/EDSC/Builders.cpp @@ -221,29 +221,51 @@ Value mlir::edsc::op::ne(Value lhs, Value rhs) { ? createFComparisonExpr(CmpFPredicate::ONE, lhs, rhs) : createIComparisonExpr(CmpIPredicate::ne, lhs, rhs); } -Value mlir::edsc::op::operator<(Value lhs, Value rhs) { +Value mlir::edsc::op::slt(Value lhs, Value rhs) { auto type = lhs.getType(); return type.isa() ? createFComparisonExpr(CmpFPredicate::OLT, lhs, rhs) - : - // TODO(ntv,zinenko): signed by default, how about unsigned? - createIComparisonExpr(CmpIPredicate::slt, lhs, rhs); + : createIComparisonExpr(CmpIPredicate::slt, lhs, rhs); } -Value mlir::edsc::op::operator<=(Value lhs, Value rhs) { +Value mlir::edsc::op::sle(Value lhs, Value rhs) { auto type = lhs.getType(); return type.isa() ? createFComparisonExpr(CmpFPredicate::OLE, lhs, rhs) : createIComparisonExpr(CmpIPredicate::sle, lhs, rhs); } -Value mlir::edsc::op::operator>(Value lhs, Value rhs) { +Value mlir::edsc::op::sgt(Value lhs, Value rhs) { auto type = lhs.getType(); return type.isa() ? createFComparisonExpr(CmpFPredicate::OGT, lhs, rhs) : createIComparisonExpr(CmpIPredicate::sgt, lhs, rhs); } -Value mlir::edsc::op::operator>=(Value lhs, Value rhs) { +Value mlir::edsc::op::sge(Value lhs, Value rhs) { auto type = lhs.getType(); return type.isa() ? createFComparisonExpr(CmpFPredicate::OGE, lhs, rhs) : createIComparisonExpr(CmpIPredicate::sge, lhs, rhs); } +Value mlir::edsc::op::ult(Value lhs, Value rhs) { + auto type = lhs.getType(); + return type.isa() + ? createFComparisonExpr(CmpFPredicate::OLT, lhs, rhs) + : createIComparisonExpr(CmpIPredicate::ult, lhs, rhs); +} +Value mlir::edsc::op::ule(Value lhs, Value rhs) { + auto type = lhs.getType(); + return type.isa() + ? createFComparisonExpr(CmpFPredicate::OLE, lhs, rhs) + : createIComparisonExpr(CmpIPredicate::ule, lhs, rhs); +} +Value mlir::edsc::op::ugt(Value lhs, Value rhs) { + auto type = lhs.getType(); + return type.isa() + ? createFComparisonExpr(CmpFPredicate::OGT, lhs, rhs) + : createIComparisonExpr(CmpIPredicate::ugt, lhs, rhs); +} +Value mlir::edsc::op::uge(Value lhs, Value rhs) { + auto type = lhs.getType(); + return type.isa() + ? createFComparisonExpr(CmpFPredicate::OGE, lhs, rhs) + : createIComparisonExpr(CmpIPredicate::uge, lhs, rhs); +} diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp index 5016ca9b3055..8cfc25d2ff8e 100644 --- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp @@ -169,8 +169,8 @@ Operation *mlir::edsc::ops::linalg_generic_pointwise_max(StructuredIndexed I1, StructuredIndexed I2, StructuredIndexed O) { BinaryPointwiseOpBuilder binOp([](Value a, Value b) -> Value { - using edsc::op::operator>; - return std_select(a > b, a, b); + using edsc::op::sgt; + return std_select(sgt(a, b), a, b); }); return linalg_generic_pointwise(binOp, I1, I2, O); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp index d031712ce5a9..ec57717eaca9 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -263,16 +263,16 @@ Value getConvOpInput(ConvOp convOp, StdIndexedValue im, continue; } - using edsc::op::operator<; - using edsc::op::operator>=; + using edsc::op::sge; + using edsc::op::slt; using edsc::op::operator||; - Value leftOutOfBound = dim < zeroIndex; + Value leftOutOfBound = slt(dim, zeroIndex); if (conds.empty()) conds.push_back(leftOutOfBound); else conds.push_back(conds.back() || leftOutOfBound); Value rightBound = std_dim(convOp.input(), idx); - conds.push_back(conds.back() || (dim >= rightBound)); + conds.push_back(conds.back() || (sge(dim, rightBound))); // When padding is involved, the indices will only be shifted to negative, // so having a max op is enough. @@ -337,8 +337,8 @@ void emitScalarImplementation(ArrayRef allIvs, PoolingMaxOp op) { // Emit scalar form. Value lhs = std_load(op.output(), indices.outputs); Value rhs = std_load(op.input(), indices.inputs); - using edsc::op::operator>; - Value maxValue = std_select(lhs > rhs, lhs, rhs); + using edsc::op::sgt; + Value maxValue = std_select(sgt(lhs, rhs), lhs, rhs); std_store(maxValue, op.output(), indices.outputs); } template @@ -347,8 +347,8 @@ void emitScalarImplementation(ArrayRef allIvs, PoolingMinOp op) { // Emit scalar form. Value lhs = std_load(op.output(), indices.outputs); Value rhs = std_load(op.input(), indices.inputs); - using edsc::op::operator<; - Value minValue = std_select(lhs < rhs, lhs, rhs); + using edsc::op::slt; + Value minValue = std_select(slt(lhs, rhs), lhs, rhs); std_store(minValue, op.output(), indices.outputs); } template diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp index 31748810f899..73f7adeeaf71 100644 --- a/mlir/test/EDSC/builder-api-test.cpp +++ b/mlir/test/EDSC/builder-api-test.cpp @@ -459,9 +459,9 @@ TEST_FUNC(diviu_op_i32) { TEST_FUNC(select_op_i32) { using namespace edsc::op; - auto f32Type = FloatType::getF32(&globalContext()); + auto i32Type = IntegerType::get(32, &globalContext()); auto memrefType = MemRefType::get( - {ShapedType::kDynamicSize, ShapedType::kDynamicSize}, f32Type, {}, 0); + {ShapedType::kDynamicSize, ShapedType::kDynamicSize}, i32Type, {}, 0); auto f = makeFunction("select_op", {}, {memrefType}); OpBuilder builder(f.getBody()); @@ -470,7 +470,18 @@ TEST_FUNC(select_op_i32) { MemRefBoundsCapture vA(f.getArgument(0)); AffineIndexedValue A(f.getArgument(0)); affineLoopNestBuilder({zero, zero}, {one, one}, {1, 1}, [&](ValueRange ivs) { - std_select(eq(ivs[0], zero), A(zero, zero), A(ivs[0], ivs[1])); + using namespace edsc::op; + Value i = ivs[0], j = ivs[1]; + std_select(eq(i, zero), A(zero, zero), A(i, j)); + std_select(ne(i, zero), A(zero, zero), A(i, j)); + std_select(slt(i, zero), A(zero, zero), A(i, j)); + std_select(sle(i, zero), A(zero, zero), A(i, j)); + std_select(sgt(i, zero), A(zero, zero), A(i, j)); + std_select(sge(i, zero), A(zero, zero), A(i, j)); + std_select(ult(i, zero), A(zero, zero), A(i, j)); + std_select(ule(i, zero), A(zero, zero), A(i, j)); + std_select(ugt(i, zero), A(zero, zero), A(i, j)); + std_select(uge(i, zero), A(zero, zero), A(i, j)); }); // clang-format off @@ -481,6 +492,42 @@ TEST_FUNC(select_op_i32) { // CHECK-DAG: {{.*}} = affine.load // CHECK-DAG: {{.*}} = affine.load // CHECK-NEXT: {{.*}} = select + // CHECK-DAG: {{.*}} = cmpi "ne" + // CHECK-DAG: {{.*}} = affine.load + // CHECK-DAG: {{.*}} = affine.load + // CHECK-NEXT: {{.*}} = select + // CHECK-DAG: {{.*}} = cmpi "slt" + // CHECK-DAG: {{.*}} = affine.load + // CHECK-DAG: {{.*}} = affine.load + // CHECK-NEXT: {{.*}} = select + // CHECK-DAG: {{.*}} = cmpi "sle" + // CHECK-DAG: {{.*}} = affine.load + // CHECK-DAG: {{.*}} = affine.load + // CHECK-NEXT: {{.*}} = select + // CHECK-DAG: {{.*}} = cmpi "sgt" + // CHECK-DAG: {{.*}} = affine.load + // CHECK-DAG: {{.*}} = affine.load + // CHECK-NEXT: {{.*}} = select + // CHECK-DAG: {{.*}} = cmpi "sge" + // CHECK-DAG: {{.*}} = affine.load + // CHECK-DAG: {{.*}} = affine.load + // CHECK-NEXT: {{.*}} = select + // CHECK-DAG: {{.*}} = cmpi "ult" + // CHECK-DAG: {{.*}} = affine.load + // CHECK-DAG: {{.*}} = affine.load + // CHECK-NEXT: {{.*}} = select + // CHECK-DAG: {{.*}} = cmpi "ule" + // CHECK-DAG: {{.*}} = affine.load + // CHECK-DAG: {{.*}} = affine.load + // CHECK-NEXT: {{.*}} = select + // CHECK-DAG: {{.*}} = cmpi "ugt" + // CHECK-DAG: {{.*}} = affine.load + // CHECK-DAG: {{.*}} = affine.load + // CHECK-NEXT: {{.*}} = select + // CHECK-DAG: {{.*}} = cmpi "uge" + // CHECK-DAG: {{.*}} = affine.load + // CHECK-DAG: {{.*}} = affine.load + // CHECK-NEXT: {{.*}} = select // clang-format on f.print(llvm::outs()); f.erase(); @@ -503,10 +550,14 @@ TEST_FUNC(select_op_f32) { Value i = ivs[0], j = ivs[1]; std_select(eq(B(i, j), B(i + one, j)), A(zero, zero), A(i, j)); std_select(ne(B(i, j), B(i + one, j)), A(zero, zero), A(i, j)); - std_select(B(i, j) >= B(i + one, j), A(zero, zero), A(i, j)); - std_select(B(i, j) <= B(i + one, j), A(zero, zero), A(i, j)); - std_select(B(i, j) < B(i + one, j), A(zero, zero), A(i, j)); - std_select(B(i, j) > B(i + one, j), A(zero, zero), A(i, j)); + std_select(sge(B(i, j), B(i + one, j)), A(zero, zero), A(i, j)); + std_select(sle(B(i, j), B(i + one, j)), A(zero, zero), A(i, j)); + std_select(slt(B(i, j), B(i + one, j)), A(zero, zero), A(i, j)); + std_select(sgt(B(i, j), B(i + one, j)), A(zero, zero), A(i, j)); + std_select(uge(B(i, j), B(i + one, j)), A(zero, zero), A(i, j)); + std_select(ule(B(i, j), B(i + one, j)), A(zero, zero), A(i, j)); + std_select(ult(B(i, j), B(i + one, j)), A(zero, zero), A(i, j)); + std_select(ugt(B(i, j), B(i + one, j)), A(zero, zero), A(i, j)); }); // CHECK-LABEL: @select_op @@ -554,6 +605,34 @@ TEST_FUNC(select_op_f32) { // CHECK-DAG: affine.load // CHECK-DAG: affine.apply // CHECK-NEXT: select + // CHECK-DAG: cmpf "oge" + // CHECK-DAG: affine.load + // CHECK-DAG: affine.load + // CHECK-DAG: affine.load + // CHECK-DAG: affine.load + // CHECK-DAG: affine.apply + // CHECK-NEXT: select + // CHECK-DAG: cmpf "ole" + // CHECK-DAG: affine.load + // CHECK-DAG: affine.load + // CHECK-DAG: affine.load + // CHECK-DAG: affine.load + // CHECK-DAG: affine.apply + // CHECK-NEXT: select + // CHECK-DAG: cmpf "olt" + // CHECK-DAG: affine.load + // CHECK-DAG: affine.load + // CHECK-DAG: affine.load + // CHECK-DAG: affine.load + // CHECK-DAG: affine.apply + // CHECK-NEXT: select + // CHECK-DAG: cmpf "ogt" + // CHECK-DAG: affine.load + // CHECK-DAG: affine.load + // CHECK-DAG: affine.load + // CHECK-DAG: affine.load + // CHECK-DAG: affine.apply + // CHECK-NEXT: select // clang-format on f.print(llvm::outs()); f.erase();