forked from OSchip/llvm-project
[mlir][Linalg] NFC - Expose helper function `substituteMin`.
This commit is contained in:
parent
1fe042041c
commit
5b2d8503d1
|
@ -893,6 +893,30 @@ struct AffineMinSCFCanonicalizationPattern
|
|||
PatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
/// Helper struct to return the results of `substituteMin`.
|
||||
struct AffineMapAndOperands {
|
||||
AffineMap map;
|
||||
SmallVector<Value> dims;
|
||||
SmallVector<Value> symbols;
|
||||
};
|
||||
/// Traverse the dims of the AffineMap of `affineMinOp` and substitute scf loop
|
||||
/// induction variables by new expressions involving the lower or upper bound:
|
||||
/// - If the AffineDimExpr mapped to a loop IV has a positive sign, it is
|
||||
/// replaced by the loop upper bound.
|
||||
/// - If the AffineDimExpr mapped to a loop IV has a negative sign, it is
|
||||
/// replaced by the loop lower bound.
|
||||
/// All loop induction variables are iteratively replaced, unless a
|
||||
/// `substituteOperation` hook is passed to more finely determine which
|
||||
/// operations are substituted.
|
||||
/// This is used as an intermediate step in computing bounding boxes and
|
||||
/// canonicalize AffineMinOps. All dim and symbol operands are assumed to have
|
||||
/// positive values (positive orthant assumptions).
|
||||
/// Return a new AffineMap, dims and symbols that have been canonicalized and
|
||||
/// simplified.
|
||||
AffineMapAndOperands substituteMin(
|
||||
AffineMinOp affineMinOp,
|
||||
llvm::function_ref<bool(Operation *)> substituteOperation = nullptr);
|
||||
|
||||
/// Converts Convolution op into vector contraction.
|
||||
///
|
||||
/// Conversion expects ConvOp to have dimensions marked in the *mask* as
|
||||
|
|
|
@ -536,8 +536,10 @@ static AffineExpr substituteLoopInExpr(AffineExpr expr, AffineExpr dimExpr,
|
|||
|
||||
/// Traverse the `dims` and substitute known min or max expressions in place of
|
||||
/// induction variables in `exprs`.
|
||||
static AffineMap substitute(AffineMap map, SmallVectorImpl<Value> &dims,
|
||||
SmallVectorImpl<Value> &symbols) {
|
||||
static AffineMap substitute(
|
||||
AffineMap map, SmallVectorImpl<Value> &dims,
|
||||
SmallVectorImpl<Value> &symbols,
|
||||
llvm::function_ref<bool(Operation *)> substituteOperation = nullptr) {
|
||||
auto exprs = llvm::to_vector<4>(map.getResults());
|
||||
for (AffineExpr &expr : exprs) {
|
||||
bool substituted = true;
|
||||
|
@ -549,17 +551,19 @@ static AffineMap substitute(AffineMap map, SmallVectorImpl<Value> &dims,
|
|||
LLVM_DEBUG(DBGS() << "Subst: " << dim << " @ " << dimExpr << "\n");
|
||||
AffineExpr substitutedExpr;
|
||||
if (auto forOp = scf::getForInductionVarOwner(dim))
|
||||
substitutedExpr = substituteLoopInExpr(
|
||||
expr, dimExpr, forOp.lowerBound(), forOp.upperBound(),
|
||||
forOp.step(), dims, symbols);
|
||||
if (!substituteOperation || substituteOperation(forOp))
|
||||
substitutedExpr = substituteLoopInExpr(
|
||||
expr, dimExpr, forOp.lowerBound(), forOp.upperBound(),
|
||||
forOp.step(), dims, symbols);
|
||||
|
||||
if (auto parallelForOp = scf::getParallelForInductionVarOwner(dim))
|
||||
for (unsigned idx = 0, e = parallelForOp.getNumLoops(); idx < e;
|
||||
++idx)
|
||||
substitutedExpr = substituteLoopInExpr(
|
||||
expr, dimExpr, parallelForOp.lowerBound()[idx],
|
||||
parallelForOp.upperBound()[idx], parallelForOp.step()[idx],
|
||||
dims, symbols);
|
||||
if (!substituteOperation || substituteOperation(parallelForOp))
|
||||
for (unsigned idx = 0, e = parallelForOp.getNumLoops(); idx < e;
|
||||
++idx)
|
||||
substitutedExpr = substituteLoopInExpr(
|
||||
expr, dimExpr, parallelForOp.lowerBound()[idx],
|
||||
parallelForOp.upperBound()[idx], parallelForOp.step()[idx],
|
||||
dims, symbols);
|
||||
|
||||
if (!substitutedExpr)
|
||||
continue;
|
||||
|
@ -578,6 +582,9 @@ static AffineMap substitute(AffineMap map, SmallVectorImpl<Value> &dims,
|
|||
exprs.front().getContext());
|
||||
|
||||
LLVM_DEBUG(DBGS() << "Map to simplify: " << map << "\n");
|
||||
LLVM_DEBUG(DBGS() << "Operands:\n");
|
||||
for (Value v : operands)
|
||||
LLVM_DEBUG(DBGS() << v << "\n");
|
||||
|
||||
// Pull in affine.apply operations and compose them fully into the
|
||||
// result.
|
||||
|
@ -596,14 +603,38 @@ static AffineMap substitute(AffineMap map, SmallVectorImpl<Value> &dims,
|
|||
return AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
|
||||
}
|
||||
|
||||
/// Traverse the dims of the AffineMap of `affineMinOp` and substitute scf loop
|
||||
/// induction variables by new expressions involving the lower or upper bound:
|
||||
/// - If the AffineDimExpr mapped to a loop IV has a positive sign, it is
|
||||
/// replaced by the loop upper bound.
|
||||
/// - If the AffineDimExpr mapped to a loop IV has a negative sign, it is
|
||||
/// replaced by the loop lower bound.
|
||||
/// All loop induction variables are iteratively replaced, unless a
|
||||
/// `substituteOperation` hook is passed to more finely determine which
|
||||
/// operations are substituted.
|
||||
/// This is used as an intermediate step in computing bounding boxes and
|
||||
/// canonicalize AffineMinOps. All dim and symbol operands are assumed to have
|
||||
/// positive values (positive orthant assumptions).
|
||||
/// Return a new AffineMap, dims and symbols that have been canonicalized and
|
||||
/// simplified.
|
||||
AffineMapAndOperands mlir::linalg::substituteMin(
|
||||
AffineMinOp affineMinOp,
|
||||
llvm::function_ref<bool(Operation *)> substituteOperation) {
|
||||
AffineMapAndOperands res{affineMinOp.getAffineMap(),
|
||||
SmallVector<Value>(affineMinOp.getDimOperands()),
|
||||
SmallVector<Value>(affineMinOp.getSymbolOperands())};
|
||||
res.map = substitute(affineMinOp.getAffineMap(), res.dims, res.symbols,
|
||||
substituteOperation);
|
||||
return res;
|
||||
}
|
||||
|
||||
LogicalResult AffineMinSCFCanonicalizationPattern::matchAndRewrite(
|
||||
AffineMinOp minOp, PatternRewriter &rewriter) const {
|
||||
LLVM_DEBUG(DBGS() << "Canonicalize AffineMinSCF: " << *minOp.getOperation()
|
||||
<< "\n");
|
||||
|
||||
SmallVector<Value, 4> dims(minOp.getDimOperands()),
|
||||
symbols(minOp.getSymbolOperands());
|
||||
AffineMap map = substitute(minOp.getAffineMap(), dims, symbols);
|
||||
auto affineMapAndOperands = substituteMin(minOp);
|
||||
AffineMap map = affineMapAndOperands.map;
|
||||
|
||||
LLVM_DEBUG(DBGS() << "Resulting map: " << map << "\n");
|
||||
|
||||
|
@ -638,8 +669,8 @@ LogicalResult AffineMinSCFCanonicalizationPattern::matchAndRewrite(
|
|||
rewriter.replaceOpWithNewOp<ConstantIndexOp>(minOp, cst.getValue());
|
||||
} else {
|
||||
auto resultMap = AffineMap::get(0, map.getNumSymbols(), {e}, ctx);
|
||||
SmallVector<Value, 4> resultOperands = dims;
|
||||
resultOperands.append(symbols.begin(), symbols.end());
|
||||
SmallVector<Value> resultOperands = affineMapAndOperands.dims;
|
||||
llvm::append_range(resultOperands, affineMapAndOperands.symbols);
|
||||
canonicalizeMapAndOperands(&resultMap, &resultOperands);
|
||||
resultMap = simplifyAffineMap(resultMap);
|
||||
rewriter.replaceOpWithNewOp<AffineApplyOp>(minOp, resultMap,
|
||||
|
|
Loading…
Reference in New Issue