[MLIR][Shape] Remove duplicate operands of `shape.assuming_all` op

Differential Revision: https://reviews.llvm.org/D103403
This commit is contained in:
Frederik Gossen 2021-05-31 13:51:20 +02:00
parent f7c95c3322
commit 1288adaa73
2 changed files with 39 additions and 25 deletions

View File

@ -429,11 +429,36 @@ struct AssumingAllToCstrEqCanonicalization
return success();
}
};
template <typename OpTy>
struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
// Find unique operands.
SmallVector<Value, 2> unique;
for (Value v : op.getOperands()) {
if (!llvm::is_contained(unique, v))
unique.push_back(v);
}
// Reduce op to equivalent with unique operands.
if (unique.size() < op.getNumOperands()) {
rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), unique,
op->getAttrs());
return success();
}
return failure();
}
};
} // namespace
void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<AssumingAllOneOp, AssumingAllToCstrEqCanonicalization>(context);
patterns.add<AssumingAllOneOp, AssumingAllToCstrEqCanonicalization,
RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
}
OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
@ -508,30 +533,6 @@ static LogicalResult verify(BroadcastOp op) {
}
namespace {
template <typename OpTy>
struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(OpTy op,
PatternRewriter &rewriter) const override {
// Find unique operands.
SmallVector<Value, 2> unique;
for (Value v : op.getOperands()) {
if (!llvm::is_contained(unique, v))
unique.push_back(v);
}
// Reduce op to equivalent with unique operands.
if (unique.size() < op.getNumOperands()) {
rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), unique,
op->getAttrs());
return success();
}
return failure();
}
};
template <typename OpTy>
struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
using OpRewritePattern<OpTy>::OpRewritePattern;

View File

@ -477,6 +477,19 @@ func @assuming_all_to_cstr_eq(%a : !shape.shape, %b : tensor<?xindex>,
return %2 : !shape.witness
}
// -----
// `assuming_all` with duplicate operands.
// CHECK-LABEL: func @assuming_all_duplicate_operands
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xindex>, %[[ARG1:.*]]: tensor<?xindex>)
func @assuming_all_duplicate_operands(%arg0 : tensor<?xindex>,
%arg1 : tensor<?xindex>) -> !shape.witness {
// CHECK: %[[RES:.*]] = shape.cstr_broadcastable %[[ARG0]], %[[ARG1]]
// CHECK: return %[[RES]]
%0 = shape.cstr_broadcastable %arg0, %arg1 : tensor<?xindex>, tensor<?xindex>
%1 = shape.assuming_all %0, %0, %0
return %1 : !shape.witness
}
// -----
// `assuming_all` with all `cstr_eq` but disjoint operands cannot be collapsed.
// CHECK-LABEL: func @assuming_all_to_cstr_eq