[mlir][linalg] Pass all operands to tile to the tile loop region builder (NFC).

Extend the signature of the tile loop nest region builder to take all operand values to use and not just the scf::For iterArgs. This change allows us to pass in all block arguments of TiledLoop and use them directly instead of replacing them after the loop generation.

Reviewed By: pifon2a

Differential Revision: https://reviews.llvm.org/D109569
This commit is contained in:
Tobias Gysi 2021-09-10 08:34:56 +00:00
parent baf1444929
commit 16488dc300
4 changed files with 35 additions and 31 deletions

View File

@ -263,7 +263,7 @@ struct RegionMatcher {
/// Utility class used to generate nested loops with ranges described by /// Utility class used to generate nested loops with ranges described by
/// `loopRanges` and loop type described by the `iteratorTypes`. `bodyBuilderFn` /// `loopRanges` and loop type described by the `iteratorTypes`. `bodyBuilderFn`
/// is used to generate the body of the innermost loop. It is passed a range /// is used to generate the body of the innermost loop. It is passed a range
/// of loop induction variables and a range of iterArgs. /// of loop induction variables and a range of operand values to use.
template <typename LoopTy> template <typename LoopTy>
struct GenerateLoopNest { struct GenerateLoopNest {
static void doit(OpBuilder &b, Location loc, ArrayRef<Range> loopRanges, static void doit(OpBuilder &b, Location loc, ArrayRef<Range> loopRanges,

View File

@ -431,8 +431,9 @@ static Optional<LinalgLoops> linalgOpToLoopsImpl(PatternRewriter &rewriter,
GenerateLoopNest<LoopTy>::doit( GenerateLoopNest<LoopTy>::doit(
rewriter, linalgOp.getLoc(), loopRanges, linalgOp, iteratorTypes, rewriter, linalgOp.getLoc(), loopRanges, linalgOp, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange ivs, [&](OpBuilder &b, Location loc, ValueRange ivs,
ValueRange iterArgs) -> scf::ValueVector { ValueRange operandValuesToUse) -> scf::ValueVector {
assert(iterArgs.empty() && "unexpected iterArgs"); assert(operandValuesToUse == linalgOp->getOperands() &&
"expect operands are captured and not passed by loop argument");
allIvs.append(ivs.begin(), ivs.end()); allIvs.append(ivs.begin(), ivs.end());
llvm::TypeSwitch<Operation *>(linalgOp) llvm::TypeSwitch<Operation *>(linalgOp)
.Case<ConvOp, PoolingMaxOp, PoolingMinOp, PoolingSumOp, LinalgOp>( .Case<ConvOp, PoolingMaxOp, PoolingMinOp, PoolingSumOp, LinalgOp>(

View File

@ -227,9 +227,9 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
// 2. Create the tiled loops. // 2. Create the tiled loops.
LinalgOp res = op; LinalgOp res = op;
SmallVector<Value, 4> ivs, tensorResults; SmallVector<Value, 4> ivs, tensorResults;
auto tiledLoopBodyBuilder = [&](OpBuilder &b, Location loc, auto tiledLoopBodyBuilder =
ValueRange localIvs, [&](OpBuilder &b, Location loc, ValueRange localIvs,
ValueRange iterArgs) -> scf::ValueVector { ValueRange operandValuesToUse) -> scf::ValueVector {
ivs.assign(localIvs.begin(), localIvs.end()); ivs.assign(localIvs.begin(), localIvs.end());
// When an `interchangeVector` is present, it has been applied to the // When an `interchangeVector` is present, it has been applied to the
@ -241,20 +241,16 @@ tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes,
else else
interchangedIvs.assign(ivs.begin(), ivs.end()); interchangedIvs.assign(ivs.begin(), ivs.end());
assert(op.getOutputTensorOperands().size() == iterArgs.size() && // Tile the `operandValuesToUse` that either match the `op` operands
"num output tensors must match number of loop iter arguments"); // themselves or the tile loop arguments forwarding them.
assert(operandValuesToUse.size() ==
SmallVector<Value> operands = op.getInputOperands(); static_cast<size_t>(op.getNumInputsAndOutputs()) &&
SmallVector<Value> outputBuffers = op.getOutputBufferOperands(); "expect the number of operands and inputs and outputs to match");
// TODO: thanks to simplifying assumption we do not need to worry about SmallVector<Value> valuesToTile = operandValuesToUse;
// order of output buffers and tensors: there is only ever one kind.
assert(outputBuffers.empty() || iterArgs.empty());
operands.append(outputBuffers.begin(), outputBuffers.end());
operands.append(iterArgs.begin(), iterArgs.end());
auto sizeBounds = auto sizeBounds =
applyMapToValues(b, loc, shapeSizesToLoopsMap, allShapeSizes); applyMapToValues(b, loc, shapeSizesToLoopsMap, allShapeSizes);
SmallVector<Value, 4> tiledOperands = makeTiledShapes( SmallVector<Value, 4> tiledOperands = makeTiledShapes(
b, loc, op, operands, interchangedIvs, tileSizes, sizeBounds); b, loc, op, valuesToTile, interchangedIvs, tileSizes, sizeBounds);
// TODO: use an interface/adaptor to avoid leaking position in // TODO: use an interface/adaptor to avoid leaking position in
// `tiledOperands`. // `tiledOperands`.

View File

@ -225,7 +225,18 @@ void GenerateLoopNest<scf::ForOp>::doit(
SmallVector<Value, 4> lbs, ubs, steps; SmallVector<Value, 4> lbs, ubs, steps;
unpackRanges(loopRanges, lbs, ubs, steps); unpackRanges(loopRanges, lbs, ubs, steps);
LoopNest loopNest = mlir::scf::buildLoopNest( LoopNest loopNest = mlir::scf::buildLoopNest(
b, loc, lbs, ubs, steps, iterArgInitValues, bodyBuilderFn); b, loc, lbs, ubs, steps, iterArgInitValues,
[&](OpBuilder &b, Location loc, ValueRange ivs, ValueRange iterArgs) {
assert(iterArgs.size() == linalgOp.getOutputTensorOperands().size() &&
"expect the number of output tensors and iter args to match");
SmallVector<Value> operandValuesToUse =
linalgOp.getInputAndOutputOperands();
if (!iterArgs.empty()) {
operandValuesToUse = linalgOp.getInputOperands();
operandValuesToUse.append(iterArgs.begin(), iterArgs.end());
}
return bodyBuilderFn(b, loc, ivs, operandValuesToUse);
});
if (!distributionOptions || loopNest.loops.empty()) if (!distributionOptions || loopNest.loops.empty())
return; return;
@ -268,7 +279,9 @@ void GenerateLoopNest<AffineForOp>::doit(
mlir::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps, mlir::buildAffineLoopNest(b, loc, lbs, ubs, constantSteps,
[&](OpBuilder &b, Location loc, ValueRange ivs) { [&](OpBuilder &b, Location loc, ValueRange ivs) {
bodyBuilderFn(b, loc, ivs, {}); SmallVector<Value> operandValuesToUse =
linalgOp.getInputAndOutputOperands();
bodyBuilderFn(b, loc, ivs, operandValuesToUse);
}); });
} }
@ -289,9 +302,10 @@ void GenerateLoopNest<TiledLoopOp>::doit(
auto wrappedBuilderFn = [&](OpBuilder &nestedBuilder, Location nestedLoc, auto wrappedBuilderFn = [&](OpBuilder &nestedBuilder, Location nestedLoc,
ValueRange ivs, ValueRange inputs, ValueRange ivs, ValueRange inputs,
ValueRange outputs) { ValueRange outputs) {
SmallVector<Value> outputTensors = linalgOp.getOutputTensorOperands(); SmallVector<Value> operandValuesToUse = inputs;
operandValuesToUse.append(outputs.begin(), outputs.end());
scf::ValueVector results = scf::ValueVector results =
bodyBuilderFn(nestedBuilder, nestedLoc, ivs, outputTensors); bodyBuilderFn(nestedBuilder, nestedLoc, ivs, operandValuesToUse);
nestedBuilder.create<linalg::YieldOp>(nestedLoc, results); nestedBuilder.create<linalg::YieldOp>(nestedLoc, results);
}; };
@ -302,15 +316,6 @@ void GenerateLoopNest<TiledLoopOp>::doit(
b.getArrayAttr(iteratorTypes), wrappedBuilderFn); b.getArrayAttr(iteratorTypes), wrappedBuilderFn);
if (!distributionTypes.empty()) if (!distributionTypes.empty())
tiledLoop.setDistributionTypes(b, distributionTypes); tiledLoop.setDistributionTypes(b, distributionTypes);
// Replace inputs/outputs with the corresponding region args.
auto isInsideTiledLoop = [&](OpOperand &operand) {
return operand.getOwner()->getBlock() == tiledLoop.getBody();
};
for (auto it : llvm::zip(inputOperands, tiledLoop.getRegionInputArgs()))
std::get<0>(it).replaceUsesWithIf(std::get<1>(it), isInsideTiledLoop);
for (auto it : llvm::zip(outputOperands, tiledLoop.getRegionOutputArgs()))
std::get<0>(it).replaceUsesWithIf(std::get<1>(it), isInsideTiledLoop);
} }
/// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`. /// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`.
@ -505,7 +510,9 @@ void GenerateLoopNest<scf::ParallelOp>::doit(
generateParallelLoopNest( generateParallelLoopNest(
b, loc, lbs, ubs, steps, iteratorTypes, b, loc, lbs, ubs, steps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange ivs) { [&](OpBuilder &b, Location loc, ValueRange ivs) {
bodyBuilderFn(b, loc, ivs, {}); SmallVector<Value> operandValuesToUse =
linalgOp.getInputAndOutputOperands();
bodyBuilderFn(b, loc, ivs, operandValuesToUse);
}, },
ivs, distributionMethod); ivs, distributionMethod);