Allow element type traits to operate on scalars

This allows confirming that a scalar argument has the same element type as a shaped one. It's easy to validate a type is shaped on its own if that's desirable, so this shouldn't make that use case harder. This matches the behavior of other traits that operate on element type (e.g. AllElementTypesMatch). Also this makes the code simpler because now we just use getElementTypeOrSelf.

Verified that all uses in core already check the type is shaped in another way.

PiperOrigin-RevId: 273068507
This commit is contained in:
Geoffrey Martin-Noble 2019-10-05 10:05:40 -07:00 committed by A. Unique TensorFlower
parent 8b9b72cee8
commit 18db4ce493
5 changed files with 40 additions and 32 deletions

View File

@ -1235,9 +1235,9 @@ def SameOperandsShape : NativeOpTrait<"SameOperandsShape">;
def SameOperandsAndResultShape : NativeOpTrait<"SameOperandsAndResultShape">;
// Op has the same operand and result type.
def SameOperandsAndResultType : NativeOpTrait<"SameOperandsAndResultType">;
// Op has the same element type for all operands.
// Op has the same element type (or type itself, if scalar) for all operands.
def SameOperandsElementType : NativeOpTrait<"SameOperandsElementType">;
// Op has the same operand and result element type.
// Op has the same operand and result element type (or type itself, if scalar).
def SameOperandsAndResultElementType :
NativeOpTrait<"SameOperandsAndResultElementType">;
// Op is a terminator.

View File

@ -645,7 +645,7 @@ public:
};
/// This class provides verification for ops that are known to have the same
/// operand element type.
/// operand element type (or the type itself if it is scalar).
///
template <typename ConcreteType>
class SameOperandsElementType
@ -657,7 +657,7 @@ public:
};
/// This class provides verification for ops that are known to have the same
/// operand and result element type.
/// operand and result element type (or the type itself if it is scalar).
///
template <typename ConcreteType>
class SameOperandsAndResultElementType

View File

@ -25,7 +25,9 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include <numeric>
using namespace mlir;
/// Form the OperationName for an op with the specified string. This either is
@ -800,17 +802,10 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) {
LogicalResult OpTrait::impl::verifySameOperandsElementType(Operation *op) {
if (failed(verifyAtLeastNOperands(op, 1)))
return failure();
auto elementType = getElementTypeOrSelf(op->getOperand(0));
auto type = op->getOperand(0)->getType().dyn_cast<ShapedType>();
if (!type)
return op->emitOpError("requires shaped type results");
auto elementType = type.getElementType();
for (auto operandType : llvm::drop_begin(op->getOperandTypes(), 1)) {
auto shapedType = operandType.dyn_cast<ShapedType>();
if (!shapedType)
return op->emitOpError("requires shaped type operands");
if (shapedType.getElementType() != elementType)
for (auto operand : llvm::drop_begin(op->getOperands(), 1)) {
if (getElementTypeOrSelf(operand) != elementType)
return op->emitOpError("requires the same element type for all operands");
}
@ -823,27 +818,18 @@ OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) {
failed(verifyAtLeastNResults(op, 1)))
return failure();
auto type = op->getResult(0)->getType().dyn_cast<ShapedType>();
if (!type)
return op->emitOpError("requires shaped type results");
auto elementType = type.getElementType();
auto elementType = getElementTypeOrSelf(op->getResult(0));
// Verify result element type matches first result's element type.
for (auto result : drop_begin(op->getResults(), 1)) {
auto resultType = result->getType().dyn_cast<ShapedType>();
if (!resultType)
return op->emitOpError("requires shaped type results");
if (resultType.getElementType() != elementType)
if (getElementTypeOrSelf(result) != elementType)
return op->emitOpError(
"requires the same element type for all operands and results");
}
// Verify operand's element type matches first result's element type.
for (auto operand : op->getOperands()) {
auto operandType = operand->getType().dyn_cast<ShapedType>();
if (!operandType)
return op->emitOpError("requires shaped type operands");
if (operandType.getElementType() != elementType)
if (getElementTypeOrSelf(operand) != elementType)
return op->emitOpError(
"requires the same element type for all operands and results");
}

View File

@ -1,12 +1,16 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s
// CHECK: succeededSameOperandsElementType
func @succeededSameOperandsElementType(%t10x10 : tensor<10x10xf32>, %t1f: tensor<1xf32>, %v1: vector<1xf32>, %t1i: tensor<1xi32>) {
func @succeededSameOperandsElementType(%t10x10 : tensor<10x10xf32>, %t1f: tensor<1xf32>, %v1: vector<1xf32>, %t1i: tensor<1xi32>, %sf: f32) {
%0 = "test.same_operand_element_type"(%t1f, %t1f) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi32>
%1 = "test.same_operand_element_type"(%t1f, %t10x10) : (tensor<1xf32>, tensor<10x10xf32>) -> tensor<1xi32>
%2 = "test.same_operand_element_type"(%t10x10, %v1) : (tensor<10x10xf32>, vector<1xf32>) -> tensor<1xi32>
%3 = "test.same_operand_element_type"(%v1, %t1f) : (vector<1xf32>, tensor<1xf32>) -> tensor<1xi32>
%4 = "test.same_operand_element_type"(%v1, %t1f) : (vector<1xf32>, tensor<1xf32>) -> tensor<121xi32>
%5 = "test.same_operand_element_type"(%sf, %sf) : (f32, f32) -> i32
%6 = "test.same_operand_element_type"(%sf, %t1f) : (f32, tensor<1xf32>) -> tensor<121xi32>
%7 = "test.same_operand_element_type"(%sf, %v1) : (f32, vector<1xf32>) -> tensor<121xi32>
%8 = "test.same_operand_element_type"(%sf, %t10x10) : (f32, tensor<10x10xf32>) -> tensor<121xi32>
return
}
@ -26,13 +30,24 @@ func @failedSameOperandAndResultElementType_no_operands() {
// -----
func @failedSameOperandElementType_scalar_type_mismatch(%si: i32, %sf: f32) {
// expected-error@+1 {{requires the same element type for all operands}}
%0 = "test.same_operand_element_type"(%sf, %si) : (f32, i32) -> tensor<1xf32>
}
// -----
// CHECK: succeededSameOperandAndResultElementType
func @succeededSameOperandAndResultElementType(%t10x10 : tensor<10x10xf32>, %t1f: tensor<1xf32>, %v1: vector<1xf32>, %t1i: tensor<1xi32>) {
func @succeededSameOperandAndResultElementType(%t10x10 : tensor<10x10xf32>, %t1f: tensor<1xf32>, %v1: vector<1xf32>, %t1i: tensor<1xi32>, %sf: f32) {
%0 = "test.same_operand_and_result_element_type"(%t1f, %t1f) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
%1 = "test.same_operand_and_result_element_type"(%t1f, %t10x10) : (tensor<1xf32>, tensor<10x10xf32>) -> tensor<1xf32>
%2 = "test.same_operand_and_result_element_type"(%t10x10, %v1) : (tensor<10x10xf32>, vector<1xf32>) -> tensor<1xf32>
%3 = "test.same_operand_and_result_element_type"(%v1, %t1f) : (vector<1xf32>, tensor<1xf32>) -> tensor<1xf32>
%4 = "test.same_operand_and_result_element_type"(%v1, %t1f) : (vector<1xf32>, tensor<1xf32>) -> tensor<121xf32>
%5 = "test.same_operand_and_result_element_type"(%sf, %sf) : (f32, f32) -> f32
%6 = "test.same_operand_and_result_element_type"(%sf, %t1f) : (f32, tensor<1xf32>) -> tensor<121xf32>
%7 = "test.same_operand_and_result_element_type"(%sf, %v1) : (f32, vector<1xf32>) -> tensor<121xf32>
%8 = "test.same_operand_and_result_element_type"(%sf, %t10x10) : (f32, tensor<10x10xf32>) -> tensor<121xf32>
return
}
@ -52,6 +67,13 @@ func @failedSameOperandAndResultElementType_operand_mismatch(%t1f: tensor<1xf32>
// -----
func @failedSameOperandAndResultElementType_result_mismatch(%t1f: tensor<1xf32>) {
// expected-error@+1 {{requires the same element type for all operands and results}}
%0:2 = "test.same_operand_and_result_element_type"(%t1f) : (tensor<1xf32>) -> (tensor<1xf32>, tensor<1xi32>)
}
// -----
func @failedSameOperandAndResultElementType_no_operands() {
// expected-error@+1 {{expected 1 or more operands}}
%0 = "test.same_operand_and_result_element_type"() : () -> tensor<1xf32>

View File

@ -236,14 +236,14 @@ def FunctionalRegionOp : TEST_Op<"functional_region_op",
def SameOperandElementTypeOp : TEST_Op<"same_operand_element_type",
[SameOperandsElementType]> {
let arguments = (ins AnyVectorOrTensor, AnyVectorOrTensor);
let results = (outs AnyVectorOrTensor);
let arguments = (ins AnyType, AnyType);
let results = (outs AnyType);
}
def SameOperandAndResultElementTypeOp : TEST_Op<"same_operand_and_result_element_type",
[SameOperandsAndResultElementType]> {
let arguments = (ins Variadic<AnyVectorOrTensor>);
let results = (outs Variadic<AnyVectorOrTensor>);
let arguments = (ins Variadic<AnyType>);
let results = (outs Variadic<AnyType>);
}
def SameOperandShapeOp : TEST_Op<"same_operand_shape", [SameOperandsShape]> {