diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 4de90058ba78..554023dc0381 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -118,6 +118,8 @@ def Tosa_Conv2DOp : Tosa_Op<"conv2d", [ let builders = [Tosa_ConvOpQuantInfoBuilder]; let verifier = [{ return verifyConvOp(*this); }]; + + let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 78a6b1b4a141..51c41e3334ad 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -423,6 +423,98 @@ void PadOp::getCanonicalizationPatterns(OwningRewritePatternList &results, results.insert(context); } +struct Conv2DFullyConnectedOptimization + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::Conv2DOp op, + PatternRewriter &rewriter) const override { + Value input = op.input(); + Value weight = op.weight(); + ShapedType inputType = input.getType().cast(); + ShapedType weightType = weight.getType().cast(); + ShapedType resultType = op.getType().cast(); + + if (!inputType.hasStaticShape() || !weightType.hasRank()) { + return failure(); + } + + // Stride must be 1 for this optimization. + for (Attribute stride : op.stride().getValue()) { + if (!stride.cast().getValue().isOne()) { + return failure(); + } + } + + // Only works for a 1x1 kernel. + ArrayRef weightShape = weightType.getShape(); + if (weightShape[1] != 1 || weightShape[2] != 1) { + return failure(); + } + + // Reshape input to [N,IH,IW,IC] -> [N * IH * IW, IC]. + ArrayRef inputShape = inputType.getShape(); + llvm::SmallVector revisedInputShape{ + inputShape[0] * inputShape[1] * inputShape[2], inputShape[3]}; + auto revisedInputShapeType = RankedTensorType::get( + revisedInputShape, + input.getType().dyn_cast().getElementType()); + auto reshapedInput = rewriter + .create( + op.getLoc(), revisedInputShapeType, input, + rewriter.getI64ArrayAttr(revisedInputShape)) + .getResult(); + + // Reshape kernel to [OC,KH,KW,IC] -> [OC, IC]. + llvm::SmallVector revisedWeightShape{weightShape[0], + weightShape[3]}; + auto revisedWeightShapeType = RankedTensorType::get( + revisedWeightShape, + weight.getType().dyn_cast().getElementType()); + auto reshapedWeight = rewriter + .create( + op.getLoc(), revisedWeightShapeType, weight, + rewriter.getI64ArrayAttr(revisedWeightShape)) + .getResult(); + + // Perform a fully connected network over the reshaped input and weight. + llvm::SmallVector fullyConnectedShape{ + inputShape[0] * inputShape[1] * inputShape[2], weightShape[0]}; + auto fullyConnectedShapeType = RankedTensorType::get( + fullyConnectedShape, + weight.getType().dyn_cast().getElementType()); + + Value fullyConnectedValue; + if (op.quantization_info()) { + fullyConnectedValue = + rewriter + .create( + op.getLoc(), fullyConnectedShapeType, reshapedInput, + reshapedWeight, op.bias(), op.quantization_info().getValue()) + .getResult(); + } else { + fullyConnectedValue = rewriter + .create( + op.getLoc(), fullyConnectedShapeType, + reshapedInput, reshapedWeight, op.bias()) + .getResult(); + } + + // Reshape output to [N, IH, IW, OC]. + llvm::SmallVector outputShape{inputShape[0], inputShape[1], + inputShape[2], weightShape[0]}; + rewriter.replaceOpWithNewOp( + op, resultType, fullyConnectedValue, + rewriter.getI64ArrayAttr(outputShape)); + return success(); + } +}; + +void Conv2DOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // Operator Folders. //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 70f26650fe61..c4d105ca438e 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -66,12 +66,52 @@ func @concat_fold_cast(%arg0: tensor) -> tensor { return %0 : tensor } +// ----- + +// CHECK-LABEL: @conv2d_as_fully_connected +func @conv2d_as_fully_connected(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<3x1x1x2xf32>, %arg2: tensor<3xf32>) -> tensor<4x10x10x3xf32> { + // CHECK-NOT: "tosa.conv2d" + // CHECK: %[[VAR0:.*]] = "tosa.reshape"(%arg0) {new_shape = [400, 2]} + // CHECK-SAME: -> tensor<400x2xf32> + // CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg1) {new_shape = [3, 2]} + // CHECK-SAME: -> tensor<3x2xf32> + // CHECK: %[[VAR2:.*]] = "tosa.fully_connected"(%[[VAR0]], %[[VAR1]], %arg2) + // CHECK-SAME: -> tensor<400x3xf32> + // CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]]) {new_shape = [4, 10, 10, 3]} + // CHECK-SAME: -> tensor<4x10x10x3xf32> + // CHECK: return %[[VAR3]] + %0 = "tosa.conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32> + return %0 : tensor<4x10x10x3xf32> +} + +// ----- + +// CHECK-LABEL: @conv2d_stride_2 +func @conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>) -> tensor<4x10x10x3xf32> { + // CHECK: "tosa.conv2d" + %weight = "tosa.const"() {value = dense<[[[[1.0, 1.0]]], [[[1.0, 1.0]]], [[[1.0, 1.0]]]]> : tensor<3x1x1x2xf32>} : ()-> tensor<3x1x1x2xf32> + %bias = "tosa.const"() {value = dense<0.0> : tensor<3xf32>} : ()-> tensor<3xf32> + %0 = "tosa.conv2d"(%arg0, %weight, %bias) {pad = [0, 0, 0, 0], stride = [2, 2], dilation = [1, 1]} : (tensor<4x10x10x2xf32>, tensor<3x1x1x2xf32>, tensor<3xf32>) -> tensor<4x10x10x3xf32> + return %0 : tensor<4x10x10x3xf32> +} + +// ----- + +// CHECK-LABEL: @conv2d_weight_2x2 +func @conv2d_weight_2x2(%arg0: tensor<4x10x10x1xf32>) -> tensor<4x10x10x1xf32> { + // CHECK: "tosa.conv2d" + %weight = "tosa.const"() {value = dense<[[[[1.0], [1.0]], [[1.0], [1.0]]]]> : tensor<1x2x2x1xf32>} : ()-> tensor<1x2x2x1xf32> + %bias = "tosa.const"() {value = dense<0.0> : tensor<1xf32>} : ()-> tensor<1xf32> + %0 = "tosa.conv2d"(%arg0, %weight, %bias) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1]} : (tensor<4x10x10x1xf32>, tensor<1x2x2x1xf32>, tensor<1xf32>) -> tensor<4x10x10x1xf32> + return %0 : tensor<4x10x10x1xf32> +} + // ---- // CHECK-LABEL: @pad_noop func @pad_noop(%arg0: tensor) -> tensor { // CHECK: return %arg0 - %0 = "tosa.const"() { value = dense<0> : tensor<2x2xi32>} : () -> tensor<2x2xi32> + %0 = "tosa.const"() { value = dense<0> : tensor<2x2xi32>} : () -> tensor<2x2xi32> %1 = "tosa.pad"(%arg0, %0) : (tensor, tensor<2x2xi32>) -> tensor return %1 : tensor } @@ -82,7 +122,7 @@ func @pad_noop(%arg0: tensor) -> tensor { func @pad_determine_val_i32(%arg0: tensor, %arg1 : tensor<2x2xi32>) -> tensor { // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<0> : tensor} // CHECK: "tosa.pad"(%arg0, %arg1, %[[ZERO]]) - %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32> + %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32> %1 = "tosa.pad"(%arg0, %arg1) : (tensor, tensor<2x2xi32>) -> tensor return %1 : tensor } @@ -93,7 +133,7 @@ func @pad_determine_val_i32(%arg0: tensor, %arg1 : tensor<2x2xi32>) -> func @pad_determine_val_f32(%arg0: tensor, %arg1 : tensor<2x2xi32>) -> tensor { // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<0.000000e+00> : tensor} // CHECK: "tosa.pad"(%arg0, %arg1, %[[ZERO]]) - %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32> + %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32> %1 = "tosa.pad"(%arg0, %arg1) : (tensor, tensor<2x2xi32>) -> tensor return %1 : tensor } @@ -104,7 +144,7 @@ func @pad_determine_val_f32(%arg0: tensor, %arg1 : tensor<2x2xi32>) -> func @pad_determine_val_quant(%arg0: tensor, %arg1 : tensor<2x2xi32>) -> tensor { // CHECK: %[[ZERO:.+]] = "tosa.const"() {value = dense<42> : tensor} // CHECK: "tosa.pad"(%arg0, %arg1, %[[ZERO]]) - %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32> + %0 = "tosa.const"() { value = dense<[[1, 0], [0, 1]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32> %1 = "tosa.pad"(%arg0, %arg1) { quantization_info = {input_zp = 42:i32} } : (tensor, tensor<2x2xi32>) -> tensor return %1 : tensor }