[mlir][tosa] Add tosa.max_pool2d as no-op canonicalization

When the input and output of a pool2d op are both 1x1, it can be canonicalized to a no-op

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D115908
This commit is contained in:
not-jenni 2021-12-16 15:20:34 -08:00 committed by Rob Suderman
parent b4618f576e
commit f9cefc7b90
3 changed files with 62 additions and 15 deletions

View File

@ -275,6 +275,8 @@ def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [
let results = (outs
Tosa_Tensor4D:$output
);
let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//

View File

@ -614,6 +614,41 @@ void DepthwiseConv2DOp::getCanonicalizationPatterns(
results.insert<DepthwiseConv2DMulOptimization>(context);
}
struct MaxPool2dIsNoOp : public OpRewritePattern<tosa::MaxPool2dOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
PatternRewriter &rewriter) const override {
Value input = op.input();
Value output = op.output();
ShapedType inputType = input.getType().cast<ShapedType>();
ShapedType outputType = output.getType().cast<ShapedType>();
if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) {
return failure();
}
// If the output and input shapes are 1x1, then this is a no op.
ArrayRef<int64_t> outputShape = outputType.getShape();
if (outputShape[1] != 1 || outputShape[2] != 1) {
return failure();
}
ArrayRef<int64_t> inputShape = inputType.getShape();
if (inputShape[1] != 1 || inputShape[2] != 1) {
return failure();
}
rewriter.replaceOp(op, input);
return success();
}
};
void MaxPool2dOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<MaxPool2dIsNoOp>(context);
}
//===----------------------------------------------------------------------===//
// Operator Folders.
//===----------------------------------------------------------------------===//

View File

@ -181,7 +181,17 @@ func @depthwise_conv2d_weight_2x2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<2x
return %0 : tensor<4x10x10x6xf32>
}
// ----
// -----
// CHECK-LABEL: @max_pool2d_is_noop
func @max_pool2d_is_noop(%arg0: tensor<10x1x1x3xf32>) -> tensor<10x1x1x3xf32> {
// CHECK-NOT: "tosa.max_pool2d"
// CHECK: return %arg0
%0 = "tosa.max_pool2d"(%arg0) {kernel = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<10x1x1x3xf32>) -> tensor<10x1x1x3xf32>
return %0 : tensor<10x1x1x3xf32>
}
// -----
// CHECK-LABEL: @pad_noop
func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
@ -191,7 +201,7 @@ func @pad_noop(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
return %1 : tensor<?x?xf32>
}
// ----
// -----
// CHECK-LABEL: @pad_determine_val_i32
func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
@ -202,7 +212,7 @@ func @pad_determine_val_i32(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) ->
return %1 : tensor<?x?xi32>
}
// ----
// -----
// CHECK-LABEL: @pad_determine_val_f32
func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xf32> {
@ -213,7 +223,7 @@ func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>) ->
return %1 : tensor<?x?xf32>
}
// ----
// -----
// CHECK-LABEL: @pad_determine_val_quant
func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {