[mlir][linalg] Fuse producers with non-permutation indexing maps

Until now, Linalg fusion only allowed fusing producers whose
operands all have permutation indexing maps. That made it easier
to deduce the subtensor/subview, but it is an unnecessary
constraint: on the tiling side we already have more advanced
logic to deduce the subranges even when an operand does not use
a permutation indexing map, e.g., the input operand of
convolution ops.

This patch reuses that tiling-side logic to deduce subranges for
fusion. This enables fusing a convolution with its consumer ops
when possible.

Along the way, we now generate proper affine.min ops to clamp
subshape sizes at shape boundaries whenever we cannot statically
prove they stay in bounds.
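
As an illustration (an editorial sketch, not part of the original
commit message; %t, %iv, %d0, %d1, %tile, and %sz are made-up names),
the guard emitted for a fused producer's subshape along a tiled
dimension looks roughly like:

  %tile = affine.min affine_map<(d0)[s0] -> (16, -d0 + s0)>(%iv)[%d0]
  %sz = affine.min affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>(%iv, %tile)[%d0]
  %st = subtensor %t[%iv, 0] [%sz, %d1] [1, 1]
      : tensor<?x?xf32> to tensor<?x?xf32>

The first affine.min is the consumer's tile size; the second clamps
the producer's subshape to min(%tile, %d0 - %iv) so the subtensor
cannot run past the end of %t.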

Differential Revision: https://reviews.llvm.org/D99014
Lei Zhang 2021-03-24 17:49:58 -04:00
parent ddf93abf49
commit e58597ee1c
8 changed files with 458 additions and 214 deletions


@@ -59,104 +59,6 @@ using llvm::dbgs;
/// More advanced use cases, analyses as well as profitability heuristics are
/// left for future work.
// Fill `offset`, `sizes` and `strides` used to iterate over the shape indexed
// by `permutationMap`.
static void inferShapeComponents(AffineMap permutationMap,
ArrayRef<Range> loopRanges,
SmallVectorImpl<OpFoldResult> &offsets,
SmallVectorImpl<OpFoldResult> &sizes,
SmallVectorImpl<OpFoldResult> &strides) {
assert(permutationMap.isProjectedPermutation() &&
"expected some subset of a permutation map");
SmallVector<Range, 4> shapeRanges(permutationMap.getNumResults());
unsigned idx = 0;
for (AffineExpr e : permutationMap.getResults()) {
// loopToOperandRangesMaps are permutations-only, just swap indices.
unsigned loopPos = e.cast<AffineDimExpr>().getPosition();
shapeRanges[idx++] = loopRanges[loopPos];
}
// Construct a new subshape for the tile.
unsigned rank = shapeRanges.size();
offsets.reserve(rank);
sizes.reserve(rank);
strides.reserve(rank);
for (auto r : shapeRanges) {
offsets.push_back(r.offset);
sizes.push_back(r.size);
strides.push_back(r.stride);
}
}
// Return a cloned version of `op` that operates on `loopRanges`, assumed to be
// a subset of the original loop ranges of `op`.
// This is achieved by applying the `loopToOperandRangesMaps` permutation maps
// to the `loopRanges` in order to obtain view ranges.
static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
ArrayRef<Range> loopRanges) {
SmallVector<Value, 8> clonedShapes;
clonedShapes.reserve(op.getNumShapedOperands());
// Iterate over the shape operands in order.
// Extract the subranges from the linearized ranges.
for (auto en : llvm::enumerate(op.getShapedOperands())) {
unsigned shapedOperandIdx = en.index();
AffineMap map = op.getIndexingMap(shapedOperandIdx);
LLVM_DEBUG(llvm::dbgs() << "shapedOperandIdx: " << shapedOperandIdx
<< " with indexingMap: " << map << "\n");
SmallVector<OpFoldResult, 4> offsets, sizes, strides;
inferShapeComponents(map, loopRanges, offsets, sizes, strides);
Value shape = en.value();
Value sub =
shape.getType().isa<MemRefType>()
? b.create<memref::SubViewOp>(loc, shape, offsets, sizes, strides)
.getResult()
: b.create<SubTensorOp>(loc, shape, offsets, sizes, strides)
.getResult();
clonedShapes.push_back(sub);
}
// Append the other operands.
auto operands = op.getAssumedNonShapedOperands();
clonedShapes.append(operands.begin(), operands.end());
// Iterate over the results in order.
// Extract the subtensor type from the linearized range.
// Since we do not enforce any canonicalizations on the fly, this is always
// fully dynamic at construction time.
SmallVector<Type, 4> resultTypes;
resultTypes.reserve(op->getNumResults());
for (RankedTensorType t : op.getOutputTensorTypes()) {
unsigned rank = t.getRank();
SmallVector<int64_t, 4> staticOffsetsVector(
rank, ShapedType::kDynamicStrideOrOffset);
SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
SmallVector<int64_t, 4> staticStridesVector(
rank, ShapedType::kDynamicStrideOrOffset);
resultTypes.push_back(SubTensorOp::inferResultType(
t.cast<RankedTensorType>(), staticOffsetsVector, staticSizesVector,
staticStridesVector));
}
Operation *clonedOp = op.clone(b, loc, resultTypes, clonedShapes);
// When the producer is an IndexedGenericOp, we have to transform its block
// IV arguments according to the tiling of the consumer, i.e. offset them by
// the values computed in `loopRanges`.
if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) {
auto &block = indexedGenericOp.region().front();
OpBuilder::InsertionGuard g(b);
b.setInsertionPointToStart(&block);
for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) {
Value oldIndex = block.getArgument(i);
// TODO: replace by an affine_apply.
AddIOp newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
loopRanges[i].offset);
oldIndex.replaceAllUsesExcept(newIndex,
SmallPtrSet<Operation *, 1>{newIndex});
}
}
return clonedOp;
}
struct ShapeDimension {
Value shape;
unsigned dimension;
@@ -208,35 +110,86 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
llvm_unreachable("Expect to be able to extract a shape defining loop range");
}
/// Fuse the producer by cloning the `producer`. The `fusedLoopsAndRanges`
/// Fuses the producer by cloning the `producer`. The `fusedLoopsAndRanges`
/// provides the loop range information for the fused loops. The rest are
/// obtained from the producer itself, since they are not tiled + fused.
static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
static LinalgOp fuse(OpBuilder &builder, LinalgOp producer,
const DenseMap<unsigned, Range> &fusedLoopsAndRanges) {
SmallVector<Value, 8> ivs, tileSizes, sizeBounds;
SmallVector<Range, 8> loopRanges;
auto zero = std_constant_index(0);
auto one = std_constant_index(1);
Location loc = producer.getLoc();
unsigned nPar = producer.getNumParallelLoops();
unsigned nRed = producer.getNumReductionLoops();
unsigned nWin = producer.getNumWindowLoops();
SmallVector<Range, 8> loopRanges(nPar + nRed + nWin);
for (auto fusedLoops : fusedLoopsAndRanges)
loopRanges[fusedLoops.first] = fusedLoops.second;
// Iterate over all dimensions. For the dimensions not identified by the
// producer map for `producerIdx`, we need to explicitly compute the shape
// that defines the loop ranges using the `producer`.
for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
if (loopRanges[i].offset)
LLVM_DEBUG(llvm::dbgs()
<< "existing LoopRange: " << loopRanges[i] << "\n");
else {
for (unsigned i = 0, e = producer.getNumLoops(); i < e; ++i) {
auto it = fusedLoopsAndRanges.find(i);
if (it != fusedLoopsAndRanges.end()) {
ivs.push_back(it->second.offset);
tileSizes.push_back(it->second.size);
sizeBounds.push_back(nullptr);
loopRanges.push_back(it->second);
LLVM_DEBUG(llvm::dbgs() << "tiled loop#" << i << " with LoopRange "
<< loopRanges.back() << "\n");
} else {
auto shapeDim = getShapeDefiningLoopRange(producer, i);
Value dim = memref_dim(shapeDim.shape, shapeDim.dimension);
loopRanges[i] = Range{std_constant_index(0), dim, std_constant_index(1)};
LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
tileSizes.push_back(zero);
sizeBounds.push_back(dim);
loopRanges.push_back(Range{zero, dim, one});
LLVM_DEBUG(llvm::dbgs() << "full loop#" << i << " with LoopRange "
<< loopRanges.back() << "\n");
}
}
return cloneWithLoopRanges(b, producer.getLoc(), producer, loopRanges);
SmallVector<Value, 8> clonedShapes;
clonedShapes.reserve(producer.getNumShapedOperands());
// Compute subranges for all tensor input/output operands.
auto tiledOperands = llvm::to_vector<4>(producer.getShapedOperands());
clonedShapes.append(makeTiledShapes(builder, loc, producer, tiledOperands,
ivs, tileSizes, sizeBounds));
// Append the other operands.
auto operands = producer.getAssumedNonShapedOperands();
clonedShapes.append(operands.begin(), operands.end());
// Iterate over the results in order.
// Extract the subtensor type from the linearized range.
// Since we do not enforce any canonicalizations on the fly, this is always
// fully dynamic at construction time.
SmallVector<Type, 4> resultTypes;
resultTypes.reserve(producer->getNumResults());
for (RankedTensorType t : producer.getOutputTensorTypes()) {
unsigned rank = t.getRank();
SmallVector<int64_t, 4> staticOffsetsVector(
rank, ShapedType::kDynamicStrideOrOffset);
SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
SmallVector<int64_t, 4> staticStridesVector(
rank, ShapedType::kDynamicStrideOrOffset);
resultTypes.push_back(SubTensorOp::inferResultType(
t.cast<RankedTensorType>(), staticOffsetsVector, staticSizesVector,
staticStridesVector));
}
Operation *clonedOp = producer.clone(builder, loc, resultTypes, clonedShapes);
// When the producer is an IndexedGenericOp, we have to transform its block
// IV arguments according to the tiling of the consumer, i.e. offset them by
// the values computed in `loopRanges`.
if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) {
auto &block = indexedGenericOp.region().front();
OpBuilder::InsertionGuard g(builder);
builder.setInsertionPointToStart(&block);
for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) {
Value oldIndex = block.getArgument(i);
// TODO: replace by an affine_apply.
AddIOp newIndex = builder.create<AddIOp>(indexedGenericOp.getLoc(),
oldIndex, loopRanges[i].offset);
oldIndex.replaceAllUsesExcept(newIndex,
SmallPtrSet<Operation *, 1>{newIndex});
}
}
return clonedOp;
}
/// Get the loop range for a dimension `dim` based on the `shapedOperand`. It is

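To make the IndexedGenericOp rewriting above concrete, a hedged
sketch (editorial; %off0, %off1, and the block signature are made
up) of the cloned producer's body after the index adjustment:

  // Producer body cloned into a tile starting at offsets (%off0, %off1).
  ^bb0(%i: index, %j: index, %in: f32):
    %i2 = addi %i, %off0 : index  // all prior uses of %i now use %i2
    %j2 = addi %j, %off1 : index  // all prior uses of %j now use %j2
    // ... original body, with %i / %j rewritten to use %i2 / %j2 ...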

@@ -27,6 +27,9 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/LoopUtils.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "linalg-utils"
using namespace mlir;
using namespace mlir::edsc;
@@ -447,11 +450,14 @@ SmallVector<Value, 4> makeTiledShapes(OpBuilder &builder, Location loc,
// that define tile subshapes.
SmallVector<Value, 8> lbs, subShapeSizes;
for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) {
LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n");
bool isTiled = !isZero(tileSizes[idx]);
lbs.push_back(isTiled ? ivs[idxIvs++] : (Value)std_constant_index(0));
// Before composing, we need to make range a closed interval.
Value size = isTiled ? tileSizes[idx] : sizeBounds[idx];
subShapeSizes.push_back(size - std_constant_index(1));
LLVM_DEBUG(llvm::dbgs() << "lb: " << lbs.back() << "\n");
LLVM_DEBUG(llvm::dbgs() << "size: " << subShapeSizes.back() << "\n");
}
MLIRContext *context = builder.getContext();
@@ -459,14 +465,18 @@ SmallVector<Value, 4> makeTiledShapes(OpBuilder &builder, Location loc,
tiledShapes.reserve(tiledOperands.size());
for (auto en : llvm::enumerate(tiledOperands)) {
Value shapedOp = en.value();
LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for operand " << shapedOp);
ShapedType shapedType = shapedOp.getType().cast<ShapedType>();
unsigned rank = shapedType.getRank();
AffineMap map = linalgOp.getIndexingMap(en.index());
// If the shape is not tiled, we can use it as is.
if (!isTiled(map, tileSizes)) {
tiledShapes.push_back(shapedOp);
LLVM_DEBUG(llvm::dbgs()
<< ": not tiled: use shape: " << shapedType << "\n");
continue;
}
LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subshape...\n");
// Construct a new subview / subtensor for the tile.
SmallVector<OpFoldResult, 4> offsets, sizes, strides;
@@ -474,22 +484,28 @@ SmallVector<Value, 4> makeTiledShapes(OpBuilder &builder, Location loc,
sizes.reserve(rank);
strides.reserve(rank);
for (unsigned r = 0; r < rank; ++r) {
LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for dim#" << r);
if (!isTiled(map.getSubMap({r}), tileSizes)) {
offsets.push_back(builder.getIndexAttr(0));
sizes.push_back(memref_dim(shapedOp, r).value);
Value dim = memref_dim(shapedOp, r).value;
sizes.push_back(dim);
strides.push_back(builder.getIndexAttr(1));
LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n");
continue;
}
LLVM_DEBUG(llvm::dbgs() << ": tiled: figure out subsize...\n");
// Tiling creates a new slice at the proper index, the slice step is 1
// (i.e. the op does not subsample, stepping occurs in the loop).
auto m = map.getSubMap({r});
LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: submap: " << map << "\n");
auto offset = applyMapToValues(builder, loc, m, lbs).front();
offsets.push_back(offset);
auto closedIntSize =
applyMapToValues(builder, loc, m, subShapeSizes).front();
// Resulting size needs to be made half open interval again.
auto size = closedIntSize + std_constant_index(1);
LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: raw size: " << size << "\n");
// The size of the subview / subtensor should be trimmed to avoid
// out-of-bounds accesses, unless we statically know the subshape size
@@ -498,6 +514,9 @@ SmallVector<Value, 4> makeTiledShapes(OpBuilder &builder, Location loc,
auto sizeCst = size.getDefiningOp<ConstantIndexOp>();
if (ShapedType::isDynamic(shapeSize) || !sizeCst ||
(shapeSize % sizeCst.getValue()) != 0) {
LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: shapeSize=" << shapeSize
<< ", size: " << size
<< ": make sure in bound with affine.min\n");
AffineExpr dim0, dim1, dim2;
bindDims(context, dim0, dim1, dim2);
// Compute min(size, dim - offset) to avoid out-of-bounds accesses.
@@ -510,6 +529,9 @@ SmallVector<Value, 4> makeTiledShapes(OpBuilder &builder, Location loc,
}
sizes.push_back(size);
LLVM_DEBUG(llvm::dbgs()
<< "makeTiledShapes: new offset: " << offset << "\n");
LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: new size: " << size << "\n");
strides.push_back(builder.getIndexAttr(1));
}
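
For reference, a sketch (editorial; %size, %dim, and %offset are
illustrative names for the values bound to dim0/dim1/dim2 above) of
the clamp computed here for one tiled dimension:

  // min(size, dim - offset): never run past the end of the shape.
  %trimmed = affine.min affine_map<(d0, d1, d2) -> (d0, d1 - d2)>(%size, %dim, %offset)

After folding, with %size itself coming from the consumer's
affine.min, this surfaces in the updated tests as maps of the form
affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>.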


@@ -16,6 +16,7 @@ module {
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>
// CHECK: func @basic_fusion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
@@ -47,8 +48,10 @@ module {
// CHECK: %[[TILE_N_2:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N_2]]]
// CHECK: %[[SV3:.+]] = memref.subview %[[ARG2]][%[[IV0]], %[[IV1]]]
// CHECK-SAME: [%[[TILE_M_2]], %[[TILE_N_2]]]
// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_M]])[%[[M_2]]]
// CHECK: %[[TILE_N_3:.+]] = affine.min #[[MAP4]](%[[IV1]], %[[TILE_N]])[%[[N_2]]]
// CHECK: %[[SV3_2:.+]] = memref.subview %[[ARG2]][%[[IV0]], %[[IV1]]]
// CHECK-SAME: [%[[TILE_M]], %[[TILE_N]]]
// CHECK-SAME: [%[[TILE_M_3]], %[[TILE_N_3]]]
// CHECK: linalg.fill(%[[SV3_2]], %[[CST]])
// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_producer"
// CHECK: scf.for %[[IV2:.+]] = %[[C0]] to %[[K]] step %[[C16]] {
@@ -86,6 +89,7 @@ module {
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (32, -d0 + s0)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>
// CHECK: func @rhs_fusion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
@@ -112,10 +116,13 @@ module {
// CHECK: %[[SV2:.+]] = memref.subview %[[ARG3]][0, %[[IV0]]]
// CHECK-SAME: [%[[M]], %[[TILE_N_2]]]
// CHECK: %[[K_2:.+]] = memref.dim %[[ARG1]], %[[C0]]
// CHECK: %[[N_3:.+]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK: %[[TILE_N_3:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_N]])[%[[N_3]]]
// CHECK: %[[SV3:.+]] = memref.subview %[[ARG1]][0, %[[IV0]]]
// CHECK-SAME: [%[[K_2]], %[[TILE_N]]]
// CHECK-SAME: [%[[K_2]], %[[TILE_N_3]]]
// CHECK: %[[TILE_N_4:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_N]])[%[[N]]]
// CHECK: %[[SV3_2:.+]] = memref.subview %[[ARG2]][0, %[[IV0]]]
// CHECK-SAME: [%[[K_2]], %[[TILE_N]]]
// CHECK-SAME: [%[[K]], %[[TILE_N_4]]]
// CHECK: linalg.copy(%[[SV3]], %[[SV3_2]])
// CHECK-SAME: __internal_linalg_transform__ = "after_rhs_fusion_producer"
// CHECK-NOT: linalg.fill
@@ -164,6 +171,7 @@ module {
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)>
// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>
// CHECK: func @two_operand_fusion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
@@ -191,13 +199,17 @@ module {
// CHECK: %[[N:.+]] = memref.dim %[[ARG3]], %[[C1]]
// CHECK: %[[SV2:.+]] = memref.subview %[[ARG3]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M_2]], %[[N]]]
// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_M]])[%[[M_2]]]
// CHECK: %[[SV2_2:.+]] = memref.subview %[[ARG3]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M]], %[[N]]]
// CHECK-SAME: [%[[TILE_M_3]], %[[N]]]
// CHECK: %[[M_3:.+]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK: %[[TILE_M_4:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_M]])[%[[M_3]]]
// CHECK: %[[K_2:.+]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK: %[[SV3:.+]] = memref.subview %[[ARG0]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M]], %[[K_2]]]
// CHECK-SAME: [%[[TILE_M_4]], %[[K_2]]]
// CHECK: %[[TILE_M_5:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_M]])[%[[M]]]
// CHECK: %[[SV3_2:.+]] = memref.subview %[[ARG1]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M]], %[[K_2]]]
// CHECK-SAME: [%[[TILE_M_5]], %[[K]]]
// CHECK: linalg.copy(%[[SV3]], %[[SV3_2]])
// CHECK-SAME: __internal_linalg_transform__ = "after_two_operand_fusion_producer"
// CHECK: linalg.fill(%[[SV2_2]], %[[CST]])
@@ -271,23 +283,24 @@ module {
// CHECK: %[[N:.+]] = memref.dim %[[ARG4]], %[[C1]]
// CHECK: %[[SV2:.+]] = memref.subview %[[ARG4]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M_2]], %[[N]]]
// CHECK: %[[K2_2:.+]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK: %[[M_3:.+]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_M]])[%[[M_3]]]
// CHECK: %[[K1:.+]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK: %[[SV3:.+]] = memref.subview %[[ARG0]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M]], %[[K1]]]
// CHECK: %[[SV4:.+]] = memref.subview %[[ARG1]][0, 0] [%[[K1]], %[[K2_2]]]
// CHECK-SAME: [%[[TILE_M_3]], %[[K1]]]
// CHECK: %[[TILE_M_4:.+]] = affine.min #[[MAP4]](%[[IV0]], %[[TILE_M]])[%[[M]]]
// CHECK: %[[SV1_2:.+]] = memref.subview %[[ARG2]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M]], %[[K2_2]]]
// CHECK-SAME: [%[[TILE_M_4]], %[[K2]]]
// CHECK: linalg.matmul
// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_producer"
// CHECK-SAME: ins(%[[SV3]], %[[SV4]]
// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
// CHECK-SAME: ins(%[[SV3]], %[[ARG1]]
// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32>)
// CHECK-SAME: outs(%[[SV1_2]] : memref<?x?xf32, #[[MAP1]]>)
// CHECK-DAG: %[[N_2:.+]] = memref.dim %[[ARG3]], %[[C1]]
// CHECK: %[[N_2:.+]] = memref.dim %[[ARG3]], %[[C1]]
// CHECK: scf.parallel (%[[IV1:.+]]) =
// CHECK-SAME: (%[[C0]]) to (%[[N_2]]) step (%[[C64]]) {
// CHECK-NEXT: scf.for %[[IV2:.+]] = %[[C0]] to %[[K]] step %[[C16]] {
// CHECK: %[[TILE_K:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[K]]]
// CHECK-NEXT: scf.for %[[IV2:.+]] = %[[C0]] to %[[K2]] step %[[C16]] {
// CHECK: %[[TILE_K:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[K2]]]
// CHECK: %[[SV6:.+]] = memref.subview %[[SV1]][0, %[[IV2]]]
// CHECK-SAME: [%[[TILE_M]], %[[TILE_K]]]
// CHECK: %[[K_2:.+]] = memref.dim %[[ARG3]], %[[C0]]
@@ -348,10 +361,11 @@ module {
// CHECK: %[[T6:.+]] = memref.subview %[[ARG2]][%[[ARG3]], %[[ARG4]]]
// CHECK: %[[T8:.+]] = memref.subview %[[ARG0]][%[[ARG3]], 0]
// CHECK: %[[T9:.+]] = memref.subview %[[ARG1]][0, %[[ARG4]]]
// CHECK: %[[T10:.+]] = memref.subview %[[T2]][%[[ARG3]], %[[ARG4]]]
// CHECK: linalg.matmul
// CHECK-SAME: after_transpose_fusion_producer
// CHECK-SAME: ins(%[[T8]], %[[T9]]
// CHECK-SAME: outs(%[[T5]]
// CHECK-SAME: outs(%[[T10]]
// CHECK-NOT: linalg.matmul
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[T5]], %[[T5]]


@@ -36,18 +36,19 @@ module {
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// CHECK: %[[TEMP:.+]] = memref.alloc(%{{.*}}, %{{.*}}) : memref<?x?xf32>
// CHECK: scf.parallel (%[[IV0:.+]], %[[IV1:.+]]) = {{.*}} {
// CHECK-DAG: %[[SV_TEMP:.+]] = memref.subview %[[TEMP]][%[[IV0]], %[[IV1]]]
// CHECK: %[[SV_TEMP_1:.+]] = memref.subview %[[TEMP]][%[[IV0]], %[[IV1]]]
// CHECK-DAG: %[[SV_ARG2:.+]] = memref.subview %[[ARG2]][%[[IV1]]]
// CHECK-DAG: %[[SV_ARG3:.+]] = memref.subview %[[ARG3]][%[[IV0]], %[[IV1]]]
// CHECK-DAG: %[[SV_ARG0:.+]] = memref.subview %[[ARG0]][%[[IV0]], 0]
// CHECK-DAG: %[[SV_ARG1:.+]] = memref.subview %[[ARG1]][0, %[[IV1]]]
// CHECK: linalg.fill(%[[SV_TEMP]], %{{.+}})
// CHECK: %[[SV_TEMP_2:.+]] = memref.subview %[[TEMP]][%[[IV0]], %[[IV1]]]
// CHECK: linalg.fill(%[[SV_TEMP_2]], %{{.+}})
// CHECK: linalg.matmul
// CHECK-SAME: ins(%[[SV_ARG0]], %[[SV_ARG1]]
// CHECK-SAME: : memref<?x?xf32, #[[MAP2]]>, memref<?x?xf32, #[[MAP2]]>)
// CHECK-SAME: outs(%[[SV_TEMP]] : memref<?x?xf32, #[[MAP2]]>)
// CHECK-SAME: outs(%[[SV_TEMP_2]] : memref<?x?xf32, #[[MAP2]]>)
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[SV_TEMP]], %[[SV_ARG2]]
// CHECK-SAME: ins(%[[SV_TEMP_1]], %[[SV_ARG2]]
// CHECK-SAME: : memref<?x?xf32, #[[MAP2]]>, memref<?xf32, #[[MAP3]]>)
// CHECK-SAME: outs(%[[SV_ARG3]] : memref<?x?xf32, #[[MAP2]]>)
// CHECK: scf.yield
@@ -83,6 +84,8 @@ module {
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>
// CHECK: func @sequence_of_matmul
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref<?x?xf32>
@@ -100,37 +103,40 @@ module {
// CHECK: scf.parallel (%[[IV0:.+]]) = (%[[C0]]) to (%[[M]])
// CHECK-SAME: step (%[[C16]]) {
// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
// CHECK: %[[SV_ALLOC2:.+]] = memref.subview %[[ALLOC2]][%[[IV0]], 0]
// CHECK: %[[SV_ALLOC3:.+]] = memref.subview %[[ALLOC2]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M]], %[[N2]]]
// CHECK: %[[M_2:.+]] = memref.dim %[[ARG4]], %[[C0]]
// CHECK: %[[TILE_M_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M_2]]]
// CHECK: %[[N3:.+]] = memref.dim %[[ARG4]], %[[C1]]
// CHECK: %[[SV_ARG4:.+]] = memref.subview %[[ARG4]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M_2]], %[[N3]]]
// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP2]](%[[IV0]], %[[TILE_M]])[%[[M_2]]]
// CHECK: %[[SV_ARG4_2:.+]] = memref.subview %[[ARG4]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M]], %[[N3]]]
// CHECK-SAME: [%[[TILE_M_3]], %[[N3]]]
// CHECK: %[[TILE_M_4:.+]] = affine.min #[[MAP2]](%[[IV0]], %[[TILE_M]])[%[[M]]]
// CHECK: %[[SV_ALLOC1:.+]] = memref.subview %[[ALLOC1]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M]], %[[N1]]]
// CHECK: %[[SV_ARG2:.+]] = memref.subview %[[ARG2]][0, 0] [%[[N1]], %[[N2]]]
// CHECK-SAME: [%[[TILE_M_4]], %[[N1]]]
// CHECK: %[[SV_ALLOC2:.+]] = memref.subview %[[ALLOC2]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M_4]], %[[N2]]]
// CHECK: %[[N0:.+]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK: %[[SV_ARG0:.+]] = memref.subview %[[ARG0]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M:.+]], %[[N0]]]
// CHECK: %[[SV_ARG1:.+]] = memref.subview %[[ARG1]][0, 0] [%[[N0]], %[[N1]]]
// CHECK-SAME: [%[[TILE_M_4]], %[[N0]]]
// CHECK: linalg.fill(%[[SV_ALLOC1]], %{{.+}})
// CHECK: linalg.matmul ins(%[[SV_ARG0]], %[[SV_ARG1]]
// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
// CHECK: linalg.matmul ins(%[[SV_ARG0]], %[[ARG1]]
// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32>)
// CHECK-SAME: outs(%[[SV_ALLOC1]] : memref<?x?xf32, #[[MAP1]]>)
// CHECK: linalg.fill(%[[SV_ALLOC2]], %{{.+}})
// CHECK: linalg.matmul ins(%[[SV_ALLOC1]], %[[SV_ARG2]]
// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32, #[[MAP1]]>)
// CHECK: linalg.matmul ins(%[[SV_ALLOC1]], %[[ARG2]]
// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32>)
// CHECK-SAME: outs(%[[SV_ALLOC2]] : memref<?x?xf32, #[[MAP1]]>)
// CHECK: linalg.fill(%[[SV_ARG4_2]], %{{.+}})
// CHECK: linalg.matmul ins(%[[SV_ALLOC2]], %[[ARG3]]
// CHECK: linalg.matmul ins(%[[SV_ALLOC3]], %[[ARG3]]
// CHECK-SAME: : memref<?x?xf32, #[[MAP1]]>, memref<?x?xf32>)
// CHECK-SAME: outs(%[[SV_ARG4]] : memref<?x?xf32, #[[MAP1]]>)
// CHECK: scf.yield
// CHECK: }
// -----
module {
@@ -189,8 +195,8 @@ module {
module {
func @tensor_matmul_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
%arg2: tensor<?x?xf32>, %arg3: tensor<?x?xf32>,
%arg4: tensor<?x?xf32>, %arg5: tensor<?x?xf32>,
%arg6: tensor<?x?xf32>) -> tensor<?x?xf32> {
%arg4: tensor<?x?xf32>, %arg5: tensor<?x?xf32>,
%arg6: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N0] * [N0, N1]
%1 = linalg.matmul ins(%0, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>)
@@ -200,7 +206,12 @@ module {
return %2 : tensor<?x?xf32>
}
}
// CHECK-LABEL: func @tensor_matmul_fusion(
// CHECK: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1) -> (16, d0 - d1)>
// CHECK: #[[MAP2:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>
// CHECK: func @tensor_matmul_fusion(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
@@ -210,36 +221,39 @@ module {
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK: %[[M:.+]] = memref.dim %[[ARG0]], %c0 : tensor<?x?xf32>
// CHECK: %[[R0:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]] =
// CHECK-SAME: iter_args(%[[ARG8:.+]] = %[[ARG6]]) -> (tensor<?x?xf32>) {
// CHECK: %[[N3:.+]] = memref.dim %[[ARG8]], %[[C1]]
// CHECK: %[[STARG6:.+]] = subtensor %[[ARG8]][%[[IV0]], 0]
// CHECK-SAME: [%{{[a-zA-Z0-9_]+}}, %[[N3]]]
// CHECK: %[[N2:.+]] = memref.dim %[[ARG3]], %[[C1]]
// CHECK: %[[N1:.+]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK: %[[STARG3:.+]] = subtensor %[[ARG3]][0, 0]
// CHECK-SAME: [%[[N1]], %[[N2]]]
// CHECK: %[[STARG4:.+]] = subtensor %[[ARG4]][%[[IV0]], 0]
// CHECK-SAME: [%{{[a-zA-Z0-9_]+}}, %[[N2]]]
// CHECK: %[[N0:.+]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK: %[[STARG0:.+]] = subtensor %[[ARG0]][%[[IV0]], 0]
// CHECK-SAME: [%{{[a-zA-Z0-9_]+}}, %[[N0]]]
// CHECK: %[[STARG1:.+]] = subtensor %[[ARG1]][0, 0]
// CHECK-SAME: [%[[N0]], %[[N1]]]
// CHECK: %[[STARG2:.+]] = subtensor %[[ARG2]][%[[IV0]], 0]
// CHECK-SAME: [%{{[a-zA-Z0-9_]+}}, %[[N1]]]
// CHECK: %[[T0:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[STARG0]], %[[STARG1]]
// CHECK-SAME: ) outs(%[[STARG2]] : tensor<?x?xf32>)
// CHECK: %[[T1:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[T0]], %[[STARG3]]
// CHECK-SAME: ) outs(%[[STARG4]] : tensor<?x?xf32>)
// CHECK: %[[T2:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[T1]], %[[ARG5]]
// CHECK-SAME: ) outs(%[[STARG6]] : tensor<?x?xf32>)
// CHECK: %[[R1:.+]] = subtensor_insert %[[T2]]
// CHECK-SAME: into %[[ARG8]][%[[IV0]], 0]
// CHECK: scf.yield %[[R1]]
// CHECK: }
// CHECK: return %[[R0]]
// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
// CHECK: %[[M_1:.+]] = memref.dim %[[ARG8]], %[[C0]]
// CHECK: %[[TILE_M_1:.+]] = affine.min #[[MAP1]](%[[M_1]], %[[IV0]])
// CHECK: %[[N3:.+]] = memref.dim %[[ARG8]], %[[C1]]
// CHECK: %[[STARG6:.+]] = subtensor %[[ARG8]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M_1]], %[[N3]]]
// CHECK: %[[M_2:.+]] = memref.dim %[[ARG4]], %[[C0]]
// CHECK: %[[TILE_M_2:.+]] = affine.min #[[MAP2]](%[[IV0]], %[[TILE_M]])[%[[M_2]]]
// CHECK: %[[N2:.+]] = memref.dim %[[ARG4]], %[[C1]]
// CHECK: %[[STARG4:.+]] = subtensor %[[ARG4]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M_2]], %[[N2]]]
// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP2]](%[[IV0]], %[[TILE_M]])[%[[M]]]
// CHECK: %[[N0:.+]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK: %[[STARG0:.+]] = subtensor %[[ARG0]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M_3]], %[[N0]]]
// CHECK: %[[M_3:.+]] = memref.dim %[[ARG2]], %[[C0]]
// CHECK: %[[TILE_M_4:.+]] = affine.min #[[MAP2]](%[[IV0]], %[[TILE_M]])[%[[M_3]]]
// CHECK: %[[N1:.+]] = memref.dim %[[ARG2]], %[[C1]]
// CHECK: %[[STARG2:.+]] = subtensor %[[ARG2]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M_4]], %[[N1]]]
// CHECK: %[[T0:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[STARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>
// CHECK-SAME: ) outs(%[[STARG2]] : tensor<?x?xf32>)
// CHECK: %[[T1:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[T0]], %arg3 : tensor<?x?xf32>, tensor<?x?xf32>
// CHECK-SAME: ) outs(%[[STARG4]] : tensor<?x?xf32>)
// CHECK: %[[T2:.+]] = linalg.matmul
// CHECK-SAME: ins(%[[T1]], %arg5 : tensor<?x?xf32>, tensor<?x?xf32>
// CHECK-SAME: ) outs(%[[STARG6]] : tensor<?x?xf32>)
// CHECK: %[[R1:.+]] = subtensor_insert %[[T2]]
// CHECK-SAME: into %[[ARG8]][%[[IV0]], 0] [%[[TILE_M_1]], %[[N3]]]
// CHECK: scf.yield %[[R1]] : tensor<?x?xf32>
// CHECK: }


@@ -17,12 +17,15 @@ module {
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)>
// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1) -> (64, d0 - d1)>
// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>
// CHECK: func @matmul_fusion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[C32:.+]] = constant 32 : index
@@ -38,18 +41,20 @@ module {
// CHECK: %[[N3:.+]] = memref.dim %[[ARG6]], %[[C1]]
// CHECK: %[[ST_ARG6:.+]] = subtensor %[[ARG6]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M_2]], %[[N3]]]
// CHECK: %[[N2:.+]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP5]](%[[IV0]], %[[TILE_M]])[%[[M]]]
// CHECK: %[[N1:.+]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK: %[[ST_ARG0:.+]] = subtensor %[[ARG0]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M]], %[[N1]]]
// CHECK: %[[ST_ARG1:.+]] = subtensor %[[ARG1]][0, 0]
// CHECK-SAME: [%[[N1]], %[[N2]]]
// CHECK-SAME: [%[[TILE_M_3]], %[[N1]]]
// CHECK: %[[M_3:.+]] = memref.dim %[[ARG2]], %[[C0]]
// CHECK: %[[TILE_M_4:.+]] = affine.min #[[MAP5]](%[[IV0]], %[[TILE_M]])[%[[M_3]]]
// CHECK: %[[N2_2:.+]] = memref.dim %[[ARG2]], %[[C1]]
// CHECK: %[[ST_ARG2:.+]] = subtensor %[[ARG2]][%[[IV0]], 0]
// CHECK-SAME: [%[[TILE_M]], %[[N2]]]
// CHECK-SAME: [%[[TILE_M_4]], %[[N2_2]]]
// CHECK: %[[LHS:.+]] = linalg.matmul
// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_producer"
// CHECK-SAME: ins(%[[ST_ARG0]], %[[ST_ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>)
// CHECK-SAME: ins(%[[ST_ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>)
// CHECK-SAME: outs(%[[ST_ARG2]] : tensor<?x?xf32>)
// CHECK: %[[N2:.+]] = memref.dim %[[ARG1]], %[[C1]]
// CHECK: %[[N3_2:.+]] = memref.dim %[[ARG3]], %[[C1]]
// CHECK: %[[YIELD0:.+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] =
// CHECK-SAME: %[[C0]] to %[[N3_2]] step %[[C64]]
@@ -59,7 +64,7 @@ module {
// CHECK-SAME: iter_args(%[[ARG10:.+]] = %[[ARG8]]) -> (tensor<?x?xf32>) {
// CHECK: %[[TILE_N2:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[N2]]]
// CHECK: %[[ST_LHS:.+]] = subtensor %[[LHS]][0, %[[IV2]]]
// CHECK-SAME: [%[[TILE_M]], %[[TILE_N2]]]
// CHECK-SAME: [%[[TILE_M_3]], %[[TILE_N2]]]
// CHECK: %[[N2_3:.+]] = memref.dim %[[ARG3]], %[[C0]]
// CHECK: %[[TILE_N2_2:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[N2_3]]]
// CHECK: %[[TILE_N3:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[N3_2]]]


@@ -252,25 +252,36 @@ func @f5(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
}
return %E : memref<?x?xf32, offset: 0, strides: [?, ?]>
}
// CHECK-LABEL: func @f5
// CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
// CHECK: #[[BOUND_2_MAP:.+]] = affine_map<(d0)[s0] -> (2, -d0 + s0)>
// CHECK: #[[BOUND_ID_MAP:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>
// CHECK: #[[BOUND_4_MAP:.+]] = affine_map<(d0)[s0] -> (4, -d0 + s0)>
// CHECK: func @f5
// CHECK-SAME: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
// CHECK-DAG: %[[C1:.*]] = constant 1 : index
// CHECK-DAG: %[[B_1:.*]] = memref.dim %[[B]], %[[C1:.*]] : memref<?x?xf32, #[[$strided2D]]>
// CHECK-DAG: %[[D_0:.*]] = memref.dim %[[D]], %[[C0:.*]] : memref<?x?xf32, #[[$strided2D]]>
// CHECK-DAG: %[[D_1:.*]] = memref.dim %[[D]], %[[C1:.*]] : memref<?x?xf32, #[[$strided2D]]>
// CHECK-DAG: %[[B_1:.*]] = memref.dim %[[B]], %[[C1]] : memref<?x?xf32, #[[$strided2D]]>
// CHECK-DAG: %[[C_0:.*]] = memref.dim %[[C]], %[[C0]] : memref<?x?xf32, #[[$strided2D]]>
// CHECK-DAG: %[[D_0:.*]] = memref.dim %[[D]], %[[C0]] : memref<?x?xf32, #[[$strided2D]]>
// CHECK-DAG: %[[D_1:.*]] = memref.dim %[[D]], %[[C1]] : memref<?x?xf32, #[[$strided2D]]>
// CHECK-DAG: %[[B_00:.*]] = memref.subview %[[B]][0, 0]{{.*}}
// CHECK: scf.for %[[I:.*]] = %{{.*}} to %[[D_0]] step %{{.*}} {
// CHECK-DAG: %[[A_I0:.*]] = memref.subview %[[A]][%[[I]], 0]
// CHECK-DAG: %[[C_I0:.*]] = memref.subview %[[C]][%[[I]], 0]
// CHECK: %[[BOUND_2_C0:.+]] = affine.min #[[BOUND_2_MAP]](%[[I]])[%[[C_0]]]
// CHECK: %[[C_I0:.*]] = memref.subview %[[C]][%[[I]], 0] [%[[BOUND_2_C0]]
// CHECK: %[[BOUND_2_D0:.+]] = affine.min #[[BOUND_2_MAP]](%[[I]])[%[[D_0]]]
// CHECK: %[[A_I0:.*]] = memref.subview %[[A]][%[[I]], 0]
// Note that %[[BOUND_ID_C0]] is essentially %[[BOUND_2_C0]].
// CHECK: %[[BOUND_ID_C0:.+]] = affine.min #[[BOUND_ID_MAP]](%[[I]], %[[BOUND_2_C0]])[%[[C_0]]]
// CHECK: %[[C_I0_OUT:.*]] = memref.subview %[[C]][%[[I]], 0] [%[[BOUND_ID_C0]]
// CHECK: scf.for %[[J:.*]] = %{{.*}} to %[[B_1]] step %{{.*}} {
// CHECK: %[[E_IJ:.*]] = memref.subview %[[E]][%[[I]], %[[J]]]
// CHECK: scf.for %[[K:.*]] = %{{.*}} to %[[D_1]] step %{{.*}} {
// CHECK-DAG: %[[D_IK:.*]] = memref.subview %[[D]][%[[I]], %[[K]]]
// CHECK-DAG: %[[B_0K:.*]] = memref.subview %[[B]][0, %[[K]]]
// CHECK-DAG: %[[B_KJ:.*]] = memref.subview %[[B]][%[[K]], %[[J]]]
// CHECK: linalg.matmul ins(%[[A_I0]], %[[B_00]]{{.*}} outs(%[[C_I0]]
// CHECK: linalg.matmul ins(%[[C_I0]], %[[B_0K]]{{.*}} outs(%[[D_IK]]
// CHECK: %[[D_IK:.*]] = memref.subview %[[D]][%[[I]], %[[K]]] [2, 4]
// CHECK: %[[B_KJ:.*]] = memref.subview %[[B]][%[[K]], %[[J]]]
// CHECK: %[[B_0K:.*]] = memref.subview %[[B]][0, %[[K]]]
// CHECK: %[[BOUND_4_D1:.+]] = affine.min #[[BOUND_4_MAP]](%[[K]])[%[[D_1]]]
// CHECK: %[[D_IK_OUT:.+]] = memref.subview %[[D]][%[[I]], %[[K]]] [%[[BOUND_2_D0]], %[[BOUND_4_D1]]]
// CHECK: linalg.matmul ins(%[[A_I0]], %[[B_00]]{{.*}} outs(%[[C_I0_OUT]]
// CHECK: linalg.matmul ins(%[[C_I0]], %[[B_0K]]{{.*}} outs(%[[D_IK_OUT]]
// CHECK: linalg.matmul ins(%[[D_IK]], %[[B_KJ]]{{.*}} outs(%[[E_IJ]]
// -----


@@ -1,11 +1,5 @@
// RUN: mlir-opt %s -test-linalg-greedy-fusion -split-input-file | FileCheck %s
#map0 = affine_map<(d0)[s0] -> (2, -d0 + s0)>
#map1 = affine_map<(d0)[s0] -> (4, -d0 + s0)>
#map2 = affine_map<(d0)[s0] -> (3, -d0 + s0)>
#map3 = affine_map<(d0, d1) -> (2, d0 - d1)>
#map4 = affine_map<(d0, d1) -> (3, d0 - d1)>
func @matmul_tensors(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
%t0 = linalg.matmul ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg2: tensor<?x?xf32>)
@@ -36,23 +30,250 @@ func @matmul_tensors(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tens
return %3 : tensor<?x?xf32>
}
// CHECK-LABEL: func @matmul_tensors(
// CHECK: #[[BOUND2_MAP:.+]] = affine_map<(d0)[s0] -> (2, -d0 + s0)>
// CHECK: #[[BOUND4_MAP:.+]] = affine_map<(d0)[s0] -> (4, -d0 + s0)>
// CHECK: func @matmul_tensors(
// CHECK-SAME: %[[A:[0-9a-z]*]]: tensor<?x?xf32>
// CHECK-SAME: %[[B:[0-9a-z]*]]: tensor<?x?xf32>
// CHECK-SAME: %[[C:[0-9a-z]*]]: tensor<?x?xf32>
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
// CHECK-DAG: %[[C1:.*]] = constant 1 : index
// CHECK-DAG: %[[dA0:.*]] = memref.dim %[[A]], %[[C0]] : tensor<?x?xf32>
// CHECK-DAG: %[[dA1:.*]] = memref.dim %[[A]], %[[C1]] : tensor<?x?xf32>
// CHECK-DAG: %[[dB0:.*]] = memref.dim %[[B]], %[[C0]] : tensor<?x?xf32>
// CHECK-DAG: %[[dB1:.*]] = memref.dim %[[B]], %[[C1]] : tensor<?x?xf32>
// CHECK-DAG: %[[dC0:.*]] = memref.dim %[[C]], %[[C0]] : tensor<?x?xf32>
// CHECK-DAG: %[[dC1:.*]] = memref.dim %[[C]], %[[C1]] : tensor<?x?xf32>
// CHECK: scf.for %[[I:[0-9a-z]*]]
// CHECK: %[[stA:.*]] = subtensor %[[A]][%[[I]], 0] [2, %[[dA1]]] [1, 1] : tensor<?x?xf32> to tensor<2x?xf32>
// CHECK: %[[sizeA0:.*]] = affine.min #[[BOUND2_MAP]](%[[I]])[%[[dA0]]]
// CHECK: %[[stA:.*]] = subtensor %[[A]][%[[I]], 0] [%[[sizeA0]], %[[dA1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[sizeC0:.*]] = affine.min #[[BOUND2_MAP]](%[[I]])[%[[dC0]]]
// CHECK-NEXT: scf.for %[[J:[0-9a-z]*]]
// CHECK-NEXT: scf.for %[[K:[0-9a-z]*]] {{.*}} iter_args(%[[RES:[0-9a-z]*]]
// CHECK-DAG: %[[stB1:.*]] = subtensor %[[B]][%[[K]], %[[J]]] [4, 3] [1, 1] : tensor<?x?xf32> to tensor<4x3xf32>
// CHECK-DAG: %[[stF:.*]] = subtensor %[[RES]][%[[I]], %[[J]]] [2, 3] [1, 1] : tensor<?x?xf32> to tensor<2x3xf32>
//
// subtensors of the producing matmul.
// CHECK-DAG: %[[stB2:.*]] = subtensor %[[B]][0, %[[K]]] [%[[dA1]], 4] [1, 1] : tensor<?x?xf32> to tensor<?x4xf32>
// CHECK-DAG: %[[stC:.*]] = subtensor %[[C]][%[[I]], %[[K]]] [2, 4] [1, 1] : tensor<?x?xf32> to tensor<2x4xf32>
// CHECK: %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor<2x?xf32>, tensor<?x4xf32>) outs(%[[stC]] : tensor<2x4xf32>) -> tensor<2x4xf32>
// CHECK-NEXT: %[[stG:.*]] = linalg.matmul ins(%[[stD]], %[[stB1]] : tensor<2x4xf32>, tensor<4x3xf32>) outs(%[[stF]] : tensor<2x3xf32>) -> tensor<2x3xf32>
// CHECK: %[[sizeB1:.*]] = affine.min #[[BOUND4_MAP]](%[[K]])[%[[dB1]]]
// CHECK: %[[stB2:.*]] = subtensor %[[B]][0, %[[K]]] [%[[dB0]], %[[sizeB1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[sizeC1:.*]] = affine.min #[[BOUND4_MAP]](%[[K]])[%[[dC1]]]
// CHECK: %[[stC:.*]] = subtensor %[[C]][%[[I]], %[[K]]] [%[[sizeC0]], %[[sizeC1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
// CHECK: %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[stC]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[CAST:.*]] = tensor.cast %[[stD]] : tensor<?x?xf32> to tensor<?x4xf32>
// CHECK-NEXT: %[[stG:.*]] = linalg.matmul ins(%[[CAST]], %[[stB1]] : tensor<?x4xf32>, tensor<4x3xf32>) outs(%[[stF]] : tensor<2x3xf32>) -> tensor<2x3xf32>
// CHECK-NEXT: subtensor_insert %[[stG]] into %[[RES]][%[[I]], %[[J]]]
// -----
func @conv_tensors_static(%input: tensor<1x225x225x32xf32>, %filter: tensor<3x3x3x32xf32>, %elementwise: tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32> {
%c112 = constant 112 : index
%c32 = constant 32 : index
%c16 = constant 16 : index
%c8 = constant 8 : index
%c4 = constant 4 : index
%c0 = constant 0 : index
%cst = constant 0.0 : f32
%init = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32>
%fill = linalg.fill(%init, %cst) : tensor<1x112x112x32xf32>, f32 -> tensor<1x112x112x32xf32>
%conv = linalg.conv_2d_input_nhwc_filter_hwcf
{dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
ins(%input, %filter : tensor<1x225x225x32xf32>, tensor<3x3x3x32xf32>)
outs(%fill : tensor<1x112x112x32xf32>) -> tensor<1x112x112x32xf32>
%for0 = scf.for %iv0 = %c0 to %c112 step %c8 iter_args(%arg0 = %fill) -> tensor<1x112x112x32xf32> {
%for1 = scf.for %iv1 = %c0 to %c112 step %c16 iter_args(%arg1 = %arg0) -> tensor<1x112x112x32xf32> {
%for2 = scf.for %iv2 = %c0 to %c32 step %c4 iter_args(%arg2 = %arg1) -> tensor<1x112x112x32xf32> {
%0 = subtensor %conv[0, %iv0, %iv1, %iv2][1, 8, 16, 4][1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x8x16x4xf32>
%1 = subtensor %elementwise[0, %iv0, %iv1, %iv2][1, 8, 16, 4][1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x8x16x4xf32>
%2 = subtensor %arg2[0, %iv0, %iv1, %iv2][1, 8, 16, 4][1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x8x16x4xf32>
%add = linalg.generic
{
indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]
}
ins(%0, %1 : tensor<1x8x16x4xf32>, tensor<1x8x16x4xf32>) outs(%2 : tensor<1x8x16x4xf32>) {
^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
%result = addf %arg3, %arg4 : f32
linalg.yield %result : f32
} -> tensor<1x8x16x4xf32>
%insert = subtensor_insert %add into %arg2[0, %iv0, %iv1, %iv2] [1, 8, 16, 4] [1, 1, 1, 1] : tensor<1x8x16x4xf32> into tensor<1x112x112x32xf32>
scf.yield %insert : tensor<1x112x112x32xf32>
}
scf.yield %for2 : tensor<1x112x112x32xf32>
}
scf.yield %for1 : tensor<1x112x112x32xf32>
}
return %for0 : tensor<1x112x112x32xf32>
}
// CHECK: #[[MAP0:.+]] = affine_map<(d0) -> (d0 * 2)>
// CHECK: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK: func @conv_tensors_static
// CHECK-SAME: (%[[INPUT:.+]]: tensor<1x225x225x32xf32>, %[[FILTER:.+]]: tensor<3x3x3x32xf32>, %[[ELEM:.+]]: tensor<1x112x112x32xf32>)
// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 112, 112, 32] : tensor<1x112x112x32xf32>
// CHECK-NEXT: %[[FILL:.+]] = linalg.fill(%[[INIT]], %cst) : tensor<1x112x112x32xf32>, f32 -> tensor<1x112x112x32xf32>
// CHECK-NEXT: scf.for %[[IV0:.+]] = %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[ARG0:.+]] = %[[FILL]])
// CHECK-NEXT: %[[OFFSET_H:.+]] = affine.apply #[[MAP0]](%[[IV0]])
// CHECK-NEXT: scf.for %[[IV1:.+]] = %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[ARG1:.+]] = %[[ARG0]])
// CHECK-NEXT: %[[OFFSET_W:.+]] = affine.apply #[[MAP0]](%[[IV1]])
// CHECK-NEXT: %[[ST_INPUT:.+]] = subtensor %arg0[0, %[[OFFSET_H]], %[[OFFSET_W]], 0] [1, 17, 33, 32] [1, 1, 1, 1] : tensor<1x225x225x32xf32> to tensor<1x17x33x32xf32>
// CHECK-NEXT: scf.for %[[IV2:.+]] = %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[ARG2:.+]] = %[[ARG1]])
// CHECK-NEXT: %[[ST_ELEM:.+]] = subtensor %[[ELEM]][0, %[[IV0]], %[[IV1]], %[[IV2]]] [1, 8, 16, 4] [1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x8x16x4xf32>
// CHECK-NEXT: %[[ST_ARG2:.+]] = subtensor %[[ARG2]][0, %[[IV0]], %[[IV1]], %[[IV2]]] [1, 8, 16, 4] [1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x8x16x4xf32>
// CHECK-NEXT: %[[ST_FILTER:.+]] = subtensor %[[FILTER]][0, 0, 0, %[[IV2]]] [3, 3, 3, 4] [1, 1, 1, 1] : tensor<3x3x3x32xf32> to tensor<3x3x3x4xf32>
// CHECK-NEXT: %[[ST_FILL:.+]] = subtensor %[[FILL]][0, %[[IV0]], %[[IV1]], %[[IV2]]] [1, 8, 16, 4] [1, 1, 1, 1] : tensor<1x112x112x32xf32> to tensor<1x8x16x4xf32>
// CHECK-NEXT: %[[ST_CONV:.+]] = linalg.conv_2d_input_nhwc_filter_hwcf
// CHECK-SAME: ins(%[[ST_INPUT]], %[[ST_FILTER]] : tensor<1x17x33x32xf32>, tensor<3x3x3x4xf32>)
// CHECK-SAME: outs(%[[ST_FILL]] : tensor<1x8x16x4xf32>)
// CHECK-NEXT: %[[ADD:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ST_CONV]], %[[ST_ELEM]] : tensor<1x8x16x4xf32>, tensor<1x8x16x4xf32>)
// CHECK-SAME: outs(%[[ST_ARG2]] : tensor<1x8x16x4xf32>)
// CHECK: subtensor_insert %[[ADD]] into %[[ARG2]][0, %[[IV0]], %[[IV1]], %[[IV2]]] [1, 8, 16, 4]
// -----
#bound4_map = affine_map<(d0)[s0] -> (4, -d0 + s0)>
#bound8_map = affine_map<(d0)[s0] -> (8, -d0 + s0)>
#bound16_map = affine_map<(d0)[s0] -> (16, -d0 + s0)>
func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?xf32>, %elementwise: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
%cst = constant 0.0 : f32
%c0 = constant 0 : index
%c1 = constant 1 : index
%c2 = constant 2 : index
%c3 = constant 3 : index
%c4 = constant 4 : index
%c8 = constant 8 : index
%c16 = constant 16 : index
%n = memref.dim %elementwise, %c0 : tensor<?x?x?x?xf32>
%oh = memref.dim %elementwise, %c1 : tensor<?x?x?x?xf32>
%ow = memref.dim %elementwise, %c2 : tensor<?x?x?x?xf32>
%oc = memref.dim %elementwise, %c3 : tensor<?x?x?x?xf32>
%init = linalg.init_tensor [%n, %oh, %ow, %oc] : tensor<?x?x?x?xf32>
%fill = linalg.fill(%init, %cst) : tensor<?x?x?x?xf32>, f32 -> tensor<?x?x?x?xf32>
%conv = linalg.conv_2d_input_nhwc_filter_hwcf
{dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
ins(%input, %filter : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
outs(%fill : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
%for0 = scf.for %iv0 = %c0 to %oh step %c8 iter_args(%arg0 = %fill) -> tensor<?x?x?x?xf32> {
%for1 = scf.for %iv1 = %c0 to %ow step %c16 iter_args(%arg1 = %arg0) -> tensor<?x?x?x?xf32> {
%for2 = scf.for %iv2 = %c0 to %oc step %c4 iter_args(%arg2 = %arg1) -> tensor<?x?x?x?xf32> {
%for3 = scf.for %iv3 = %c0 to %oc step %c2 iter_args(%arg3 = %arg2) -> tensor<?x?x?x?xf32> {
%n_size = affine.min #bound8_map(%iv0)[%n]
%oh_size = affine.min #bound16_map(%iv1)[%oh]
%ow_size = affine.min #bound4_map(%iv2)[%ow]
%oc_size = affine.min #bound4_map(%iv2)[%oc]
%0 = subtensor %conv[%iv0, %iv1, %iv2, %iv3][%n_size, %oh_size, %ow_size, %oc_size][1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
%1 = subtensor %elementwise[%iv0, %iv1, %iv2, %iv3][%n_size, %oh_size, %ow_size, %oc_size][1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
%2 = subtensor %arg3[%iv0, %iv1, %iv2, %iv3][%n_size, %oh_size, %ow_size, %oc_size][1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
%add = linalg.generic
{
indexing_maps = [
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
iterator_types = ["parallel", "parallel", "parallel", "parallel"]
}
ins(%0, %1 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) outs(%2 : tensor<?x?x?x?xf32>) {
^bb0(%arg4: f32, %arg5: f32, %arg6: f32):
%result = addf %arg4, %arg5 : f32
linalg.yield %result : f32
} -> tensor<?x?x?x?xf32>
%insert = subtensor_insert %add into %arg3[%iv0, %iv1, %iv2, %iv3] [%n_size, %oh_size, %ow_size, %oc_size] [1, 1, 1, 1] : tensor<?x?x?x?xf32> into tensor<?x?x?x?xf32>
scf.yield %insert : tensor<?x?x?x?xf32>
}
scf.yield %for3 : tensor<?x?x?x?xf32>
}
scf.yield %for2 : tensor<?x?x?x?xf32>
}
scf.yield %for1 : tensor<?x?x?x?xf32>
}
return %for0 : tensor<?x?x?x?xf32>
}
// -----
// CHECK: #[[BOUND8_MAP:.+]] = affine_map<(d0)[s0] -> (8, -d0 + s0)>
// CHECK: #[[BOUND_MAP:.+]] = affine_map<(d0, d1)[s0] -> (d1, -d0 + s0)>
// CHECK: #[[BOUND16_MAP:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)>
// CHECK: #[[X2_MAP:.+]] = affine_map<(d0) -> (d0 * 2)>
// CHECK: #[[INPUT_BOUND:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * 2 + s0 - 2, d1 * -2 + s1)>
// CHECK: #[[BOUND4_MAP:.+]] = affine_map<(d0)[s0] -> (4, -d0 + s0)>
// CHECK: func @conv_tensors_dynamic
// CHECK-SAME: (%[[INPUT]]: tensor<?x?x?x?xf32>, %[[FILTER]]: tensor<?x?x?x?xf32>, %[[ELEM]]: tensor<?x?x?x?xf32>)
// CHECK-DAG: %[[C0:.+]] = constant 0 : index
// CHECK-DAG: %[[C1:.+]] = constant 1 : index
// CHECK-DAG: %[[C2:.+]] = constant 2 : index
// CHECK-DAG: %[[C3:.+]] = constant 3 : index
// CHECK-DAG: %[[ELEM_N:.+]] = memref.dim %[[ELEM]], %[[C0]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[ELEM_OH:.+]] = memref.dim %[[ELEM]], %[[C1]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[ELEM_OW:.+]] = memref.dim %[[ELEM]], %[[C2]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[ELEM_OC:.+]] = memref.dim %[[ELEM]], %[[C3]] : tensor<?x?x?x?xf32>
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[ELEM_N]], %[[ELEM_OH]], %[[ELEM_OW]], %[[ELEM_OC]]] : tensor<?x?x?x?xf32>
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %cst) : tensor<?x?x?x?xf32>, f32 -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[FILTER_H:.+]] = memref.dim %[[FILTER]], %[[C0]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[FILTER_W:.+]] = memref.dim %[[FILTER]], %[[C1]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_N:.+]] = memref.dim %[[INPUT]], %[[C0]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_H:.+]] = memref.dim %[[INPUT]], %[[C1]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_W:.+]] = memref.dim %[[INPUT]], %[[C2]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_C:.+]] = memref.dim %[[INPUT]], %[[C3]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[FILTER_IC:.+]] = memref.dim %[[FILTER]], %[[C2]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[FILTER_OC:.+]] = memref.dim %[[FILTER]], %[[C3]] : tensor<?x?x?x?xf32>
// CHECK: scf.for %[[IV0:.+]] = %{{.+}} to %[[ELEM_OH]] step %{{.+}} iter_args(%{{.+}} = %[[FILL]])
// CHECK-NEXT: %[[SIZE_ELEM_N:.+]] = affine.min #[[BOUND8_MAP]](%[[IV0]])[%[[ELEM_N]]]
// CHECK-NEXT: %[[SIZE_INPUT_N:.+]] = affine.min #[[BOUND_MAP]](%[[IV0]], %[[SIZE_ELEM_N]])[%[[INPUT_N]]]
// CHECK-NEXT: %[[SIZE_ELEM_N_2:.+]] = affine.min #[[BOUND_MAP]](%[[IV0]], %[[SIZE_ELEM_N]])[%[[ELEM_N]]]
// CHECK-NEXT: scf.for %[[IV1:.+]] = %{{.+}} to %[[ELEM_OW]]
// CHECK-NEXT: %[[SIZE_ELEM_OH:.+]] = affine.min #[[BOUND16_MAP]](%[[IV1]])[%[[ELEM_OH]]]
// CHECK-NEXT: %[[OFFSET_OH:.+]] = affine.apply #[[X2_MAP]](%[[IV1]])
// CHECK-NEXT: %[[SIZE_INPUT_H:.+]] = affine.min #[[INPUT_BOUND]](%[[SIZE_ELEM_OH]], %[[IV1]])[%[[FILTER_H]], %[[INPUT_H]]]
// CHECK-NEXT: %[[SIZE_ELEM_OH_2:.+]] = affine.min #[[BOUND_MAP]](%[[IV1]], %[[SIZE_ELEM_OH]])[%[[ELEM_OH]]]
// CHECK-NEXT: scf.for %[[IV2:.+]] = %{{.+}} to %[[ELEM_OC]]
// CHECK-NEXT: %[[SIZE_ELEM_OW:.+]] = affine.min #[[BOUND4_MAP]](%[[IV2]])[%[[ELEM_OW]]]
// CHECK-NEXT: %[[SIZE_ELEM_OC:.+]] = affine.min #[[BOUND4_MAP]](%[[IV2]])[%[[ELEM_OC]]]
// CHECK-NEXT: %[[OFFSET_OW:.+]] = affine.apply #[[X2_MAP]](%[[IV2]])
// CHECK-NEXT: %[[SIZE_INPUT_W:.+]] = affine.min #[[INPUT_BOUND]](%[[SIZE_ELEM_OW]], %[[IV2]])[%[[FILTER_W]], %[[INPUT_W]]]
// CHECK-NEXT: %[[ST_INPUT:.+]] = subtensor %[[INPUT]][%[[IV0]], %[[OFFSET_OH]], %[[OFFSET_OW]], 0]
// CHECK-SAME: [%[[SIZE_INPUT_N]], %[[SIZE_INPUT_H]], %[[SIZE_INPUT_W]], %[[INPUT_C]]]
// CHECK-NEXT: %[[SIZE_ELEM_OW_2:.+]] = affine.min #[[BOUND_MAP]](%[[IV2]], %[[SIZE_ELEM_OW]])[%[[ELEM_OW]]]
// CHECK-NEXT: scf.for %[[IV3:.+]] = %{{.+}} to %[[ELEM_OC]] step %{{.+}} iter_args(%[[ARG:[a-z0-9]+]]
// CHECK-NEXT: %[[ST_ELEM:.+]] = subtensor %[[ELEM]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
// CHECK-SAME: [%[[SIZE_ELEM_N]], %[[SIZE_ELEM_OH]], %[[SIZE_ELEM_OW]], %[[SIZE_ELEM_OC]]]
// CHECK-NEXT: %[[ST_ARG:.+]] = subtensor %[[ARG]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
// CHECK-SAME: [%[[SIZE_ELEM_N]], %[[SIZE_ELEM_OH]], %[[SIZE_ELEM_OW]], %[[SIZE_ELEM_OC]]]
// CHECK-NEXT: %[[SIZE_ELEM_OC_2:.+]] = affine.min #[[BOUND_MAP]](%[[IV3]], %[[SIZE_ELEM_OC]])[%[[FILTER_OC]]]
// CHECK-NEXT: %[[ST_FILTER:.+]] = subtensor %[[FILTER]][0, 0, 0, %[[IV3]]]
// CHECK-SAME: [%[[FILTER_H]], %[[FILTER_W]], %[[FILTER_IC]], %[[SIZE_ELEM_OC_2]]]
// CHECK-NEXT: %[[SIZE_ELEM_OC_3:.+]] = affine.min #[[BOUND_MAP]](%[[IV3]], %[[SIZE_ELEM_OC]])[%[[ELEM_OC]]]
// CHECK-NEXT: %[[ST_FILL:.+]] = subtensor %[[FILL]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
// CHECK-SAME: [%[[SIZE_ELEM_N_2]], %[[SIZE_ELEM_OH_2]], %[[SIZE_ELEM_OW_2]], %[[SIZE_ELEM_OC_3]]]
// CHECK-NEXT: %[[ST_CONV:.+]] = linalg.conv_2d_input_nhwc_filter_hwcf
// CHECK-SAME: ins(%[[ST_INPUT]], %[[ST_FILTER]] : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
// CHECK-SAME: outs(%[[ST_FILL]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
// CHECK-NEXT: %[[ST_ADD:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ST_CONV]], %[[ST_ELEM]] : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
// CHECK-SAME: outs(%[[ST_ARG]] : tensor<?x?x?x?xf32>)
// CHECK: subtensor_insert %[[ST_ADD]] into %[[ARG]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
// CHECK-SAME: [%[[SIZE_ELEM_N]], %[[SIZE_ELEM_OH]], %[[SIZE_ELEM_OW]], %[[SIZE_ELEM_OC]]]


@@ -179,6 +179,10 @@ static LogicalResult fuseLinalgOpsGreedily(FuncOp f) {
namespace {
struct TestLinalgGreedyFusion
: public PassWrapper<TestLinalgGreedyFusion, FunctionPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<AffineDialect, linalg::LinalgDialect, memref::MemRefDialect,
scf::SCFDialect>();
}
void runOnFunction() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns =