diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index b5a7a20bdf7e..bc8947c7ae73 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -776,7 +776,7 @@ static LogicalResult verifyShapeMatch(Type type1, Type type2) { // Check unranked tensor cases if (type1.isa() || type2.isa()) - return failure(); + return success(); // Check normal vector/tensor cases if (auto sType1 = type1.dyn_cast()) { diff --git a/mlir/unittests/IR/OpDefinitionTest.cpp b/mlir/unittests/IR/OpDefinitionTest.cpp index 40740097eca9..dcc83ac6bf9e 100644 --- a/mlir/unittests/IR/OpDefinitionTest.cpp +++ b/mlir/unittests/IR/OpDefinitionTest.cpp @@ -25,13 +25,14 @@ using namespace mlir::OpTrait::impl; namespace { -// TODO: Replace with regular test once this trait is used by operation in core. -TEST(OpDefinitionTest, SameOperandAndResultElementType) { - MLIRContext context; #define FILE_LOC \ FileLineColLoc::get(UniquedFilename::get(__FILE__, &context), __LINE__, 0, \ &context) +// TODO: Replace with regular test once this trait is used by operation in core. +// TODO(b/132891206): Replace with dialect test. +TEST(OpDefinitionTest, SameOperandAndResultElementType) { + MLIRContext context; Builder b(&context); auto *operandtF32x10x10 = Operation::create( FILE_LOC, OperationName("some_const", &context), /*operands=*/{}, @@ -84,8 +85,47 @@ TEST(OpDefinitionTest, SameOperandAndResultElementType) { b.getTensorType({5}, b.getF32Type()))); EXPECT_FALSE(valid(FILE_LOC, operandtI32x1, operandvF32x1, b.getTensorType({5}, b.getF32Type()))); - -#undef FILE_LOC } +TEST(OpDefinitionTest, SameOperandAndResultShape) { + MLIRContext context; + Builder b(&context); + auto *operandtF32x10x10 = Operation::create( + FILE_LOC, OperationName("some_const", &context), /*operands=*/{}, + /*resultTypes=*/{b.getTensorType({10, 10}, b.getF32Type())}, + /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0, + /*resizableOperandList=*/false, &context); + auto *operandtF32x1 = Operation::create( + FILE_LOC, OperationName("some_const", &context), /*operands=*/{}, + /*resultTypes=*/{b.getTensorType({1}, b.getF32Type())}, + /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0, + /*resizableOperandList=*/false, &context); + auto *operandtF32xunranked = Operation::create( + FILE_LOC, OperationName("some_const", &context), /*operands=*/{}, + /*resultTypes=*/{b.getTensorType(b.getF32Type())}, + /*attributes=*/llvm::None, /*successors=*/{}, /*numRegions=*/0, + /*resizableOperandList=*/false, &context); + + // SameOperandAndResultShape trait. + auto valid = [&](Location loc, Operation *x, Operation *y, Type resultType) { + auto op = Operation::create(loc, OperationName("some_op", &context), + /*operands=*/{x->getResult(0), y->getResult(0)}, + /*resultTypes=*/{resultType}, + /*attributes=*/llvm::None, /*successors=*/{}, + /*numRegions=*/0, + /*resizableOperandList=*/false, &context); + return succeeded(verifySameOperandsAndResultShape(op)); + }; + + EXPECT_TRUE(valid(FILE_LOC, operandtF32x1, operandtF32x1, + b.getTensorType({1}, b.getF32Type()))); + EXPECT_FALSE(valid(FILE_LOC, operandtF32x1, operandtF32x1, + b.getTensorType({12}, b.getF32Type()))); + EXPECT_FALSE(valid(FILE_LOC, operandtF32x1, operandtF32x10x10, + b.getTensorType({1}, b.getF32Type()))); + EXPECT_TRUE(valid(FILE_LOC, operandtF32x1, operandtF32xunranked, + b.getTensorType({1}, b.getF32Type()))); +} + +#undef FILE_LOC } // end namespace