[mlir][linalg] Cleanup LinalgOp usage in fusion (NFC).

Replace the uses of deprecated Structured Op Interface methods in Fusion.cpp. This patch is based on https://reviews.llvm.org/D103394.

Differential Revision: https://reviews.llvm.org/D103437
Tobias Gysi 2021-06-01 08:20:58 +00:00
parent c2e5226a85
commit 7594f5028a
1 changed file with 43 additions and 46 deletions
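For context, the migration pattern this patch applies throughout: index-correlated accessors (`getShapedOperands`, `indexing_maps`) are replaced by OpOperand-based queries. A minimal sketch, using only the interface methods exercised in the diff below; the helper name is hypothetical and not part of the patch:

    #include "mlir/Dialect/Linalg/IR/LinalgOps.h"

    // Sketch of old vs. new Structured Op Interface usage (illustrative
    // helper, not part of the patch).
    static void walkIndexingMaps(mlir::linalg::LinalgOp op) {
      // Old (deprecated): correlate shaped operands with the indexing_maps
      // attribute by position:
      //   auto maps = op.indexing_maps();
      //   for (auto en : llvm::enumerate(op.getShapedOperands()))
      //     auto map = maps[en.index()].cast<AffineMapAttr>().getValue();
      // New: iterate OpOperands and query the tied indexing map directly.
      for (mlir::OpOperand *opOperand : op.getInputAndOutputOperands()) {
        mlir::AffineMap map = op.getTiedIndexingMap(opOperand);
        (void)map; // Use the map here.
      }
    }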


@@ -69,10 +69,9 @@ struct ShapeDimension {
 static ShapeDimension
 getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
                           bool fromSubViewOpOnly = false) {
-  auto maps = op.indexing_maps();
   // Iterate over the inputs and outputs in order.
   // Extract the subranges from the linearized ranges.
-  for (auto en : llvm::enumerate(op.getShapedOperands())) {
+  for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
     // The method `getRangeFromOperandShape` requires using SubViewOp or
     // SubTensorOps. If the value isn't defined from there continue.
     // todo: The method should be adapted to get the values from
@@ -80,27 +79,26 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
     // currently returns a `linalg.range`. The fix here is to move this op to
     // `std` dialect and add the method to `ViewInterface`.
     if (fromSubViewOpOnly && !isa_and_nonnull<memref::SubViewOp, SubTensorOp>(
-                                 en.value().getDefiningOp()))
+                                 opOperand->get().getDefiningOp()))
       continue;
 
-    unsigned idx = en.index();
-    auto map = maps[idx].cast<AffineMapAttr>().getValue();
-    LLVM_DEBUG(llvm::dbgs()
-               << "getShapeDefiningLoopRange I/O idx: " << idx << "\n");
+    AffineMap map = op.getTiedIndexingMap(opOperand);
+    LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange I/O idx: "
+                            << opOperand->getOperandNumber() << "\n");
     LLVM_DEBUG(llvm::dbgs()
                << "getShapeDefiningLoopRange map: " << map << "\n");
-    Value shape = en.value();
     SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
-    for (auto en2 : llvm::enumerate(map.getResults())) {
-      auto dimExpr = en2.value().dyn_cast<AffineDimExpr>();
+    for (auto en : llvm::enumerate(map.getResults())) {
+      auto dimExpr = en.value().dyn_cast<AffineDimExpr>();
       if (!dimExpr)
         continue;
-      if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
+      if (loopDepth == en.value().cast<AffineDimExpr>().getPosition()) {
         LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
                                 << loopDepth << "\n");
-        LLVM_DEBUG(llvm::dbgs()
-                   << "getShapeDefiningLoopRange shape: " << shape << "\n");
-        return ShapeDimension{shape, static_cast<unsigned>(en2.index())};
+        LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange shape: "
+                                << opOperand->get() << "\n");
+        return ShapeDimension{opOperand->get(),
+                              static_cast<unsigned>(en.index())};
       }
     }
   }
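The `fromSubViewOpOnly` guard above only admits operands produced by view-like ops. Factored out as a standalone predicate (a sketch; `isSliceLike` is an illustrative name, and the headers assume this revision, where `SubTensorOp` still lives in the `std` dialect):

    #include "mlir/Dialect/MemRef/IR/MemRef.h"
    #include "mlir/Dialect/StandardOps/IR/Ops.h"

    // True if `v` is produced by one of the slice ops that
    // getRangeFromOperandShape understands.
    static bool isSliceLike(mlir::Value v) {
      return mlir::isa_and_nonnull<mlir::memref::SubViewOp, mlir::SubTensorOp>(
          v.getDefiningOp());
    }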
@@ -122,26 +120,24 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
 // would need to add the intermediate results to `linalg.yield`. After that a
 // canonicalization pass would move the unused output args of the `tiled_loop`
 // to the `input` section.
-static SmallVector<Value, 4> getTiledOperands(OpBuilder &b, LinalgOp producer) {
+static SmallVector<Value> getTiledOperands(OpBuilder &b, LinalgOp producer) {
   auto tiledLoop = dyn_cast<TiledLoopOp>(b.getBlock()->getParentOp());
   if (!tiledLoop)
-    return llvm::to_vector<4>(producer.getShapedOperands());
+    return producer.getInputAndOutputOperands();
 
-  SmallVector<Value, 4> tiledOperands;
+  SmallVector<Value> tiledOperands;
   assert(producer.hasTensorSemantics() &&
          "only fusion on tensors is currently supported for TiledLinalgOp");
 
-  for (auto producerInput : producer.getInputTensors()) {
-    OpOperand *addedInput = tiledLoop.findInputOperand(producerInput);
+  for (OpOperand *producerInput : producer.getInputTensorOperands()) {
+    OpOperand *addedInput = tiledLoop.findInputOperand(producerInput->get());
     if (addedInput == nullptr)
-      addedInput = &tiledLoop.appendInputOperand(b, producerInput);
+      addedInput = &tiledLoop.appendInputOperand(b, producerInput->get());
     BlockArgument addedBlockArg = tiledLoop.getTiedBlockArgument(*addedInput);
     tiledOperands.push_back(addedBlockArg);
   }
-  for (auto &en : llvm::enumerate(producer.getOutputTensors())) {
-    Value producerOutput = en.value();
-    Value result = producer->getResult(en.index());
+  for (OpOperand *producerOutput : producer.getOutputTensorOperands()) {
+    OpResult result = producer.getTiedOpResult(producerOutput);
     OpOperand *resultInputOperand = tiledLoop.findInputOperand(result);
     OpOperand *resultOutputOperand = tiledLoop.findOutputOperand(result);
     assert((resultInputOperand != nullptr) ^ (resultOutputOperand != nullptr) &&
@@ -152,10 +148,11 @@ static SmallVector<Value, 4> getTiledOperands(OpBuilder &b, LinalgOp producer) {
     int opNumber = isInput ? resultInputOperand->getOperandNumber()
                            : resultOutputOperand->getOperandNumber();
 
-    OpOperand *addedOutput = tiledLoop.findOutputOperand(producerOutput);
+    OpOperand *addedOutput = tiledLoop.findOutputOperand(producerOutput->get());
     if (addedOutput == nullptr)
-      addedOutput = isInput ? &tiledLoop.appendInputOperand(b, producerOutput)
-                            : &tiledLoop.appendOutputOperand(b, producerOutput);
+      addedOutput =
+          isInput ? &tiledLoop.appendInputOperand(b, producerOutput->get())
+                  : &tiledLoop.appendOutputOperand(b, producerOutput->get());
 
     OpOperand &resultOperand = tiledLoop->getOpOperand(opNumber);
     auto addedBlockArg = tiledLoop.getTiedBlockArgument(*addedOutput);
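Both loops in `getTiledOperands` rely on the same lookup-or-append discipline: a value enters the `linalg.tiled_loop` operand list at most once, and all uses inside the region go through the tied block argument. A minimal sketch of that discipline for inputs, built from the `TiledLoopOp` calls shown in the hunks above (the helper itself is illustrative):

    // Map `v` into the tiled_loop region: reuse an existing input operand
    // if the value was already captured, otherwise append a fresh one, and
    // return the block argument tied to it.
    static mlir::Value mapToTiledLoopArg(mlir::OpBuilder &b,
                                         mlir::linalg::TiledLoopOp tiledLoop,
                                         mlir::Value v) {
      mlir::OpOperand *operand = tiledLoop.findInputOperand(v);
      if (operand == nullptr)
        operand = &tiledLoop.appendInputOperand(b, v);
      return tiledLoop.getTiedBlockArgument(*operand);
    }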
@@ -200,7 +197,7 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
   }
 
   SmallVector<Value, 8> clonedShapes;
-  clonedShapes.reserve(producer.getNumShapedOperands());
+  clonedShapes.reserve(producer.getNumInputsAndOutputs());
 
   // Compute subranges for all tensor input/output operands.
   clonedShapes.append(makeTiledShapes(b, loc, producer,
@@ -267,16 +264,9 @@ static Range getRangeFromOperandShape(OpBuilder &b, Location loc,
   llvm_unreachable("SubviewOp or SubTensorOp expected");
 }
 
-/// Fuses the producer of `producerIdx` into the loop immediately enclosing
-/// `consumer`. This is achieved by "recomputing" the `producer` at the time it
-/// is needed just before the `consumer`.
-///
-/// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are
-/// 2 cases:
-///   1. Buffer case: `producerIdx` is the index of the buffer in
-///      `producer.getOutputBuffers()`.
-///   2. Tensor case: `producerIdx` is the index of the tensor in
-///      `producer.getResults()`.
+/// Fuses the producer into the loop immediately enclosing the consumer.
+/// This is achieved by "recomputing" the producer at the time it
+/// is needed just before the consumer.
 static LinalgOp fuse(OpBuilder &b, LinalgOp producerOp, AffineMap producerMap,
                      OpOperand &consumerOpOperand) {
   LLVM_DEBUG(llvm::dbgs() << "Producer map: " << producerMap << "\n");
@@ -548,9 +538,10 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(consumerOp);
   LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumerOp << "\n");
+  OpOperand *opOperand =
+      producerOp.getOutputOperand(producerOpResult.getResultNumber());
   LinalgOp fusedProducer =
-      fuse(b, producerOp,
-           producerOp.getOutputIndexingMap(producerOpResult.getResultNumber()),
+      fuse(b, producerOp, producerOp.getTiedIndexingMap(opOperand),
            consumerOpOperand);
 
   // Replace use.
@@ -770,9 +761,9 @@ FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
   FusableOpDependencesTy fusableDependences;
   DenseMap<Operation *, SmallVector<AffineMap, 1>> fusedProducerIndexingMap;
   for (LinalgOp op : reverse(ops)) {
-    for (OpOperand &opOperand : op.getShapedOpOperands()) {
+    for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
       Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
-          fusableDependence = findFusableProducer(opOperand, dependenceGraph);
+          fusableDependence = findFusableProducer(*opOperand, dependenceGraph);
       if (!fusableDependence)
         continue;
       // Canonicalize indexed generic ops before fusion.
@@ -905,10 +896,16 @@ fuseOperations(OpBuilder &b, LinalgOp rootOp, TiledLinalgOp tiledLinalgOp,
     // To keep the second type of information while letting the unfused op die
     // unused, we need to forward the producer output operand.
     if (auto forOp = dyn_cast<scf::ForOp>(tiledLinalgOp.loops.front())) {
-      for (auto &operand : forOp.getIterOpOperands())
-        if (auto opResult = operand.get().dyn_cast<OpResult>())
-          if (opResult.getOwner() == origOp)
-            operand.set(origOp.getOutputTensors()[opResult.getResultNumber()]);
+      for (auto &operand : forOp.getIterOpOperands()) {
+        if (auto opResult = operand.get().dyn_cast<OpResult>()) {
+          if (opResult.getOwner() == origOp) {
+            Value output =
+                origOp.getOutputOperand(opResult.getResultNumber())->get();
+            assert(output.getType().isa<RankedTensorType>());
+            operand.set(output);
+          }
+        }
+      }
     }
   }
   return fusedOps;
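The same accessor substitution recurs throughout the patch: a positional tensor lookup becomes an OpOperand lookup followed by `->get()`, which is also what makes the extra tensor-type assertion above possible. The equivalence as a sketch (illustrative helper; `i` must be a valid output index):

    // Equivalent accessors at this revision.
    static mlir::Value getOutputTensor(mlir::linalg::LinalgOp op, unsigned i) {
      // Deprecated: return op.getOutputTensors()[i];
      return op.getOutputOperand(i)->get();
    }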