forked from OSchip/llvm-project
Don't fail verifying unranked shapes as being the same as this could be valid at runtime.
tensor<*xf32> could be a tensor<1xf32> at runtime but this verifyShapeMatch would return failure and say function is invalid. -- PiperOrigin-RevId: 248583038
This commit is contained in:
parent
1982afb145
commit
e489e59246
|
@ -776,7 +776,7 @@ static LogicalResult verifyShapeMatch(Type type1, Type type2) {
|
|||
|
||||
// Check unranked tensor cases
|
||||
if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>())
|
||||
return failure();
|
||||
return success();
|
||||
|
||||
// Check normal vector/tensor cases
|
||||
if (auto sType1 = type1.dyn_cast<ShapedType>()) {
|
||||
|
|
|
@ -25,13 +25,14 @@ using namespace mlir::OpTrait::impl;
|
|||
|
||||
namespace {
|
||||
|
||||
// TODO: Replace with regular test once this trait is used by operation in core.
|
||||
TEST(OpDefinitionTest, SameOperandAndResultElementType) {
|
||||
MLIRContext context;
|
||||
#define FILE_LOC \
|
||||
FileLineColLoc::get(UniquedFilename::get(__FILE__, &context), __LINE__, 0, \
|
||||
&context)
|
||||
|
||||
// TODO: Replace with regular test once this trait is used by operation in core.
|
||||
// TODO(b/132891206): Replace with dialect test.
|
||||
TEST(OpDefinitionTest, SameOperandAndResultElementType) {
|
||||
MLIRContext context;
|
||||
Builder b(&context);
|
||||
auto *operandtF32x10x10 = Operation::create(
|
||||
FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
|
||||
|
@ -84,8 +85,47 @@ TEST(OpDefinitionTest, SameOperandAndResultElementType) {
|
|||
b.getTensorType({5}, b.getF32Type())));
|
||||
EXPECT_FALSE(valid(FILE_LOC, operandtI32x1, operandvF32x1,
|
||||
b.getTensorType({5}, b.getF32Type())));
|
||||
|
||||
#undef FILE_LOC
|
||||
}
|
||||
|
||||
TEST(OpDefinitionTest, SameOperandAndResultShape) {
|
||||
MLIRContext context;
|
||||
Builder b(&context);
|
||||
auto *operandtF32x10x10 = Operation::create(
|
||||
FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
|
||||
/*resultTypes=*/{b.getTensorType({10, 10}, b.getF32Type())},
|
||||
/*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
|
||||
/*resizableOperandList=*/false, &context);
|
||||
auto *operandtF32x1 = Operation::create(
|
||||
FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
|
||||
/*resultTypes=*/{b.getTensorType({1}, b.getF32Type())},
|
||||
/*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
|
||||
/*resizableOperandList=*/false, &context);
|
||||
auto *operandtF32xunranked = Operation::create(
|
||||
FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
|
||||
/*resultTypes=*/{b.getTensorType(b.getF32Type())},
|
||||
/*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
|
||||
/*resizableOperandList=*/false, &context);
|
||||
|
||||
// SameOperandAndResultShape trait.
|
||||
auto valid = [&](Location loc, Operation *x, Operation *y, Type resultType) {
|
||||
auto op = Operation::create(loc, OperationName("some_op", &context),
|
||||
/*operands=*/{x->getResult(0), y->getResult(0)},
|
||||
/*resultTypes=*/{resultType},
|
||||
/*attributes=*/llvm::None, /*successors=*/{},
|
||||
/*numRegions=*/0,
|
||||
/*resizableOperandList=*/false, &context);
|
||||
return succeeded(verifySameOperandsAndResultShape(op));
|
||||
};
|
||||
|
||||
EXPECT_TRUE(valid(FILE_LOC, operandtF32x1, operandtF32x1,
|
||||
b.getTensorType({1}, b.getF32Type())));
|
||||
EXPECT_FALSE(valid(FILE_LOC, operandtF32x1, operandtF32x1,
|
||||
b.getTensorType({12}, b.getF32Type())));
|
||||
EXPECT_FALSE(valid(FILE_LOC, operandtF32x1, operandtF32x10x10,
|
||||
b.getTensorType({1}, b.getF32Type())));
|
||||
EXPECT_TRUE(valid(FILE_LOC, operandtF32x1, operandtF32xunranked,
|
||||
b.getTensorType({1}, b.getF32Type())));
|
||||
}
|
||||
|
||||
#undef FILE_LOC
|
||||
} // end namespace
|
||||
|
|
Loading…
Reference in New Issue