[mlir][linalg] remove interchange option on linalg to loop lowering.

The interchange option attached to the linalg to loop lowering affects only the loops and does not update the memory accesses generated in the body of the operation. Instead of performing the interchange during the loop lowering, use the interchange pattern.

Differential Revision: https://reviews.llvm.org/D100758
This commit is contained in:
Tobias Gysi 2021-04-22 08:22:37 +00:00
parent fbc6f42dbe
commit 0e777e4ad7
4 changed files with 48 additions and 183 deletions

View File

@ -62,12 +62,6 @@ def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> {
let summary = "Lower the operations from the linalg dialect into affine " let summary = "Lower the operations from the linalg dialect into affine "
"loops"; "loops";
let constructor = "mlir::createConvertLinalgToAffineLoopsPass()"; let constructor = "mlir::createConvertLinalgToAffineLoopsPass()";
let options = [
ListOption<"interchangeVector", "interchange-vector", "unsigned",
"Permute the loops in the nest following the given "
"interchange vector",
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
];
let dependentDialects = [ let dependentDialects = [
"AffineDialect", "linalg::LinalgDialect", "memref::MemRefDialect"]; "AffineDialect", "linalg::LinalgDialect", "memref::MemRefDialect"];
} }
@ -75,12 +69,6 @@ def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> {
def LinalgLowerToLoops : FunctionPass<"convert-linalg-to-loops"> { def LinalgLowerToLoops : FunctionPass<"convert-linalg-to-loops"> {
let summary = "Lower the operations from the linalg dialect into loops"; let summary = "Lower the operations from the linalg dialect into loops";
let constructor = "mlir::createConvertLinalgToLoopsPass()"; let constructor = "mlir::createConvertLinalgToLoopsPass()";
let options = [
ListOption<"interchangeVector", "interchange-vector", "unsigned",
"Permute the loops in the nest following the given "
"interchange vector",
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
];
let dependentDialects = [ let dependentDialects = [
"linalg::LinalgDialect", "linalg::LinalgDialect",
"scf::SCFDialect", "scf::SCFDialect",
@ -103,12 +91,6 @@ def LinalgLowerToParallelLoops
let summary = "Lower the operations from the linalg dialect into parallel " let summary = "Lower the operations from the linalg dialect into parallel "
"loops"; "loops";
let constructor = "mlir::createConvertLinalgToParallelLoopsPass()"; let constructor = "mlir::createConvertLinalgToParallelLoopsPass()";
let options = [
ListOption<"interchangeVector", "interchange-vector", "unsigned",
"Permute the loops in the nest following the given "
"interchange vector",
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
];
let dependentDialects = [ let dependentDialects = [
"AffineDialect", "AffineDialect",
"linalg::LinalgDialect", "linalg::LinalgDialect",

View File

@ -338,28 +338,16 @@ LogicalResult vectorizeLinalgOp(OpBuilder &builder, Operation *op,
/// Emits a loop nest of `LoopTy` with the proper body for `op`. /// Emits a loop nest of `LoopTy` with the proper body for `op`.
template <typename LoopTy> template <typename LoopTy>
Optional<LinalgLoops> Optional<LinalgLoops> linalgLowerOpToLoops(OpBuilder &builder, Operation *op);
linalgLowerOpToLoops(OpBuilder &builder, Operation *op,
ArrayRef<unsigned> interchangeVector = {});
/// Emits a loop nest of `scf.for` with the proper body for `op`. The generated /// Emits a loop nest of `scf.for` with the proper body for `op`.
/// loop nest will follow the `interchangeVector`-permutated iterator order. If LogicalResult linalgOpToLoops(OpBuilder &builder, Operation *op);
/// `interchangeVector` is empty, then no permutation happens.
LogicalResult linalgOpToLoops(OpBuilder &builder, Operation *op,
ArrayRef<unsigned> interchangeVector = {});
/// Emits a loop nest of `scf.parallel` with the proper body for `op`. The /// Emits a loop nest of `scf.parallel` with the proper body for `op`.
/// generated loop nest will follow the `interchangeVector`-permutated LogicalResult linalgOpToParallelLoops(OpBuilder &builder, Operation *op);
// iterator order. If `interchangeVector` is empty, then no permutation happens.
LogicalResult
linalgOpToParallelLoops(OpBuilder &builder, Operation *op,
ArrayRef<unsigned> interchangeVector = {});
/// Emits a loop nest of `affine.for` with the proper body for `op`. The /// Emits a loop nest of `affine.for` with the proper body for `op`.
/// generated loop nest will follow the `interchangeVector`-permutated LogicalResult linalgOpToAffineLoops(OpBuilder &builder, Operation *op);
// iterator order. If `interchangeVector` is empty, then no permutation happens.
LogicalResult linalgOpToAffineLoops(OpBuilder &builder, Operation *op,
ArrayRef<unsigned> interchangeVector = {});
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Preconditions that ensure the corresponding transformation succeeds and can // Preconditions that ensure the corresponding transformation succeeds and can
@ -808,10 +796,9 @@ struct LinalgLoweringPattern : public RewritePattern {
LinalgLoweringPattern( LinalgLoweringPattern(
MLIRContext *context, LinalgLoweringType loweringType, MLIRContext *context, LinalgLoweringType loweringType,
LinalgTransformationFilter filter = LinalgTransformationFilter(), LinalgTransformationFilter filter = LinalgTransformationFilter(),
ArrayRef<unsigned> interchangeVector = {}, PatternBenefit benefit = 1) PatternBenefit benefit = 1)
: RewritePattern(OpTy::getOperationName(), benefit, context), : RewritePattern(OpTy::getOperationName(), benefit, context),
filter(filter), loweringType(loweringType), filter(filter), loweringType(loweringType) {}
interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
// TODO: Move implementation to .cpp once named ops are auto-generated. // TODO: Move implementation to .cpp once named ops are auto-generated.
LogicalResult matchAndRewrite(Operation *op, LogicalResult matchAndRewrite(Operation *op,
@ -827,15 +814,15 @@ struct LinalgLoweringPattern : public RewritePattern {
// TODO: Move lowering to library calls here. // TODO: Move lowering to library calls here.
return failure(); return failure();
case LinalgLoweringType::Loops: case LinalgLoweringType::Loops:
if (failed(linalgOpToLoops(rewriter, op, interchangeVector))) if (failed(linalgOpToLoops(rewriter, op)))
return failure(); return failure();
break; break;
case LinalgLoweringType::AffineLoops: case LinalgLoweringType::AffineLoops:
if (failed(linalgOpToAffineLoops(rewriter, op, interchangeVector))) if (failed(linalgOpToAffineLoops(rewriter, op)))
return failure(); return failure();
break; break;
case LinalgLoweringType::ParallelLoops: case LinalgLoweringType::ParallelLoops:
if (failed(linalgOpToParallelLoops(rewriter, op, interchangeVector))) if (failed(linalgOpToParallelLoops(rewriter, op)))
return failure(); return failure();
break; break;
} }
@ -850,8 +837,6 @@ private:
/// Controls whether the pattern lowers to library calls, scf.for, affine.for /// Controls whether the pattern lowers to library calls, scf.for, affine.for
/// or scf.parallel. /// or scf.parallel.
LinalgLoweringType loweringType; LinalgLoweringType loweringType;
/// Permutated loop order in the generated loop nest.
SmallVector<unsigned, 4> interchangeVector;
}; };
/// Linalg generalization patterns /// Linalg generalization patterns

View File

@ -457,9 +457,8 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs,
} }
template <typename LoopTy> template <typename LoopTy>
static Optional<LinalgLoops> static Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op,
linalgOpToLoopsImpl(Operation *op, OpBuilder &builder, OpBuilder &builder) {
ArrayRef<unsigned> interchangeVector) {
using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy; using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy;
ScopedContext scope(builder, op->getLoc()); ScopedContext scope(builder, op->getLoc());
@ -472,13 +471,6 @@ linalgOpToLoopsImpl(Operation *op, OpBuilder &builder,
auto loopRanges = linalgOp.createLoopRanges(builder, op->getLoc()); auto loopRanges = linalgOp.createLoopRanges(builder, op->getLoc());
auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue()); auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue());
if (!interchangeVector.empty()) {
assert(interchangeVector.size() == loopRanges.size());
assert(interchangeVector.size() == iteratorTypes.size());
applyPermutationToVector(loopRanges, interchangeVector);
applyPermutationToVector(iteratorTypes, interchangeVector);
}
SmallVector<Value, 4> allIvs; SmallVector<Value, 4> allIvs;
GenerateLoopNest<LoopTy>::doit( GenerateLoopNest<LoopTy>::doit(
loopRanges, /*iterInitArgs=*/{}, iteratorTypes, loopRanges, /*iterInitArgs=*/{}, iteratorTypes,
@ -511,11 +503,10 @@ linalgOpToLoopsImpl(Operation *op, OpBuilder &builder,
} }
/// Replace the index operations in the body of the loop nest by the matching /// Replace the index operations in the body of the loop nest by the matching
/// induction variables. If available use the interchange vector to map the /// induction variables.
/// interchanged induction variables to the dimension of the index operation. static void replaceIndexOpsByInductionVariables(LinalgOp linalgOp,
static void replaceIndexOpsByInductionVariables( PatternRewriter &rewriter,
LinalgOp linalgOp, PatternRewriter &rewriter, ArrayRef<Operation *> loopOps, ArrayRef<Operation *> loopOps) {
ArrayRef<unsigned> interchangeVector) {
// Extract the induction variables of the loop nest from outer to inner. // Extract the induction variables of the loop nest from outer to inner.
SmallVector<Value> allIvs; SmallVector<Value> allIvs;
for (Operation *loopOp : loopOps) { for (Operation *loopOp : loopOps) {
@ -538,16 +529,8 @@ static void replaceIndexOpsByInductionVariables(
if (!loopOps.empty()) { if (!loopOps.empty()) {
LoopLikeOpInterface loopOp = loopOps.back(); LoopLikeOpInterface loopOp = loopOps.back();
for (IndexOp indexOp : for (IndexOp indexOp :
llvm::make_early_inc_range(loopOp.getLoopBody().getOps<IndexOp>())) { llvm::make_early_inc_range(loopOp.getLoopBody().getOps<IndexOp>()))
// Search the indexing dimension in the interchange vector if available. rewriter.replaceOp(indexOp, allIvs[indexOp.dim()]);
assert(interchangeVector.empty() ||
interchangeVector.size() == linalgOp.getNumLoops());
const auto *it = llvm::find(interchangeVector, indexOp.dim());
uint64_t dim = it != interchangeVector.end()
? std::distance(interchangeVector.begin(), it)
: indexOp.dim();
rewriter.replaceOp(indexOp, allIvs[dim]);
}
} }
} }
@ -555,39 +538,31 @@ namespace {
template <typename LoopType> template <typename LoopType>
class LinalgRewritePattern : public RewritePattern { class LinalgRewritePattern : public RewritePattern {
public: public:
LinalgRewritePattern(MLIRContext *context, LinalgRewritePattern(MLIRContext *context)
ArrayRef<unsigned> interchangeVector) : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
: RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
LogicalResult matchAndRewrite(Operation *op, LogicalResult matchAndRewrite(Operation *op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
auto linalgOp = dyn_cast<LinalgOp>(op); auto linalgOp = dyn_cast<LinalgOp>(op);
if (!isa<LinalgOp>(op)) if (!isa<LinalgOp>(op))
return failure(); return failure();
Optional<LinalgLoops> loopOps = Optional<LinalgLoops> loopOps = linalgOpToLoopsImpl<LoopType>(op, rewriter);
linalgOpToLoopsImpl<LoopType>(op, rewriter, interchangeVector);
if (!loopOps.hasValue()) if (!loopOps.hasValue())
return failure(); return failure();
replaceIndexOpsByInductionVariables(linalgOp, rewriter, loopOps.getValue(), replaceIndexOpsByInductionVariables(linalgOp, rewriter, loopOps.getValue());
interchangeVector);
rewriter.eraseOp(op); rewriter.eraseOp(op);
return success(); return success();
} }
private:
SmallVector<unsigned, 4> interchangeVector;
}; };
struct FoldAffineOp; struct FoldAffineOp;
} // namespace } // namespace
template <typename LoopType> template <typename LoopType>
static void lowerLinalgToLoopsImpl(FuncOp funcOp, static void lowerLinalgToLoopsImpl(FuncOp funcOp) {
ArrayRef<unsigned> interchangeVector) {
MLIRContext *context = funcOp.getContext(); MLIRContext *context = funcOp.getContext();
RewritePatternSet patterns(context); RewritePatternSet patterns(context);
patterns.add<LinalgRewritePattern<LoopType>>(context, interchangeVector); patterns.add<LinalgRewritePattern<LoopType>>(context);
memref::DimOp::getCanonicalizationPatterns(patterns, context); memref::DimOp::getCanonicalizationPatterns(patterns, context);
AffineApplyOp::getCanonicalizationPatterns(patterns, context); AffineApplyOp::getCanonicalizationPatterns(patterns, context);
patterns.add<FoldAffineOp>(context); patterns.add<FoldAffineOp>(context);
@ -639,7 +614,7 @@ struct LowerToAffineLoops
registry.insert<memref::MemRefDialect>(); registry.insert<memref::MemRefDialect>();
} }
void runOnFunction() override { void runOnFunction() override {
lowerLinalgToLoopsImpl<AffineForOp>(getFunction(), interchangeVector); lowerLinalgToLoopsImpl<AffineForOp>(getFunction());
} }
}; };
@ -648,14 +623,14 @@ struct LowerToLoops : public LinalgLowerToLoopsBase<LowerToLoops> {
registry.insert<memref::MemRefDialect, scf::SCFDialect>(); registry.insert<memref::MemRefDialect, scf::SCFDialect>();
} }
void runOnFunction() override { void runOnFunction() override {
lowerLinalgToLoopsImpl<scf::ForOp>(getFunction(), interchangeVector); lowerLinalgToLoopsImpl<scf::ForOp>(getFunction());
} }
}; };
struct LowerToParallelLoops struct LowerToParallelLoops
: public LinalgLowerToParallelLoopsBase<LowerToParallelLoops> { : public LinalgLowerToParallelLoopsBase<LowerToParallelLoops> {
void runOnFunction() override { void runOnFunction() override {
lowerLinalgToLoopsImpl<scf::ParallelOp>(getFunction(), interchangeVector); lowerLinalgToLoopsImpl<scf::ParallelOp>(getFunction());
} }
}; };
} // namespace } // namespace
@ -676,43 +651,38 @@ mlir::createConvertLinalgToAffineLoopsPass() {
/// Emits a loop nest with the proper body for `op`. /// Emits a loop nest with the proper body for `op`.
template <typename LoopTy> template <typename LoopTy>
Optional<LinalgLoops> Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder,
mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, Operation *op, Operation *op) {
ArrayRef<unsigned> interchangeVector) { return linalgOpToLoopsImpl<LoopTy>(op, builder);
return linalgOpToLoopsImpl<LoopTy>(op, builder, interchangeVector);
} }
template Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops<AffineForOp>(
OpBuilder &builder, Operation *op, ArrayRef<unsigned> interchangeVector);
template Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops<scf::ForOp>(
OpBuilder &builder, Operation *op, ArrayRef<unsigned> interchangeVector);
template Optional<LinalgLoops> template Optional<LinalgLoops>
mlir::linalg::linalgLowerOpToLoops<scf::ParallelOp>( mlir::linalg::linalgLowerOpToLoops<AffineForOp>(OpBuilder &builder,
OpBuilder &builder, Operation *op, ArrayRef<unsigned> interchangeVector); Operation *op);
template Optional<LinalgLoops>
mlir::linalg::linalgLowerOpToLoops<scf::ForOp>(OpBuilder &builder,
Operation *op);
template Optional<LinalgLoops>
mlir::linalg::linalgLowerOpToLoops<scf::ParallelOp>(OpBuilder &builder,
Operation *op);
/// Emits a loop nest of `affine.for` with the proper body for `op`. /// Emits a loop nest of `affine.for` with the proper body for `op`.
LogicalResult LogicalResult mlir::linalg::linalgOpToAffineLoops(OpBuilder &builder,
mlir::linalg::linalgOpToAffineLoops(OpBuilder &builder, Operation *op, Operation *op) {
ArrayRef<unsigned> interchangeVector) { Optional<LinalgLoops> loops = linalgLowerOpToLoops<AffineForOp>(builder, op);
Optional<LinalgLoops> loops =
linalgLowerOpToLoops<AffineForOp>(builder, op, interchangeVector);
return loops ? success() : failure(); return loops ? success() : failure();
} }
/// Emits a loop nest of `scf.for` with the proper body for `op`. /// Emits a loop nest of `scf.for` with the proper body for `op`.
LogicalResult LogicalResult mlir::linalg::linalgOpToLoops(OpBuilder &builder, Operation *op) {
mlir::linalg::linalgOpToLoops(OpBuilder &builder, Operation *op, Optional<LinalgLoops> loops = linalgLowerOpToLoops<scf::ForOp>(builder, op);
ArrayRef<unsigned> interchangeVector) {
Optional<LinalgLoops> loops =
linalgLowerOpToLoops<scf::ForOp>(builder, op, interchangeVector);
return loops ? success() : failure(); return loops ? success() : failure();
} }
/// Emits a loop nest of `scf.parallel` with the proper body for `op`. /// Emits a loop nest of `scf.parallel` with the proper body for `op`.
LogicalResult LogicalResult mlir::linalg::linalgOpToParallelLoops(OpBuilder &builder,
mlir::linalg::linalgOpToParallelLoops(OpBuilder &builder, Operation *op, Operation *op) {
ArrayRef<unsigned> interchangeVector) {
Optional<LinalgLoops> loops = Optional<LinalgLoops> loops =
linalgLowerOpToLoops<scf::ParallelOp>(builder, op, interchangeVector); linalgLowerOpToLoops<scf::ParallelOp>(builder, op);
return loops ? success() : failure(); return loops ? success() : failure();
} }

View File

@ -1,72 +0,0 @@
// RUN: mlir-opt %s -convert-linalg-to-loops="interchange-vector=4,0,3,1,2" -split-input-file | FileCheck --check-prefix=LOOP %s
// RUN: mlir-opt %s -convert-linalg-to-parallel-loops="interchange-vector=4,0,3,1,2" -split-input-file | FileCheck --check-prefix=PARALLEL %s
// RUN: mlir-opt %s -convert-linalg-to-affine-loops="interchange-vector=4,0,3,1,2" -split-input-file | FileCheck --check-prefix=AFFINE %s
func @copy(%input: memref<1x2x3x4x5xf32>, %output: memref<1x2x3x4x5xf32>) {
linalg.copy(%input, %output): memref<1x2x3x4x5xf32>, memref<1x2x3x4x5xf32>
return
}
// LOOP: scf.for %{{.*}} = %c0 to %c5 step %c1
// LOOP: scf.for %{{.*}} = %c0 to %c1 step %c1
// LOOP: scf.for %{{.*}} = %c0 to %c4 step %c1
// LOOP: scf.for %{{.*}} = %c0 to %c2 step %c1
// LOOP: scf.for %{{.*}} = %c0 to %c3 step %c1
// PARALLEL: scf.parallel
// PARALLEL-SAME: to (%c5, %c1, %c4, %c2, %c3)
// AFFINE: affine.for %{{.*}} = 0 to 5
// AFFINE: affine.for %{{.*}} = 0 to 1
// AFFINE: affine.for %{{.*}} = 0 to 4
// AFFINE: affine.for %{{.*}} = 0 to 2
// AFFINE: affine.for %{{.*}} = 0 to 3
// -----
#map = affine_map<(i, j, k, l, m) -> (i, j, k, l, m)>
func @generic(%output: memref<1x2x3x4x5xindex>) {
linalg.generic {indexing_maps = [#map],
iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
outs(%output : memref<1x2x3x4x5xindex>) {
^bb0(%arg0 : index):
%i = linalg.index 0 : index
%j = linalg.index 1 : index
%k = linalg.index 2 : index
%l = linalg.index 3 : index
%m = linalg.index 4 : index
%0 = addi %i, %j : index
%1 = addi %0, %k : index
%2 = addi %1, %l : index
%3 = addi %2, %m : index
linalg.yield %3: index
}
return
}
// LOOP: scf.for %[[m:.*]] = %c0 to %c5 step %c1
// LOOP: scf.for %[[i:.*]] = %c0 to %c1 step %c1
// LOOP: scf.for %[[l:.*]] = %c0 to %c4 step %c1
// LOOP: scf.for %[[j:.*]] = %c0 to %c2 step %c1
// LOOP: scf.for %[[k:.*]] = %c0 to %c3 step %c1
// LOOP: %{{.*}} = addi %[[i]], %[[j]] : index
// LOOP: %{{.*}} = addi %{{.*}}, %[[k]] : index
// LOOP: %{{.*}} = addi %{{.*}}, %[[l]] : index
// LOOP: %{{.*}} = addi %{{.*}}, %[[m]] : index
// PARALLEL: scf.parallel (%[[m:.*]], %[[i:.*]], %[[l:.*]], %[[j:.*]], %[[k:.*]]) =
// PARALLEL-SAME: to (%c5, %c1, %c4, %c2, %c3)
// PARALLEL: %{{.*}} = addi %[[i]], %[[j]] : index
// PARALLEL: %{{.*}} = addi %{{.*}}, %[[k]] : index
// PARALLEL: %{{.*}} = addi %{{.*}}, %[[l]] : index
// PARALLEL: %{{.*}} = addi %{{.*}}, %[[m]] : index
// AFFINE: affine.for %[[m:.*]] = 0 to 5
// AFFINE: affine.for %[[i:.*]] = 0 to 1
// AFFINE: affine.for %[[l:.*]] = 0 to 4
// AFFINE: affine.for %[[j:.*]] = 0 to 2
// AFFINE: affine.for %[[k:.*]] = 0 to 3
// AFFINE: %{{.*}} = addi %[[i]], %[[j]] : index
// AFFINE: %{{.*}} = addi %{{.*}}, %[[k]] : index
// AFFINE: %{{.*}} = addi %{{.*}}, %[[l]] : index
// AFFINE: %{{.*}} = addi %{{.*}}, %[[m]] : index