forked from OSchip/llvm-project
[MLIR][Shape] Remove empty extent tensor operands
Empty extent tensor operands were only removed when they were defined as a constant. Additionally, we can remove them if they are known to be empty by their type `tensor<0xindex>`. Differential Revision: https://reviews.llvm.org/D101351
This commit is contained in:
parent
a950f66de2
commit
f8d7bd996f
|
@ -70,6 +70,9 @@ public:
|
|||
/// Returns the number of elements held by this attribute.
|
||||
int64_t size() const { return getNumElements(); }
|
||||
|
||||
/// Returns if the number of elements held by this attribute is 0.
|
||||
bool empty() const { return size() == 0; }
|
||||
|
||||
/// Generates a new ElementsAttr by mapping each int value to a new
|
||||
/// underlying APInt. The new values can represent either an integer or float.
|
||||
/// This ElementsAttr should contain integers.
|
||||
|
|
|
@ -534,8 +534,14 @@ struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
|
|||
LogicalResult matchAndRewrite(OpTy op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto isPotentiallyNonEmptyShape = [](Value shape) {
|
||||
if (auto constShape = shape.getDefiningOp<ConstShapeOp>())
|
||||
return constShape.shape().size() != 0;
|
||||
if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
|
||||
if (extentTensorTy.getDimSize(0) == 0)
|
||||
return false;
|
||||
}
|
||||
if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
|
||||
if (constShape.shape().empty())
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
};
|
||||
auto newOperands = llvm::to_vector<8>(
|
||||
|
|
|
@ -641,13 +641,13 @@ func @f() {
|
|||
// -----
|
||||
// Empty shape arguments can be removed from broadcastable ops.
|
||||
// CHECK-LABEL: func @f
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xindex>, %[[ARG1:.*]]: tensor<?xindex>)
|
||||
func @f(%arg0 : tensor<?xindex>, %arg1 : tensor<?xindex>) {
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xindex>, %[[ARG1:.*]]: tensor<?xindex>, %{{.*}}: tensor<0xindex>)
|
||||
func @f(%arg0 : tensor<?xindex>, %arg1 : tensor<?xindex>, %arg2 : tensor<0xindex>) {
|
||||
// CHECK-NOT: const_shape
|
||||
// CHECK: cstr_broadcastable %[[ARG0]], %[[ARG1]] : tensor<?xindex>, tensor<?xindex>
|
||||
%0 = shape.const_shape [] : !shape.shape
|
||||
%1 = shape.cstr_broadcastable %arg0, %arg1, %0
|
||||
: tensor<?xindex>, tensor<?xindex>, !shape.shape
|
||||
%1 = shape.cstr_broadcastable %arg0, %arg1, %0, %arg2
|
||||
: tensor<?xindex>, tensor<?xindex>, !shape.shape, tensor<0xindex>
|
||||
"consume.witness"(%1) : (!shape.witness) -> ()
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue