[mlir][tosa] Add tosa.max_pool2d lowering to linalg int max pooling additions
Lowers tosa.max_pool2d to equivalent linalg operations. Includes adding the integer max pooling named ops linalg needs, with corresponding tests.

Differential Revision: https://reviews.llvm.org/D99824
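In broad strokes, the lowering seeds an output tensor with the smallest representable value of the element type and then applies the matching NHWC max-pooling named op. A minimal before/after sketch, assembled from the float test added below:

    %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]}
        : (tensor<1x6x34x62xf32>) -> (tensor<1x4x32x62xf32>)

    // becomes, roughly:
    %cst = constant -3.40282347E+38 : f32
    %init = linalg.init_tensor [1, 4, 32, 62] : tensor<1x4x32x62xf32>
    %fill = linalg.fill(%init, %cst) : tensor<1x4x32x62xf32>, f32 -> tensor<1x4x32x62xf32>
    %kernel = linalg.init_tensor [3, 3] : tensor<3x3xf32>
    %res = linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
        ins(%arg0, %kernel : tensor<1x6x34x62xf32>, tensor<3x3xf32>)
        outs(%fill : tensor<1x4x32x62xf32>) -> tensor<1x4x32x62xf32>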
parent 4a84b03ece
commit ceeb5b0f87
@@ -352,6 +352,51 @@ def pooling_nhwc_sum
            ow * strides[1] + kw * dilations[1], c));
}

ods_def<PoolingNHWCMaxI8Op>:
def pooling_nhwc_i8_max
  (I: i8(N, H, W, C), K: i8(KH, KW))
    -> (O: i8(N, OH, OW, C))
  attr(strides: 2xi64, dilations: 2xi64)
{
  O(n, oh, ow, c) =
      std_select<kh, kw>(std_cmpi_sgt(I(n, oh * strides[0] + kh * dilations[0],
                                        ow * strides[1] + kw * dilations[1], c),
                                      O(n, oh, ow, c)),
                         I(n, oh * strides[0] + kh * dilations[0],
                           ow * strides[1] + kw * dilations[1], c),
                         O(n, oh, ow, c));
}

ods_def<PoolingNHWCMaxI16Op>:
def pooling_nhwc_i16_max
  (I: i16(N, H, W, C), K: i16(KH, KW))
    -> (O: i16(N, OH, OW, C))
  attr(strides: 2xi64, dilations: 2xi64)
{
  O(n, oh, ow, c) =
      std_select<kh, kw>(std_cmpi_sgt(I(n, oh * strides[0] + kh * dilations[0],
                                        ow * strides[1] + kw * dilations[1], c),
                                      O(n, oh, ow, c)),
                         I(n, oh * strides[0] + kh * dilations[0],
                           ow * strides[1] + kw * dilations[1], c),
                         O(n, oh, ow, c));
}

ods_def<PoolingNHWCMaxI32Op>:
def pooling_nhwc_i32_max
  (I: i32(N, H, W, C), K: i32(KH, KW))
    -> (O: i32(N, OH, OW, C))
  attr(strides: 2xi64, dilations: 2xi64)
{
  O(n, oh, ow, c) =
      std_select<kh, kw>(std_cmpi_sgt(I(n, oh * strides[0] + kh * dilations[0],
                                        ow * strides[1] + kw * dilations[1], c),
                                      O(n, oh, ow, c)),
                         I(n, oh * strides[0] + kh * dilations[0],
                           ow * strides[1] + kw * dilations[1], c),
                         O(n, oh, ow, c));
}

ods_def<PoolingNHWCMaxFOp>:
def pooling_nhwc_max
  (I: f32(N, H, W, C), K: f32(KH, KW))
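Each of these integer variants expresses max as a compare-and-select over the window reduction. When the named op is generalized, the std_cmpi_sgt/std_select pair becomes the body of a linalg.generic region like the following sketch (the second block argument carries the fake window tensor's element and goes unused), as the generalization tests further down verify:

    ^bb0(%in: i8, %window: i8, %acc: i8):
      %cmp = cmpi sgt, %in, %acc : i8
      %max = select %cmp, %in, %acc : i8
      linalg.yield %max : i8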
@@ -59,6 +59,15 @@ struct CmpFValueBuilder : public ValueBuilder<CmpFOp> {
using std_cmpf_ogt = CmpFValueBuilder<CmpFPredicate::OGT>;
using std_cmpf_olt = CmpFValueBuilder<CmpFPredicate::OLT>;

template <CmpIPredicate Predicate>
struct CmpIValueBuilder : public ValueBuilder<CmpIOp> {
  using ValueBuilder<CmpIOp>::ValueBuilder;
  template <typename... Args>
  CmpIValueBuilder(Args... args) : ValueBuilder<CmpIOp>(Predicate, args...) {}
};

using std_cmpi_sgt = CmpIValueBuilder<CmpIPredicate::sgt>;

/// Branches into `block` with `operands`.
BranchOp std_br(Block *block, ValueRange operands);
@@ -1230,6 +1230,22 @@ public:
          "Pad converter requires static shaped input / padding values.");
    }

    Attribute constantAttr;
    if (elementTy.isa<FloatType>())
      constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
    else if (elementTy.isa<IntegerType>() && !padOp.quantization_info())
      constantAttr = rewriter.getIntegerAttr(elementTy, 0);
    else if (elementTy.isa<IntegerType>() && padOp.quantization_info()) {
      auto value = padOp.quantization_info().getValue().input_zp().getValue();
      constantAttr = rewriter.getIntegerAttr(elementTy, value.getZExtValue());
    }

    if (!constantAttr) {
      return rewriter.notifyMatchFailure(
          padOp,
          "tosa.pad to linalg lowering encountered an unknown element type");
    }

    Value lowIndex = rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(0));
    Value highIndex =
        rewriter.create<ConstantOp>(loc, rewriter.getIndexAttr(1));
@@ -1256,22 +1272,6 @@ public:
      highValues.push_back(highVal);
    }

    Attribute constantAttr;
    if (elementTy.isa<FloatType>())
      constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
    else if (elementTy.isa<IntegerType>() && !padOp.quantization_info())
      constantAttr = rewriter.getIntegerAttr(elementTy, 0);
    else if (elementTy.isa<IntegerType>() && padOp.quantization_info()) {
      auto value = padOp.quantization_info().getValue().input_zp().getValue();
      constantAttr = rewriter.getIntegerAttr(elementTy, value.getZExtValue());
    }

    if (!constantAttr) {
      return rewriter.notifyMatchFailure(
          padOp,
          "tosa.pad to linalg lowering encountered an unknown element type");
    }

    Value constant = rewriter.create<ConstantOp>(loc, constantAttr);

    auto newPadOp = linalg::PadTensorOp::createPadScalarOp(
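These two hunks move the fill-constant computation in the tosa.pad lowering ahead of the index construction, so the pattern selects its constant (and can fail cleanly) before materializing any values; for quantized integer pads the constant is the input zero point rather than 0. A rough sketch of the pad this feeds, assuming a hypothetical 2x2 i8 input with zero point 42:

    %zp = constant 42 : i8
    %padded = linalg.pad_tensor %arg0 low[0, 1] high[0, 1] {
    ^bb0(%i0: index, %i1: index):
      linalg.yield %zp : i8
    } : tensor<2x2xi8> to tensor<2x4xi8>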
@@ -1523,6 +1523,128 @@ public:
  }
};

class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
public:
  using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    Value input = op.input();
    ShapedType inputTy = input.getType().cast<ShapedType>();
    Type inElementTy = inputTy.getElementType();

    ShapedType resultTy = op.getType().cast<ShapedType>();
    Type outElementTy = inputTy.getElementType();
    int64_t rank = inputTy.getRank();

    if (!inputTy.hasStaticShape())
      return failure();

    // Determine what the initial value needs to be for the max pool op.
    Attribute initialAttr;
    if (outElementTy.isF32())
      initialAttr = rewriter.getFloatAttr(
          outElementTy,
          APFloat::getLargest(
              outElementTy.cast<FloatType>().getFloatSemantics(), true));

    if (outElementTy.isa<IntegerType>())
      initialAttr = rewriter.getIntegerAttr(
          outElementTy,
          APInt::getSignedMinValue(outElementTy.getIntOrFloatBitWidth()));

    if (!initialAttr)
      return rewriter.notifyMatchFailure(
          op, "Unsupported initial value for tosa.maxpool_2d op");

    Value initialValue = rewriter.create<ConstantOp>(loc, initialAttr);

    SmallVector<int64_t> kernel, stride, pad;
    getValuesFromIntArrayAttribute(op.kernel(), kernel);
    getValuesFromIntArrayAttribute(op.stride(), stride);
    getValuesFromIntArrayAttribute(op.pad(), pad);

    Attribute strideAttr = rewriter.getI64VectorAttr(stride);
    Attribute dilationAttr = rewriter.getI64VectorAttr({1, 1});

    // If non-zero padding we need to pad the input
    if (llvm::any_of(pad, [](int64_t v) { return v != 0; })) {
      SmallVector<int64_t, 4> paddedShape;
      for (int64_t i = 0; i < rank; i++)
        paddedShape.push_back(inputTy.getDimSize(i));

      paddedShape[1] += pad[0] + pad[1];
      paddedShape[2] += pad[2] + pad[3];

      OpFoldResult zeroIndex = rewriter.getIndexAttr(0);
      OpFoldResult heightLowPadIndex = rewriter.getIndexAttr(pad[0]);
      OpFoldResult heightHighPadIndex = rewriter.getIndexAttr(pad[1]);
      OpFoldResult widthLowPadIndex = rewriter.getIndexAttr(pad[2]);
      OpFoldResult widthHighPadIndex = rewriter.getIndexAttr(pad[3]);

      SmallVector<OpFoldResult, 4> lowIndices = {zeroIndex, heightLowPadIndex,
                                                 widthLowPadIndex, zeroIndex};
      SmallVector<OpFoldResult, 4> highIndices = {zeroIndex, heightHighPadIndex,
                                                  widthHighPadIndex, zeroIndex};

      input = linalg::PadTensorOp::createPadScalarOp(
                  RankedTensorType::get(paddedShape, inElementTy), input,
                  initialValue, lowIndices, highIndices, loc, rewriter)
                  .result();
    }

    Value initTensor = rewriter.create<linalg::InitTensorOp>(
        loc, resultTy.getShape(), resultTy.getElementType());

    Value filledInitTensor =
        rewriter.create<linalg::FillOp>(loc, initTensor, initialValue).result();

    Value fakeWindowDims =
        rewriter.create<linalg::InitTensorOp>(loc, kernel, outElementTy);

    auto createOp = [&](auto *typePtr) -> linalg::LinalgOp {
      return cast<linalg::LinalgOp>(
          rewriter
              .create<std::remove_pointer_t<decltype(typePtr)>>(
                  loc, ArrayRef<Type>{resultTy},
                  ValueRange{input, fakeWindowDims}, filledInitTensor,
                  dilationAttr, strideAttr)
              .getOperation());
    };

    if (inElementTy.isF32()) {
      linalg::LinalgOp poolingOp =
          createOp(static_cast<linalg::PoolingNHWCMaxFOp *>(nullptr));
      rewriter.replaceOp(op, poolingOp->getResult(0));
      return success();
    }

    if (inElementTy.isInteger(8)) {
      linalg::LinalgOp poolingOp =
          createOp(static_cast<linalg::PoolingNHWCMaxI8Op *>(nullptr));
      rewriter.replaceOp(op, poolingOp->getResult(0));
      return success();
    }

    if (inElementTy.isInteger(16)) {
      linalg::LinalgOp poolingOp =
          createOp(static_cast<linalg::PoolingNHWCMaxI16Op *>(nullptr));
      rewriter.replaceOp(op, poolingOp->getResult(0));
      return success();
    }

    if (inElementTy.isInteger(32)) {
      linalg::LinalgOp poolingOp =
          createOp(static_cast<linalg::PoolingNHWCMaxI32Op *>(nullptr));
      rewriter.replaceOp(op, poolingOp->getResult(0));
      return success();
    }

    return failure();
  }
};

} // namespace

void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
@@ -1579,6 +1701,7 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
      TileConverter,
      TransposeConverter,
      MatMulConverter,
      MaxPool2dConverter,
      FullyConnectedConverter>(patterns->getContext());
  // clang-format on
}
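One detail the tests below pin down: when the pad attribute is non-zero, the converter pads the input with the same initial value used to seed the accumulator, so padded cells can never win the max. A sketch of the pad emitted for the @max_pool_padded case, with the region written out (the CHECK lines only match the yield):

    %cst = constant -3.40282347E+38 : f32
    %padded = linalg.pad_tensor %arg0 low[0, 0, 0, 0] high[0, 0, 1, 0] {
    ^bb0(%i0: index, %i1: index, %i2: index, %i3: index):
      linalg.yield %cst : f32
    } : tensor<1x6x34x62xf32> to tensor<1x6x35x62xf32>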
@@ -873,3 +873,53 @@ func @table16(%arg0: tensor<6xi16>, %arg1: tensor<513xi16>) -> () {
  %0 = "tosa.table"(%arg0, %arg1) : (tensor<6xi16>, tensor<513xi16>) -> (tensor<6xi32>)
  return
}

// -----

// CHECK-LABEL: @max_pool
func @max_pool(%arg0: tensor<1x6x34x62xf32>) -> () {
  // CHECK-DAG: [[CONST:%.+]] = constant -3.40282347E+38
  // CHECK-DAG: [[INIT:%.+]] = linalg.init_tensor [1, 4, 32, 62]
  // CHECK-DAG: [[FILL:%.+]] = linalg.fill([[INIT]], [[CONST]])
  // CHECK-DAG: [[KERNEL:%.+]] = linalg.init_tensor [3, 3]
  // CHECK: linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%arg0, [[KERNEL]] : tensor<1x6x34x62xf32>, tensor<3x3xf32>) outs([[FILL]] : tensor<1x4x32x62xf32>)
  %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xf32>) -> (tensor<1x4x32x62xf32>)
  return
}

// CHECK-LABEL: @max_pool_padded
func @max_pool_padded(%arg0: tensor<1x6x34x62xf32>) -> () {
  // CHECK-DAG: [[CONST:%.+]] = constant -3.40282347E+38 : f32
  // CHECK-DAG: [[PAD:%.+]] = linalg.pad_tensor %arg0 low[0, 0, 0, 0] high[0, 0, 1, 0]
  // CHECK-DAG: linalg.yield [[CONST]]
  // CHECK-DAG: [[INIT:%.+]] = linalg.init_tensor [1, 4, 33, 62]
  // CHECK-DAG: [[FILL:%.+]] = linalg.fill([[INIT]], [[CONST]])
  // CHECK-DAG: [[KERNEL:%.+]] = linalg.init_tensor [3, 3]
  // CHECK: linalg.pooling_nhwc_max {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins([[PAD]], [[KERNEL]] : tensor<1x6x35x62xf32>, tensor<3x3xf32>) outs([[FILL]] : tensor<1x4x33x62xf32>)
  %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 1], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xf32>) -> (tensor<1x4x33x62xf32>)
  return
}

// CHECK-LABEL: @max_pool_i8
func @max_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> () {
  // CHECK: constant -128
  // CHECK: linalg.pooling_nhwc_i8_max
  %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xi8>) -> (tensor<1x4x32x62xi8>)
  return
}

// CHECK-LABEL: @max_pool_i16
func @max_pool_i16(%arg0: tensor<1x6x34x62xi16>) -> () {
  // CHECK: constant -32768
  // CHECK: linalg.pooling_nhwc_i16_max
  %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xi16>) -> (tensor<1x4x32x62xi16>)
  return
}

// CHECK-LABEL: @max_pool_i32
func @max_pool_i32(%arg0: tensor<1x6x34x62xi32>) -> () {
  // CHECK: constant -2147483648
  // CHECK: linalg.pooling_nhwc_i32_max
  %0 = "tosa.max_pool2d"(%arg0) {pad = [0, 0, 0, 0], kernel = [3, 3], stride = [1, 1]} : (tensor<1x6x34x62xi32>) -> (tensor<1x4x32x62xi32>)
  return
}
@@ -340,6 +340,84 @@ func @pooling_nhwc_max(%input: memref<?x?x?x?xf32>, %fake: memref<2x3xf32>, %init

// -----

func @pooling_nhwc_i8_max(%input: memref<?x?x?x?xi8>, %fake: memref<2x3xi8>, %init: memref<?x?x?x?xi8>) {
  linalg.pooling_nhwc_i8_max {dilations = dense<1> : tensor<2xi64>, strides = dense<[2, 3]> : tensor<2xi64>}
    ins(%input, %fake: memref<?x?x?x?xi8>, memref<2x3xi8>)
    outs(%init: memref<?x?x?x?xi8>)
  return
}

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 3 + d5, d3)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>

// CHECK: func @pooling_nhwc_i8_max

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?x?xi8>, memref<2x3xi8>)
// CHECK-SAME: outs(%{{.+}} : memref<?x?x?x?xi8>)

// CHECK: ^{{.+}}(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: i8)
// CHECK-NEXT: %[[CMP:.+]] = cmpi sgt, %[[BBARG0]], %[[BBARG2]] : i8
// CHECK-NEXT: %[[RES:.+]] = select %[[CMP]], %[[BBARG0]], %[[BBARG2]] : i8
// CHECK-NEXT: linalg.yield %[[RES]] : i8

// -----

func @pooling_nhwc_i16_max(%input: memref<?x?x?x?xi16>, %fake: memref<2x3xi16>, %init: memref<?x?x?x?xi16>) {
  linalg.pooling_nhwc_i16_max {dilations = dense<1> : tensor<2xi64>, strides = dense<[2, 3]> : tensor<2xi64>}
    ins(%input, %fake: memref<?x?x?x?xi16>, memref<2x3xi16>)
    outs(%init: memref<?x?x?x?xi16>)
  return
}

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 3 + d5, d3)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>

// CHECK: func @pooling_nhwc_i16_max

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?x?xi16>, memref<2x3xi16>)
// CHECK-SAME: outs(%{{.+}} : memref<?x?x?x?xi16>)

// CHECK: ^{{.+}}(%[[BBARG0:.+]]: i16, %[[BBARG1:.+]]: i16, %[[BBARG2:.+]]: i16)
// CHECK-NEXT: %[[CMP:.+]] = cmpi sgt, %[[BBARG0]], %[[BBARG2]] : i16
// CHECK-NEXT: %[[RES:.+]] = select %[[CMP]], %[[BBARG0]], %[[BBARG2]] : i16
// CHECK-NEXT: linalg.yield %[[RES]] : i16

// -----

func @pooling_nhwc_i32_max(%input: memref<?x?x?x?xi32>, %fake: memref<2x3xi32>, %init: memref<?x?x?x?xi32>) {
  linalg.pooling_nhwc_i32_max {dilations = dense<1> : tensor<2xi64>, strides = dense<[2, 3]> : tensor<2xi64>}
    ins(%input, %fake: memref<?x?x?x?xi32>, memref<2x3xi32>)
    outs(%init: memref<?x?x?x?xi32>)
  return
}

// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 * 2 + d4, d2 * 3 + d5, d3)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>

// CHECK: func @pooling_nhwc_i32_max

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?x?xi32>, memref<2x3xi32>)
// CHECK-SAME: outs(%{{.+}} : memref<?x?x?x?xi32>)

// CHECK: ^{{.+}}(%[[BBARG0:.+]]: i32, %[[BBARG1:.+]]: i32, %[[BBARG2:.+]]: i32)
// CHECK-NEXT: %[[CMP:.+]] = cmpi sgt, %[[BBARG0]], %[[BBARG2]] : i32
// CHECK-NEXT: %[[RES:.+]] = select %[[CMP]], %[[BBARG0]], %[[BBARG2]] : i32
// CHECK-NEXT: linalg.yield %[[RES]] : i32

// -----

func @pooling_nhwc_min(%input: memref<?x?x?x?xf32>, %fake: memref<2x3xf32>, %init: memref<?x?x?x?xf32>) {
  linalg.pooling_nhwc_min {dilations = dense<3> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
    ins(%input, %fake: memref<?x?x?x?xf32>, memref<2x3xf32>)
@@ -344,6 +344,109 @@ func @pooling_nhwc_max(%input: memref<1x4x4x1xf32>, %fake: memref<3x3xf32>, %out
  return
}

// -----

// CHECK-LABEL: func @pooling_nhwc_i8_max_tensor
// CHECK: %{{.+}} = linalg.pooling_nhwc_i8_max
// CHECK-SAME: dilations = dense<1> : tensor<2xi64>
// CHECK-SAME: strides = dense<1> : tensor<2xi64>
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xi8>, tensor<3x3xi8>)
// CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xi8>) -> tensor<1x2x2x1xi8>
func @pooling_nhwc_i8_max_tensor(%input: tensor<1x4x4x1xi8>) -> tensor<1x2x2x1xi8> {
  %fake = linalg.init_tensor [3, 3] : tensor<3x3xi8>
  %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xi8>
  %cst = constant 0 : i8
  %fill = linalg.fill(%init, %cst) : tensor<1x2x2x1xi8>, i8 -> tensor<1x2x2x1xi8>
  %res = linalg.pooling_nhwc_i8_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
    ins(%input, %fake: tensor<1x4x4x1xi8>, tensor<3x3xi8>)
    outs(%fill: tensor<1x2x2x1xi8>) -> tensor<1x2x2x1xi8>
  return %res : tensor<1x2x2x1xi8>
}

// -----

// CHECK-LABEL: func @pooling_nhwc_i8_max
// CHECK: linalg.pooling_nhwc_i8_max
// CHECK-SAME: dilations = dense<1> : tensor<2xi64>
// CHECK-SAME: strides = dense<1> : tensor<2xi64>
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x4x4x1xi8>, memref<3x3xi8>)
// CHECK-SAME: outs(%{{.+}} : memref<1x2x2x1xi8>)
func @pooling_nhwc_i8_max(%input: memref<1x4x4x1xi8>, %fake: memref<3x3xi8>, %output: memref<1x2x2x1xi8>) {
  linalg.pooling_nhwc_i8_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
    ins(%input, %fake: memref<1x4x4x1xi8>, memref<3x3xi8>)
    outs(%output: memref<1x2x2x1xi8>)
  return
}

// -----

// CHECK-LABEL: func @pooling_nhwc_i16_max_tensor
// CHECK: %{{.+}} = linalg.pooling_nhwc_i16_max
// CHECK-SAME: dilations = dense<1> : tensor<2xi64>
// CHECK-SAME: strides = dense<1> : tensor<2xi64>
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xi16>, tensor<3x3xi16>)
// CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xi16>) -> tensor<1x2x2x1xi16>
func @pooling_nhwc_i16_max_tensor(%input: tensor<1x4x4x1xi16>) -> tensor<1x2x2x1xi16> {
  %fake = linalg.init_tensor [3, 3] : tensor<3x3xi16>
  %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xi16>
  %cst = constant 0 : i16
  %fill = linalg.fill(%init, %cst) : tensor<1x2x2x1xi16>, i16 -> tensor<1x2x2x1xi16>
  %res = linalg.pooling_nhwc_i16_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
    ins(%input, %fake: tensor<1x4x4x1xi16>, tensor<3x3xi16>)
    outs(%fill: tensor<1x2x2x1xi16>) -> tensor<1x2x2x1xi16>
  return %res : tensor<1x2x2x1xi16>
}

// -----

// CHECK-LABEL: func @pooling_nhwc_i16_max
// CHECK: linalg.pooling_nhwc_i16_max
// CHECK-SAME: dilations = dense<1> : tensor<2xi64>
// CHECK-SAME: strides = dense<1> : tensor<2xi64>
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x4x4x1xi16>, memref<3x3xi16>)
// CHECK-SAME: outs(%{{.+}} : memref<1x2x2x1xi16>)
func @pooling_nhwc_i16_max(%input: memref<1x4x4x1xi16>, %fake: memref<3x3xi16>, %output: memref<1x2x2x1xi16>) {
  linalg.pooling_nhwc_i16_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
    ins(%input, %fake: memref<1x4x4x1xi16>, memref<3x3xi16>)
    outs(%output: memref<1x2x2x1xi16>)
  return
}

// -----

// CHECK-LABEL: func @pooling_nhwc_i32_max_tensor
// CHECK: %{{.+}} = linalg.pooling_nhwc_i32_max
// CHECK-SAME: dilations = dense<1> : tensor<2xi64>
// CHECK-SAME: strides = dense<1> : tensor<2xi64>
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x4x4x1xi32>, tensor<3x3xi32>)
// CHECK-SAME: outs(%{{.+}} : tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
func @pooling_nhwc_i32_max_tensor(%input: tensor<1x4x4x1xi32>) -> tensor<1x2x2x1xi32> {
  %fake = linalg.init_tensor [3, 3] : tensor<3x3xi32>
  %init = linalg.init_tensor [1, 2, 2, 1] : tensor<1x2x2x1xi32>
  %cst = constant 0 : i32
  %fill = linalg.fill(%init, %cst) : tensor<1x2x2x1xi32>, i32 -> tensor<1x2x2x1xi32>
  %res = linalg.pooling_nhwc_i32_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
    ins(%input, %fake: tensor<1x4x4x1xi32>, tensor<3x3xi32>)
    outs(%fill: tensor<1x2x2x1xi32>) -> tensor<1x2x2x1xi32>
  return %res : tensor<1x2x2x1xi32>
}

// -----

// CHECK-LABEL: func @pooling_nhwc_i32_max
// CHECK: linalg.pooling_nhwc_i32_max
// CHECK-SAME: dilations = dense<1> : tensor<2xi64>
// CHECK-SAME: strides = dense<1> : tensor<2xi64>
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<1x4x4x1xi32>, memref<3x3xi32>)
// CHECK-SAME: outs(%{{.+}} : memref<1x2x2x1xi32>)
func @pooling_nhwc_i32_max(%input: memref<1x4x4x1xi32>, %fake: memref<3x3xi32>, %output: memref<1x2x2x1xi32>) {
  linalg.pooling_nhwc_i32_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
    ins(%input, %fake: memref<1x4x4x1xi32>, memref<3x3xi32>)
    outs(%output: memref<1x2x2x1xi32>)
  return
}

// -----

// CHECK-LABEL: func @pooling_nhwc_min_tensor