[mlir][tosa] Fix tosa average_pool2d to linalg type issue

Average pool assumed the same input/output type. Result type for integers
is always an i32, should be updated appropriately.

Reviewed By: GMNGeoffrey

Differential Revision: https://reviews.llvm.org/D111590
This commit is contained in:
Rob Suderman 2021-10-12 13:02:29 -07:00
parent d7e766c781
commit 95e4b71519
5 changed files with 58 additions and 13 deletions

View File

@ -82,6 +82,8 @@ def Tosa_AvgPool2dOp : Tosa_Op<"avg_pool2d", [
);
let builders = [Tosa_AvgPool2dOpQuantInfoBuilder];
let verifier = [{ return verifyAveragePoolOp(*this); }];
}
//===----------------------------------------------------------------------===//

View File

@ -2796,7 +2796,7 @@ public:
Type inElementTy = inputTy.getElementType();
ShapedType resultTy = op.getType().template cast<ShapedType>();
Type resultETy = inputTy.getElementType();
Type resultETy = op.getType().cast<ShapedType>().getElementType();
Type accETy =
inElementTy.isa<IntegerType>() ? rewriter.getI32Type() : inElementTy;
@ -2810,9 +2810,10 @@ public:
pad.resize(2, 0);
getValuesFromIntArrayAttribute(op.pad(), pad);
pad.resize(pad.size() + 2, 0);
Attribute initialAttr = rewriter.getZeroAttr(accETy);
Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);
Attribute padAttr = rewriter.getZeroAttr(inElementTy);
Value paddedInput = applyPad(loc, input, pad, padAttr, rewriter);
Attribute initialAttr = rewriter.getZeroAttr(accETy);
Value initialValue = rewriter.create<ConstantOp>(loc, initialAttr);
SmallVector<int64_t> kernel, stride;
@ -2909,8 +2910,7 @@ public:
// to be applied.
Value poolVal = args[0];
if (accETy.isa<FloatType>()) {
auto countF =
rewriter.create<mlir::SIToFPOp>(loc, inElementTy, countI);
auto countF = rewriter.create<mlir::SIToFPOp>(loc, accETy, countI);
poolVal =
rewriter.create<DivFOp>(loc, poolVal, countF)->getResult(0);
} else {
@ -2974,8 +2974,11 @@ public:
auto clamp = clampHelper<mlir::CmpIOp>(
loc, scaled, min, max, CmpIPredicate::slt, rewriter);
poolVal = clamp;
// Convert type.
poolVal = rewriter.create<TruncateIOp>(loc, resultETy, clamp);
if (resultETy != clamp.getType()) {
poolVal = rewriter.create<TruncateIOp>(loc, resultETy, poolVal);
}
}
// Cast to output type.

View File

@ -342,6 +342,26 @@ static LogicalResult verifyConvOp(T op) {
return success();
}
static LogicalResult verifyAveragePoolOp(tosa::AvgPool2dOp op) {
auto inputETy = op.input().getType().cast<ShapedType>().getElementType();
auto resultETy = op.getType().cast<ShapedType>().getElementType();
if (auto quantType = inputETy.dyn_cast<mlir::quant::UniformQuantizedType>())
inputETy = quantType.getStorageType();
if (auto quantType = resultETy.dyn_cast<mlir::quant::UniformQuantizedType>())
resultETy = quantType.getStorageType();
if (inputETy.isF32() && resultETy.isF32())
return success();
if (inputETy.isInteger(8) && resultETy.isInteger(32))
return success();
if (inputETy.isInteger(16) && resultETy.isInteger(32))
return success();
return op.emitOpError("input/output element types are incompatible.");
}
//===----------------------------------------------------------------------===//
// TOSA Operator Quantization Builders.
//===----------------------------------------------------------------------===//

View File

@ -1465,15 +1465,14 @@ func @avg_pool_i8(%arg0 : tensor<1x128x128x2xi8>) -> () {
// CHECK: %[[SCALE:.+]] = "tosa.apply_scale"(%{{.+}}, %[[MULTIPLIER]], %[[SHIFT]]) {double_round = false}
// CHECK: %[[OUTZP:.+]] = constant -128
// CHECK: %[[OUT:.+]] = addi %[[SCALE]], %[[OUTZP]]
// CHECK: %[[MIN:.+]] = constant -128
// CHECK: %[[MAX:.+]] = constant 127
// CHECK: %[[MIN:.+]] = constant -2147483648
// CHECK: %[[MAX:.+]] = constant 2147483647
// CHECK: %[[CMP_MIN:.+]] = cmpi slt, %[[OUT]], %[[MIN]]
// CHECK: %[[CLMP_MIN:.+]] = select %[[CMP_MIN]], %[[MIN]], %[[OUT]]
// CHECK: %[[CMP_MAX:.+]] = cmpi slt, %[[MAX]], %[[OUT]]
// CHECK: %[[CLMP_MAX:.+]] = select %[[CMP_MAX]], %[[MAX]], %[[CLMP_MIN]]
// CHECK: %[[TRUNC:.+]] = trunci %[[CLMP_MAX]]
// CHECK: linalg.yield %[[TRUNC]]
%0 = "tosa.avg_pool2d"(%arg0) {kernel = [4, 4], pad = [0, 0, 0, 0], quantization_info = {input_zp = -128 : i32, output_zp = -128 : i32}, stride = [4, 4]} : (tensor<1x128x128x2xi8>) -> tensor<1x32x32x2xi8>
// CHECK: linalg.yield %[[CLMP_MAX]]
%0 = "tosa.avg_pool2d"(%arg0) {kernel = [4, 4], pad = [0, 0, 0, 0], quantization_info = {input_zp = -128 : i32, output_zp = -128 : i32}, stride = [4, 4]} : (tensor<1x128x128x2xi8>) -> tensor<1x32x32x2xi32>
return
}

View File

@ -10,12 +10,33 @@ func @test_argmax(%arg0: tensor<14x19xf32>) -> tensor<14xi32> {
}
// -----
// CHECK-LABEL: avg_pool2d
func @test_avg_pool2d(%arg0: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> {
// CHECK-LABEL: avg_pool2d_f32
func @test_avg_pool2d_f32(%arg0: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> {
%0 = "tosa.avg_pool2d"(%arg0) {kernel = [2, 2], pad = [0, 1, 0, 1], stride = [1, 1]} : (tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32>
return %0 : tensor<1x7x7x9xf32>
}
// -----
// CHECK-LABEL: avg_pool2d_i8
func @test_avg_pool2d_i8(%arg0: tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi32> {
%0 = "tosa.avg_pool2d"(%arg0) {kernel = [2, 2], pad = [0, 1, 0, 1], stride = [1, 1]} : (tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi32>
return %0 : tensor<1x7x7x9xi32>
}
// -----
// CHECK-LABEL: avg_pool2d_i16
func @test_avg_pool2d_i16(%arg0: tensor<1x7x7x9xi16>) -> tensor<1x7x7x9xi32> {
%0 = "tosa.avg_pool2d"(%arg0) {kernel = [2, 2], pad = [0, 1, 0, 1], stride = [1, 1]} : (tensor<1x7x7x9xi16>) -> tensor<1x7x7x9xi32>
return %0 : tensor<1x7x7x9xi32>
}
// -----
// CHECK-LABEL: avg_pool2d_q8
func @test_avg_pool2d_q8(%arg0: tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i32:f32, 0.01>> {
%0 = "tosa.avg_pool2d"(%arg0) {kernel = [2, 2], pad = [0, 1, 0, 1], stride = [1, 1]} : (tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i32:f32, 0.01>>
return %0 : tensor<1x7x7x9x!quant.uniform<i32:f32, 0.01>>
}
// -----
// CHECK-LABEL: conv2d
func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, %arg2: tensor<8xf32>) -> tensor<1x4x4x8xf32> {