[mlir][Linalg] Introduce linalg.pooling_min/max/sum op.

Summary:
Performs an N-D pooling operation similarly to the description in the TF
documentation:
https://www.tensorflow.org/api_docs/python/tf/nn/pool

Different from the description, this operation doesn't perform on batch and
channel. It only takes tensors of rank `N`.

```
  output[x[0], ..., x[N-1]] =
    REDUCE_{z[0], ..., z[N-1]}
      input[
            x[0] * strides[0] - pad_before[0] + dilation_rate[0]*z[0],
            ...
            x[N-1]*strides[N-1] - pad_before[N-1] + dilation_rate[N-1]*z[N-1]
            ],
```

The required optional arguments are:
  - strides: an i64 array specifying the stride (i.e. step) for window
    loops.
  - dilations: an i64 array specifying the filter upsampling/input
    downsampling rate
  - padding: an i64 array of pairs (low, high) specifying the number of
    elements to pad along a dimension.

If strides or dilations attributes are missing then the default value is
one for each of the input dimensions. Similarly, padding values are zero
for both low and high in each of the dimensions, if not specified.

Differential Revision: https://reviews.llvm.org/D76414
This commit is contained in:
Hanhan Wang 2020-03-31 21:21:33 -07:00
parent bb3111cbaf
commit 69ddee1d2a
9 changed files with 444 additions and 45 deletions

View File

@ -29,6 +29,9 @@ namespace mlir {
namespace linalg {
class ConvOp;
class PoolingMaxOp;
class PoolingMinOp;
class PoolingSumOp;
/// Returns the name mangled library call name to disambiguate between different
/// overloads at the C level. The name mangling scheme is basic and uses MLIR
@ -60,12 +63,13 @@ std::string generateLibraryCallName(Operation *op);
SmallVector<AffineExpr, 4> makeAffineDimExprs(unsigned num, unsigned &startIdx,
MLIRContext *context);
/// Builds the indexing expressions for a ConvOp `op`. Returns the vector of
/// AffineMaps representing:
/// `stride[i] * xs[i] + dilation[i] * zs[i] - pad_low[i]`
SmallVector<AffineExpr, 4> weightedConvInputIndex(ConvOp op,
ArrayRef<AffineExpr> xs,
ArrayRef<AffineExpr> zs);
/// Builds the indexing expressions for a ConvOp/PoolingOp `op`. Returns the
/// vector of AffineMaps representing:
/// `stride[i] * outputDims[i] + dilation[i] * windowDims[i] - pad_low[i]`
template <typename PoolingOp>
extern SmallVector<AffineExpr, 4>
weightedPoolingInputIndex(PoolingOp op, ArrayRef<AffineExpr> outputDims,
ArrayRef<AffineExpr> windowDims);
/// Returns `maybeMap.get()` if `maybeMap` is set, otherwise returns the
/// symbol-less identity map of `rank`.

View File

@ -251,7 +251,69 @@ def MatmulOp : LinalgStructured_Op<"matmul", [NInputs<2>, NOutputs<1>]> {
let hasFolder = 1;
}
def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
/// A base class for pooling operation such as conv. The arguments must contain
/// optional arguments `strides`, `dilations` and `padding` with following type:
/// OptionalAttr<I64ArrayAttr>:$strides
/// OptionalAttr<I64ArrayAttr>:$dilations
/// OptionalAttr<I64ElementsAttr>:$padding
/// `stirdes` denotes the step of each window along the dimension.
class PoolingBase_Op<string mnemonic, list<OpTrait> props>
: LinalgStructured_Op<mnemonic, props> {
let description = [{
Performs an N-D pooling operation similarly to the description in the TF
documentation:
https://www.tensorflow.org/api_docs/python/tf/nn/pool
Different from the description, this operation doesn't perform on batch and
channel. It only takes tensors of rank `N`.
```
output[x[0], ..., x[N-1]] =
REDUCE_{z[0], ..., z[N-1]}
input[
x[0] * strides[0] - pad_before[0] + dilation_rate[0]*z[0],
...
x[N-1]*strides[N-1] - pad_before[N-1] + dilation_rate[N-1]*z[N-1]
],
```
The required optional arguments are:
- strides: an i64 array specifying the stride (i.e. step) for window
loops.
- dilations: an i64 array specifying the filter upsampling/input
downsampling rate
- padding: an i64 array of pairs (low, high) specifying the number of
elements to pad along a dimension.
If strides or dilations attributes are missing then the default value is
one for each of the input dimensions. Similarly, padding values are zero
for both low and high in each of the dimensions, if not specified.
}];
code commonUtils = libraryCallName # [{
int64_t getStride(unsigned i) {
assert(i < getNumWindowLoops());
if (!strides().hasValue()) return 1;
return strides()->getValue()[i]
.cast<IntegerAttr>().getValue().getSExtValue();
}
int64_t getDilation(unsigned i) {
assert(i < getNumWindowLoops());
if (!dilations().hasValue()) return 1;
return dilations()->getValue()[i]
.cast<IntegerAttr>().getValue().getSExtValue();
}
int64_t getLowPad(unsigned i) {
assert(i < getNumWindowLoops());
if (!padding().hasValue()) return 0;
return padding().getValue().getValue<int64_t>({i, 0});
}
}];
}
def ConvOp : PoolingBase_Op<"conv", [NInputs<2>, NOutputs<1>]> {
let description = [{
Generic n-D convolution as described in the TF documentation:
@ -282,7 +344,7 @@ def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
OptionalAttr<I64ArrayAttr>:$dilations,
OptionalAttr<I64ElementsAttr>:$padding);
let extraClassDeclaration = libraryCallName # [{
let extraClassDeclaration = commonUtils # [{
// TODO(ntv) extend to support more than 1 dimensions and potentially
// grouping too.
unsigned getNumBatchDimensions() { return 1; }
@ -309,26 +371,6 @@ def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
return iters;
}
int64_t getStride(unsigned i) {
assert(i < getNumWindowLoops());
if (!strides().hasValue()) return 1;
return strides()->getValue()[i]
.cast<IntegerAttr>().getValue().getSExtValue();
}
int64_t getDilation(unsigned i) {
assert(i < getNumWindowLoops());
if (!dilations().hasValue()) return 1;
return dilations()->getValue()[i]
.cast<IntegerAttr>().getValue().getSExtValue();
}
int64_t getLowPad(unsigned i) {
assert(i < getNumWindowLoops());
if (!padding().hasValue()) return 0;
return padding().getValue().getValue<int64_t>({i, 0});
}
// F(z0, ..., zN-1, q, k) *
// I(b, x0 + z0 - pad_low_0, ..., xN-1 + zN-1 - pad_low_N-1, q)
// -> O(b, x0, ..., xN-1, k)
@ -358,7 +400,7 @@ def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
// Window reduction dims: sum_{z[0], ..., z[N-1], q}
auto zs = makeAffineDimExprs(nWin, idx, context);
// Construct the weighedSum expression.
auto ws = weightedConvInputIndex(*this, xs, zs);
auto ws = weightedPoolingInputIndex(*this, xs, zs);
return SmallVector<AffineMap, 8>{
// filter[z[0], ..., z[N-1], q, k]
AffineMap::get(idx, 0, concat(concat(zs, qs), ks)),
@ -378,6 +420,86 @@ def ConvOp : LinalgStructured_Op<"conv", [NInputs<2>, NOutputs<1>]> {
let hasFolder = 1;
}
class SingleInputPoolingBase_Op<string mnemonic>
: PoolingBase_Op<mnemonic, [NInputs<2>, NOutputs<1>]> {
let description = [{
A base class for single input pooling function.
TODO: Figure out a better way to handle window dimensions, i.e., eliminate
the fake memref.
The window dimensions are specified by argument `windowDims`. The i-th
dimension in the shape of `windowDims` denotes the size of the window along
dimension i. For example, if the window size is 2x3, then a memref<2x3>
should be passed to the operation as `windowDims`.
}];
let arguments = (ins AnyStridedMemRef:$input,
AnyStridedMemRef:$windowDims,
AnyStridedMemRef:$output,
OptionalAttr<I64ArrayAttr>:$strides,
OptionalAttr<I64ArrayAttr>:$dilations,
OptionalAttr<I64ElementsAttr>:$padding);
let extraClassDeclaration = commonUtils# [{
llvm::Optional<SmallVector<StringRef, 8>> referenceIterators() {
// Outer parallel loops are always the number of output dimensions.
unsigned nPar = getOutputShapedType(0).getRank();
// The window loops has the same number loops with output dimensions.
unsigned nWin = nPar;
SmallVector<StringRef, 8> iters(nPar, getParallelIteratorTypeName());
iters.reserve(nPar + nWin);
iters.append(nWin, getWindowIteratorTypeName());
return iters;
}
llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps() {
MLIRContext *context = getContext();
auto nPar = getNumParallelLoops();
auto nWin = getNumWindowLoops();
assert(nWin > 0 && "expected at least one window dimension");
unsigned idx = 0;
auto outputDims = makeAffineDimExprs(nPar, idx, context);
auto windowDims = makeAffineDimExprs(nWin, idx, context);
// Construct the weighedSum expression.
auto inputDims =
weightedPoolingInputIndex(*this, outputDims, windowDims);
return SmallVector<AffineMap, 8>{
// input
AffineMap::get(idx, 0, inputDims),
// windowDims
AffineMap::get(idx, 0, windowDims),
// output
AffineMap::get(idx, 0, outputDims)
};
}
}];
let verifier = [{ return ::verify(*this); }];
let hasFolder = 1;
}
def PoolingMaxOp: SingleInputPoolingBase_Op<"pooling_max"> {
let description = [{
Takes max op as pooling operation, i.e., it samples the maximum value in the
window.
}];
}
def PoolingMinOp: SingleInputPoolingBase_Op<"pooling_min"> {
let description = [{
Takes min op as pooling operation, i.e., it samples the minimum value in the
window.
}];
}
def PoolingSumOp: SingleInputPoolingBase_Op<"pooling_sum"> {
let description = [{
Takes add op as pooling operation, i.e., it accumulates the values in the
window.
}];
}
//===----------------------------------------------------------------------===//
// Generic Linalg ops.
//===----------------------------------------------------------------------===//

View File

@ -72,6 +72,15 @@ constexpr StringRef getFunAttrName() { return "fun"; }
/// function that implements the structured op.
constexpr StringRef getLibraryCallAttrName() { return "library_call"; }
/// Attribute name for the StrArrayAttr which encodes the value of strides.
constexpr StringRef getStridesAttrName() { return "strides"; }
/// Attribute name for the StrArrayAttr which encodes the value of dilations.
constexpr StringRef getDilationsAttrName() { return "dilations"; }
/// Attribute name for the StrArrayAttr which encodes the value of paddings.
constexpr StringRef getPaddingAttrName() { return "padding"; }
/// Use to encode that a particular iterator type has parallel semantics.
constexpr StringRef getParallelIteratorTypeName() { return "parallel"; }

View File

@ -524,12 +524,21 @@ populateLinalgToStandardConversionPatterns(OwningRewritePatternList &patterns,
MLIRContext *ctx) {
// TODO(ntv) ConvOp conversion needs to export a descriptor with relevant
// attribute values such as kernel striding and dilation.
patterns.insert<CopyTransposeConversion, LinalgOpConversion<ConvOp>,
LinalgOpConversion<CopyOp>, LinalgOpConversion<DotOp>,
LinalgOpConversion<FillOp>, LinalgOpConversion<GenericOp>,
LinalgOpConversion<IndexedGenericOp>,
LinalgOpConversion<MatmulOp>, LinalgOpConversion<MatvecOp>>(
ctx);
// clang-format off
patterns.insert<
CopyTransposeConversion,
LinalgOpConversion<ConvOp>,
LinalgOpConversion<PoolingMaxOp>,
LinalgOpConversion<PoolingMinOp>,
LinalgOpConversion<PoolingSumOp>,
LinalgOpConversion<CopyOp>,
LinalgOpConversion<DotOp>,
LinalgOpConversion<FillOp>,
LinalgOpConversion<GenericOp>,
LinalgOpConversion<IndexedGenericOp>,
LinalgOpConversion<MatmulOp>,
LinalgOpConversion<MatvecOp>>(ctx);
// clang-format on
}
} // namespace

View File

@ -140,7 +140,6 @@ static void printGenericOp(OpAsmPrinter &p, GenericOpType op) {
p.printRegion(op.region());
p.printOptionalAttrDict(op.getAttrs(), attrNames);
p << ": " << op.getOperandTypes();
auto outputTensorTypes = op.getResultTypes();
if (!outputTensorTypes.empty())
p << " -> " << outputTensorTypes;
@ -827,8 +826,10 @@ static LogicalResult verify(CopyOp op) {
return success();
}
static LogicalResult
verifyStrideOrDilation(ConvOp op, ArrayRef<Attribute> attrs, bool isStride) {
template <typename LinalgPoolingOp>
static LogicalResult verifyStrideOrDilation(LinalgPoolingOp op,
ArrayRef<Attribute> attrs,
bool isStride) {
auto strideOrDilation = isStride ? "stride" : "dilation";
if (attrs.size() != op.getNumWindowLoops())
return op.emitOpError("expects num ")
@ -860,6 +861,41 @@ static LogicalResult verify(ConvOp op) {
return success();
}
template <typename PoolingOp>
LogicalResult verifySingleInputPoolingOp(PoolingOp op) {
auto inputType = op.input().getType().template cast<MemRefType>();
auto outputType = op.output().getType().template cast<MemRefType>();
if (outputType.getElementType() != inputType.getElementType())
return op.emitOpError("expects memref elemental types to match");
auto windowDimsType = op.windowDims().getType().template cast<MemRefType>();
if (outputType.getRank() != inputType.getRank() ||
outputType.getRank() != windowDimsType.getRank())
return op.emitOpError("expects memref ranks to match");
if (auto strides = op.strides()) {
if (failed(
verifyStrideOrDilation(op, strides->getValue(), /*isStride=*/true)))
return failure();
}
if (auto dilations = op.dilations()) {
if (failed(verifyStrideOrDilation(op, dilations->getValue(),
/*isStride=*/false)))
return failure();
}
return success();
}
static LogicalResult verify(PoolingMaxOp op) {
return verifySingleInputPoolingOp(op);
}
static LogicalResult verify(PoolingMinOp op) {
return verifySingleInputPoolingOp(op);
}
static LogicalResult verify(PoolingSumOp op) {
return verifySingleInputPoolingOp(op);
}
namespace mlir {
namespace linalg {
@ -894,21 +930,34 @@ mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
return res;
}
template <typename PoolingOp>
SmallVector<AffineExpr, 4>
mlir::linalg::weightedConvInputIndex(ConvOp op, ArrayRef<AffineExpr> xs,
ArrayRef<AffineExpr> zs) {
assert(xs.size() == zs.size());
mlir::linalg::weightedPoolingInputIndex(PoolingOp op,
ArrayRef<AffineExpr> outputDims,
ArrayRef<AffineExpr> windowDims) {
assert(outputDims.size() == windowDims.size());
SmallVector<AffineExpr, 4> res;
res.reserve(xs.size());
for (unsigned i = 0, e = xs.size(); i < e; ++i) {
res.reserve(outputDims.size());
for (unsigned i = 0, e = outputDims.size(); i < e; ++i) {
// TODO(ntv): add a level of indirection to linalg.generic.
auto expr =
op.getStride(i) * xs[i] + op.getDilation(i) * zs[i] - op.getLowPad(i);
auto expr = op.getStride(i) * outputDims[i] +
op.getDilation(i) * windowDims[i] - op.getLowPad(i);
res.push_back(expr);
}
return res;
}
#define INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(OP_TYPE) \
template SmallVector<AffineExpr, 4> \
mlir::linalg::weightedPoolingInputIndex<OP_TYPE>( \
OP_TYPE op, ArrayRef<AffineExpr> outputDims, \
ArrayRef<AffineExpr> windowDims);
INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(ConvOp)
INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(PoolingMaxOp)
INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(PoolingMinOp)
INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(PoolingSumOp)
SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
ArrayRef<AffineExpr> b) {
auto rangeA = llvm::make_range(a.begin(), a.end());
@ -959,6 +1008,18 @@ LogicalResult ConvOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
return foldMemRefCast(*this);
}
LogicalResult PoolingMaxOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
return foldMemRefCast(*this);
}
LogicalResult PoolingMinOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
return foldMemRefCast(*this);
}
LogicalResult PoolingSumOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
return foldMemRefCast(*this);
}
LogicalResult CopyOp::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {
return foldMemRefCast(*this);

View File

@ -106,6 +106,23 @@ static void inlineRegionAndEmitStdStore(OpType op,
}
}
// Returns a pair that contains input indices and output indices of a
// SingleInputPoolingOp `op`.
template <typename SingleInputPoolingOp>
static std::pair<SmallVector<ValueHandle, 8>, SmallVector<ValueHandle, 8>>
getInputAndOutputIndices(ArrayRef<Value> allIvs, SingleInputPoolingOp op) {
auto &b = ScopedContext::getBuilder();
auto loc = ScopedContext::getLocation();
auto mapsRange = op.indexing_maps().template getAsRange<AffineMapAttr>();
auto maps =
functional::map([](AffineMapAttr a) { return a.getValue(); }, mapsRange);
SmallVector<ValueHandle, 8> iIdx(
makeCanonicalAffineApplies(b, loc, maps[0], allIvs));
SmallVector<ValueHandle, 8> oIdx(
makeCanonicalAffineApplies(b, loc, maps[2], allIvs));
return {iIdx, oIdx};
}
namespace {
template <typename IndexedValueType, typename LinalgOpType>
class LinalgScopedEmitter {};
@ -273,6 +290,57 @@ public:
}
};
template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, PoolingMaxOp> {
public:
static void emitScalarImplementation(ArrayRef<Value> allIvs,
PoolingMaxOp op) {
auto indices = getInputAndOutputIndices(allIvs, op);
ValueHandleArray iIdx(indices.first);
ValueHandleArray oIdx(indices.second);
// Emit scalar form.
ValueHandle lhs = std_load(op.output(), oIdx);
ValueHandle rhs = std_load(op.input(), iIdx);
using edsc::op::operator>;
ValueHandle maxValue = std_select(lhs > rhs, lhs, rhs);
std_store(maxValue, op.output(), oIdx);
}
};
template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, PoolingMinOp> {
public:
static void emitScalarImplementation(ArrayRef<Value> allIvs,
PoolingMinOp op) {
auto indices = getInputAndOutputIndices(allIvs, op);
ValueHandleArray iIdx(indices.first);
ValueHandleArray oIdx(indices.second);
// Emit scalar form.
ValueHandle lhs = std_load(op.output(), oIdx);
ValueHandle rhs = std_load(op.input(), iIdx);
using edsc::op::operator<;
ValueHandle minValue = std_select(lhs < rhs, lhs, rhs);
std_store(minValue, op.output(), oIdx);
}
};
template <typename IndexedValueType>
class LinalgScopedEmitter<IndexedValueType, PoolingSumOp> {
public:
static void emitScalarImplementation(ArrayRef<Value> allIvs,
PoolingSumOp op) {
auto indices = getInputAndOutputIndices(allIvs, op);
SmallVector<ValueHandle, 8> iIdx = indices.first;
SmallVector<ValueHandle, 8> oIdx = indices.second;
IndexedValueType input(op.input()), output(op.output());
// Emit scalar form.
output(oIdx) += input(iIdx);
}
};
// Emits the MLIR for the scalar part of the generic op by:
// 1. Emitting std_load and std_store ops for each input and output
// view in order. This is achieved by applying the appropriate input or
@ -688,6 +756,9 @@ INSTANTIATE_LINALG_OP_TO_LOOPS(DotOp)
INSTANTIATE_LINALG_OP_TO_LOOPS(MatvecOp)
INSTANTIATE_LINALG_OP_TO_LOOPS(MatmulOp)
INSTANTIATE_LINALG_OP_TO_LOOPS(ConvOp)
INSTANTIATE_LINALG_OP_TO_LOOPS(PoolingMaxOp)
INSTANTIATE_LINALG_OP_TO_LOOPS(PoolingMinOp)
INSTANTIATE_LINALG_OP_TO_LOOPS(PoolingSumOp)
INSTANTIATE_LINALG_OP_TO_LOOPS(GenericOp)
INSTANTIATE_LINALG_OP_TO_LOOPS(IndexedGenericOp)

View File

@ -513,3 +513,14 @@ func @reshape(%arg0: memref<?x?x?xf32>) {
%0 = linalg.reshape %arg0 [affine_map<(i, j, k) -> (i, j)>, affine_map<(i, j, k) -> (k)>] :
memref<?x?x?xf32> into memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * s0 + d1)>>
}
// -----
func @pooling_rank_mismatch(%arg0: memref<?x?x?xf32>,
%arg1: memref<2x3xf32>,
%arg2: memref<?x?x?xf32>) {
// expected-error @+1 {{expects memref ranks to match}}
linalg.pooling_max(%arg0, %arg1, %arg2) {strides = [2, 1, 2]}:
memref<?x?x?xf32>, memref<2x3xf32>, memref<?x?x?xf32>
return
}

View File

@ -9,6 +9,7 @@
// CHECK-DAG: #[[strided4D:.*]] = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)>
// CHECK-DAG: #[[clampMinMap:.*]] = affine_map<(d0) -> (d0, 0)>
// CHECK-DAG: #[[Stride1Dilation1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
// CHECK-DAG: #[[Stride2Dilation1:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)>
// CHECK-DAG: #[[Stride2Dilation4:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1 * 4)>
// CHECK-DAG: #[[Stride3Dilation5:.*]] = affine_map<(d0, d1) -> (d0 * 3 + d1 * 5)>
@ -251,6 +252,75 @@ func @conv_padding(%arg0: memref<?x?x?x?xf32>,
// CHECK: %{{.*}} = addf %{{.*}}, %{{.*}} : f32
// CHECK: store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref<?x?x?x?xf32>
func @pooling_max(%arg0: memref<?x?xf32>,
%arg1: memref<?x?xi32>,
%arg2: memref<?x?xf32>) {
linalg.pooling_max(%arg0, %arg1, %arg2) { strides = [2, 1] }:
memref<?x?xf32>, memref<?x?xi32>, memref<?x?xf32>
return
}
// CHECK-LABEL: func @pooling_max
// CHECK: %[[WX:.*]] = dim %arg1, 0 : memref<?x?xi32>
// CHECK: %[[WY:.*]] = dim %arg1, 1 : memref<?x?xi32>
// CHECK: %[[OX:.*]] = dim %arg2, 0 : memref<?x?xf32>
// CHECK: %[[OY:.*]] = dim %arg2, 1 : memref<?x?xf32>
// CHECK: loop.for %{{.*}} = %{{.*}} to %[[OX]] step %{{.*}} {
// CHECK: loop.for %{{.*}} = %{{.*}} to %[[OY]] step %{{.*}} {
// CHECK: loop.for %{{.*}} = %{{.*}} to %[[WX]] step %{{.*}} {
// CHECK: loop.for %{{.*}} = %{{.*}} to %[[WY]] step %{{.*}} {
// CHECK: %[[IX:.*]] = affine.apply #[[Stride2Dilation1]](%{{.*}}, %{{.*}})
// CHECK: %[[IY:.*]] = affine.apply #[[Stride1Dilation1]](%{{.*}}, %{{.*}})
// CHECK: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
// CHECK: %{{.*}} = load %{{.*}}[%[[IX]], %[[IY]]] : memref<?x?xf32>
// CHECK: %[[RES:.*]] = select %{{.*}}, %{{.*}}, %{{.*}} : f32
// CHECK: store %[[RES]], %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
func @pooling_min(%arg0: memref<?x?xf32>,
%arg1: memref<?x?xi32>,
%arg2: memref<?x?xf32>) {
linalg.pooling_min(%arg0, %arg1, %arg2) { strides = [2, 1] }:
memref<?x?xf32>, memref<?x?xi32>, memref<?x?xf32>
return
}
// CHECK-LABEL: func @pooling_min
// CHECK: %[[WX:.*]] = dim %arg1, 0 : memref<?x?xi32>
// CHECK: %[[WY:.*]] = dim %arg1, 1 : memref<?x?xi32>
// CHECK: %[[OX:.*]] = dim %arg2, 0 : memref<?x?xf32>
// CHECK: %[[OY:.*]] = dim %arg2, 1 : memref<?x?xf32>
// CHECK: loop.for %{{.*}} = %{{.*}} to %[[OX]] step %{{.*}} {
// CHECK: loop.for %{{.*}} = %{{.*}} to %[[OY]] step %{{.*}} {
// CHECK: loop.for %{{.*}} = %{{.*}} to %[[WX]] step %{{.*}} {
// CHECK: loop.for %{{.*}} = %{{.*}} to %[[WY]] step %{{.*}} {
// CHECK: %[[IX:.*]] = affine.apply #[[Stride2Dilation1]](%{{.*}}, %{{.*}})
// CHECK: %[[IY:.*]] = affine.apply #[[Stride1Dilation1]](%{{.*}}, %{{.*}})
// CHECK: %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
// CHECK: %{{.*}} = load %{{.*}}[%[[IX]], %[[IY]]] : memref<?x?xf32>
// CHECK: %[[RES:.*]] = select %{{.*}}, %{{.*}}, %{{.*}} : f32
// CHECK: store %[[RES]], %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
func @pooling_sum(%arg0: memref<?x?xf32>,
%arg1: memref<?x?xi32>,
%arg2: memref<?x?xf32>) {
linalg.pooling_sum(%arg0, %arg1, %arg2) { strides = [2, 1] }:
memref<?x?xf32>, memref<?x?xi32>, memref<?x?xf32>
return
}
// CHECK-LABEL: func @pooling_sum
// CHECK: %[[WX:.*]] = dim %arg1, 0 : memref<?x?xi32>
// CHECK: %[[WY:.*]] = dim %arg1, 1 : memref<?x?xi32>
// CHECK: %[[OX:.*]] = dim %arg2, 0 : memref<?x?xf32>
// CHECK: %[[OY:.*]] = dim %arg2, 1 : memref<?x?xf32>
// CHECK: loop.for %{{.*}} = %{{.*}} to %[[OX]] step %{{.*}} {
// CHECK: loop.for %{{.*}} = %{{.*}} to %[[OY]] step %{{.*}} {
// CHECK: loop.for %{{.*}} = %{{.*}} to %[[WX]] step %{{.*}} {
// CHECK: loop.for %{{.*}} = %{{.*}} to %[[WY]] step %{{.*}} {
// CHECK: %[[IX:.*]] = affine.apply #[[Stride2Dilation1]](%{{.*}}, %{{.*}})
// CHECK: %[[IY:.*]] = affine.apply #[[Stride1Dilation1]](%{{.*}}, %{{.*}})
// CHECK: %[[RHS:.*]] = load %{{.*}}[%[[IX]], %[[IY]]] : memref<?x?xf32>
// CHECK: %[[LHS:.*]] = load %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
// CHECK: %[[RES:.*]] = addf %[[LHS]], %[[RHS]] : f32
// CHECK: store %[[RES]], %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32>
func @foo(%0: f32, %1: f32, %2: f32) -> (f32, f32) {
%f0 = constant 0.0 : f32
return %f0, %f0 : f32, f32

View File

@ -244,6 +244,48 @@ func @conv_padding(%arg0: memref<?x?x?x?xf32>,
// -----
func @pooling_max(%arg0: memref<?x?x?xf32>,
%arg1: memref<?x?x?xi32>,
%arg2: memref<?x?x?xf32>) {
linalg.pooling_max(%arg0, %arg1, %arg2) {strides = [2, 1, 2]}:
memref<?x?x?xf32>, memref<?x?x?xi32>, memref<?x?x?xf32>
return
}
// CHECK-LABEL: func @pooling_max
// CHECK: linalg.pooling_max(%{{.*}}, %{{.*}}, %{{.*}})
// CHECK-SAME: {strides = [2, 1, 2]}
// CHECK-SAME: memref<?x?x?xf32>, memref<?x?x?xi32>, memref<?x?x?xf32>
// -----
func @pooling_min(%arg0: memref<?x?x?xf32>,
%arg1: memref<?x?x?xi32>,
%arg2: memref<?x?x?xf32>) {
linalg.pooling_min(%arg0, %arg1, %arg2) {strides = [2, 1, 2]}:
memref<?x?x?xf32>, memref<?x?x?xi32>, memref<?x?x?xf32>
return
}
// CHECK-LABEL: func @pooling_min
// CHECK: linalg.pooling_min(%{{.*}}, %{{.*}}, %{{.*}})
// CHECK-SAME: {strides = [2, 1, 2]}
// CHECK-SAME: memref<?x?x?xf32>, memref<?x?x?xi32>, memref<?x?x?xf32>
// -----
func @pooling_sum(%arg0: memref<?x?x?xf32>,
%arg1: memref<?x?x?xi32>,
%arg2: memref<?x?x?xf32>) {
linalg.pooling_sum(%arg0, %arg1, %arg2) {strides = [2, 1, 2]}:
memref<?x?x?xf32>, memref<?x?x?xi32>, memref<?x?x?xf32>
return
}
// CHECK-LABEL: func @pooling_sum
// CHECK: linalg.pooling_sum(%{{.*}}, %{{.*}}, %{{.*}})
// CHECK-SAME: {strides = [2, 1, 2]}
// CHECK-SAME: memref<?x?x?xf32>, memref<?x?x?xi32>, memref<?x?x?xf32>
// -----
// CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
// CHECK-DAG: #[[strided3D:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)>