forked from OSchip/llvm-project
Decouple affine->standard lowering from the pass
The lowering from the Affine dialect to the Standard dialect was originally implemented as a standalone pass. However, it may be used by other passes willing to lower away some of the affine constructs as a part of their operation. Decouple the transformation functions from the pass infrastructure and expose the entry point for the lowering. Also update the lowering functions to use `LogicalResult` instead of bool for return values. -- PiperOrigin-RevId: 250229198
This commit is contained in:
parent
ffc4cf7091
commit
d4c071cc69
|
@ -0,0 +1,28 @@
|
||||||
|
//===- LowerAffine.h - Convert Affine to Standard dialect -------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The MLIR Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
#ifndef MLIR_TRANSFORMS_LOWERAFFINE_H
|
||||||
|
#define MLIR_TRANSFORMS_LOWERAFFINE_H
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
class Function;
|
||||||
|
class LogicalResult;
|
||||||
|
|
||||||
|
LogicalResult lowerAffineConstructs(Function &function);
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // MLIR_TRANSFORMS_LOWERAFFINE_H
|
|
@ -20,6 +20,7 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/Transforms/LowerAffine.h"
|
||||||
#include "mlir/AffineOps/AffineOps.h"
|
#include "mlir/AffineOps/AffineOps.h"
|
||||||
#include "mlir/IR/AffineExprVisitor.h"
|
#include "mlir/IR/AffineExprVisitor.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
|
@ -243,10 +244,6 @@ Optional<SmallVector<Value *, 8>> static expandAffineMap(
|
||||||
namespace {
|
namespace {
|
||||||
struct LowerAffinePass : public FunctionPass<LowerAffinePass> {
|
struct LowerAffinePass : public FunctionPass<LowerAffinePass> {
|
||||||
void runOnFunction() override;
|
void runOnFunction() override;
|
||||||
|
|
||||||
bool lowerAffineFor(AffineForOp forOp);
|
|
||||||
bool lowerAffineIf(AffineIfOp ifOp);
|
|
||||||
bool lowerAffineApply(AffineApplyOp op);
|
|
||||||
};
|
};
|
||||||
} // end anonymous namespace
|
} // end anonymous namespace
|
||||||
|
|
||||||
|
@ -319,7 +316,7 @@ static Value *buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate,
|
||||||
// | <code after the AffineForOp> |
|
// | <code after the AffineForOp> |
|
||||||
// +--------------------------------+
|
// +--------------------------------+
|
||||||
//
|
//
|
||||||
bool LowerAffinePass::lowerAffineFor(AffineForOp forOp) {
|
static LogicalResult lowerAffineFor(AffineForOp forOp) {
|
||||||
auto loc = forOp.getLoc();
|
auto loc = forOp.getLoc();
|
||||||
auto *forInst = forOp.getOperation();
|
auto *forInst = forOp.getOperation();
|
||||||
|
|
||||||
|
@ -356,7 +353,7 @@ bool LowerAffinePass::lowerAffineFor(AffineForOp forOp) {
|
||||||
auto affDim = builder.getAffineDimExpr(0);
|
auto affDim = builder.getAffineDimExpr(0);
|
||||||
auto stepped = expandAffineExpr(&builder, loc, affDim + affStep, iv, {});
|
auto stepped = expandAffineExpr(&builder, loc, affDim + affStep, iv, {});
|
||||||
if (!stepped)
|
if (!stepped)
|
||||||
return true;
|
return failure();
|
||||||
// We know we applied a one-dimensional map.
|
// We know we applied a one-dimensional map.
|
||||||
builder.create<BranchOp>(loc, conditionBlock, stepped);
|
builder.create<BranchOp>(loc, conditionBlock, stepped);
|
||||||
|
|
||||||
|
@ -369,7 +366,7 @@ bool LowerAffinePass::lowerAffineFor(AffineForOp forOp) {
|
||||||
auto lbValues = expandAffineMap(&builder, forInst->getLoc(),
|
auto lbValues = expandAffineMap(&builder, forInst->getLoc(),
|
||||||
forOp.getLowerBoundMap(), operands);
|
forOp.getLowerBoundMap(), operands);
|
||||||
if (!lbValues)
|
if (!lbValues)
|
||||||
return true;
|
return failure();
|
||||||
Value *lowerBound =
|
Value *lowerBound =
|
||||||
buildMinMaxReductionSeq(loc, CmpIPredicate::SGT, *lbValues, builder);
|
buildMinMaxReductionSeq(loc, CmpIPredicate::SGT, *lbValues, builder);
|
||||||
|
|
||||||
|
@ -378,7 +375,7 @@ bool LowerAffinePass::lowerAffineFor(AffineForOp forOp) {
|
||||||
auto ubValues = expandAffineMap(&builder, forInst->getLoc(),
|
auto ubValues = expandAffineMap(&builder, forInst->getLoc(),
|
||||||
forOp.getUpperBoundMap(), operands);
|
forOp.getUpperBoundMap(), operands);
|
||||||
if (!ubValues)
|
if (!ubValues)
|
||||||
return true;
|
return failure();
|
||||||
Value *upperBound =
|
Value *upperBound =
|
||||||
buildMinMaxReductionSeq(loc, CmpIPredicate::SLT, *ubValues, builder);
|
buildMinMaxReductionSeq(loc, CmpIPredicate::SLT, *ubValues, builder);
|
||||||
builder.create<BranchOp>(loc, conditionBlock, lowerBound);
|
builder.create<BranchOp>(loc, conditionBlock, lowerBound);
|
||||||
|
@ -392,7 +389,7 @@ bool LowerAffinePass::lowerAffineFor(AffineForOp forOp) {
|
||||||
|
|
||||||
// Ok, we're done!
|
// Ok, we're done!
|
||||||
forOp.erase();
|
forOp.erase();
|
||||||
return false;
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert an "if" operation into a flow of basic blocks.
|
// Convert an "if" operation into a flow of basic blocks.
|
||||||
|
@ -454,7 +451,7 @@ bool LowerAffinePass::lowerAffineFor(AffineForOp forOp) {
|
||||||
// | <code after the AffineIfOp> |
|
// | <code after the AffineIfOp> |
|
||||||
// +--------------------------------+
|
// +--------------------------------+
|
||||||
//
|
//
|
||||||
bool LowerAffinePass::lowerAffineIf(AffineIfOp ifOp) {
|
static LogicalResult lowerAffineIf(AffineIfOp ifOp) {
|
||||||
auto *ifInst = ifOp.getOperation();
|
auto *ifInst = ifOp.getOperation();
|
||||||
auto loc = ifInst->getLoc();
|
auto loc = ifInst->getLoc();
|
||||||
|
|
||||||
|
@ -476,7 +473,7 @@ bool LowerAffinePass::lowerAffineIf(AffineIfOp ifOp) {
|
||||||
if (!oldThenBlocks.empty()) {
|
if (!oldThenBlocks.empty()) {
|
||||||
// We currently only handle one 'then' block.
|
// We currently only handle one 'then' block.
|
||||||
if (std::next(oldThenBlocks.begin()) != oldThenBlocks.end())
|
if (std::next(oldThenBlocks.begin()) != oldThenBlocks.end())
|
||||||
return true;
|
return failure();
|
||||||
|
|
||||||
Block *oldThen = &oldThenBlocks.front();
|
Block *oldThen = &oldThenBlocks.front();
|
||||||
|
|
||||||
|
@ -495,7 +492,7 @@ bool LowerAffinePass::lowerAffineIf(AffineIfOp ifOp) {
|
||||||
if (!oldElseBlocks.empty()) {
|
if (!oldElseBlocks.empty()) {
|
||||||
// We currently only handle one 'else' block.
|
// We currently only handle one 'else' block.
|
||||||
if (std::next(oldElseBlocks.begin()) != oldElseBlocks.end())
|
if (std::next(oldElseBlocks.begin()) != oldElseBlocks.end())
|
||||||
return true;
|
return failure();
|
||||||
|
|
||||||
auto *oldElse = &oldElseBlocks.front();
|
auto *oldElse = &oldElseBlocks.front();
|
||||||
elseBlock = new Block();
|
elseBlock = new Block();
|
||||||
|
@ -541,7 +538,7 @@ bool LowerAffinePass::lowerAffineIf(AffineIfOp ifOp) {
|
||||||
operandsRef.take_front(numDims),
|
operandsRef.take_front(numDims),
|
||||||
operandsRef.drop_front(numDims));
|
operandsRef.drop_front(numDims));
|
||||||
if (!affResult)
|
if (!affResult)
|
||||||
return true;
|
return failure();
|
||||||
|
|
||||||
// Compare the result of the apply and branch.
|
// Compare the result of the apply and branch.
|
||||||
auto comparisonOp = builder.create<CmpIOp>(
|
auto comparisonOp = builder.create<CmpIOp>(
|
||||||
|
@ -566,26 +563,26 @@ bool LowerAffinePass::lowerAffineIf(AffineIfOp ifOp) {
|
||||||
|
|
||||||
// Ok, we're done!
|
// Ok, we're done!
|
||||||
ifInst->erase();
|
ifInst->erase();
|
||||||
return false;
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert an "affine.apply" operation into a sequence of arithmetic
|
// Convert an "affine.apply" operation into a sequence of arithmetic
|
||||||
// operations using the StandardOps dialect. Return true on error.
|
// operations using the StandardOps dialect. Return true on error.
|
||||||
bool LowerAffinePass::lowerAffineApply(AffineApplyOp op) {
|
static LogicalResult lowerAffineApply(AffineApplyOp op) {
|
||||||
FuncBuilder builder(op.getOperation());
|
FuncBuilder builder(op.getOperation());
|
||||||
auto maybeExpandedMap =
|
auto maybeExpandedMap =
|
||||||
expandAffineMap(&builder, op.getLoc(), op.getAffineMap(),
|
expandAffineMap(&builder, op.getLoc(), op.getAffineMap(),
|
||||||
llvm::to_vector<8>(op.getOperands()));
|
llvm::to_vector<8>(op.getOperands()));
|
||||||
if (!maybeExpandedMap)
|
if (!maybeExpandedMap)
|
||||||
return true;
|
return failure();
|
||||||
|
|
||||||
Value *original = op.getResult();
|
Value *original = op.getResult();
|
||||||
Value *expanded = (*maybeExpandedMap)[0];
|
Value *expanded = (*maybeExpandedMap)[0];
|
||||||
if (!expanded)
|
if (!expanded)
|
||||||
return true;
|
return failure();
|
||||||
original->replaceAllUsesWith(expanded);
|
original->replaceAllUsesWith(expanded);
|
||||||
op.erase();
|
op.erase();
|
||||||
return false;
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Entry point of the function convertor.
|
// Entry point of the function convertor.
|
||||||
|
@ -600,35 +597,37 @@ bool LowerAffinePass::lowerAffineApply(AffineApplyOp op) {
|
||||||
// Individual operations are simply appended to the end of the last basic block
|
// Individual operations are simply appended to the end of the last basic block
|
||||||
// of the current region. The SESE invariant allows us to easily handle nested
|
// of the current region. The SESE invariant allows us to easily handle nested
|
||||||
// structures of arbitrary complexity.
|
// structures of arbitrary complexity.
|
||||||
//
|
LogicalResult mlir::lowerAffineConstructs(Function &function) {
|
||||||
// During the conversion, we maintain a mapping between the Values present in
|
|
||||||
// the original function and their Value images in the function under
|
|
||||||
// construction. When an Value is used, it gets replaced with the
|
|
||||||
// corresponding Value that has been defined previously. The value flow
|
|
||||||
// starts with function arguments converted to basic block arguments.
|
|
||||||
void LowerAffinePass::runOnFunction() {
|
|
||||||
SmallVector<Operation *, 8> instsToRewrite;
|
SmallVector<Operation *, 8> instsToRewrite;
|
||||||
|
|
||||||
// Collect all the For operations as well as AffineIfOps and AffineApplyOps.
|
// Collect all the For operations as well as AffineIfOps and AffineApplyOps.
|
||||||
// We do this as a prepass to avoid invalidating the walker with our rewrite.
|
// We do this as a prepass to avoid invalidating the walker with our rewrite.
|
||||||
getFunction().walk([&](Operation *op) {
|
function.walk([&](Operation *op) {
|
||||||
if (isa<AffineApplyOp>(op) || isa<AffineForOp>(op) || isa<AffineIfOp>(op))
|
if (isa<AffineApplyOp>(op) || isa<AffineForOp>(op) || isa<AffineIfOp>(op))
|
||||||
instsToRewrite.push_back(op);
|
instsToRewrite.push_back(op);
|
||||||
});
|
});
|
||||||
|
|
||||||
// Rewrite all of the ifs and fors. We walked the operations in postorders,
|
// Rewrite all of the ifs and fors. We walked the operations in postorder,
|
||||||
// so we know that we will rewrite them in the reverse order.
|
// so we know that we will rewrite them in the reverse order.
|
||||||
for (auto *op : llvm::reverse(instsToRewrite)) {
|
for (auto *op : llvm::reverse(instsToRewrite)) {
|
||||||
if (auto ifOp = dyn_cast<AffineIfOp>(op)) {
|
if (auto ifOp = dyn_cast<AffineIfOp>(op)) {
|
||||||
if (lowerAffineIf(ifOp))
|
if (failed(lowerAffineIf(ifOp)))
|
||||||
return signalPassFailure();
|
return failure();
|
||||||
} else if (auto forOp = dyn_cast<AffineForOp>(op)) {
|
} else if (auto forOp = dyn_cast<AffineForOp>(op)) {
|
||||||
if (lowerAffineFor(forOp))
|
if (failed(lowerAffineFor(forOp)))
|
||||||
return signalPassFailure();
|
return failure();
|
||||||
} else if (lowerAffineApply(cast<AffineApplyOp>(op))) {
|
} else if (failed(lowerAffineApply(cast<AffineApplyOp>(op)))) {
|
||||||
return signalPassFailure();
|
return failure();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run the affine lowering as a function pass.
|
||||||
|
void LowerAffinePass::runOnFunction() {
|
||||||
|
if (failed(lowerAffineConstructs(getFunction())))
|
||||||
|
signalPassFailure();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Lowers If and For operations within a function into their lower level CFG
|
/// Lowers If and For operations within a function into their lower level CFG
|
||||||
|
|
Loading…
Reference in New Issue