[MLIR][Shape] Remove empty extent tensor operands

Empty extent tensor operands were only removed when they were defined as a
constant. Additionally, we can remove them if they are known to be empty by
their type `tensor<0xindex>`.

Differential Revision: https://reviews.llvm.org/D101351
This commit is contained in:
Frederik Gossen 2021-04-27 14:50:52 +02:00
parent a950f66de2
commit f8d7bd996f
3 changed files with 15 additions and 6 deletions

View File

@ -70,6 +70,9 @@ public:
/// Returns the number of elements held by this attribute.
int64_t size() const { return getNumElements(); }
/// Returns if the number of elements held by this attribute is 0.
bool empty() const { return size() == 0; }
/// Generates a new ElementsAttr by mapping each int value to a new
/// underlying APInt. The new values can represent either an integer or float.
/// This ElementsAttr should contain integers.

View File

@ -534,8 +534,14 @@ struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
auto isPotentiallyNonEmptyShape = [](Value shape) {
if (auto constShape = shape.getDefiningOp<ConstShapeOp>())
return constShape.shape().size() != 0;
if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
if (extentTensorTy.getDimSize(0) == 0)
return false;
}
if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
if (constShape.shape().empty())
return false;
}
return true;
};
auto newOperands = llvm::to_vector<8>(

View File

@ -641,13 +641,13 @@ func @f() {
// -----
// Empty shape arguments can be removed from broadcastable ops.
// CHECK-LABEL: func @f
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xindex>, %[[ARG1:.*]]: tensor<?xindex>)
func @f(%arg0 : tensor<?xindex>, %arg1 : tensor<?xindex>) {
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xindex>, %[[ARG1:.*]]: tensor<?xindex>, %{{.*}}: tensor<0xindex>)
func @f(%arg0 : tensor<?xindex>, %arg1 : tensor<?xindex>, %arg2 : tensor<0xindex>) {
// CHECK-NOT: const_shape
// CHECK: cstr_broadcastable %[[ARG0]], %[[ARG1]] : tensor<?xindex>, tensor<?xindex>
%0 = shape.const_shape [] : !shape.shape
%1 = shape.cstr_broadcastable %arg0, %arg1, %0
: tensor<?xindex>, tensor<?xindex>, !shape.shape
%1 = shape.cstr_broadcastable %arg0, %arg1, %0, %arg2
: tensor<?xindex>, tensor<?xindex>, !shape.shape, tensor<0xindex>
"consume.witness"(%1) : (!shape.witness) -> ()
return
}