forked from OSchip/llvm-project
[mlir][tosa] Add tosa.avg_pool2d lowering
Added the float lowerings for avg pool with corresponding tests. Differential Revision: https://reviews.llvm.org/D100793
This commit is contained in:
parent
987e52851e
commit
648dfdfc24
|
@ -1626,18 +1626,19 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
|
||||
template <typename SrcOp>
|
||||
class Pool2dConverter : public OpRewritePattern<SrcOp> {
|
||||
public:
|
||||
using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;
|
||||
using OpRewritePattern<SrcOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
|
||||
LogicalResult matchAndRewrite(SrcOp op,
|
||||
PatternRewriter &rewriter) const final {
|
||||
Location loc = op.getLoc();
|
||||
Value input = op.input();
|
||||
ShapedType inputTy = input.getType().cast<ShapedType>();
|
||||
Type inElementTy = inputTy.getElementType();
|
||||
|
||||
ShapedType resultTy = op.getType().cast<ShapedType>();
|
||||
ShapedType resultTy = op.getType().template cast<ShapedType>();
|
||||
Type outElementTy = inputTy.getElementType();
|
||||
int64_t rank = inputTy.getRank();
|
||||
|
||||
|
@ -1646,17 +1647,20 @@ public:
|
|||
|
||||
// Determine what the initial value needs to be for the max pool op.
|
||||
Attribute initialAttr;
|
||||
if (outElementTy.isF32())
|
||||
if (isa<tosa::MaxPool2dOp>(op) && outElementTy.isF32())
|
||||
initialAttr = rewriter.getFloatAttr(
|
||||
outElementTy,
|
||||
APFloat::getLargest(
|
||||
outElementTy.cast<FloatType>().getFloatSemantics(), true));
|
||||
|
||||
if (outElementTy.isa<IntegerType>())
|
||||
if (isa<tosa::MaxPool2dOp>(op) && outElementTy.isa<IntegerType>())
|
||||
initialAttr = rewriter.getIntegerAttr(
|
||||
outElementTy,
|
||||
APInt::getSignedMinValue(outElementTy.getIntOrFloatBitWidth()));
|
||||
|
||||
if (isa<tosa::AvgPool2dOp>(op) && outElementTy.isa<FloatType>())
|
||||
initialAttr = rewriter.getZeroAttr(outElementTy);
|
||||
|
||||
if (!initialAttr)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "Unsupported initial value for tosa.maxpool_2d op");
|
||||
|
@ -1670,6 +1674,7 @@ public:
|
|||
|
||||
Attribute strideAttr = rewriter.getI64VectorAttr(stride);
|
||||
Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});
|
||||
int64_t kernelSize = kernel[0] * kernel[1];
|
||||
|
||||
// If non-zero padding we need to pad the input
|
||||
if (llvm::any_of(pad, [](int64_t v) { return v != 0; })) {
|
||||
|
@ -1716,34 +1721,46 @@ public:
|
|||
.getOperation());
|
||||
};
|
||||
|
||||
if (inElementTy.isF32()) {
|
||||
if (isa<tosa::MaxPool2dOp>(op) && inElementTy.isF32()) {
|
||||
linalg::LinalgOp poolingOp =
|
||||
createOp(static_cast<linalg::PoolingNHWCMaxFOp *>(nullptr));
|
||||
rewriter.replaceOp(op, poolingOp->getResult(0));
|
||||
return success();
|
||||
}
|
||||
|
||||
if (inElementTy.isInteger(8)) {
|
||||
if (isa<tosa::MaxPool2dOp>(op) && inElementTy.isInteger(8)) {
|
||||
linalg::LinalgOp poolingOp =
|
||||
createOp(static_cast<linalg::PoolingNHWCMaxI8Op *>(nullptr));
|
||||
rewriter.replaceOp(op, poolingOp->getResult(0));
|
||||
return success();
|
||||
}
|
||||
|
||||
if (inElementTy.isInteger(16)) {
|
||||
if (isa<tosa::MaxPool2dOp>(op) && inElementTy.isInteger(16)) {
|
||||
linalg::LinalgOp poolingOp =
|
||||
createOp(static_cast<linalg::PoolingNHWCMaxI16Op *>(nullptr));
|
||||
rewriter.replaceOp(op, poolingOp->getResult(0));
|
||||
return success();
|
||||
}
|
||||
|
||||
if (inElementTy.isInteger(32)) {
|
||||
if (isa<tosa::MaxPool2dOp>(op) && inElementTy.isInteger(32)) {
|
||||
linalg::LinalgOp poolingOp =
|
||||
createOp(static_cast<linalg::PoolingNHWCMaxI32Op *>(nullptr));
|
||||
rewriter.replaceOp(op, poolingOp->getResult(0));
|
||||
return success();
|
||||
}
|
||||
|
||||
if (isa<tosa::AvgPool2dOp>(op) && inElementTy.isF32()) {
|
||||
linalg::LinalgOp poolingOp =
|
||||
createOp(static_cast<linalg::PoolingNHWCSumFOp *>(nullptr));
|
||||
auto constAttr = DenseElementsAttr::get(
|
||||
resultTy, static_cast<float>(1.0 / kernelSize));
|
||||
auto constant = rewriter.create<ConstantOp>(loc, constAttr);
|
||||
auto mul = rewriter.create<tosa::MulOp>(
|
||||
loc, resultTy, poolingOp->getResult(0), constant, 0);
|
||||
rewriter.replaceOp(op, mul.output());
|
||||
return success();
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
|
@ -1805,7 +1822,8 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
|
|||
TileConverter,
|
||||
TransposeConverter,
|
||||
MatMulConverter,
|
||||
MaxPool2dConverter,
|
||||
Pool2dConverter<tosa::AvgPool2dOp>,
|
||||
Pool2dConverter<tosa::MaxPool2dOp>,
|
||||
FullyConnectedConverter>(patterns->getContext());
|
||||
// clang-format on
|
||||
// clang-format on
|
||||
}
|
||||
|
|
|
@ -923,6 +923,21 @@ func @max_pool_i32(%arg0: tensor<1x6x34x62xi32>) -> () {
|
|||
%0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xi32>) -> (tensor<1x4x32x62xi32>)
|
||||
return
|
||||
}
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @avg_pool
|
||||
func @avg_pool(%arg0: tensor<1x6x34x62xf32>) -> () {
|
||||
// CHECK-DAG: [[CONST:%.+]] = constant 0
|
||||
// CHECK-DAG: [[INIT:%.+]] = linalg.init_tensor [1, 3, 31, 62]
|
||||
// CHECK-DAG: [[FILL:%.+]] = linalg.fill([[INIT]], [[CONST]])
|
||||
// CHECK-DAG: [[KERNEL:%.+]] = linalg.init_tensor [4, 4]
|
||||
// CHECK: linalg.pooling_nhwc_sum {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%arg0, [[KERNEL]] : tensor<1x6x34x62xf32>, tensor<4x4xf32>) outs([[FILL]] : tensor<1x3x31x62xf32>)
|
||||
// CHECK: constant dense<6.250000e-02>
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: mulf
|
||||
%0 = "tosa.avg_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [4, 4], stride = [1, 1]} : (tensor<1x6x34x62xf32>) -> (tensor<1x3x31x62xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
|
|
Loading…
Reference in New Issue