forked from OSchip/llvm-project
[mlir] Correct verifyCompatibleShapes
verifyCompatibleShapes is not transitive. Create an n-ary version and update SameOperandShapes and SameOperandAndResultShapes traits to use it. Differential Revision: https://reviews.llvm.org/D98331
This commit is contained in:
parent
b8c58374f6
commit
25a20b8aa6
|
@ -59,6 +59,13 @@ LogicalResult verifyCompatibleShape(Type type1, Type type2);
|
|||
/// each pair wise entries have compatible shape.
|
||||
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2);
|
||||
|
||||
/// Returns success if all given types have compatible shapes. That is, they are
|
||||
/// all scalars (not shaped), or they are all shaped types and any ranked shapes
|
||||
/// have compatible dimensions. The element type does not matter.
|
||||
LogicalResult verifyCompatibleShapes(TypeRange types);
|
||||
|
||||
/// Dimensions are compatible if all non-dynamic dims are equal.
|
||||
LogicalResult verifyCompatibleDims(ArrayRef<int64_t> dims);
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Utility Iterators
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -834,11 +834,9 @@ LogicalResult OpTrait::impl::verifySameOperandsShape(Operation *op) {
|
|||
if (failed(verifyAtLeastNOperands(op, 1)))
|
||||
return failure();
|
||||
|
||||
auto type = op->getOperand(0).getType();
|
||||
for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) {
|
||||
if (failed(verifyCompatibleShape(opType, type)))
|
||||
return op->emitOpError() << "requires the same shape for all operands";
|
||||
}
|
||||
if (failed(verifyCompatibleShapes(op->getOperandTypes())))
|
||||
return op->emitOpError() << "requires the same shape for all operands";
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -847,17 +845,13 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) {
|
|||
failed(verifyAtLeastNResults(op, 1)))
|
||||
return failure();
|
||||
|
||||
auto type = op->getOperand(0).getType();
|
||||
for (auto resultType : op->getResultTypes()) {
|
||||
if (failed(verifyCompatibleShape(resultType, type)))
|
||||
return op->emitOpError()
|
||||
<< "requires the same shape for all operands and results";
|
||||
}
|
||||
for (auto opType : llvm::drop_begin(op->getOperandTypes(), 1)) {
|
||||
if (failed(verifyCompatibleShape(opType, type)))
|
||||
return op->emitOpError()
|
||||
<< "requires the same shape for all operands and results";
|
||||
}
|
||||
SmallVector<Type, 8> types(op->getOperandTypes());
|
||||
types.append(llvm::to_vector<4>(op->getResultTypes()));
|
||||
|
||||
if (failed(verifyCompatibleShapes(types)))
|
||||
return op->emitOpError()
|
||||
<< "requires the same shape for all operands and results";
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -11,6 +11,9 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
|
@ -97,6 +100,57 @@ LogicalResult mlir::verifyCompatibleShapes(TypeRange types1, TypeRange types2) {
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult mlir::verifyCompatibleDims(ArrayRef<int64_t> dims) {
|
||||
if (dims.empty())
|
||||
return success();
|
||||
auto staticDim = std::accumulate(
|
||||
dims.begin(), dims.end(), dims.front(), [](auto fold, auto dim) {
|
||||
return ShapedType::isDynamic(dim) ? fold : dim;
|
||||
});
|
||||
return success(llvm::all_of(dims, [&](auto dim) {
|
||||
return ShapedType::isDynamic(dim) || dim == staticDim;
|
||||
}));
|
||||
}
|
||||
|
||||
/// Returns success if all given types have compatible shapes. That is, they are
|
||||
/// all scalars (not shaped), or they are all shaped types and any ranked shapes
|
||||
/// have compatible dimensions. Dimensions are compatible if all non-dynamic
|
||||
/// dims are equal. The element type does not matter.
|
||||
LogicalResult mlir::verifyCompatibleShapes(TypeRange types) {
|
||||
auto shapedTypes = llvm::to_vector<8>(llvm::map_range(
|
||||
types, [](auto type) { return type.template dyn_cast<ShapedType>(); }));
|
||||
// Return failure if some, but not all are not shaped. Return early if none
|
||||
// are shaped also.
|
||||
if (llvm::none_of(shapedTypes, [](auto t) { return t; }))
|
||||
return success();
|
||||
if (!llvm::all_of(shapedTypes, [](auto t) { return t; }))
|
||||
return failure();
|
||||
|
||||
// Remove all unranked shapes
|
||||
auto shapes = llvm::to_vector<8>(llvm::make_filter_range(
|
||||
shapedTypes, [](auto shapedType) { return shapedType.hasRank(); }));
|
||||
if (shapes.empty())
|
||||
return success();
|
||||
|
||||
// All ranks should be equal
|
||||
auto firstRank = shapes.front().getRank();
|
||||
if (llvm::any_of(shapes,
|
||||
[&](auto shape) { return firstRank != shape.getRank(); }))
|
||||
return failure();
|
||||
|
||||
for (unsigned i = 0; i < firstRank; ++i) {
|
||||
// Retrieve all ranked dimensions
|
||||
auto dims = llvm::to_vector<8>(llvm::map_range(
|
||||
llvm::make_filter_range(
|
||||
shapes, [&](auto shape) { return shape.getRank() >= i; }),
|
||||
[&](auto shape) { return shape.getDimSize(i); }));
|
||||
if (verifyCompatibleDims(dims).failed())
|
||||
return failure();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
OperandElementTypeIterator::OperandElementTypeIterator(
|
||||
Operation::operand_iterator it)
|
||||
: llvm::mapped_iterator<Operation::operand_iterator, Type (*)(Value)>(
|
||||
|
|
|
@ -133,6 +133,13 @@ func @failedSameOperandAndResultShape_operand_result_mismatch(%t10x10 : tensor<1
|
|||
|
||||
// -----
|
||||
|
||||
func @failedSameOperandAndResultShape_operand_result_mismatch(%t10 : tensor<10xf32>, %t1: tensor<?xf32>) {
|
||||
// expected-error@+1 {{requires the same shape for all operands and results}}
|
||||
"test.same_operand_and_result_shape"(%t1, %t10) : (tensor<?xf32>, tensor<10xf32>) -> tensor<3xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @failedSameOperandAndResultShape_no_operands() {
|
||||
// expected-error@+1 {{expected 1 or more operands}}
|
||||
"test.same_operand_and_result_shape"() : () -> (tensor<1xf32>)
|
||||
|
@ -347,7 +354,7 @@ func @failedSingleBlockImplicitTerminator_missing_terminator() {
|
|||
func private @foo()
|
||||
"test.finish" () : () -> ()
|
||||
}) : () -> ()
|
||||
func private @foo()
|
||||
func private @foo()
|
||||
|
||||
// -----
|
||||
|
||||
|
|
Loading…
Reference in New Issue