[mlir][tosa] Disable tosa.depthwise_conv2d canonicalizer for quantized case

Quantized case needs to include zero-point corrections before the tosa.mul.
Disabled for the quantized use-case.

Reviewed By: NatashaKnk

Differential Revision: https://reviews.llvm.org/D115264
This commit is contained in:
Rob Suderman 2021-12-07 10:03:31 -08:00
parent 5bf4f2acb8
commit e9fae0f19e
2 changed files with 15 additions and 0 deletions

View File

@ -526,12 +526,18 @@ struct DepthwiseConv2DMulOptimization
ShapedType inputType = input.getType().cast<ShapedType>(); ShapedType inputType = input.getType().cast<ShapedType>();
ShapedType weightType = weight.getType().cast<ShapedType>(); ShapedType weightType = weight.getType().cast<ShapedType>();
ShapedType resultType = op.output().getType().cast<ShapedType>(); ShapedType resultType = op.output().getType().cast<ShapedType>();
Type inputEType = inputType.getElementType();
if (!(inputType.hasStaticShape() && weightType.hasStaticShape() && if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
resultType.hasStaticShape())) { resultType.hasStaticShape())) {
return failure(); return failure();
} }
// Quantization information needs to still be performed.
if (op.quantization_info() || !inputEType.isa<FloatType>()) {
return failure();
}
// Stride must be 1 for this optimization. // Stride must be 1 for this optimization.
for (Attribute stride : op.stride().getValue()) { for (Attribute stride : op.stride().getValue()) {
if (!stride.cast<IntegerAttr>().getValue().isOne()) { if (!stride.cast<IntegerAttr>().getValue().isOne()) {

View File

@ -128,6 +128,15 @@ func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x
// ----- // -----
// CHECK-LABEL: @depthwise_conv2d_as_mul_q
func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<1x1x2x3xi8>, %arg2: tensor<6xi32>) -> tensor<4x10x10x6xi32> {
// CHECK: "tosa.depthwise_conv2d"
%0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 0 : i32, weight_zp = 0 : i32}} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32>
return %0 : tensor<4x10x10x6xi32>
}
// -----
// CHECK-LABEL: @depthwise_conv2d_stride_2 // CHECK-LABEL: @depthwise_conv2d_stride_2
func @depthwise_conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> { func @depthwise_conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
// CHECK: "tosa.depthwise_conv2d" // CHECK: "tosa.depthwise_conv2d"