forked from OSchip/llvm-project
[mlir][tosa] Add remaining tosa comparison folders
Added numerical splat folders for comparison operations and equal of two identical int values. Reviewed By: NatashaKnk Differential Revision: https://reviews.llvm.org/D133138
This commit is contained in:
parent
a4d48e3b0b
commit
5a231720bc
|
@ -1143,6 +1143,8 @@ def Tosa_EqualOp : Tosa_Op<"equal", [InferTensorType, ResultsBroadcastableShape,
|
||||||
/// InferTypeOpInterface.
|
/// InferTypeOpInterface.
|
||||||
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
|
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
|
||||||
}];
|
}];
|
||||||
|
|
||||||
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -1191,6 +1193,8 @@ def Tosa_GreaterEqualOp : Tosa_Op<"greater_equal", [
|
||||||
let results = (outs
|
let results = (outs
|
||||||
I1Tensor:$output
|
I1Tensor:$output
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let hasFolder = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -675,6 +675,11 @@ struct APIntFoldGreater {
|
||||||
APIntFoldGreater() {}
|
APIntFoldGreater() {}
|
||||||
APInt operator()(APInt l, APInt r) { return APInt(1, l.sgt(r)); }
|
APInt operator()(APInt l, APInt r) { return APInt(1, l.sgt(r)); }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct APIntFoldGreaterEqual {
|
||||||
|
APIntFoldGreaterEqual() {}
|
||||||
|
APInt operator()(APInt l, APInt r) { return APInt(1, l.sge(r)); }
|
||||||
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
OpFoldResult GreaterOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult GreaterOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
@ -689,6 +694,42 @@ OpFoldResult GreaterOp::fold(ArrayRef<Attribute> operands) {
|
||||||
lhsAttr, rhsAttr, resultTy);
|
lhsAttr, rhsAttr, resultTy);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
OpFoldResult GreaterEqualOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
auto resultTy = getType().dyn_cast<RankedTensorType>();
|
||||||
|
auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
|
||||||
|
auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
|
||||||
|
|
||||||
|
if (!lhsAttr || !rhsAttr)
|
||||||
|
return {};
|
||||||
|
|
||||||
|
return BinaryFolder<APIntFoldGreaterEqual,
|
||||||
|
ComparisonFold<std::greater_equal<APFloat>>>(
|
||||||
|
lhsAttr, rhsAttr, resultTy);
|
||||||
|
}
|
||||||
|
|
||||||
|
OpFoldResult EqualOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
auto resultTy = getType().dyn_cast<RankedTensorType>();
|
||||||
|
auto lhsAttr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
|
||||||
|
auto rhsAttr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
|
||||||
|
Value lhs = getInput1();
|
||||||
|
Value rhs = getInput2();
|
||||||
|
auto lhsTy = lhs.getType().cast<ShapedType>();
|
||||||
|
|
||||||
|
// If we are comparing an integer value to itself it is always true. We can
|
||||||
|
// not do this with float due to float values.
|
||||||
|
if (lhsTy.getElementType().isa<IntegerType>() && resultTy.hasStaticShape() &&
|
||||||
|
lhs == rhs) {
|
||||||
|
return DenseElementsAttr::get(resultTy, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!lhsAttr || !rhsAttr)
|
||||||
|
return {};
|
||||||
|
|
||||||
|
return BinaryFolder<ComparisonFold<std::equal_to<APInt>>,
|
||||||
|
ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
|
||||||
|
resultTy);
|
||||||
|
}
|
||||||
|
|
||||||
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
|
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
|
||||||
if (getInput().getType() == getType())
|
if (getInput().getType() == getType())
|
||||||
return getInput();
|
return getInput();
|
||||||
|
|
|
@ -350,50 +350,108 @@ func.func @fold_sub_splat_f32() -> tensor<10xf32> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: @fold_greater_splat_f32_true
|
// CHECK-LABEL: @fold_greater_splat_f32
|
||||||
func.func @fold_greater_splat_f32_true() -> tensor<10xi1> {
|
func.func @fold_greater_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) {
|
||||||
%one = "tosa.const"() {value = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
|
%0 = "tosa.const"() {value = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
|
||||||
%two = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
|
%1 = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
|
||||||
%add = "tosa.greater"(%one, %two) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
|
%2 = "tosa.const"() {value = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
|
||||||
// CHECK: %[[BOOL:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
|
%3 = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
|
||||||
// CHECK: return %[[BOOL]]
|
%true = "tosa.greater"(%0, %1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
|
||||||
return %add : tensor<10xi1>
|
%false = "tosa.greater"(%2, %3) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
|
||||||
|
// CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
|
||||||
|
// CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
|
||||||
|
// CHECK: return %[[TRUE]], %[[FALSE]]
|
||||||
|
return %true, %false : tensor<10xi1>, tensor<10xi1>
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: @fold_greater_splat_f32_false
|
// CHECK-LABEL: @fold_greater_splat_i32
|
||||||
func.func @fold_greater_splat_f32_false() -> tensor<10xi1> {
|
func.func @fold_greater_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) {
|
||||||
%one = "tosa.const"() {value = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
|
%0 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
|
||||||
%two = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
|
%1 = "tosa.const"() {value = dense<8> : tensor<10xi32>} : () -> tensor<10xi32>
|
||||||
%add = "tosa.greater"(%one, %two) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
|
%2 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
|
||||||
// CHECK: %[[BOOL:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
|
%3 = "tosa.const"() {value = dense<-12> : tensor<10xi32>} : () -> tensor<10xi32>
|
||||||
// CHECK: return %[[BOOL]]
|
%false = "tosa.greater"(%0, %1) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
|
||||||
return %add : tensor<10xi1>
|
%true = "tosa.greater"(%2, %3) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
|
||||||
|
// CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
|
||||||
|
// CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
|
||||||
|
// CHECK: return %[[FALSE]], %[[TRUE]]
|
||||||
|
return %false, %true : tensor<10xi1>, tensor<10xi1>
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: @fold_greater_splat_i32_false
|
// CHECK-LABEL: @fold_greater_eq_splat_f32
|
||||||
func.func @fold_greater_splat_i32_false() -> tensor<10xi1> {
|
func.func @fold_greater_eq_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) {
|
||||||
%one = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
|
%0 = "tosa.const"() {value = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
|
||||||
%two = "tosa.const"() {value = dense<8> : tensor<10xi32>} : () -> tensor<10xi32>
|
%1 = "tosa.const"() {value = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
|
||||||
%add = "tosa.greater"(%one, %two) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
|
%2 = "tosa.const"() {value = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
|
||||||
// CHECK: %[[BOOL:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
|
%3 = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
|
||||||
// CHECK: return %[[BOOL]]
|
%true = "tosa.greater_equal"(%0, %1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
|
||||||
return %add : tensor<10xi1>
|
%false = "tosa.greater_equal"(%2, %3) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
|
||||||
|
// CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
|
||||||
|
// CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
|
||||||
|
// CHECK: return %[[TRUE]], %[[FALSE]]
|
||||||
|
return %true, %false : tensor<10xi1>, tensor<10xi1>
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: @fold_greater_splat_i32_true
|
// CHECK-LABEL: @fold_greater_eq_splat_i32
|
||||||
func.func @fold_greater_splat_i32_true() -> tensor<10xi1> {
|
func.func @fold_greater_eq_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) {
|
||||||
%one = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
|
%0 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
|
||||||
%two = "tosa.const"() {value = dense<-12> : tensor<10xi32>} : () -> tensor<10xi32>
|
%1 = "tosa.const"() {value = dense<8> : tensor<10xi32>} : () -> tensor<10xi32>
|
||||||
%add = "tosa.greater"(%one, %two) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
|
%2 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
|
||||||
// CHECK: %[[BOOL:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
|
%3 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
|
||||||
// CHECK: return %[[BOOL]]
|
%true = "tosa.greater_equal"(%2, %3) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
|
||||||
return %add : tensor<10xi1>
|
%false = "tosa.greater_equal"(%0, %1) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
|
||||||
|
// CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
|
||||||
|
// CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
|
||||||
|
// CHECK: return %[[TRUE]], %[[FALSE]]
|
||||||
|
return %true, %false : tensor<10xi1>, tensor<10xi1>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @fold_eq_splat_f32
|
||||||
|
func.func @fold_eq_splat_f32() -> (tensor<10xi1>, tensor<10xi1>) {
|
||||||
|
%0 = "tosa.const"() {value = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
|
||||||
|
%1 = "tosa.const"() {value = dense<4.0> : tensor<10xf32>} : () -> tensor<10xf32>
|
||||||
|
%2 = "tosa.const"() {value = dense<1.0> : tensor<10xf32>} : () -> tensor<10xf32>
|
||||||
|
%3 = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32>
|
||||||
|
%true = "tosa.equal"(%0, %1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
|
||||||
|
%false = "tosa.equal"(%2, %3) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xi1>
|
||||||
|
// CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
|
||||||
|
// CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
|
||||||
|
// CHECK: return %[[TRUE]], %[[FALSE]]
|
||||||
|
return %true, %false : tensor<10xi1>, tensor<10xi1>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @fold_eq_splat_i32
|
||||||
|
func.func @fold_eq_splat_i32() -> (tensor<10xi1>, tensor<10xi1>) {
|
||||||
|
%0 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
|
||||||
|
%1 = "tosa.const"() {value = dense<8> : tensor<10xi32>} : () -> tensor<10xi32>
|
||||||
|
%2 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
|
||||||
|
%3 = "tosa.const"() {value = dense<-10> : tensor<10xi32>} : () -> tensor<10xi32>
|
||||||
|
%true = "tosa.equal"(%2, %3) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
|
||||||
|
%false = "tosa.equal"(%0, %1) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
|
||||||
|
// CHECK-DAG: %[[TRUE:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
|
||||||
|
// CHECK-DAG: %[[FALSE:.+]] = "tosa.const"() {value = dense<false> : tensor<10xi1>}
|
||||||
|
// CHECK: return %[[TRUE]], %[[FALSE]]
|
||||||
|
return %true, %false : tensor<10xi1>, tensor<10xi1>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @fold_eq_i32
|
||||||
|
func.func @fold_eq_i32(%arg0 : tensor<10xi32>) -> (tensor<10xi1>) {
|
||||||
|
// CHECK: %[[TRUE:.+]] = "tosa.const"() {value = dense<true> : tensor<10xi1>}
|
||||||
|
%0 = "tosa.equal"(%arg0, %arg0) : (tensor<10xi32>, tensor<10xi32>) -> tensor<10xi1>
|
||||||
|
// CHECK: return %[[TRUE]]
|
||||||
|
return %0 : tensor<10xi1>
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
Loading…
Reference in New Issue