From f9cefc7b9089bc915121ef5890c641b95cc55819 Mon Sep 17 00:00:00 2001 From: not-jenni Date: Thu, 16 Dec 2021 15:20:34 -0800 Subject: [PATCH] [mlir][tosa] Add tosa.max_pool2d as no-op canonicalization When the input and output of a pool2d op are both 1x1, it can be canonicalized to a no-op Reviewed By: rsuderman Differential Revision: https://reviews.llvm.org/D115908 --- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 24 ++++++++------ mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 35 ++++++++++++++++++++ mlir/test/Dialect/Tosa/canonicalize.mlir | 18 +++++++--- 3 files changed, 62 insertions(+), 15 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 173f26db6c93..982880e027e0 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // // This file defines the operation set for the TOSA dialect as defined in -// the TOSA specfication (https://developer.mlplatform.org/w/tosa/). +// the TOSA specfication (https://developer.mlplatform.org/w/tosa/). // //===----------------------------------------------------------------------===// @@ -58,7 +58,7 @@ def Tosa_ArgMaxOp : Tosa_Op<"argmax", [ //===----------------------------------------------------------------------===// def Tosa_AvgPool2dOp : Tosa_Op<"avg_pool2d", [ DeclareOpInterfaceMethods, + ["inferReturnTypeComponents"]>, NoSideEffect]> { let summary = "Performs max pooling on the input."; @@ -275,6 +275,8 @@ def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [ let results = (outs Tosa_Tensor4D:$output ); + + let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// @@ -326,9 +328,9 @@ def Tosa_ClampOp : Tosa_Op<"clamp", [ let description = [{ Clamp to an arbitrary minimum and maximum value. - Maximum and minimum values are specified as values in the range of the + Maximum and minimum values are specified as values in the range of the input type. - No zero point subtraction is done to the values, thus to clamp to the zero + No zero point subtraction is done to the values, thus to clamp to the zero point value, the zero point itself should be supplied as the minimum value. }]; @@ -488,7 +490,7 @@ def Tosa_BitwiseAndOp : Tosa_Op<"bitwise_and", [ let description = [{ Elementwise bitwise AND of input1 and input2. Axis of size 1 - will be broadcast as necessary. + will be broadcast as necessary. }]; let arguments = (ins @@ -1379,7 +1381,7 @@ def Tosa_ConcatOp : Tosa_Op<"concat", [ let summary = "Concatenates tensors along one dimension."; let description = [{ - Concatenate a variadic amount of tensors along a given axis. No data + Concatenate a variadic amount of tensors along a given axis. No data conversion happens during a concat operation. }]; @@ -1405,7 +1407,7 @@ def Tosa_PadOp : Tosa_Op<"pad", [ let summary = "Pads a tensor with value specified."; let description = [{ - Pads a tensor along borders of each dimension with pad_value. + Pads a tensor along borders of each dimension with pad_value. }]; let arguments = (ins @@ -1510,7 +1512,7 @@ def Tosa_SliceOp: Tosa_Op<"slice", [ //===----------------------------------------------------------------------===// def Tosa_TileOp: Tosa_Op<"tile", [ DeclareOpInterfaceMethods, + ["inferReturnTypeComponents"]>, NoSideEffect]> { let summary = "Tile operator"; @@ -1534,7 +1536,7 @@ def Tosa_TileOp: Tosa_Op<"tile", [ //===----------------------------------------------------------------------===// def Tosa_TransposeOp : Tosa_Op<"transpose", [ DeclareOpInterfaceMethods, + ["inferReturnTypeComponents"]>, NoSideEffect]> { let summary = "Transpose operator"; @@ -1565,7 +1567,7 @@ def Tosa_TransposeOp : Tosa_Op<"transpose", [ //===----------------------------------------------------------------------===// def Tosa_GatherOp : Tosa_Op<"gather", [ DeclareOpInterfaceMethods, + ["inferReturnTypeComponents"]>, NoSideEffect]> { let summary = "Gather operation,"; @@ -1697,7 +1699,7 @@ def Tosa_CastOp: Tosa_Op<"cast", [NoSideEffect, //===----------------------------------------------------------------------===// // Operator: rescale //===----------------------------------------------------------------------===// -def Tosa_RescaleOp: Tosa_Op<"rescale", [NoSideEffect, +def Tosa_RescaleOp: Tosa_Op<"rescale", [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "Tosa rescale operator"; diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 68c0d015c3a6..9809e57e3a9a 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -614,6 +614,41 @@ void DepthwiseConv2DOp::getCanonicalizationPatterns( results.insert(context); } +struct MaxPool2dIsNoOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::MaxPool2dOp op, + PatternRewriter &rewriter) const override { + Value input = op.input(); + Value output = op.output(); + ShapedType inputType = input.getType().cast(); + ShapedType outputType = output.getType().cast(); + + if (!inputType.hasStaticShape() || !outputType.hasStaticShape()) { + return failure(); + } + + // If the output and input shapes are 1x1, then this is a no op. + ArrayRef outputShape = outputType.getShape(); + if (outputShape[1] != 1 || outputShape[2] != 1) { + return failure(); + } + + ArrayRef inputShape = inputType.getShape(); + if (inputShape[1] != 1 || inputShape[2] != 1) { + return failure(); + } + + rewriter.replaceOp(op, input); + return success(); + } +}; + +void MaxPool2dOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // Operator Folders. //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 91c2e3ce7feb..fa5304bcfb04 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -181,7 +181,17 @@ func @depthwise_conv2d_weight_2x2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<2x return %0 : tensor<4x10x10x6xf32> } -// ---- +// ----- + +// CHECK-LABEL: @max_pool2d_is_noop +func @max_pool2d_is_noop(%arg0: tensor<10x1x1x3xf32>) -> tensor<10x1x1x3xf32> { + // CHECK-NOT: "tosa.max_pool2d" + // CHECK: return %arg0 + %0 = "tosa.max_pool2d"(%arg0) {kernel = [1, 1], pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<10x1x1x3xf32>) -> tensor<10x1x1x3xf32> + return %0 : tensor<10x1x1x3xf32> +} + +// ----- // CHECK-LABEL: @pad_noop func @pad_noop(%arg0: tensor) -> tensor { @@ -191,7 +201,7 @@ func @pad_noop(%arg0: tensor) -> tensor { return %1 : tensor } -// ---- +// ----- // CHECK-LABEL: @pad_determine_val_i32 func @pad_determine_val_i32(%arg0: tensor, %arg1 : tensor<2x2xi32>) -> tensor { @@ -202,7 +212,7 @@ func @pad_determine_val_i32(%arg0: tensor, %arg1 : tensor<2x2xi32>) -> return %1 : tensor } -// ---- +// ----- // CHECK-LABEL: @pad_determine_val_f32 func @pad_determine_val_f32(%arg0: tensor, %arg1 : tensor<2x2xi32>) -> tensor { @@ -213,7 +223,7 @@ func @pad_determine_val_f32(%arg0: tensor, %arg1 : tensor<2x2xi32>) -> return %1 : tensor } -// ---- +// ----- // CHECK-LABEL: @pad_determine_val_quant func @pad_determine_val_quant(%arg0: tensor, %arg1 : tensor<2x2xi32>) -> tensor {