forked from OSchip/llvm-project
[mlir][tosa] Disable tosa.depthwise_conv2d canonicalizer for quantized case
Quantized case needs to include zero-point corrections before the tosa.mul. Disabled for the quantized use-case. Reviewed By: NatashaKnk Differential Revision: https://reviews.llvm.org/D115264
This commit is contained in:
parent
5bf4f2acb8
commit
e9fae0f19e
|
@ -526,12 +526,18 @@ struct DepthwiseConv2DMulOptimization
|
||||||
ShapedType inputType = input.getType().cast<ShapedType>();
|
ShapedType inputType = input.getType().cast<ShapedType>();
|
||||||
ShapedType weightType = weight.getType().cast<ShapedType>();
|
ShapedType weightType = weight.getType().cast<ShapedType>();
|
||||||
ShapedType resultType = op.output().getType().cast<ShapedType>();
|
ShapedType resultType = op.output().getType().cast<ShapedType>();
|
||||||
|
Type inputEType = inputType.getElementType();
|
||||||
|
|
||||||
if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
|
if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
|
||||||
resultType.hasStaticShape())) {
|
resultType.hasStaticShape())) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Quantization information needs to still be performed.
|
||||||
|
if (op.quantization_info() || !inputEType.isa<FloatType>()) {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
// Stride must be 1 for this optimization.
|
// Stride must be 1 for this optimization.
|
||||||
for (Attribute stride : op.stride().getValue()) {
|
for (Attribute stride : op.stride().getValue()) {
|
||||||
if (!stride.cast<IntegerAttr>().getValue().isOne()) {
|
if (!stride.cast<IntegerAttr>().getValue().isOne()) {
|
||||||
|
|
|
@ -128,6 +128,15 @@ func @depthwise_conv2d_as_mul(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: @depthwise_conv2d_as_mul_q
|
||||||
|
func @depthwise_conv2d_as_mul_q(%arg0: tensor<4x10x10x2xi8>, %arg1: tensor<1x1x2x3xi8>, %arg2: tensor<6xi32>) -> tensor<4x10x10x6xi32> {
|
||||||
|
// CHECK: "tosa.depthwise_conv2d"
|
||||||
|
%0 = "tosa.depthwise_conv2d"(%arg0, %arg1, %arg2) {pad = [0, 0, 0, 0], stride = [1, 1], dilation = [1, 1], quantization_info = {input_zp = 0 : i32, weight_zp = 0 : i32}} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>) -> tensor<4x10x10x6xi32>
|
||||||
|
return %0 : tensor<4x10x10x6xi32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: @depthwise_conv2d_stride_2
|
// CHECK-LABEL: @depthwise_conv2d_stride_2
|
||||||
func @depthwise_conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
|
func @depthwise_conv2d_stride_2(%arg0: tensor<4x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<6xf32>) -> tensor<4x10x10x6xf32> {
|
||||||
// CHECK: "tosa.depthwise_conv2d"
|
// CHECK: "tosa.depthwise_conv2d"
|
||||||
|
|
Loading…
Reference in New Issue