[mlir][tosa] Add remaining tosa comparison folders

Added numerical splat folders for comparison operations and
equal of two identical int values.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D133138
This commit is contained in:
Rob Suderman 2022-09-01 14:10:19 -07:00
parent a4d48e3b0b
commit 5a231720bc
3 changed files with 135 additions and 32 deletions

View File

@ -1143,6 +1143,8 @@ def Tosa_EqualOp : Tosa_Op<"equal", [InferTensorType, ResultsBroadcastableShape,
/// InferTypeOpInterface.
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
@ -1191,6 +1193,8 @@ def Tosa_GreaterEqualOp : Tosa_Op<"greater_equal", [
let results = (outs
I1Tensor:$output
);
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//

View File

@ -675,6 +675,11 @@ struct APIntFoldGreater {
APIntFoldGreater() {}
APInt operator()(APInt l, APInt r) { return APInt(1, l.sgt(r)); }
};
struct APIntFoldGreaterEqual {
APIntFoldGreaterEqual() {}
APInt operator()(APInt l, APInt r) { return APInt(1, l.sge(r)); }
};
} // namespace
OpFoldResult GreaterOp::fold(ArrayRef<Attribute> operands) {
@ -689,6 +694,42 @@ OpFoldResult GreaterOp::fold(ArrayRef<Attribute> operands) {
lhsAttr, rhsAttr, resultTy);
}
OpFoldResult GreaterEqualOp::fold(ArrayRef<Attribute> operands) {
auto resultTy = getType().dyn_cast<RankedTensorType>();
auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
if (!lhsAttr || !rhsAttr)
return {};
return BinaryFolder<APIntFoldGreaterEqual,
ComparisonFold<std::greater_equal<APFloat>>>(
lhsAttr, rhsAttr, resultTy);
}
OpFoldResult EqualOp::fold(ArrayRef<Attribute> operands) {
auto resultTy = getType().dyn_cast<RankedTensorType>();
auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
Value lhs = getInput1();
Value rhs = getInput2();
auto lhsTy = lhs.getType().cast<ShapedType>();
// If we are comparing an integer value to itself it is always true. We can
// not do this with float due to float values.
if (lhsTy.getElementType().isa<IntegerType>() && resultTy.hasStaticShape() &&
lhs == rhs) {
return DenseElementsAttr::get(resultTy, true);
}
if (!lhsAttr || !rhsAttr)
return {};
return BinaryFolder<ComparisonFold<std::equal_to<APInt>>,
ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
resultTy);
}
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
if (getInput().getType() == getType())
return getInput();

View File

@ -350,50 +350,108 @@ func.func @fold_sub_splat_f32() -> tensor<10xf32> {
// -----
// CHECK-LABEL: @fold_greater_splat_f32_true
func.func @fold_greater_splat_f32_true() -> tensor<10xi1> {
%one = "tosa.const"() {value = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
%two = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
%add = "tosa.greater"(%one, %two) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
// CHECK: %[[BOOL:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
// CHECK: return %[[BOOL]]
return %add : tensor<10xi1>
// CHECK-LABEL: @fold_greater_splat_f32
func.func @fold_greater_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) {
%0 = "tosa.const"() {value = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
%1 = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
%2 = "tosa.const"() {value = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
%3 = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
%true = "tosa.greater"(%0, %1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
%false = "tosa.greater"(%2, %3) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
// CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
// CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
// CHECK: return %[[TRUE]], %[[FALSE]]
return %true, %false : tensor<10xi1>, tensor<10xi1>
}
// -----
// CHECK-LABEL: @fold_greater_splat_f32_false
func.func @fold_greater_splat_f32_false() -> tensor<10xi1> {
%one = "tosa.const"() {value = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
%two = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
%add = "tosa.greater"(%one, %two) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
// CHECK: %[[BOOL:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
// CHECK: return %[[BOOL]]
return %add : tensor<10xi1>
// CHECK-LABEL: @fold_greater_splat_i32
func.func @fold_greater_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) {
%0 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
%1 = "tosa.const"() {value = dense<8> : tensor<10xi32>} : () -> tensor<10xi32>
%2 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
%3 = "tosa.const"() {value = dense<-12> : tensor<10xi32>} : () -> tensor<10xi32>
%false = "tosa.greater"(%0, %1) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
%true = "tosa.greater"(%2, %3) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
// CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
// CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
// CHECK: return %[[FALSE]], %[[TRUE]]
return %false, %true : tensor<10xi1>, tensor<10xi1>
}
// -----
// CHECK-LABEL: @fold_greater_splat_i32_false
func.func @fold_greater_splat_i32_false() -> tensor<10xi1> {
%one = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
%two = "tosa.const"() {value = dense<8> : tensor<10xi32>} : () -> tensor<10xi32>
%add = "tosa.greater"(%one, %two) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
// CHECK: %[[BOOL:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
// CHECK: return %[[BOOL]]
return %add : tensor<10xi1>
// CHECK-LABEL: @fold_greater_eq_splat_f32
func.func @fold_greater_eq_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) {
%0 = "tosa.const"() {value = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
%1 = "tosa.const"() {value = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
%2 = "tosa.const"() {value = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
%3 = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
%true = "tosa.greater_equal"(%0, %1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
%false = "tosa.greater_equal"(%2, %3) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
// CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
// CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
// CHECK: return %[[TRUE]], %[[FALSE]]
return %true, %false : tensor<10xi1>, tensor<10xi1>
}
// -----
// CHECK-LABEL: @fold_greater_splat_i32_true
func.func @fold_greater_splat_i32_true() -> tensor<10xi1> {
%one = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
%two = "tosa.const"() {value = dense<-12> : tensor<10xi32>} : () -> tensor<10xi32>
%add = "tosa.greater"(%one, %two) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
// CHECK: %[[BOOL:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
// CHECK: return %[[BOOL]]
return %add : tensor<10xi1>
// CHECK-LABEL: @fold_greater_eq_splat_i32
func.func @fold_greater_eq_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) {
%0 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
%1 = "tosa.const"() {value = dense<8> : tensor<10xi32>} : () -> tensor<10xi32>
%2 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
%3 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
%true = "tosa.greater_equal"(%2, %3) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
%false = "tosa.greater_equal"(%0, %1) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
// CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
// CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
// CHECK: return %[[TRUE]], %[[FALSE]]
return %true, %false : tensor<10xi1>, tensor<10xi1>
}
// -----
// CHECK-LABEL: @fold_eq_splat_f32
func.func @fold_eq_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) {
%0 = "tosa.const"() {value = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
%1 = "tosa.const"() {value = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
%2 = "tosa.const"() {value = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
%3 = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
%true = "tosa.equal"(%0, %1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
%false = "tosa.equal"(%2, %3) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
// CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
// CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
// CHECK: return %[[TRUE]], %[[FALSE]]
return %true, %false : tensor<10xi1>, tensor<10xi1>
}
// -----
// CHECK-LABEL: @fold_eq_splat_i32
func.func @fold_eq_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) {
%0 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
%1 = "tosa.const"() {value = dense<8> : tensor<10xi32>} : () -> tensor<10xi32>
%2 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
%3 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
%true = "tosa.equal"(%2, %3) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
%false = "tosa.equal"(%0, %1) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
// CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
// CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
// CHECK: return %[[TRUE]], %[[FALSE]]
return %true, %false : tensor<10xi1>, tensor<10xi1>
}
// -----
// CHECK-LABEL: @fold_eq_i32
func.func @fold_eq_i32(%arg0 : tensor<10xi32>) -> (tensor<10xi1>) {
// CHECK: %[[TRUE:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
%0 = "tosa.equal"(%arg0, %arg0) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
// CHECK: return %[[TRUE]]
return %0 : tensor<10xi1>
}
// -----