forked from OSchip/llvm-project
Allow element type traits to operate on scalars
This allows confirming that a scalar argument has the same element type as a shaped one. It's easy to validate a type is shaped on its own if that's desirable, so this shouldn't make that use case harder. This matches the behavior of other traits that operate on element type (e.g. AllElementTypesMatch). Also this makes the code simpler because now we just use getElementTypeOrSelf. Verified that all uses in core already check the type is shaped in another way. PiperOrigin-RevId: 273068507
This commit is contained in:
parent
8b9b72cee8
commit
18db4ce493
|
@ -1235,9 +1235,9 @@ def SameOperandsShape : NativeOpTrait<"SameOperandsShape">;
|
|||
def SameOperandsAndResultShape : NativeOpTrait<"SameOperandsAndResultShape">;
|
||||
// Op has the same operand and result type.
|
||||
def SameOperandsAndResultType : NativeOpTrait<"SameOperandsAndResultType">;
|
||||
// Op has the same element type for all operands.
|
||||
// Op has the same element type (or type itself, if scalar) for all operands.
|
||||
def SameOperandsElementType : NativeOpTrait<"SameOperandsElementType">;
|
||||
// Op has the same operand and result element type.
|
||||
// Op has the same operand and result element type (or type itself, if scalar).
|
||||
def SameOperandsAndResultElementType :
|
||||
NativeOpTrait<"SameOperandsAndResultElementType">;
|
||||
// Op is a terminator.
|
||||
|
|
|
@ -645,7 +645,7 @@ public:
|
|||
};
|
||||
|
||||
/// This class provides verification for ops that are known to have the same
|
||||
/// operand element type.
|
||||
/// operand element type (or the type itself if it is scalar).
|
||||
///
|
||||
template <typename ConcreteType>
|
||||
class SameOperandsElementType
|
||||
|
@ -657,7 +657,7 @@ public:
|
|||
};
|
||||
|
||||
/// This class provides verification for ops that are known to have the same
|
||||
/// operand and result element type.
|
||||
/// operand and result element type (or the type itself if it is scalar).
|
||||
///
|
||||
template <typename ConcreteType>
|
||||
class SameOperandsAndResultElementType
|
||||
|
|
|
@ -25,7 +25,9 @@
|
|||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include <numeric>
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
/// Form the OperationName for an op with the specified string. This either is
|
||||
|
@ -800,17 +802,10 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) {
|
|||
LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) {
|
||||
if (failed(verifyAtLeastNOperands(op, 1)))
|
||||
return failure();
|
||||
auto elementType = getElementTypeOrSelf(op->getOperand(0));
|
||||
|
||||
auto type = op->getOperand(0)->getType().dyn_cast<ShapedType>();
|
||||
if (!type)
|
||||
return op->emitOpError("requires shaped type results");
|
||||
auto elementType = type.getElementType();
|
||||
|
||||
for (auto operandType : llvm::drop_begin(op->getOperandTypes(), 1)) {
|
||||
auto shapedType = operandType.dyn_cast<ShapedType>();
|
||||
if (!shapedType)
|
||||
return op->emitOpError("requires shaped type operands");
|
||||
if (shapedType.getElementType() != elementType)
|
||||
for (auto operand : llvm::drop_begin(op->getOperands(), 1)) {
|
||||
if (getElementTypeOrSelf(operand) != elementType)
|
||||
return op->emitOpError("requires the same element type for all operands");
|
||||
}
|
||||
|
||||
|
@ -823,27 +818,18 @@ OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) {
|
|||
failed(verifyAtLeastNResults(op, 1)))
|
||||
return failure();
|
||||
|
||||
auto type = op->getResult(0)->getType().dyn_cast<ShapedType>();
|
||||
if (!type)
|
||||
return op->emitOpError("requires shaped type results");
|
||||
auto elementType = type.getElementType();
|
||||
auto elementType = getElementTypeOrSelf(op->getResult(0));
|
||||
|
||||
// Verify result element type matches first result's element type.
|
||||
for (auto result : drop_begin(op->getResults(), 1)) {
|
||||
auto resultType = result->getType().dyn_cast<ShapedType>();
|
||||
if (!resultType)
|
||||
return op->emitOpError("requires shaped type results");
|
||||
if (resultType.getElementType() != elementType)
|
||||
if (getElementTypeOrSelf(result) != elementType)
|
||||
return op->emitOpError(
|
||||
"requires the same element type for all operands and results");
|
||||
}
|
||||
|
||||
// Verify operand's element type matches first result's element type.
|
||||
for (auto operand : op->getOperands()) {
|
||||
auto operandType = operand->getType().dyn_cast<ShapedType>();
|
||||
if (!operandType)
|
||||
return op->emitOpError("requires shaped type operands");
|
||||
if (operandType.getElementType() != elementType)
|
||||
if (getElementTypeOrSelf(operand) != elementType)
|
||||
return op->emitOpError(
|
||||
"requires the same element type for all operands and results");
|
||||
}
|
||||
|
|
|
@ -1,12 +1,16 @@
|
|||
// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
|
||||
|
||||
// CHECK: succeededSameOperandsElementType
|
||||
func @succeededSameOperandsElementType(%t10x10 : tensor<10x10xf32>, %t1f: tensor<1xf32>, %v1: vector<1xf32>, %t1i: tensor<1xi32>) {
|
||||
func @succeededSameOperandsElementType(%t10x10 : tensor<10x10xf32>, %t1f: tensor<1xf32>, %v1: vector<1xf32>, %t1i: tensor<1xi32>, %sf: f32) {
|
||||
%0 = "test.same_operand_element_type"(%t1f, %t1f) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi32>
|
||||
%1 = "test.same_operand_element_type"(%t1f, %t10x10) : (tensor<1xf32>, tensor<10x10xf32>) -> tensor<1xi32>
|
||||
%2 = "test.same_operand_element_type"(%t10x10, %v1) : (tensor<10x10xf32>, vector<1xf32>) -> tensor<1xi32>
|
||||
%3 = "test.same_operand_element_type"(%v1, %t1f) : (vector<1xf32>, tensor<1xf32>) -> tensor<1xi32>
|
||||
%4 = "test.same_operand_element_type"(%v1, %t1f) : (vector<1xf32>, tensor<1xf32>) -> tensor<121xi32>
|
||||
%5 = "test.same_operand_element_type"(%sf, %sf) : (f32, f32) -> i32
|
||||
%6 = "test.same_operand_element_type"(%sf, %t1f) : (f32, tensor<1xf32>) -> tensor<121xi32>
|
||||
%7 = "test.same_operand_element_type"(%sf, %v1) : (f32, vector<1xf32>) -> tensor<121xi32>
|
||||
%8 = "test.same_operand_element_type"(%sf, %t10x10) : (f32, tensor<10x10xf32>) -> tensor<121xi32>
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -26,13 +30,24 @@ func @failedSameOperandAndResultElementType_no_operands() {
|
|||
|
||||
// -----
|
||||
|
||||
func @failedSameOperandElementType_scalar_type_mismatch(%si: i32, %sf: f32) {
|
||||
// expected-error@+1 {{requires the same element type for all operands}}
|
||||
%0 = "test.same_operand_element_type"(%sf, %si) : (f32, i32) -> tensor<1xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: succeededSameOperandAndResultElementType
|
||||
func @succeededSameOperandAndResultElementType(%t10x10 : tensor<10x10xf32>, %t1f: tensor<1xf32>, %v1: vector<1xf32>, %t1i: tensor<1xi32>) {
|
||||
func @succeededSameOperandAndResultElementType(%t10x10 : tensor<10x10xf32>, %t1f: tensor<1xf32>, %v1: vector<1xf32>, %t1i: tensor<1xi32>, %sf: f32) {
|
||||
%0 = "test.same_operand_and_result_element_type"(%t1f, %t1f) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
|
||||
%1 = "test.same_operand_and_result_element_type"(%t1f, %t10x10) : (tensor<1xf32>, tensor<10x10xf32>) -> tensor<1xf32>
|
||||
%2 = "test.same_operand_and_result_element_type"(%t10x10, %v1) : (tensor<10x10xf32>, vector<1xf32>) -> tensor<1xf32>
|
||||
%3 = "test.same_operand_and_result_element_type"(%v1, %t1f) : (vector<1xf32>, tensor<1xf32>) -> tensor<1xf32>
|
||||
%4 = "test.same_operand_and_result_element_type"(%v1, %t1f) : (vector<1xf32>, tensor<1xf32>) -> tensor<121xf32>
|
||||
%5 = "test.same_operand_and_result_element_type"(%sf, %sf) : (f32, f32) -> f32
|
||||
%6 = "test.same_operand_and_result_element_type"(%sf, %t1f) : (f32, tensor<1xf32>) -> tensor<121xf32>
|
||||
%7 = "test.same_operand_and_result_element_type"(%sf, %v1) : (f32, vector<1xf32>) -> tensor<121xf32>
|
||||
%8 = "test.same_operand_and_result_element_type"(%sf, %t10x10) : (f32, tensor<10x10xf32>) -> tensor<121xf32>
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -52,6 +67,13 @@ func @failedSameOperandAndResultElementType_operand_mismatch(%t1f: tensor<1xf32>
|
|||
|
||||
// -----
|
||||
|
||||
func @failedSameOperandAndResultElementType_result_mismatch(%t1f: tensor<1xf32>) {
|
||||
// expected-error@+1 {{requires the same element type for all operands and results}}
|
||||
%0:2 = "test.same_operand_and_result_element_type"(%t1f) : (tensor<1xf32>) -> (tensor<1xf32>, tensor<1xi32>)
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @failedSameOperandAndResultElementType_no_operands() {
|
||||
// expected-error@+1 {{expected 1 or more operands}}
|
||||
%0 = "test.same_operand_and_result_element_type"() : () -> tensor<1xf32>
|
||||
|
|
|
@ -236,14 +236,14 @@ def FunctionalRegionOp : TEST_Op<"functional_region_op",
|
|||
|
||||
def SameOperandElementTypeOp : TEST_Op<"same_operand_element_type",
|
||||
[SameOperandsElementType]> {
|
||||
let arguments = (ins AnyVectorOrTensor, AnyVectorOrTensor);
|
||||
let results = (outs AnyVectorOrTensor);
|
||||
let arguments = (ins AnyType, AnyType);
|
||||
let results = (outs AnyType);
|
||||
}
|
||||
|
||||
def SameOperandAndResultElementTypeOp : TEST_Op<"same_operand_and_result_element_type",
|
||||
[SameOperandsAndResultElementType]> {
|
||||
let arguments = (ins Variadic<AnyVectorOrTensor>);
|
||||
let results = (outs Variadic<AnyVectorOrTensor>);
|
||||
let arguments = (ins Variadic<AnyType>);
|
||||
let results = (outs Variadic<AnyType>);
|
||||
}
|
||||
|
||||
def SameOperandShapeOp : TEST_Op<"same_operand_shape", [SameOperandsShape]> {
|
||||
|
|
Loading…
Reference in New Issue