forked from OSchip/llvm-project
Add SameOperandsAndResultElementType trait.
This trait only works for tensor and vector types at the moment, verifying that the element type of an op with only tensor and vector types match. Added a unit test for it as there is no op currently in core that uses this trait. -- PiperOrigin-RevId: 246661697
This commit is contained in:
parent
3b930b0d70
commit
dcab80115f
|
@ -781,6 +781,9 @@ def NoSideEffect : NativeOpTrait<"HasNoSideEffect">;
|
|||
def SameValueShape : NativeOpTrait<"SameOperandsAndResultShape">;
|
||||
// Op has the same operand and result type.
|
||||
def SameValueType : NativeOpTrait<"SameOperandsAndResultType">;
|
||||
// Op has the same operand and result element type.
|
||||
def SameOperandsAndResultElementType :
|
||||
NativeOpTrait<"SameOperandsAndResultElementType">;
|
||||
// Op is a terminator.
|
||||
def Terminator : NativeOpTrait<"IsTerminator">;
|
||||
|
||||
|
|
|
@ -317,6 +317,7 @@ LogicalResult verifyOneResult(Operation *op);
|
|||
LogicalResult verifyNResults(Operation *op, unsigned numOperands);
|
||||
LogicalResult verifyAtLeastNResults(Operation *op, unsigned numOperands);
|
||||
LogicalResult verifySameOperandsAndResultShape(Operation *op);
|
||||
LogicalResult verifySameOperandsAndResultElementType(Operation *op);
|
||||
LogicalResult verifySameOperandsAndResultType(Operation *op);
|
||||
LogicalResult verifyResultsAreBoolLike(Operation *op);
|
||||
LogicalResult verifyResultsAreFloatLike(Operation *op);
|
||||
|
@ -572,12 +573,24 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
/// This class provides verification for ops that are known to have the same
|
||||
/// operand and result element type.
|
||||
///
|
||||
/// TODO: This only works for VectorOrTensorType at the moment.
|
||||
template <typename ConcreteType>
|
||||
class SameOperandsAndResultElementType
|
||||
: public TraitBase<ConcreteType, SameOperandsAndResultElementType> {
|
||||
public:
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
return impl::verifySameOperandsAndResultElementType(op);
|
||||
}
|
||||
};
|
||||
|
||||
/// This class provides verification for ops that are known to have the same
|
||||
/// operand and result type.
|
||||
///
|
||||
/// Note: this trait subsumes the SameOperandsAndResultShape trait.
|
||||
/// Additionally, it requires all operands and results should also have
|
||||
/// the same element type.
|
||||
/// Note: this trait subsumes the SameOperandsAndResultShape and
|
||||
/// SameOperandsAndResultElementType traits.
|
||||
template <typename ConcreteType>
|
||||
class SameOperandsAndResultType
|
||||
: public TraitBase<ConcreteType, SameOperandsAndResultType> {
|
||||
|
|
|
@ -798,6 +798,39 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) {
|
||||
if (op->getNumOperands() == 0 || op->getNumResults() == 0)
|
||||
return failure();
|
||||
|
||||
auto type = op->getResult(0)->getType().dyn_cast<VectorOrTensorType>();
|
||||
if (!type)
|
||||
return op->emitOpError("requires vector or tensor type results");
|
||||
auto elementType = type.getElementType();
|
||||
|
||||
// Verify result element type matches first result's element type.
|
||||
for (auto result : drop_begin(op->getResults(), 1)) {
|
||||
auto resultType = result->getType().dyn_cast<VectorOrTensorType>();
|
||||
if (!resultType)
|
||||
return op->emitOpError("requires vector or tensor type results");
|
||||
if (resultType.getElementType() != elementType)
|
||||
return op->emitOpError(
|
||||
"requires the same element type for all operands and results");
|
||||
}
|
||||
|
||||
// Verify operand's element type matches first result's element type.
|
||||
for (auto operand : op->getOperands()) {
|
||||
auto operandType = operand->getType().dyn_cast<VectorOrTensorType>();
|
||||
if (!operandType)
|
||||
return op->emitOpError("requires vector or tensor type operands");
|
||||
if (operandType.getElementType() != elementType)
|
||||
return op->emitOpError(
|
||||
"requires the same element type for all operands and results");
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) {
|
||||
if (op->getNumOperands() == 0 || op->getNumResults() == 0)
|
||||
return failure();
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
add_mlir_unittest(MLIRIRTests
|
||||
DialectTest.cpp
|
||||
OperationSupportTest.cpp
|
||||
OpDefinitionTest.cpp
|
||||
SDBMTest.cpp
|
||||
)
|
||||
target_link_libraries(MLIRIRTests
|
||||
|
|
|
@ -0,0 +1,91 @@
|
|||
//===- OpDefinitionTest.cpp - Op definition unit tests --------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "gmock/gmock.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::OpTrait::impl;
|
||||
|
||||
namespace {
|
||||
|
||||
// TODO: Replace with regular test once this trait is used by operation in core.
|
||||
TEST(OpDefinitionTest, SameOperandAndResultElementType) {
|
||||
MLIRContext context;
|
||||
#define FILE_LOC \
|
||||
FileLineColLoc::get(UniquedFilename::get(__FILE__, &context), __LINE__, 0, \
|
||||
&context)
|
||||
|
||||
Builder b(&context);
|
||||
auto *operandtF32x10x10 = Operation::create(
|
||||
FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
|
||||
/*resultTypes=*/{b.getTensorType({10, 10}, b.getF32Type())},
|
||||
/*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
|
||||
/*resizableOperandList=*/false, &context);
|
||||
auto *operandtF32x1 = Operation::create(
|
||||
FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
|
||||
/*resultTypes=*/{b.getTensorType({1}, b.getF32Type())},
|
||||
/*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
|
||||
/*resizableOperandList=*/false, &context);
|
||||
auto *operandvF32x1 = Operation::create(
|
||||
FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
|
||||
/*resultTypes=*/{b.getVectorType({1}, b.getF32Type())},
|
||||
/*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
|
||||
/*resizableOperandList=*/false, &context);
|
||||
auto *operandtI32x1 = Operation::create(
|
||||
FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
|
||||
/*resultTypes=*/{b.getTensorType({1}, b.getIntegerType(32))},
|
||||
/*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
|
||||
/*resizableOperandList=*/false, &context);
|
||||
|
||||
// Verifies whether an op with x and y as inputs and resultType satisfies the
|
||||
// SameOperandAndResultElementType trait.
|
||||
auto valid = [&](Location loc, Operation *x, Operation *y, Type resultType) {
|
||||
auto op = Operation::create(loc, OperationName("some_op", &context),
|
||||
/*operands=*/{x->getResult(0), y->getResult(0)},
|
||||
/*resultTypes=*/{resultType},
|
||||
/*attributes=*/llvm::None, /*successors=*/{},
|
||||
/*numRegions=*/0,
|
||||
/*resizableOperandList=*/false, &context);
|
||||
return succeeded(verifySameOperandsAndResultElementType(op));
|
||||
};
|
||||
|
||||
EXPECT_TRUE(valid(FILE_LOC, operandtF32x1, operandtF32x1,
|
||||
b.getTensorType({12}, b.getF32Type())));
|
||||
EXPECT_TRUE(valid(FILE_LOC, operandtF32x10x10, operandtF32x1,
|
||||
b.getTensorType({5}, b.getF32Type())));
|
||||
EXPECT_FALSE(valid(FILE_LOC, operandtF32x10x10, operandtI32x1,
|
||||
b.getTensorType({7}, b.getF32Type())));
|
||||
EXPECT_FALSE(valid(FILE_LOC, operandtF32x10x10, operandtF32x1,
|
||||
b.getTensorType({12}, b.getIntegerType(32))));
|
||||
EXPECT_FALSE(valid(FILE_LOC, operandtF32x10x10, operandtI32x1,
|
||||
b.getTensorType({9}, b.getIntegerType(32))));
|
||||
EXPECT_TRUE(valid(FILE_LOC, operandtF32x10x10, operandtF32x1,
|
||||
b.getVectorType({9}, b.getF32Type())));
|
||||
EXPECT_TRUE(valid(FILE_LOC, operandtF32x10x10, operandvF32x1,
|
||||
b.getVectorType({9}, b.getF32Type())));
|
||||
EXPECT_TRUE(valid(FILE_LOC, operandtF32x1, operandvF32x1,
|
||||
b.getTensorType({5}, b.getF32Type())));
|
||||
EXPECT_FALSE(valid(FILE_LOC, operandtI32x1, operandvF32x1,
|
||||
b.getTensorType({5}, b.getF32Type())));
|
||||
|
||||
#undef FILE_LOC
|
||||
}
|
||||
|
||||
} // end namespace
|
Loading…
Reference in New Issue