llvm-project/mlir/lib/Transforms/LowerAffine.cpp

756 lines
32 KiB
C++

//===- LowerAffine.cpp - Lower affine constructs to primitives ------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file lowers affine constructs (If and For statements, AffineApply
// operations) within a function into their lower level CFG equivalent blocks.
//
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/LowerAffine.h"
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/Pass.h"
#include "mlir/StandardOps/Ops.h"
#include "mlir/Support/Functional.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
using namespace mlir;
namespace {
// Visit affine expressions recursively and build the sequence of operations
// that correspond to it. Visitation functions return an Value of the
// expression subtree they visited or `nullptr` on error.
class AffineApplyExpander
: public AffineExprVisitor<AffineApplyExpander, Value *> {
public:
// This internal class expects arguments to be non-null, checks must be
// performed at the call site.
AffineApplyExpander(OpBuilder &builder, ArrayRef<Value *> dimValues,
ArrayRef<Value *> symbolValues, Location loc)
: builder(builder), dimValues(dimValues), symbolValues(symbolValues),
loc(loc) {}
template <typename OpTy> Value *buildBinaryExpr(AffineBinaryOpExpr expr) {
auto lhs = visit(expr.getLHS());
auto rhs = visit(expr.getRHS());
if (!lhs || !rhs)
return nullptr;
auto op = builder.create<OpTy>(loc, lhs, rhs);
return op.getResult();
}
Value *visitAddExpr(AffineBinaryOpExpr expr) {
return buildBinaryExpr<AddIOp>(expr);
}
Value *visitMulExpr(AffineBinaryOpExpr expr) {
return buildBinaryExpr<MulIOp>(expr);
}
// Euclidean modulo operation: negative RHS is not allowed.
// Remainder of the euclidean integer division is always non-negative.
//
// Implemented as
//
// a mod b =
// let remainder = srem a, b;
// negative = a < 0 in
// select negative, remainder + b, remainder.
Value *visitModExpr(AffineBinaryOpExpr expr) {
auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>();
if (!rhsConst) {
emitError(
loc,
"semi-affine expressions (modulo by non-const) are not supported");
return nullptr;
}
if (rhsConst.getValue() <= 0) {
emitError(loc, "modulo by non-positive value is not supported");
return nullptr;
}
auto lhs = visit(expr.getLHS());
auto rhs = visit(expr.getRHS());
assert(lhs && rhs && "unexpected affine expr lowering failure");
Value *remainder = builder.create<RemISOp>(loc, lhs, rhs);
Value *zeroCst = builder.create<ConstantIndexOp>(loc, 0);
Value *isRemainderNegative =
builder.create<CmpIOp>(loc, CmpIPredicate::SLT, remainder, zeroCst);
Value *correctedRemainder = builder.create<AddIOp>(loc, remainder, rhs);
Value *result = builder.create<SelectOp>(loc, isRemainderNegative,
correctedRemainder, remainder);
return result;
}
// Floor division operation (rounds towards negative infinity).
//
// For positive divisors, it can be implemented without branching and with a
// single division operation as
//
// a floordiv b =
// let negative = a < 0 in
// let absolute = negative ? -a - 1 : a in
// let quotient = absolute / b in
// negative ? -quotient - 1 : quotient
Value *visitFloorDivExpr(AffineBinaryOpExpr expr) {
auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>();
if (!rhsConst) {
emitError(
loc,
"semi-affine expressions (division by non-const) are not supported");
return nullptr;
}
if (rhsConst.getValue() <= 0) {
emitError(loc, "division by non-positive value is not supported");
return nullptr;
}
auto lhs = visit(expr.getLHS());
auto rhs = visit(expr.getRHS());
assert(lhs && rhs && "unexpected affine expr lowering failure");
Value *zeroCst = builder.create<ConstantIndexOp>(loc, 0);
Value *noneCst = builder.create<ConstantIndexOp>(loc, -1);
Value *negative =
builder.create<CmpIOp>(loc, CmpIPredicate::SLT, lhs, zeroCst);
Value *negatedDecremented = builder.create<SubIOp>(loc, noneCst, lhs);
Value *dividend =
builder.create<SelectOp>(loc, negative, negatedDecremented, lhs);
Value *quotient = builder.create<DivISOp>(loc, dividend, rhs);
Value *correctedQuotient = builder.create<SubIOp>(loc, noneCst, quotient);
Value *result =
builder.create<SelectOp>(loc, negative, correctedQuotient, quotient);
return result;
}
// Ceiling division operation (rounds towards positive infinity).
//
// For positive divisors, it can be implemented without branching and with a
// single division operation as
//
// a ceildiv b =
// let negative = a <= 0 in
// let absolute = negative ? -a : a - 1 in
// let quotient = absolute / b in
// negative ? -quotient : quotient + 1
Value *visitCeilDivExpr(AffineBinaryOpExpr expr) {
auto rhsConst = expr.getRHS().dyn_cast<AffineConstantExpr>();
if (!rhsConst) {
emitError(loc) << "semi-affine expressions (division by non-const) are "
"not supported";
return nullptr;
}
if (rhsConst.getValue() <= 0) {
emitError(loc, "division by non-positive value is not supported");
return nullptr;
}
auto lhs = visit(expr.getLHS());
auto rhs = visit(expr.getRHS());
assert(lhs && rhs && "unexpected affine expr lowering failure");
Value *zeroCst = builder.create<ConstantIndexOp>(loc, 0);
Value *oneCst = builder.create<ConstantIndexOp>(loc, 1);
Value *nonPositive =
builder.create<CmpIOp>(loc, CmpIPredicate::SLE, lhs, zeroCst);
Value *negated = builder.create<SubIOp>(loc, zeroCst, lhs);
Value *decremented = builder.create<SubIOp>(loc, lhs, oneCst);
Value *dividend =
builder.create<SelectOp>(loc, nonPositive, negated, decremented);
Value *quotient = builder.create<DivISOp>(loc, dividend, rhs);
Value *negatedQuotient = builder.create<SubIOp>(loc, zeroCst, quotient);
Value *incrementedQuotient = builder.create<AddIOp>(loc, quotient, oneCst);
Value *result = builder.create<SelectOp>(loc, nonPositive, negatedQuotient,
incrementedQuotient);
return result;
}
Value *visitConstantExpr(AffineConstantExpr expr) {
auto valueAttr =
builder.getIntegerAttr(builder.getIndexType(), expr.getValue());
auto op =
builder.create<ConstantOp>(loc, builder.getIndexType(), valueAttr);
return op.getResult();
}
Value *visitDimExpr(AffineDimExpr expr) {
assert(expr.getPosition() < dimValues.size() &&
"affine dim position out of range");
return dimValues[expr.getPosition()];
}
Value *visitSymbolExpr(AffineSymbolExpr expr) {
assert(expr.getPosition() < symbolValues.size() &&
"symbol dim position out of range");
return symbolValues[expr.getPosition()];
}
private:
OpBuilder &builder;
ArrayRef<Value *> dimValues;
ArrayRef<Value *> symbolValues;
Location loc;
};
} // namespace
// Create a sequence of operations that implement the `expr` applied to the
// given dimension and symbol values.
mlir::Value *mlir::expandAffineExpr(OpBuilder &builder, Location loc,
AffineExpr expr,
ArrayRef<Value *> dimValues,
ArrayRef<Value *> symbolValues) {
return AffineApplyExpander(builder, dimValues, symbolValues, loc).visit(expr);
}
// Create a sequence of operations that implement the `affineMap` applied to
// the given `operands` (as it it were an AffineApplyOp).
Optional<SmallVector<Value *, 8>> static expandAffineMap(
OpBuilder &builder, Location loc, AffineMap affineMap,
ArrayRef<Value *> operands) {
auto numDims = affineMap.getNumDims();
auto expanded = functional::map(
[numDims, &builder, loc, operands](AffineExpr expr) {
return expandAffineExpr(builder, loc, expr,
operands.take_front(numDims),
operands.drop_front(numDims));
},
affineMap.getResults());
if (llvm::all_of(expanded, [](Value *v) { return v; }))
return expanded;
return None;
}
// Given a range of values, emit the code that reduces them with "min" or "max"
// depending on the provided comparison predicate. The predicate defines which
// comparison to perform, "lt" for "min", "gt" for "max" and is used for the
// `cmpi` operation followed by the `select` operation:
//
// %cond = cmpi "predicate" %v0, %v1
// %result = select %cond, %v0, %v1
//
// Multiple values are scanned in a linear sequence. This creates a data
// dependences that wouldn't exist in a tree reduction, but is easier to
// recognize as a reduction by the subsequent passes.
static Value *buildMinMaxReductionSeq(Location loc, CmpIPredicate predicate,
ArrayRef<Value *> values,
OpBuilder &builder) {
assert(!llvm::empty(values) && "empty min/max chain");
auto valueIt = values.begin();
Value *value = *valueIt++;
for (; valueIt != values.end(); ++valueIt) {
auto cmpOp = builder.create<CmpIOp>(loc, predicate, value, *valueIt);
value = builder.create<SelectOp>(loc, cmpOp.getResult(), value, *valueIt);
}
return value;
}
// Emit instructions that correspond to the affine map in the lower bound
// applied to the respective operands, and compute the maximum value across
// the results.
Value *mlir::lowerAffineLowerBound(AffineForOp op, OpBuilder &builder) {
SmallVector<Value *, 8> boundOperands(op.getLowerBoundOperands());
auto lbValues = expandAffineMap(builder, op.getLoc(), op.getLowerBoundMap(),
boundOperands);
if (!lbValues)
return nullptr;
return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::SGT, *lbValues,
builder);
}
// Emit instructions that correspond to the affine map in the upper bound
// applied to the respective operands, and compute the minimum value across
// the results.
Value *mlir::lowerAffineUpperBound(AffineForOp op, OpBuilder &builder) {
SmallVector<Value *, 8> boundOperands(op.getUpperBoundOperands());
auto ubValues = expandAffineMap(builder, op.getLoc(), op.getUpperBoundMap(),
boundOperands);
if (!ubValues)
return nullptr;
return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::SLT, *ubValues,
builder);
}
namespace {
// Affine terminators are removed.
class AffineTerminatorLowering : public ConversionPattern {
public:
AffineTerminatorLowering(MLIRContext *ctx)
: ConversionPattern(AffineTerminatorOp::getOperationName(), 1, ctx) {}
virtual PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
rewriter.replaceOp(op, {});
return matchSuccess();
}
};
// Create a CFG subgraph for the loop around its body blocks (if the body
// contained other loops, they have been already lowered to a flow of blocks).
// Maintain the invariants that a CFG subgraph created for any loop has a single
// entry and a single exit, and that the entry/exit blocks are respectively
// first/last blocks in the parent region. The original loop operation is
// replaced by the initialization operations that set up the initial value of
// the loop induction variable (%iv) and computes the loop bounds that are loop-
// invariant for affine loops. The operations following the original affine.for
// are split out into a separate continuation (exit) block. A condition block is
// created before the continuation block. It checks the exit condition of the
// loop and branches either to the continuation block, or to the first block of
// the body. Induction variable modification is appended to the last block of
// the body (which is the exit block from the body subgraph thanks to the
// invariant we maintain) along with a branch that loops back to the condition
// block.
//
// NOTE: this relies on the DialectConversion infrastructure knowing how to undo
// the creation of operations if the conversion fails. In particular, lowering
// of the affine maps may insert operations and then fail on a semi-affine map.
//
// +---------------------------------+
// | <code before the AffineForOp> |
// | <compute initial %iv value> |
// | br cond(%iv) |
// +---------------------------------+
// |
// -------| |
// | v v
// | +--------------------------------+
// | | cond(%iv): |
// | | <compare %iv to upper bound> |
// | | cond_br %r, body, end |
// | +--------------------------------+
// | | |
// | | -------------|
// | v |
// | +--------------------------------+ |
// | | body-first: | |
// | | <body contents> | |
// | +--------------------------------+ |
// | | |
// | ... |
// | | |
// | +--------------------------------+ |
// | | body-last: | |
// | | <body contents> | |
// | | %new_iv =<add step to %iv> | |
// | | br cond(%new_iv) | |
// | +--------------------------------+ |
// | | |
// |----------- |--------------------
// v
// +--------------------------------+
// | end: |
// | <code after the AffineForOp> |
// +--------------------------------+
//
class AffineForLowering : public ConversionPattern {
public:
AffineForLowering(MLIRContext *ctx)
: ConversionPattern(AffineForOp::getOperationName(), 1, ctx) {}
virtual PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto forOp = cast<AffineForOp>(op);
Location loc = op->getLoc();
// Start by splitting the block containing the 'affine.for' into two parts.
// The part before will get the init code, the part after will be the end
// point.
auto *initBlock = rewriter.getInsertionBlock();
auto initPosition = rewriter.getInsertionPoint();
auto *endBlock = rewriter.splitBlock(initBlock, initPosition);
// Use the first block of the loop body as the condition block since it is
// the block that has the induction variable as its argument. Split out
// all operations from the first block into a new block. Move all body
// blocks from the loop body region to the region containing the loop.
auto *conditionBlock = &forOp.getRegion().front();
auto *firstBodyBlock =
rewriter.splitBlock(conditionBlock, conditionBlock->begin());
auto *lastBodyBlock = &forOp.getRegion().back();
rewriter.inlineRegionBefore(forOp.getRegion(), endBlock);
auto *iv = conditionBlock->getArgument(0);
// Append the induction variable stepping logic to the last body block and
// branch back to the condition block. Construct an affine expression f :
// (x -> x+step) and apply this expression to the induction variable.
rewriter.setInsertionPointToEnd(lastBodyBlock);
auto affStep = rewriter.getAffineConstantExpr(forOp.getStep());
auto affDim = rewriter.getAffineDimExpr(0);
auto stepped = expandAffineExpr(rewriter, loc, affDim + affStep, iv, {});
if (!stepped)
return matchFailure();
rewriter.create<BranchOp>(loc, conditionBlock, stepped);
// Compute loop bounds before branching to the condition.
rewriter.setInsertionPointToEnd(initBlock);
Value *lowerBound = lowerAffineLowerBound(forOp, rewriter);
Value *upperBound = lowerAffineUpperBound(forOp, rewriter);
if (!lowerBound || !upperBound)
return matchFailure();
rewriter.create<BranchOp>(loc, conditionBlock, lowerBound);
// With the body block done, we can fill in the condition block.
rewriter.setInsertionPointToEnd(conditionBlock);
auto comparison =
rewriter.create<CmpIOp>(loc, CmpIPredicate::SLT, iv, upperBound);
rewriter.create<CondBranchOp>(loc, comparison, firstBodyBlock,
ArrayRef<Value *>(), endBlock,
ArrayRef<Value *>());
// Ok, we're done!
rewriter.replaceOp(op, {});
return matchSuccess();
}
};
// Create a CFG subgraph for the affine.if operation (including its "then" and
// optional "else" operation blocks). We maintain the invariants that the
// subgraph has a single entry and a single exit point, and that the entry/exit
// blocks are respectively the first/last block of the enclosing region. The
// operations following the affine.if are split into a continuation (subgraph
// exit) block. The condition is lowered to a chain of blocks that implement the
// short-circuit scheme. Condition blocks are created by splitting out an empty
// block from the block that contains the affine.if operation. They
// conditionally branch to either the first block of the "then" region, or to
// the first block of the "else" region. If the latter is absent, they branch
// to the continuation block instead. The last blocks of "then" and "else"
// regions (which are known to be exit blocks thanks to the invariant we
// maintain).
//
// NOTE: this relies on the DialectConversion infrastructure knowing how to undo
// the creation of operations if the conversion fails. In particular, lowering
// of the affine maps may insert operations and then fail on a semi-affine map.
//
// +--------------------------------+
// | <code before the AffineIfOp> |
// | %zero = constant 0 : index |
// | %v = affine.apply #expr1(%ops) |
// | %c = cmpi "sge" %v, %zero |
// | cond_br %c, %next, %else |
// +--------------------------------+
// | |
// | --------------|
// v |
// +--------------------------------+ |
// | next: | |
// | <repeat the check for expr2> | |
// | cond_br %c, %next2, %else | |
// +--------------------------------+ |
// | | |
// ... --------------|
// | <Per-expression checks> |
// v |
// +--------------------------------+ |
// | last: | |
// | <repeat the check for exprN> | |
// | cond_br %c, %then, %else | |
// +--------------------------------+ |
// | | |
// | --------------|
// v |
// +--------------------------------+ |
// | then: | |
// | <then contents> | |
// | br continue | |
// +--------------------------------+ |
// | |
// |---------- |-------------
// | V
// | +--------------------------------+
// | | else: |
// | | <else contents> |
// | | br continue |
// | +--------------------------------+
// | |
// ------| |
// v v
// +--------------------------------+
// | continue: |
// | <code after the AffineIfOp> |
// +--------------------------------+
//
class AffineIfLowering : public ConversionPattern {
public:
AffineIfLowering(MLIRContext *ctx)
: ConversionPattern(AffineIfOp::getOperationName(), 1, ctx) {}
virtual PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto ifOp = cast<AffineIfOp>(op);
auto loc = op->getLoc();
// Start by splitting the block containing the 'affine.if' into two parts.
// The part before will contain the condition, the part after will be the
// continuation point.
auto *condBlock = rewriter.getInsertionBlock();
auto opPosition = rewriter.getInsertionPoint();
auto *continueBlock = rewriter.splitBlock(condBlock, opPosition);
// Move blocks from the "then" region to the region containing 'affine.if',
// place it before the continuation block, and branch to it.
auto *thenBlock = &ifOp.getThenBlocks().front();
rewriter.setInsertionPointToEnd(&ifOp.getThenBlocks().back());
rewriter.create<BranchOp>(loc, continueBlock);
rewriter.inlineRegionBefore(ifOp.getThenBlocks(), continueBlock);
// Move blocks from the "else" region (if present) to the region containing
// 'affine.if', place it before the continuation block and branch to it. It
// will be placed after the "then" regions.
auto *elseBlock = continueBlock;
if (!ifOp.getElseBlocks().empty()) {
elseBlock = &ifOp.getElseBlocks().front();
rewriter.setInsertionPointToEnd(&ifOp.getElseBlocks().back());
rewriter.create<BranchOp>(loc, continueBlock);
rewriter.inlineRegionBefore(ifOp.getElseBlocks(), continueBlock);
}
// Now we just have to handle the condition logic.
auto integerSet = ifOp.getIntegerSet();
// Implement short-circuit logic. For each affine expression in the
// 'affine.if' condition, convert it into an affine map and call
// `affine.apply` to obtain the resulting value. Perform the equality or
// the greater-than-or-equality test between this value and zero depending
// on the equality flag of the condition. If the test fails, jump
// immediately to the false branch, which may be the else block if it is
// present or the continuation block otherwise. If the test succeeds, jump
// to the next block testing the next conjunct of the condition in the
// similar way. When all conjuncts have been handled, jump to the 'then'
// block instead.
rewriter.setInsertionPointToEnd(condBlock);
Value *zeroConstant = rewriter.create<ConstantIndexOp>(loc, 0);
for (unsigned i = 0, e = integerSet.getNumConstraints(); i < e; ++i) {
AffineExpr constraintExpr = integerSet.getConstraint(i);
bool isEquality = integerSet.isEq(i);
// Create the fall-through block for the next condition, if present, by
// splitting an empty block out of an existing block. Otherwise treat the
// first "then" block as the block we should branch to if the (last)
// condition is true.
auto *nextBlock = (i == e - 1)
? thenBlock
: rewriter.splitBlock(condBlock, condBlock->end());
// Build and apply an affine expression
auto numDims = integerSet.getNumDims();
Value *affResult = expandAffineExpr(rewriter, loc, constraintExpr,
operands.take_front(numDims),
operands.drop_front(numDims));
if (!affResult)
return matchFailure();
auto comparisonOp = rewriter.create<CmpIOp>(
loc, isEquality ? CmpIPredicate::EQ : CmpIPredicate::SGE, affResult,
zeroConstant);
rewriter.create<CondBranchOp>(loc, comparisonOp.getResult(), nextBlock,
/*trueArgs=*/ArrayRef<Value *>(), elseBlock,
/*falseArgs=*/ArrayRef<Value *>());
rewriter.setInsertionPointToEnd(nextBlock);
condBlock = nextBlock;
}
// Ok, we're done!
rewriter.replaceOp(op, {});
return matchSuccess();
}
};
// Convert an "affine.apply" operation into a sequence of arithmetic
// operations using the StandardOps dialect.
class AffineApplyLowering : public ConversionPattern {
public:
AffineApplyLowering(MLIRContext *ctx)
: ConversionPattern(AffineApplyOp::getOperationName(), 1, ctx) {}
virtual PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto affineApplyOp = cast<AffineApplyOp>(op);
auto maybeExpandedMap = expandAffineMap(
rewriter, op->getLoc(), affineApplyOp.getAffineMap(), operands);
if (!maybeExpandedMap)
return matchFailure();
rewriter.replaceOp(op, *maybeExpandedMap);
return matchSuccess();
}
};
// Apply the affine map from an 'affine.load' operation to its operands, and
// feed the results to a newly created 'std.load' operation (which replaces the
// original 'affine.load').
class AffineLoadLowering : public ConversionPattern {
public:
AffineLoadLowering(MLIRContext *ctx)
: ConversionPattern(AffineLoadOp::getOperationName(), 1, ctx) {}
virtual PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto affineLoadOp = cast<AffineLoadOp>(op);
// Expand affine map from 'affineLoadOp'.
auto maybeExpandedMap =
expandAffineMap(rewriter, op->getLoc(), affineLoadOp.getAffineMap(),
operands.drop_front());
if (!maybeExpandedMap)
return matchFailure();
// Build std.load memref[expandedMap.results].
rewriter.replaceOpWithNewOp<LoadOp>(op, operands[0], *maybeExpandedMap);
return matchSuccess();
}
};
// Apply the affine map from an 'affine.store' operation to its operands, and
// feed the results to a newly created 'std.store' operation (which replaces the
// original 'affine.store').
class AffineStoreLowering : public ConversionPattern {
public:
AffineStoreLowering(MLIRContext *ctx)
: ConversionPattern(AffineStoreOp::getOperationName(), 1, ctx) {}
virtual PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto affineStoreOp = cast<AffineStoreOp>(op);
// Expand affine map from 'affineStoreOp'.
auto maybeExpandedMap =
expandAffineMap(rewriter, op->getLoc(), affineStoreOp.getAffineMap(),
operands.drop_front(2));
if (!maybeExpandedMap)
return matchFailure();
// Build std.store valutToStore, memref[expandedMap.results].
rewriter.replaceOpWithNewOp<StoreOp>(op, operands[0], operands[1],
*maybeExpandedMap);
return matchSuccess();
}
};
// Apply the affine maps from an 'affine.dma_start' operation to each of their
// respective map operands, and feed the results to a newly created
// 'std.dma_start' operation (which replaces the original 'affine.dma_start').
class AffineDmaStartLowering : public ConversionPattern {
public:
AffineDmaStartLowering(MLIRContext *ctx)
: ConversionPattern(AffineDmaStartOp::getOperationName(), 1, ctx) {}
virtual PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto affineDmaStartOp = cast<AffineDmaStartOp>(op);
// Expand affine map for DMA source memref.
auto maybeExpandedSrcMap = expandAffineMap(
rewriter, op->getLoc(), affineDmaStartOp.getSrcMap(),
operands.drop_front(affineDmaStartOp.getSrcMemRefOperandIndex() + 1));
if (!maybeExpandedSrcMap)
return matchFailure();
// Expand affine map for DMA destination memref.
auto maybeExpandedDstMap = expandAffineMap(
rewriter, op->getLoc(), affineDmaStartOp.getDstMap(),
operands.drop_front(affineDmaStartOp.getDstMemRefOperandIndex() + 1));
if (!maybeExpandedDstMap)
return matchFailure();
// Expand affine map for DMA tag memref.
auto maybeExpandedTagMap = expandAffineMap(
rewriter, op->getLoc(), affineDmaStartOp.getTagMap(),
operands.drop_front(affineDmaStartOp.getTagMemRefOperandIndex() + 1));
if (!maybeExpandedTagMap)
return matchFailure();
// Build std.dma_start operation with affine map results.
auto *srcMemRef = operands[affineDmaStartOp.getSrcMemRefOperandIndex()];
auto *dstMemRef = operands[affineDmaStartOp.getDstMemRefOperandIndex()];
auto *tagMemRef = operands[affineDmaStartOp.getTagMemRefOperandIndex()];
unsigned numElementsIndex = affineDmaStartOp.getTagMemRefOperandIndex() +
1 + affineDmaStartOp.getTagMap().getNumInputs();
auto *numElements = operands[numElementsIndex];
auto *stride =
affineDmaStartOp.isStrided() ? operands[numElementsIndex + 1] : nullptr;
auto *eltsPerStride =
affineDmaStartOp.isStrided() ? operands[numElementsIndex + 2] : nullptr;
rewriter.replaceOpWithNewOp<DmaStartOp>(
op, srcMemRef, *maybeExpandedSrcMap, dstMemRef, *maybeExpandedDstMap,
numElements, tagMemRef, *maybeExpandedTagMap, stride, eltsPerStride);
return matchSuccess();
}
};
// Apply the affine map from an 'affine.dma_wait' operation tag memref,
// and feed the results to a newly created 'std.dma_wait' operation (which
// replaces the original 'affine.dma_wait').
class AffineDmaWaitLowering : public ConversionPattern {
public:
AffineDmaWaitLowering(MLIRContext *ctx)
: ConversionPattern(AffineDmaWaitOp::getOperationName(), 1, ctx) {}
virtual PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
auto affineDmaWaitOp = cast<AffineDmaWaitOp>(op);
// Expand affine map for DMA tag memref.
auto maybeExpandedTagMap =
expandAffineMap(rewriter, op->getLoc(), affineDmaWaitOp.getTagMap(),
operands.drop_front());
if (!maybeExpandedTagMap)
return matchFailure();
// Build std.dma_wait operation with affine map results.
unsigned numElementsIndex = 1 + affineDmaWaitOp.getTagMap().getNumInputs();
rewriter.replaceOpWithNewOp<DmaWaitOp>(
op, operands[0], *maybeExpandedTagMap, operands[numElementsIndex]);
return matchSuccess();
}
};
} // end namespace
LogicalResult mlir::lowerAffineConstructs(Function function) {
OwningRewritePatternList patterns;
RewriteListBuilder<AffineApplyLowering, AffineDmaStartLowering,
AffineDmaWaitLowering, AffineLoadLowering,
AffineStoreLowering, AffineForLowering, AffineIfLowering,
AffineTerminatorLowering>::build(patterns,
function.getContext());
ConversionTarget target(*function.getContext());
target.addLegalDialect<StandardOpsDialect>();
return applyConversionPatterns(function, target, std::move(patterns));
}
namespace {
class LowerAffinePass : public FunctionPass<LowerAffinePass> {
void runOnFunction() override { lowerAffineConstructs(getFunction()); }
};
} // namespace
/// Lowers If and For operations within a function into their lower level CFG
/// equivalent blocks.
FunctionPassBase *mlir::createLowerAffinePass() {
return new LowerAffinePass();
}
static PassRegistration<LowerAffinePass>
pass("lower-affine",
"Lower If, For, AffineApply operations to primitive equivalents");