Don't fail verifying unranked shapes as being the same as this could be valid at runtime.

tensor<*xf32> could be a tensor<1xf32> at runtime but this verifyShapeMatch would return failure and say function is invalid.

--

PiperOrigin-RevId: 248583038
This commit is contained in:
Jacques Pienaar 2019-05-16 12:57:35 -07:00 committed by Mehdi Amini
parent 1982afb145
commit e489e59246
2 changed files with 46 additions and 6 deletions

View File

@ -776,7 +776,7 @@ static LogicalResult verifyShapeMatch(Type type1, Type type2) {
// Check unranked tensor cases
if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>())
return failure();
return success();
// Check normal vector/tensor cases
if (auto sType1 = type1.dyn_cast<ShapedType>()) {

View File

@ -25,13 +25,14 @@ using namespace mlir::OpTrait::impl;
namespace {
// TODO: Replace with regular test once this trait is used by operation in core.
TEST(OpDefinitionTest, SameOperandAndResultElementType) {
MLIRContext context;
#define FILE_LOC \
FileLineColLoc::get(UniquedFilename::get(__FILE__, &context), __LINE__, 0, \
&context)
// TODO: Replace with regular test once this trait is used by operation in core.
// TODO(b/132891206): Replace with dialect test.
TEST(OpDefinitionTest, SameOperandAndResultElementType) {
MLIRContext context;
Builder b(&context);
auto *operandtF32x10x10 = Operation::create(
FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
@ -84,8 +85,47 @@ TEST(OpDefinitionTest, SameOperandAndResultElementType) {
b.getTensorType({5}, b.getF32Type())));
EXPECT_FALSE(valid(FILE_LOC, operandtI32x1, operandvF32x1,
b.getTensorType({5}, b.getF32Type())));
#undef FILE_LOC
}
TEST(OpDefinitionTest, SameOperandAndResultShape) {
MLIRContext context;
Builder b(&context);
auto *operandtF32x10x10 = Operation::create(
FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
/*resultTypes=*/{b.getTensorType({10, 10}, b.getF32Type())},
/*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
/*resizableOperandList=*/false, &context);
auto *operandtF32x1 = Operation::create(
FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
/*resultTypes=*/{b.getTensorType({1}, b.getF32Type())},
/*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
/*resizableOperandList=*/false, &context);
auto *operandtF32xunranked = Operation::create(
FILE_LOC, OperationName("some_const", &context), /*operands=*/{},
/*resultTypes=*/{b.getTensorType(b.getF32Type())},
/*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0,
/*resizableOperandList=*/false, &context);
// SameOperandAndResultShape trait.
auto valid = [&](Location loc, Operation *x, Operation *y, Type resultType) {
auto op = Operation::create(loc, OperationName("some_op", &context),
/*operands=*/{x->getResult(0), y->getResult(0)},
/*resultTypes=*/{resultType},
/*attributes=*/llvm::None, /*successors=*/{},
/*numRegions=*/0,
/*resizableOperandList=*/false, &context);
return succeeded(verifySameOperandsAndResultShape(op));
};
EXPECT_TRUE(valid(FILE_LOC, operandtF32x1, operandtF32x1,
b.getTensorType({1}, b.getF32Type())));
EXPECT_FALSE(valid(FILE_LOC, operandtF32x1, operandtF32x1,
b.getTensorType({12}, b.getF32Type())));
EXPECT_FALSE(valid(FILE_LOC, operandtF32x1, operandtF32x10x10,
b.getTensorType({1}, b.getF32Type())));
EXPECT_TRUE(valid(FILE_LOC, operandtF32x1, operandtF32xunranked,
b.getTensorType({1}, b.getF32Type())));
}
#undef FILE_LOC
} // end namespace