[mlir][linalg] remove interchange option on linalg to loop lowering.

The interchange option attached to the linalg to loop lowering affects only the loops and does not update the memory accesses generated in to body of the operation. Instead of performing the interchange during the loop lowering use the interchange pattern.

Differential Revision: https://reviews.llvm.org/D100758
This commit is contained in:
Tobias Gysi 2021-04-22 08:22:37 +00:00
parent fbc6f42dbe
commit 0e777e4ad7
4 changed files with 48 additions and 183 deletions

View File

@ -62,12 +62,6 @@ def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> {
let summary = "Lower the operations from the linalg dialect into affine "
"loops";
let constructor = "mlir::createConvertLinalgToAffineLoopsPass()";
let options = [
ListOption<"interchangeVector", "interchange-vector", "unsigned",
"Permute the loops in the nest following the given "
"interchange vector",
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
];
let dependentDialects = [
"AffineDialect", "linalg::LinalgDialect", "memref::MemRefDialect"];
}
@ -75,12 +69,6 @@ def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> {
def LinalgLowerToLoops : FunctionPass<"convert-linalg-to-loops"> {
let summary = "Lower the operations from the linalg dialect into loops";
let constructor = "mlir::createConvertLinalgToLoopsPass()";
let options = [
ListOption<"interchangeVector", "interchange-vector", "unsigned",
"Permute the loops in the nest following the given "
"interchange vector",
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
];
let dependentDialects = [
"linalg::LinalgDialect",
"scf::SCFDialect",
@ -103,12 +91,6 @@ def LinalgLowerToParallelLoops
let summary = "Lower the operations from the linalg dialect into parallel "
"loops";
let constructor = "mlir::createConvertLinalgToParallelLoopsPass()";
let options = [
ListOption<"interchangeVector", "interchange-vector", "unsigned",
"Permute the loops in the nest following the given "
"interchange vector",
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
];
let dependentDialects = [
"AffineDialect",
"linalg::LinalgDialect",

View File

@ -338,28 +338,16 @@ LogicalResult vectorizeLinalgOp(OpBuilder &builder, Operation *op,
/// Emits a loop nest of `LoopTy` with the proper body for `op`.
template <typename LoopTy>
Optional<LinalgLoops>
linalgLowerOpToLoops(OpBuilder &builder, Operation *op,
ArrayRef<unsigned> interchangeVector = {});
Optional<LinalgLoops> linalgLowerOpToLoops(OpBuilder &builder, Operation *op);
/// Emits a loop nest of `scf.for` with the proper body for `op`. The generated
/// loop nest will follow the `interchangeVector`-permutated iterator order. If
/// `interchangeVector` is empty, then no permutation happens.
LogicalResult linalgOpToLoops(OpBuilder &builder, Operation *op,
ArrayRef<unsigned> interchangeVector = {});
/// Emits a loop nest of `scf.for` with the proper body for `op`.
LogicalResult linalgOpToLoops(OpBuilder &builder, Operation *op);
/// Emits a loop nest of `scf.parallel` with the proper body for `op`. The
/// generated loop nest will follow the `interchangeVector`-permutated
// iterator order. If `interchangeVector` is empty, then no permutation happens.
LogicalResult
linalgOpToParallelLoops(OpBuilder &builder, Operation *op,
ArrayRef<unsigned> interchangeVector = {});
/// Emits a loop nest of `scf.parallel` with the proper body for `op`.
LogicalResult linalgOpToParallelLoops(OpBuilder &builder, Operation *op);
/// Emits a loop nest of `affine.for` with the proper body for `op`. The
/// generated loop nest will follow the `interchangeVector`-permutated
// iterator order. If `interchangeVector` is empty, then no permutation happens.
LogicalResult linalgOpToAffineLoops(OpBuilder &builder, Operation *op,
ArrayRef<unsigned> interchangeVector = {});
/// Emits a loop nest of `affine.for` with the proper body for `op`.
LogicalResult linalgOpToAffineLoops(OpBuilder &builder, Operation *op);
//===----------------------------------------------------------------------===//
// Preconditions that ensure the corresponding transformation succeeds and can
@ -808,10 +796,9 @@ struct LinalgLoweringPattern : public RewritePattern {
LinalgLoweringPattern(
MLIRContext *context, LinalgLoweringType loweringType,
LinalgTransformationFilter filter = LinalgTransformationFilter(),
ArrayRef<unsigned> interchangeVector = {}, PatternBenefit benefit = 1)
PatternBenefit benefit = 1)
: RewritePattern(OpTy::getOperationName(), benefit, context),
filter(filter), loweringType(loweringType),
interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
filter(filter), loweringType(loweringType) {}
// TODO: Move implementation to .cpp once named ops are auto-generated.
LogicalResult matchAndRewrite(Operation *op,
@ -827,15 +814,15 @@ struct LinalgLoweringPattern : public RewritePattern {
// TODO: Move lowering to library calls here.
return failure();
case LinalgLoweringType::Loops:
if (failed(linalgOpToLoops(rewriter, op, interchangeVector)))
if (failed(linalgOpToLoops(rewriter, op)))
return failure();
break;
case LinalgLoweringType::AffineLoops:
if (failed(linalgOpToAffineLoops(rewriter, op, interchangeVector)))
if (failed(linalgOpToAffineLoops(rewriter, op)))
return failure();
break;
case LinalgLoweringType::ParallelLoops:
if (failed(linalgOpToParallelLoops(rewriter, op, interchangeVector)))
if (failed(linalgOpToParallelLoops(rewriter, op)))
return failure();
break;
}
@ -850,8 +837,6 @@ private:
/// Controls whether the pattern lowers to library calls, scf.for, affine.for
/// or scf.parallel.
LinalgLoweringType loweringType;
/// Permutated loop order in the generated loop nest.
SmallVector<unsigned, 4> interchangeVector;
};
/// Linalg generalization patterns

View File

@ -457,9 +457,8 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs,
}
template <typename LoopTy>
static Optional<LinalgLoops>
linalgOpToLoopsImpl(Operation *op, OpBuilder &builder,
ArrayRef<unsigned> interchangeVector) {
static Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op,
OpBuilder &builder) {
using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy;
ScopedContext scope(builder, op->getLoc());
@ -472,13 +471,6 @@ linalgOpToLoopsImpl(Operation *op, OpBuilder &builder,
auto loopRanges = linalgOp.createLoopRanges(builder, op->getLoc());
auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue());
if (!interchangeVector.empty()) {
assert(interchangeVector.size() == loopRanges.size());
assert(interchangeVector.size() == iteratorTypes.size());
applyPermutationToVector(loopRanges, interchangeVector);
applyPermutationToVector(iteratorTypes, interchangeVector);
}
SmallVector<Value, 4> allIvs;
GenerateLoopNest<LoopTy>::doit(
loopRanges, /*iterInitArgs=*/{}, iteratorTypes,
@ -511,11 +503,10 @@ linalgOpToLoopsImpl(Operation *op, OpBuilder &builder,
}
/// Replace the index operations in the body of the loop nest by the matching
/// induction variables. If available use the interchange vector to map the
/// interchanged induction variables to the dimension of the index operation.
static void replaceIndexOpsByInductionVariables(
LinalgOp linalgOp, PatternRewriter &rewriter, ArrayRef<Operation *> loopOps,
ArrayRef<unsigned> interchangeVector) {
/// induction variables.
static void replaceIndexOpsByInductionVariables(LinalgOp linalgOp,
PatternRewriter &rewriter,
ArrayRef<Operation *> loopOps) {
// Extract the induction variables of the loop nest from outer to inner.
SmallVector<Value> allIvs;
for (Operation *loopOp : loopOps) {
@ -538,16 +529,8 @@ static void replaceIndexOpsByInductionVariables(
if (!loopOps.empty()) {
LoopLikeOpInterface loopOp = loopOps.back();
for (IndexOp indexOp :
llvm::make_early_inc_range(loopOp.getLoopBody().getOps<IndexOp>())) {
// Search the indexing dimension in the interchange vector if available.
assert(interchangeVector.empty() ||
interchangeVector.size() == linalgOp.getNumLoops());
const auto *it = llvm::find(interchangeVector, indexOp.dim());
uint64_t dim = it != interchangeVector.end()
? std::distance(interchangeVector.begin(), it)
: indexOp.dim();
rewriter.replaceOp(indexOp, allIvs[dim]);
}
llvm::make_early_inc_range(loopOp.getLoopBody().getOps<IndexOp>()))
rewriter.replaceOp(indexOp, allIvs[indexOp.dim()]);
}
}
@ -555,39 +538,31 @@ namespace {
template <typename LoopType>
class LinalgRewritePattern : public RewritePattern {
public:
LinalgRewritePattern(MLIRContext *context,
ArrayRef<unsigned> interchangeVector)
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
LinalgRewritePattern(MLIRContext *context)
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override {
auto linalgOp = dyn_cast<LinalgOp>(op);
if (!isa<LinalgOp>(op))
return failure();
Optional<LinalgLoops> loopOps =
linalgOpToLoopsImpl<LoopType>(op, rewriter, interchangeVector);
Optional<LinalgLoops> loopOps = linalgOpToLoopsImpl<LoopType>(op, rewriter);
if (!loopOps.hasValue())
return failure();
replaceIndexOpsByInductionVariables(linalgOp, rewriter, loopOps.getValue(),
interchangeVector);
replaceIndexOpsByInductionVariables(linalgOp, rewriter, loopOps.getValue());
rewriter.eraseOp(op);
return success();
}
private:
SmallVector<unsigned, 4> interchangeVector;
};
struct FoldAffineOp;
} // namespace
template <typename LoopType>
static void lowerLinalgToLoopsImpl(FuncOp funcOp,
ArrayRef<unsigned> interchangeVector) {
static void lowerLinalgToLoopsImpl(FuncOp funcOp) {
MLIRContext *context = funcOp.getContext();
RewritePatternSet patterns(context);
patterns.add<LinalgRewritePattern<LoopType>>(context, interchangeVector);
patterns.add<LinalgRewritePattern<LoopType>>(context);
memref::DimOp::getCanonicalizationPatterns(patterns, context);
AffineApplyOp::getCanonicalizationPatterns(patterns, context);
patterns.add<FoldAffineOp>(context);
@ -639,7 +614,7 @@ struct LowerToAffineLoops
registry.insert<memref::MemRefDialect>();
}
void runOnFunction() override {
lowerLinalgToLoopsImpl<AffineForOp>(getFunction(), interchangeVector);
lowerLinalgToLoopsImpl<AffineForOp>(getFunction());
}
};
@ -648,14 +623,14 @@ struct LowerToLoops : public LinalgLowerToLoopsBase<LowerToLoops> {
registry.insert<memref::MemRefDialect, scf::SCFDialect>();
}
void runOnFunction() override {
lowerLinalgToLoopsImpl<scf::ForOp>(getFunction(), interchangeVector);
lowerLinalgToLoopsImpl<scf::ForOp>(getFunction());
}
};
struct LowerToParallelLoops
: public LinalgLowerToParallelLoopsBase<LowerToParallelLoops> {
void runOnFunction() override {
lowerLinalgToLoopsImpl<scf::ParallelOp>(getFunction(), interchangeVector);
lowerLinalgToLoopsImpl<scf::ParallelOp>(getFunction());
}
};
} // namespace
@ -676,43 +651,38 @@ mlir::createConvertLinalgToAffineLoopsPass() {
/// Emits a loop nest with the proper body for `op`.
template <typename LoopTy>
Optional<LinalgLoops>
mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, Operation *op,
ArrayRef<unsigned> interchangeVector) {
return linalgOpToLoopsImpl<LoopTy>(op, builder, interchangeVector);
Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder,
Operation *op) {
return linalgOpToLoopsImpl<LoopTy>(op, builder);
}
template Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops<AffineForOp>(
OpBuilder &builder, Operation *op, ArrayRef<unsigned> interchangeVector);
template Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops<scf::ForOp>(
OpBuilder &builder, Operation *op, ArrayRef<unsigned> interchangeVector);
template Optional<LinalgLoops>
mlir::linalg::linalgLowerOpToLoops<scf::ParallelOp>(
OpBuilder &builder, Operation *op, ArrayRef<unsigned> interchangeVector);
mlir::linalg::linalgLowerOpToLoops<AffineForOp>(OpBuilder &builder,
Operation *op);
template Optional<LinalgLoops>
mlir::linalg::linalgLowerOpToLoops<scf::ForOp>(OpBuilder &builder,
Operation *op);
template Optional<LinalgLoops>
mlir::linalg::linalgLowerOpToLoops<scf::ParallelOp>(OpBuilder &builder,
Operation *op);
/// Emits a loop nest of `affine.for` with the proper body for `op`.
LogicalResult
mlir::linalg::linalgOpToAffineLoops(OpBuilder &builder, Operation *op,
ArrayRef<unsigned> interchangeVector) {
Optional<LinalgLoops> loops =
linalgLowerOpToLoops<AffineForOp>(builder, op, interchangeVector);
LogicalResult mlir::linalg::linalgOpToAffineLoops(OpBuilder &builder,
Operation *op) {
Optional<LinalgLoops> loops = linalgLowerOpToLoops<AffineForOp>(builder, op);
return loops ? success() : failure();
}
/// Emits a loop nest of `scf.for` with the proper body for `op`.
LogicalResult
mlir::linalg::linalgOpToLoops(OpBuilder &builder, Operation *op,
ArrayRef<unsigned> interchangeVector) {
Optional<LinalgLoops> loops =
linalgLowerOpToLoops<scf::ForOp>(builder, op, interchangeVector);
LogicalResult mlir::linalg::linalgOpToLoops(OpBuilder &builder, Operation *op) {
Optional<LinalgLoops> loops = linalgLowerOpToLoops<scf::ForOp>(builder, op);
return loops ? success() : failure();
}
/// Emits a loop nest of `scf.parallel` with the proper body for `op`.
LogicalResult
mlir::linalg::linalgOpToParallelLoops(OpBuilder &builder, Operation *op,
ArrayRef<unsigned> interchangeVector) {
LogicalResult mlir::linalg::linalgOpToParallelLoops(OpBuilder &builder,
Operation *op) {
Optional<LinalgLoops> loops =
linalgLowerOpToLoops<scf::ParallelOp>(builder, op, interchangeVector);
linalgLowerOpToLoops<scf::ParallelOp>(builder, op);
return loops ? success() : failure();
}

View File

@ -1,72 +0,0 @@
// RUN: mlir-opt %s -convert-linalg-to-loops="interchange-vector=4,0,3,1,2" -split-input-file | FileCheck --check-prefix=LOOP %s
// RUN: mlir-opt %s -convert-linalg-to-parallel-loops="interchange-vector=4,0,3,1,2" -split-input-file | FileCheck --check-prefix=PARALLEL %s
// RUN: mlir-opt %s -convert-linalg-to-affine-loops="interchange-vector=4,0,3,1,2" -split-input-file | FileCheck --check-prefix=AFFINE %s
func @copy(%input: memref<1x2x3x4x5xf32>, %output: memref<1x2x3x4x5xf32>) {
linalg.copy(%input, %output): memref<1x2x3x4x5xf32>, memref<1x2x3x4x5xf32>
return
}
// LOOP: scf.for %{{.*}} = %c0 to %c5 step %c1
// LOOP: scf.for %{{.*}} = %c0 to %c1 step %c1
// LOOP: scf.for %{{.*}} = %c0 to %c4 step %c1
// LOOP: scf.for %{{.*}} = %c0 to %c2 step %c1
// LOOP: scf.for %{{.*}} = %c0 to %c3 step %c1
// PARALLEL: scf.parallel
// PARALLEL-SAME: to (%c5, %c1, %c4, %c2, %c3)
// AFFINE: affine.for %{{.*}} = 0 to 5
// AFFINE: affine.for %{{.*}} = 0 to 1
// AFFINE: affine.for %{{.*}} = 0 to 4
// AFFINE: affine.for %{{.*}} = 0 to 2
// AFFINE: affine.for %{{.*}} = 0 to 3
// -----
#map = affine_map<(i, j, k, l, m) -> (i, j, k, l, m)>
func @generic(%output: memref<1x2x3x4x5xindex>) {
linalg.generic {indexing_maps = [#map],
iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
outs(%output : memref<1x2x3x4x5xindex>) {
^bb0(%arg0 : index):
%i = linalg.index 0 : index
%j = linalg.index 1 : index
%k = linalg.index 2 : index
%l = linalg.index 3 : index
%m = linalg.index 4 : index
%0 = addi %i, %j : index
%1 = addi %0, %k : index
%2 = addi %1, %l : index
%3 = addi %2, %m : index
linalg.yield %3: index
}
return
}
// LOOP: scf.for %[[m:.*]] = %c0 to %c5 step %c1
// LOOP: scf.for %[[i:.*]] = %c0 to %c1 step %c1
// LOOP: scf.for %[[l:.*]] = %c0 to %c4 step %c1
// LOOP: scf.for %[[j:.*]] = %c0 to %c2 step %c1
// LOOP: scf.for %[[k:.*]] = %c0 to %c3 step %c1
// LOOP: %{{.*}} = addi %[[i]], %[[j]] : index
// LOOP: %{{.*}} = addi %{{.*}}, %[[k]] : index
// LOOP: %{{.*}} = addi %{{.*}}, %[[l]] : index
// LOOP: %{{.*}} = addi %{{.*}}, %[[m]] : index
// PARALLEL: scf.parallel (%[[m:.*]], %[[i:.*]], %[[l:.*]], %[[j:.*]], %[[k:.*]]) =
// PARALLEL-SAME: to (%c5, %c1, %c4, %c2, %c3)
// PARALLEL: %{{.*}} = addi %[[i]], %[[j]] : index
// PARALLEL: %{{.*}} = addi %{{.*}}, %[[k]] : index
// PARALLEL: %{{.*}} = addi %{{.*}}, %[[l]] : index
// PARALLEL: %{{.*}} = addi %{{.*}}, %[[m]] : index
// AFFINE: affine.for %[[m:.*]] = 0 to 5
// AFFINE: affine.for %[[i:.*]] = 0 to 1
// AFFINE: affine.for %[[l:.*]] = 0 to 4
// AFFINE: affine.for %[[j:.*]] = 0 to 2
// AFFINE: affine.for %[[k:.*]] = 0 to 3
// AFFINE: %{{.*}} = addi %[[i]], %[[j]] : index
// AFFINE: %{{.*}} = addi %{{.*}}, %[[k]] : index
// AFFINE: %{{.*}} = addi %{{.*}}, %[[l]] : index
// AFFINE: %{{.*}} = addi %{{.*}}, %[[m]] : index