From e9fae0f19eec1fce746101b410d2345f0fbf89b4 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Tue, 7 Dec 2021 10:03:31 -0800 Subject: [PATCH] [mlir][tosa] Disable tosa.depthwise_conv2d canonicalizer for quantized case Quantized case needs to include zero-point corrections before the tosa.mul. Disabled for the quantized use-case. Reviewed By: NatashaKnk Differential Revision: https://reviews.llvm.org/D115264 --- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 6 ++++++ mlir/test/Dialect/Tosa/canonicalize.mlir | 9 +++++++++ 2 files changed, 15 insertions(+) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index cefe13f57dbb..601e66006d6f 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -526,12 +526,18 @@ struct DepthwiseConv2DMulOptimization ShapedType inputType = input.getType().cast(); ShapedType weightType = weight.getType().cast(); ShapedType resultType = op.output().getType().cast(); + Type inputEType = inputType.getElementType(); if (!(inputType.hasStaticShape() && weightType.hasStaticShape() && resultType.hasStaticShape())) { return failure(); } + // Quantization information needs to still be performed. + if (op.quantization_info() || !inputEType.isa()) { + return failure(); + } + // Stride must be 1 for this optimization. for (Attribute stride : op.stride().getValue()) { if (!stride.cast().getValue().isOne()) { diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index ed659ee91964..a9418be3e632 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -128,6 +128,15 @@ func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x // ----- +// CHECK-LABEL: @depthwise_conv2d_as_mul_q +func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<1x1x2x3xi8>, %arg2: tensor<6xi32>) -> tensor<4x10x10x6xi32> { + // CHECK: "tosa.depthwise_conv2d" + %0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 0 : i32, weight_zp = 0 : i32}} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32> + return %0 : tensor<4x10x10x6xi32> +} + +// ----- + // CHECK-LABEL: @depthwise_conv2d_stride_2 func @depthwise_conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> { // CHECK: "tosa.depthwise_conv2d"