Add SameOperandsAndResultElementType trait.

This trait only works for tensor and vector types at the moment, verifying that the element type of an op with only tensor and vector types match. Added a unit test for it as there is no op currently in core that uses this trait.

--

PiperOrigin-RevId: 246661697
This commit is contained in:
Jacques Pienaar 2019-05-04 11:14:40 -07:00 committed by Mehdi Amini
parent 3b930b0d70
commit dcab80115f
5 changed files with 144 additions and 3 deletions

View File

@ -781,6 +781,9 @@ def NoSideEffect : NativeOpTrait<"HasNoSideEffect">;
def SameValueShape : NativeOpTrait<"SameOperandsAndResultShape">;
// Op has the same operand and result type.
def SameValueType : NativeOpTrait<"SameOperandsAndResultType">;
// Op has the same operand and result element type.
def SameOperandsAndResultElementType :
NativeOpTrait<"SameOperandsAndResultElementType">;
// Op is a terminator.
def Terminator : NativeOpTrait<"IsTerminator">;

View File

@ -317,6 +317,7 @@ LogicalResult verifyOneResult(Operation *op);
LogicalResult verifyNResults(Operation *op, unsigned numOperands);
LogicalResult verifyAtLeastNResults(Operation *op, unsigned numOperands);
LogicalResult verifySameOperandsAndResultShape(Operation *op);
LogicalResult verifySameOperandsAndResultElementType(Operation *op);
LogicalResult verifySameOperandsAndResultType(Operation *op);
LogicalResult verifyResultsAreBoolLike(Operation *op);
LogicalResult verifyResultsAreFloatLike(Operation *op);
@ -572,12 +573,24 @@ public:
}
};
/// This class provides verification for ops that are known to have the same
/// operand and result element type.
///
/// TODO: This only works for VectorOrTensorType at the moment.
template <typename ConcreteType>
class SameOperandsAndResultElementType
: public TraitBase<ConcreteType, SameOperandsAndResultElementType> {
public:
static LogicalResult verifyTrait(Operation *op) {
return impl::verifySameOperandsAndResultElementType(op);
}
};
/// This class provides verification for ops that are known to have the same
/// operand and result type.
///
/// Note: this trait subsumes the SameOperandsAndResultShape trait.
/// Additionally, it requires all operands and results should also have
/// the same element type.
/// Note: this trait subsumes the SameOperandsAndResultShape and
/// SameOperandsAndResultElementType traits.
template <typename ConcreteType>
class SameOperandsAndResultType
: public TraitBase<ConcreteType, SameOperandsAndResultType> {

View File

@ -798,6 +798,39 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) {
return success();
}
LogicalResult
OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) {
if (op->getNumOperands() == 0 || op->getNumResults() == 0)
return failure();
auto type = op->getResult(0)->getType().dyn_cast<VectorOrTensorType>();
if (!type)
return op->emitOpError("requires vector or tensor type results");
auto elementType = type.getElementType();
// Verify result element type matches first result's element type.
for (auto result : drop_begin(op->getResults(), 1)) {
auto resultType = result->getType().dyn_cast<VectorOrTensorType>();
if (!resultType)
return op->emitOpError("requires vector or tensor type results");
if (resultType.getElementType() != elementType)
return op->emitOpError(
"requires the same element type for all operands and results");
}
// Verify operand's element type matches first result's element type.
for (auto operand : op->getOperands()) {
auto operandType = operand->getType().dyn_cast<VectorOrTensorType>();
if (!operandType)
return op->emitOpError("requires vector or tensor type operands");
if (operandType.getElementType() != elementType)
return op->emitOpError(
"requires the same element type for all operands and results");
}
return success();
}
LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) {
if (op->getNumOperands() == 0 || op->getNumResults() == 0)
return failure();

View File

@ -1,6 +1,7 @@
add_mlir_unittest(MLIRIRTests
DialectTest.cpp
OperationSupportTest.cpp
OpDefinitionTest.cpp
SDBMTest.cpp
)
target_link_libraries(MLIRIRTests

View File

@ -0,0 +1,91 @@
//===- OpDefinitionTest.cpp - Op definition unit tests --------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/StandardTypes.h"
#include "gmock/gmock.h"
using namespace mlir;
using namespace mlir::OpTrait::impl;
namespace {
// TODO: Replace with regular test once this trait is used by operation in core.
TEST(OpDefinitionTest, SameOperandAndResultElementType) {
MLIRContext context;
#define FILE_LOC \
FileLineColLoc::get(UniquedFilename::get(__FILE__, &context), __LINE__, 0, \
&context)
Builder b(&context);
auto *operandtF32x10x10 = Operation::create(
FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
/*resultTypes=*/{b.getTensorType({10, 10}, b.getF32Type())},
/*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
/*resizableOperandList=*/false, &context);
auto *operandtF32x1 = Operation::create(
FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
/*resultTypes=*/{b.getTensorType({1}, b.getF32Type())},
/*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
/*resizableOperandList=*/false, &context);
auto *operandvF32x1 = Operation::create(
FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
/*resultTypes=*/{b.getVectorType({1}, b.getF32Type())},
/*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
/*resizableOperandList=*/false, &context);
auto *operandtI32x1 = Operation::create(
FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
/*resultTypes=*/{b.getTensorType({1}, b.getIntegerType(32))},
/*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
/*resizableOperandList=*/false, &context);
// Verifies whether an op with x and y as inputs and resultType satisfies the
// SameOperandAndResultElementType trait.
auto valid = [&](Location loc, Operation *x, Operation *y, Type resultType) {
auto op = Operation::create(loc, OperationName("some_op", &context),
/*operands=*/{x->getResult(0), y->getResult(0)},
/*resultTypes=*/{resultType},
/*attributes=*/llvm::None, /*successors=*/{},
/*numRegions=*/0,
/*resizableOperandList=*/false, &context);
return succeeded(verifySameOperandsAndResultElementType(op));
};
EXPECT_TRUE(valid(FILE_LOC, operandtF32x1, operandtF32x1,
b.getTensorType({12}, b.getF32Type())));
EXPECT_TRUE(valid(FILE_LOC, operandtF32x10x10, operandtF32x1,
b.getTensorType({5}, b.getF32Type())));
EXPECT_FALSE(valid(FILE_LOC, operandtF32x10x10, operandtI32x1,
b.getTensorType({7}, b.getF32Type())));
EXPECT_FALSE(valid(FILE_LOC, operandtF32x10x10, operandtF32x1,
b.getTensorType({12}, b.getIntegerType(32))));
EXPECT_FALSE(valid(FILE_LOC, operandtF32x10x10, operandtI32x1,
b.getTensorType({9}, b.getIntegerType(32))));
EXPECT_TRUE(valid(FILE_LOC, operandtF32x10x10, operandtF32x1,
b.getVectorType({9}, b.getF32Type())));
EXPECT_TRUE(valid(FILE_LOC, operandtF32x10x10, operandvF32x1,
b.getVectorType({9}, b.getF32Type())));
EXPECT_TRUE(valid(FILE_LOC, operandtF32x1, operandvF32x1,
b.getTensorType({5}, b.getF32Type())));
EXPECT_FALSE(valid(FILE_LOC, operandtI32x1, operandvF32x1,
b.getTensorType({5}, b.getF32Type())));
#undef FILE_LOC
}
} // end namespace