forked from OSchip/llvm-project
[MLIR][Shape] Remove duplicate operands of `shape.assuming_all` op
Differential Revision: https://reviews.llvm.org/D103403
This commit is contained in:
parent
f7c95c3322
commit
1288adaa73
|
@ -429,11 +429,36 @@ struct AssumingAllToCstrEqCanonicalization
|
|||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OpTy>
|
||||
struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
|
||||
using OpRewritePattern<OpTy>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(OpTy op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Find unique operands.
|
||||
SmallVector<Value, 2> unique;
|
||||
for (Value v : op.getOperands()) {
|
||||
if (!llvm::is_contained(unique, v))
|
||||
unique.push_back(v);
|
||||
}
|
||||
|
||||
// Reduce op to equivalent with unique operands.
|
||||
if (unique.size() < op.getNumOperands()) {
|
||||
rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), unique,
|
||||
op->getAttrs());
|
||||
return success();
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext *context) {
|
||||
patterns.add<AssumingAllOneOp, AssumingAllToCstrEqCanonicalization>(context);
|
||||
patterns.add<AssumingAllOneOp, AssumingAllToCstrEqCanonicalization,
|
||||
RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
|
||||
}
|
||||
|
||||
OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
|
||||
|
@ -508,30 +533,6 @@ static LogicalResult verify(BroadcastOp op) {
|
|||
}
|
||||
|
||||
namespace {
|
||||
template <typename OpTy>
|
||||
struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
|
||||
using OpRewritePattern<OpTy>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(OpTy op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// Find unique operands.
|
||||
SmallVector<Value, 2> unique;
|
||||
for (Value v : op.getOperands()) {
|
||||
if (!llvm::is_contained(unique, v))
|
||||
unique.push_back(v);
|
||||
}
|
||||
|
||||
// Reduce op to equivalent with unique operands.
|
||||
if (unique.size() < op.getNumOperands()) {
|
||||
rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), unique,
|
||||
op->getAttrs());
|
||||
return success();
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename OpTy>
|
||||
struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
|
||||
using OpRewritePattern<OpTy>::OpRewritePattern;
|
||||
|
|
|
@ -477,6 +477,19 @@ func @assuming_all_to_cstr_eq(%a : !shape.shape, %b : tensor<?xindex>,
|
|||
return %2 : !shape.witness
|
||||
}
|
||||
|
||||
// -----
|
||||
// `assuming_all` with duplicate operands.
|
||||
// CHECK-LABEL: func @assuming_all_duplicate_operands
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xindex>, %[[ARG1:.*]]: tensor<?xindex>)
|
||||
func @assuming_all_duplicate_operands(%arg0 : tensor<?xindex>,
|
||||
%arg1 : tensor<?xindex>) -> !shape.witness {
|
||||
// CHECK: %[[RES:.*]] = shape.cstr_broadcastable %[[ARG0]], %[[ARG1]]
|
||||
// CHECK: return %[[RES]]
|
||||
%0 = shape.cstr_broadcastable %arg0, %arg1 : tensor<?xindex>, tensor<?xindex>
|
||||
%1 = shape.assuming_all %0, %0, %0
|
||||
return %1 : !shape.witness
|
||||
}
|
||||
|
||||
// -----
|
||||
// `assuming_all` with all `cstr_eq` but disjoint operands cannot be collapsed.
|
||||
// CHECK-LABEL: func @assuming_all_to_cstr_eq
|
||||
|
|
Loading…
Reference in New Issue