diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index a906076c3141..05dec4f34a32 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1235,9 +1235,9 @@ def SameOperandsShape : NativeOpTrait<"SameOperandsShape">; def SameOperandsAndResultShape : NativeOpTrait<"SameOperandsAndResultShape">; // Op has the same operand and result type. def SameOperandsAndResultType : NativeOpTrait<"SameOperandsAndResultType">; -// Op has the same element type for all operands. +// Op has the same element type (or type itself, if scalar) for all operands. def SameOperandsElementType : NativeOpTrait<"SameOperandsElementType">; -// Op has the same operand and result element type. +// Op has the same operand and result element type (or type itself, if scalar). def SameOperandsAndResultElementType : NativeOpTrait<"SameOperandsAndResultElementType">; // Op is a terminator. diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index 0b0c1877dcdd..dd82e7b7f715 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -645,7 +645,7 @@ public: }; /// This class provides verification for ops that are known to have the same -/// operand element type. +/// operand element type (or the type itself if it is scalar). /// template class SameOperandsElementType @@ -657,7 +657,7 @@ public: }; /// This class provides verification for ops that are known to have the same -/// operand and result element type. +/// operand and result element type (or the type itself if it is scalar). /// template class SameOperandsAndResultElementType diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 2c05bb49b94e..23983bc610dc 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -25,7 +25,9 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" #include + using namespace mlir; /// Form the OperationName for an op with the specified string. This either is @@ -800,17 +802,10 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) { LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) { if (failed(verifyAtLeastNOperands(op, 1))) return failure(); + auto elementType = getElementTypeOrSelf(op->getOperand(0)); - auto type = op->getOperand(0)->getType().dyn_cast(); - if (!type) - return op->emitOpError("requires shaped type results"); - auto elementType = type.getElementType(); - - for (auto operandType : llvm::drop_begin(op->getOperandTypes(), 1)) { - auto shapedType = operandType.dyn_cast(); - if (!shapedType) - return op->emitOpError("requires shaped type operands"); - if (shapedType.getElementType() != elementType) + for (auto operand : llvm::drop_begin(op->getOperands(), 1)) { + if (getElementTypeOrSelf(operand) != elementType) return op->emitOpError("requires the same element type for all operands"); } @@ -823,27 +818,18 @@ OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) { failed(verifyAtLeastNResults(op, 1))) return failure(); - auto type = op->getResult(0)->getType().dyn_cast(); - if (!type) - return op->emitOpError("requires shaped type results"); - auto elementType = type.getElementType(); + auto elementType = getElementTypeOrSelf(op->getResult(0)); // Verify result element type matches first result's element type. for (auto result : drop_begin(op->getResults(), 1)) { - auto resultType = result->getType().dyn_cast(); - if (!resultType) - return op->emitOpError("requires shaped type results"); - if (resultType.getElementType() != elementType) + if (getElementTypeOrSelf(result) != elementType) return op->emitOpError( "requires the same element type for all operands and results"); } // Verify operand's element type matches first result's element type. for (auto operand : op->getOperands()) { - auto operandType = operand->getType().dyn_cast(); - if (!operandType) - return op->emitOpError("requires shaped type operands"); - if (operandType.getElementType() != elementType) + if (getElementTypeOrSelf(operand) != elementType) return op->emitOpError( "requires the same element type for all operands and results"); } diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir index bd07bc08950e..926547ccd7df 100644 --- a/mlir/test/IR/traits.mlir +++ b/mlir/test/IR/traits.mlir @@ -1,12 +1,16 @@ // RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s // CHECK: succeededSameOperandsElementType -func @succeededSameOperandsElementType(%t10x10 : tensor<10x10xf32>, %t1f: tensor<1xf32>, %v1: vector<1xf32>, %t1i: tensor<1xi32>) { +func @succeededSameOperandsElementType(%t10x10 : tensor<10x10xf32>, %t1f: tensor<1xf32>, %v1: vector<1xf32>, %t1i: tensor<1xi32>, %sf: f32) { %0 = "test.same_operand_element_type"(%t1f, %t1f) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi32> %1 = "test.same_operand_element_type"(%t1f, %t10x10) : (tensor<1xf32>, tensor<10x10xf32>) -> tensor<1xi32> %2 = "test.same_operand_element_type"(%t10x10, %v1) : (tensor<10x10xf32>, vector<1xf32>) -> tensor<1xi32> %3 = "test.same_operand_element_type"(%v1, %t1f) : (vector<1xf32>, tensor<1xf32>) -> tensor<1xi32> %4 = "test.same_operand_element_type"(%v1, %t1f) : (vector<1xf32>, tensor<1xf32>) -> tensor<121xi32> + %5 = "test.same_operand_element_type"(%sf, %sf) : (f32, f32) -> i32 + %6 = "test.same_operand_element_type"(%sf, %t1f) : (f32, tensor<1xf32>) -> tensor<121xi32> + %7 = "test.same_operand_element_type"(%sf, %v1) : (f32, vector<1xf32>) -> tensor<121xi32> + %8 = "test.same_operand_element_type"(%sf, %t10x10) : (f32, tensor<10x10xf32>) -> tensor<121xi32> return } @@ -26,13 +30,24 @@ func @failedSameOperandAndResultElementType_no_operands() { // ----- +func @failedSameOperandElementType_scalar_type_mismatch(%si: i32, %sf: f32) { + // expected-error@+1 {{requires the same element type for all operands}} + %0 = "test.same_operand_element_type"(%sf, %si) : (f32, i32) -> tensor<1xf32> +} + +// ----- + // CHECK: succeededSameOperandAndResultElementType -func @succeededSameOperandAndResultElementType(%t10x10 : tensor<10x10xf32>, %t1f: tensor<1xf32>, %v1: vector<1xf32>, %t1i: tensor<1xi32>) { +func @succeededSameOperandAndResultElementType(%t10x10 : tensor<10x10xf32>, %t1f: tensor<1xf32>, %v1: vector<1xf32>, %t1i: tensor<1xi32>, %sf: f32) { %0 = "test.same_operand_and_result_element_type"(%t1f, %t1f) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> %1 = "test.same_operand_and_result_element_type"(%t1f, %t10x10) : (tensor<1xf32>, tensor<10x10xf32>) -> tensor<1xf32> %2 = "test.same_operand_and_result_element_type"(%t10x10, %v1) : (tensor<10x10xf32>, vector<1xf32>) -> tensor<1xf32> %3 = "test.same_operand_and_result_element_type"(%v1, %t1f) : (vector<1xf32>, tensor<1xf32>) -> tensor<1xf32> %4 = "test.same_operand_and_result_element_type"(%v1, %t1f) : (vector<1xf32>, tensor<1xf32>) -> tensor<121xf32> + %5 = "test.same_operand_and_result_element_type"(%sf, %sf) : (f32, f32) -> f32 + %6 = "test.same_operand_and_result_element_type"(%sf, %t1f) : (f32, tensor<1xf32>) -> tensor<121xf32> + %7 = "test.same_operand_and_result_element_type"(%sf, %v1) : (f32, vector<1xf32>) -> tensor<121xf32> + %8 = "test.same_operand_and_result_element_type"(%sf, %t10x10) : (f32, tensor<10x10xf32>) -> tensor<121xf32> return } @@ -52,6 +67,13 @@ func @failedSameOperandAndResultElementType_operand_mismatch(%t1f: tensor<1xf32> // ----- +func @failedSameOperandAndResultElementType_result_mismatch(%t1f: tensor<1xf32>) { + // expected-error@+1 {{requires the same element type for all operands and results}} + %0:2 = "test.same_operand_and_result_element_type"(%t1f) : (tensor<1xf32>) -> (tensor<1xf32>, tensor<1xi32>) +} + +// ----- + func @failedSameOperandAndResultElementType_no_operands() { // expected-error@+1 {{expected 1 or more operands}} %0 = "test.same_operand_and_result_element_type"() : () -> tensor<1xf32> diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td index a7c5fa052356..dd620def6b02 100644 --- a/mlir/test/lib/TestDialect/TestOps.td +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -236,14 +236,14 @@ def FunctionalRegionOp : TEST_Op<"functional_region_op", def SameOperandElementTypeOp : TEST_Op<"same_operand_element_type", [SameOperandsElementType]> { - let arguments = (ins AnyVectorOrTensor, AnyVectorOrTensor); - let results = (outs AnyVectorOrTensor); + let arguments = (ins AnyType, AnyType); + let results = (outs AnyType); } def SameOperandAndResultElementTypeOp : TEST_Op<"same_operand_and_result_element_type", [SameOperandsAndResultElementType]> { - let arguments = (ins Variadic); - let results = (outs Variadic); + let arguments = (ins Variadic); + let results = (outs Variadic); } def SameOperandShapeOp : TEST_Op<"same_operand_shape", [SameOperandsShape]> {