[mlir][Linalg] NFC - Expose helper function `substituteMin`.

This commit is contained in:
Nicolas Vasilache 2021-03-19 16:21:15 +00:00
parent 1fe042041c
commit 5b2d8503d1
2 changed files with 71 additions and 16 deletions

View File

@ -893,6 +893,30 @@ struct AffineMinSCFCanonicalizationPattern
PatternRewriter &rewriter) const override;
};
/// Helper struct to return the results of `substituteMin`.
struct AffineMapAndOperands {
AffineMap map;
SmallVector<Value> dims;
SmallVector<Value> symbols;
};
/// Traverse the dims of the AffineMap of `affineMinOp` and substitute scf loop
/// induction variables by new expressions involving the lower or upper bound:
/// - If the AffineDimExpr mapped to a loop IV has a positive sign, it is
/// replaced by the loop upper bound.
/// - If the AffineDimExpr mapped to a loop IV has a negative sign, it is
/// replaced by the loop lower bound.
/// All loop induction variables are iteratively replaced, unless a
/// `substituteOperation` hook is passed to more finely determine which
/// operations are substituted.
/// This is used as an intermediate step in computing bounding boxes and
/// canonicalize AffineMinOps. All dim and symbol operands are assumed to have
/// positive values (positive orthant assumptions).
/// Return a new AffineMap, dims and symbols that have been canonicalized and
/// simplified.
AffineMapAndOperands substituteMin(
AffineMinOp affineMinOp,
llvm::function_ref<bool(Operation *)> substituteOperation = nullptr);
/// Converts Convolution op into vector contraction.
///
/// Conversion expects ConvOp to have dimensions marked in the *mask* as

View File

@ -536,8 +536,10 @@ static AffineExpr substituteLoopInExpr(AffineExpr expr, AffineExpr dimExpr,
/// Traverse the `dims` and substitute known min or max expressions in place of
/// induction variables in `exprs`.
static AffineMap substitute(AffineMap map, SmallVectorImpl<Value> &dims,
SmallVectorImpl<Value> &symbols) {
static AffineMap substitute(
AffineMap map, SmallVectorImpl<Value> &dims,
SmallVectorImpl<Value> &symbols,
llvm::function_ref<bool(Operation *)> substituteOperation = nullptr) {
auto exprs = llvm::to_vector<4>(map.getResults());
for (AffineExpr &expr : exprs) {
bool substituted = true;
@ -549,17 +551,19 @@ static AffineMap substitute(AffineMap map, SmallVectorImpl<Value> &dims,
LLVM_DEBUG(DBGS() << "Subst: " << dim << " @ " << dimExpr << "\n");
AffineExpr substitutedExpr;
if (auto forOp = scf::getForInductionVarOwner(dim))
substitutedExpr = substituteLoopInExpr(
expr, dimExpr, forOp.lowerBound(), forOp.upperBound(),
forOp.step(), dims, symbols);
if (!substituteOperation || substituteOperation(forOp))
substitutedExpr = substituteLoopInExpr(
expr, dimExpr, forOp.lowerBound(), forOp.upperBound(),
forOp.step(), dims, symbols);
if (auto parallelForOp = scf::getParallelForInductionVarOwner(dim))
for (unsigned idx = 0, e = parallelForOp.getNumLoops(); idx < e;
++idx)
substitutedExpr = substituteLoopInExpr(
expr, dimExpr, parallelForOp.lowerBound()[idx],
parallelForOp.upperBound()[idx], parallelForOp.step()[idx],
dims, symbols);
if (!substituteOperation || substituteOperation(parallelForOp))
for (unsigned idx = 0, e = parallelForOp.getNumLoops(); idx < e;
++idx)
substitutedExpr = substituteLoopInExpr(
expr, dimExpr, parallelForOp.lowerBound()[idx],
parallelForOp.upperBound()[idx], parallelForOp.step()[idx],
dims, symbols);
if (!substitutedExpr)
continue;
@ -578,6 +582,9 @@ static AffineMap substitute(AffineMap map, SmallVectorImpl<Value> &dims,
exprs.front().getContext());
LLVM_DEBUG(DBGS() << "Map to simplify: " << map << "\n");
LLVM_DEBUG(DBGS() << "Operands:\n");
for (Value v : operands)
LLVM_DEBUG(DBGS() << v << "\n");
// Pull in affine.apply operations and compose them fully into the
// result.
@ -596,14 +603,38 @@ static AffineMap substitute(AffineMap map, SmallVectorImpl<Value> &dims,
return AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
}
/// Traverse the dims of the AffineMap of `affineMinOp` and substitute scf loop
/// induction variables by new expressions involving the lower or upper bound:
/// - If the AffineDimExpr mapped to a loop IV has a positive sign, it is
/// replaced by the loop upper bound.
/// - If the AffineDimExpr mapped to a loop IV has a negative sign, it is
/// replaced by the loop lower bound.
/// All loop induction variables are iteratively replaced, unless a
/// `substituteOperation` hook is passed to more finely determine which
/// operations are substituted.
/// This is used as an intermediate step in computing bounding boxes and
/// canonicalize AffineMinOps. All dim and symbol operands are assumed to have
/// positive values (positive orthant assumptions).
/// Return a new AffineMap, dims and symbols that have been canonicalized and
/// simplified.
AffineMapAndOperands mlir::linalg::substituteMin(
AffineMinOp affineMinOp,
llvm::function_ref<bool(Operation *)> substituteOperation) {
AffineMapAndOperands res{affineMinOp.getAffineMap(),
SmallVector<Value>(affineMinOp.getDimOperands()),
SmallVector<Value>(affineMinOp.getSymbolOperands())};
res.map = substitute(affineMinOp.getAffineMap(), res.dims, res.symbols,
substituteOperation);
return res;
}
LogicalResult AffineMinSCFCanonicalizationPattern::matchAndRewrite(
AffineMinOp minOp, PatternRewriter &rewriter) const {
LLVM_DEBUG(DBGS() << "Canonicalize AffineMinSCF: " << *minOp.getOperation()
<< "\n");
SmallVector<Value, 4> dims(minOp.getDimOperands()),
symbols(minOp.getSymbolOperands());
AffineMap map = substitute(minOp.getAffineMap(), dims, symbols);
auto affineMapAndOperands = substituteMin(minOp);
AffineMap map = affineMapAndOperands.map;
LLVM_DEBUG(DBGS() << "Resulting map: " << map << "\n");
@ -638,8 +669,8 @@ LogicalResult AffineMinSCFCanonicalizationPattern::matchAndRewrite(
rewriter.replaceOpWithNewOp<ConstantIndexOp>(minOp, cst.getValue());
} else {
auto resultMap = AffineMap::get(0, map.getNumSymbols(), {e}, ctx);
SmallVector<Value, 4> resultOperands = dims;
resultOperands.append(symbols.begin(), symbols.end());
SmallVector<Value> resultOperands = affineMapAndOperands.dims;
llvm::append_range(resultOperands, affineMapAndOperands.symbols);
canonicalizeMapAndOperands(&resultMap, &resultOperands);
resultMap = simplifyAffineMap(resultMap);
rewriter.replaceOpWithNewOp<AffineApplyOp>(minOp, resultMap,