From acaf85f7000e69766f5a86a52bff0becc50aaa91 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Thu, 28 Jan 2021 10:47:07 -0800 Subject: [PATCH] Add convenience function for checking arrays of shapes compatible. Expand existing one to handle the common case for verifying compatible is existing and inferred. This considers arrays equivalent if they they have the same size and pairwise compatible elements. --- mlir/include/mlir/IR/TypeUtilities.h | 5 +++++ mlir/lib/IR/TypeUtilities.cpp | 12 ++++++++++++ 2 files changed, 17 insertions(+) diff --git a/mlir/include/mlir/IR/TypeUtilities.h b/mlir/include/mlir/IR/TypeUtilities.h index f4ec9bd43bc5..f5e611124e70 100644 --- a/mlir/include/mlir/IR/TypeUtilities.h +++ b/mlir/include/mlir/IR/TypeUtilities.h @@ -55,6 +55,11 @@ LogicalResult verifyCompatibleShape(ArrayRef shape1, /// does not matter. LogicalResult verifyCompatibleShape(Type type1, Type type2); +/// Returns success if the given two arrays have the same number of elements and +/// each pair wise entries have compatible shape. +LogicalResult verifyCompatibleShapes(ArrayRef types1, + ArrayRef types2); + //===----------------------------------------------------------------------===// // Utility Iterators //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/TypeUtilities.cpp b/mlir/lib/IR/TypeUtilities.cpp index 6018f2161bd2..7e96a69a1537 100644 --- a/mlir/lib/IR/TypeUtilities.cpp +++ b/mlir/lib/IR/TypeUtilities.cpp @@ -86,6 +86,18 @@ LogicalResult mlir::verifyCompatibleShape(Type type1, Type type2) { return verifyCompatibleShape(sType1.getShape(), sType2.getShape()); } +/// Returns success if the given two arrays have the same number of elements and +/// each pair wise entries have compatible shape. +LogicalResult mlir::verifyCompatibleShapes(ArrayRef types1, + ArrayRef types2) { + if (types1.size() != types2.size()) + return failure(); + for (auto it : zip_first(types1, types2)) + if (failed(verifyCompatibleShape(std::get<0>(it), std::get<1>(it)))) + return failure(); + return success(); +} + OperandElementTypeIterator::OperandElementTypeIterator( Operation::operand_iterator it) : llvm::mapped_iterator(