[MLIR] Canonicalize broadcast operations on single shapes

This covers cases that are not folded away because the extent tensor type
becomes more concrete in the process.

Differential Revision: https://reviews.llvm.org/D98782
This commit is contained in:
Frederik Gossen 2021-03-18 08:58:59 +01:00
parent 6802fdf887
commit 1ce70c15ed
2 changed files with 28 additions and 1 deletions

View File

@ -414,11 +414,26 @@ struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
return failure();
}
};
struct BroadcastForwardSingleOperandPattern
: public OpRewritePattern<BroadcastOp> {
using OpRewritePattern<BroadcastOp>::OpRewritePattern;
LogicalResult matchAndRewrite(BroadcastOp op,
PatternRewriter &rewriter) const override {
if (op.getNumOperands() == 1) {
rewriter.replaceOp(op, op.shapes().front());
return success();
}
return failure();
}
};
} // namespace
void BroadcastOp::getCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
patterns.insert<RemoveDuplicateOperandsPattern<BroadcastOp>>(context);
patterns.insert<BroadcastForwardSingleOperandPattern,
RemoveDuplicateOperandsPattern<BroadcastOp>>(context);
}
//===----------------------------------------------------------------------===//

View File

@ -1119,3 +1119,15 @@ func @broadcast_on_duplicate_shapes(%a : !shape.shape, %b : !shape.shape)
!shape.shape, !shape.shape, !shape.shape, !shape.shape -> !shape.shape
return %0 : !shape.shape
}
// -----
// CHECK-LABEL: @broadcast_on_single_operand
// CHECK-SAME: (%[[A:.*]]: tensor<3xindex>)
func @broadcast_on_single_operand(%a : tensor<3xindex>) {
// CHECK-NOT: broadcast
// CHECK: "use"(%[[A]])
%0 = shape.broadcast %a : tensor<3xindex> -> tensor<?xindex>
"use"(%0) : (tensor<?xindex>) -> ()
return
}