[mlir][linalg] Cleanup LinalgOp usage in fusion (NFC).
Replace the uses of deprecated Structured Op Interface methods in Fusion.cpp. This patch is based on https://reviews.llvm.org/D103394.

Differential Revision: https://reviews.llvm.org/D103437
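The change is mechanical: positional, `Value`-based accessors give way to `OpOperand`-based ones. A minimal before/after sketch distilled from the hunks below; the helper names `walkOperandsOld`/`walkOperandsNew` and the includes are illustrative assumptions, not part of the patch, and the methods match the interface as of this commit's era.

#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // assumed era-appropriate header
#include "llvm/ADT/STLExtras.h"

using namespace mlir;

// Deprecated style: enumerate shaped operands and index indexing_maps().
static void walkOperandsOld(linalg::LinalgOp op) {
  auto maps = op.indexing_maps();
  for (auto en : llvm::enumerate(op.getShapedOperands())) {
    AffineMap map = maps[en.index()].cast<AffineMapAttr>().getValue();
    Value shape = en.value();
    (void)map;
    (void)shape;
  }
}

// New style: traverse OpOperands and query the tied indexing map directly.
static void walkOperandsNew(linalg::LinalgOp op) {
  for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
    AffineMap map = op.getTiedIndexingMap(opOperand);
    Value shape = opOperand->get();
    (void)map;
    (void)shape;
  }
}

The later sketches assume the same includes and namespace.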
commit 7594f5028a (parent c2e5226a85)
mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp

@@ -69,10 +69,9 @@ struct ShapeDimension {
 static ShapeDimension
 getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
                           bool fromSubViewOpOnly = false) {
-  auto maps = op.indexing_maps();
   // Iterate over the inputs and outputs in order.
   // Extract the subranges from the linearized ranges.
-  for (auto en : llvm::enumerate(op.getShapedOperands())) {
+  for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
     // The method `getRangeFromOperandShape` requires using SubViewOp or
     // SubTensorOps. If the value isnt defined from there continue.
     // todo: The method should be adapted to get the values from
@@ -80,27 +79,26 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
     // currently returns a `linalg.range`. The fix here is to move this op to
     // `std` dialect and add the method to `ViewInterface`.
     if (fromSubViewOpOnly && !isa_and_nonnull<memref::SubViewOp, SubTensorOp>(
-                                 en.value().getDefiningOp()))
+                                 opOperand->get().getDefiningOp()))
       continue;

-    unsigned idx = en.index();
-    auto map = maps[idx].cast<AffineMapAttr>().getValue();
-    LLVM_DEBUG(llvm::dbgs()
-               << "getShapeDefiningLoopRange I/O idx: " << idx << "\n");
+    AffineMap map = op.getTiedIndexingMap(opOperand);
+    LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange I/O idx: "
+                            << opOperand->getOperandNumber() << "\n");
     LLVM_DEBUG(llvm::dbgs()
                << "getShapeDefiningLoopRange map: " << map << "\n");
-    Value shape = en.value();
     SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
-    for (auto en2 : llvm::enumerate(map.getResults())) {
-      auto dimExpr = en2.value().dyn_cast<AffineDimExpr>();
+    for (auto en : llvm::enumerate(map.getResults())) {
+      auto dimExpr = en.value().dyn_cast<AffineDimExpr>();
       if (!dimExpr)
         continue;
-      if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
+      if (loopDepth == en.value().cast<AffineDimExpr>().getPosition()) {
         LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
                                 << loopDepth << "\n");
-        LLVM_DEBUG(llvm::dbgs()
-                   << "getShapeDefiningLoopRange shape: " << shape << "\n");
-        return ShapeDimension{shape, static_cast<unsigned>(en2.index())};
+        LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange shape: "
+                                << opOperand->get() << "\n");
+        return ShapeDimension{opOperand->get(),
+                              static_cast<unsigned>(en.index())};
       }
     }
   }
@@ -122,26 +120,24 @@ getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
 // would need to add the intermediate results to `linalg.yield`. After that a
 // canonicalization pass would move the unused output args of the `tiled_loop`
 // to the `input` section.
-static SmallVector<Value, 4> getTiledOperands(OpBuilder &b, LinalgOp producer) {
+static SmallVector<Value> getTiledOperands(OpBuilder &b, LinalgOp producer) {
   auto tiledLoop = dyn_cast<TiledLoopOp>(b.getBlock()->getParentOp());
   if (!tiledLoop)
-    return llvm::to_vector<4>(producer.getShapedOperands());
+    return producer.getInputAndOutputOperands();

-  SmallVector<Value, 4> tiledOperands;
+  SmallVector<Value> tiledOperands;
   assert(producer.hasTensorSemantics() &&
          "only fusion on tensors is currently supported for TiledLinalgOp");

-  for (auto producerInput : producer.getInputTensors()) {
-    OpOperand *addedInput = tiledLoop.findInputOperand(producerInput);
+  for (OpOperand *producerInput : producer.getInputTensorOperands()) {
+    OpOperand *addedInput = tiledLoop.findInputOperand(producerInput->get());
     if (addedInput == nullptr)
-      addedInput = &tiledLoop.appendInputOperand(b, producerInput);
+      addedInput = &tiledLoop.appendInputOperand(b, producerInput->get());
     BlockArgument addedBlockArg = tiledLoop.getTiedBlockArgument(*addedInput);
     tiledOperands.push_back(addedBlockArg);
   }
-  for (auto &en : llvm::enumerate(producer.getOutputTensors())) {
-    Value producerOutput = en.value();
-
-    Value result = producer->getResult(en.index());
+  for (OpOperand *producerOutput : producer.getOutputTensorOperands()) {
+    OpResult result = producer.getTiedOpResult(producerOutput);
     OpOperand *resultInputOperand = tiledLoop.findInputOperand(result);
     OpOperand *resultOutputOperand = tiledLoop.findOutputOperand(result);
     assert((resultInputOperand != nullptr) ^ (resultOutputOperand != nullptr) &&
@@ -152,10 +148,11 @@ static SmallVector<Value, 4> getTiledOperands(OpBuilder &b, LinalgOp producer) {
     int opNumber = isInput ? resultInputOperand->getOperandNumber()
                            : resultOutputOperand->getOperandNumber();

-    OpOperand *addedOutput = tiledLoop.findOutputOperand(producerOutput);
+    OpOperand *addedOutput = tiledLoop.findOutputOperand(producerOutput->get());
     if (addedOutput == nullptr)
-      addedOutput = isInput ? &tiledLoop.appendInputOperand(b, producerOutput)
-                            : &tiledLoop.appendOutputOperand(b, producerOutput);
+      addedOutput =
+          isInput ? &tiledLoop.appendInputOperand(b, producerOutput->get())
+                  : &tiledLoop.appendOutputOperand(b, producerOutput->get());

     OpOperand &resultOperand = tiledLoop->getOpOperand(opNumber);
     auto addedBlockArg = tiledLoop.getTiedBlockArgument(*addedOutput);
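The two `getTiledOperands` hunks above share one pattern: look an operand up on the enclosing `linalg.tiled_loop`, append it if it is not threaded through yet, and then use the tied block argument inside the loop body. A condensed, hedged sketch of that pattern, reusing the `TiledLoopOp` helpers from the diff; `collectTiledInputs` is an illustrative name, not part of the patch.

// Condensed sketch of the OpOperand-based forwarding done above.
static SmallVector<Value> collectTiledInputs(OpBuilder &b,
                                             linalg::LinalgOp producer,
                                             linalg::TiledLoopOp tiledLoop) {
  SmallVector<Value> tiledOperands;
  for (OpOperand *producerInput : producer.getInputTensorOperands()) {
    // Reuse the tiled_loop input if the value is already threaded through.
    OpOperand *addedInput = tiledLoop.findInputOperand(producerInput->get());
    if (addedInput == nullptr)
      addedInput = &tiledLoop.appendInputOperand(b, producerInput->get());
    // Inside the loop body, refer to the tied block argument, not the value.
    tiledOperands.push_back(tiledLoop.getTiedBlockArgument(*addedInput));
  }
  return tiledOperands;
}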
@@ -200,7 +197,7 @@ static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
   }

   SmallVector<Value, 8> clonedShapes;
-  clonedShapes.reserve(producer.getNumShapedOperands());
+  clonedShapes.reserve(producer.getNumInputsAndOutputs());

   // Compute subranges for all tensor input/output operands.
   clonedShapes.append(makeTiledShapes(b, loc, producer,
@@ -267,16 +264,9 @@ static Range getRangeFromOperandShape(OpBuilder &b, Location loc,
   llvm_unreachable("SubviewOp or SubTensorOp expected");
 }

-/// Fuses the producer of `producerIdx` into the loop immediately enclosing
-/// `consumer`. This is achieved by "recomputing" the `producer` at the time it
-/// is needed just before the `consumer.
-///
-/// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are
-/// 2 cases:
-///   1. Buffer case: `producerIdx` is the index of the buffer in
-///      `producer.getOutputBuffers()`.
-///   2. Tensor case: `producerIdx` is the index of the tensor in
-///      `producer.getResults()`.
+/// Fuses the producer into the loop immediately enclosing the consumer.
+/// This is achieved by "recomputing" the producer at the time it
+/// is needed just before the consumer.
 static LinalgOp fuse(OpBuilder &b, LinalgOp producerOp, AffineMap producerMap,
                      OpOperand &consumerOpOperand) {
   LLVM_DEBUG(llvm::dbgs() << "Producer map: " << producerMap << "\n");
@@ -548,9 +538,10 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(consumerOp);
   LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumerOp << "\n");
+  OpOperand *opOperand =
+      producerOp.getOutputOperand(producerOpResult.getResultNumber());
   LinalgOp fusedProducer =
-      fuse(b, producerOp,
-           producerOp.getOutputIndexingMap(producerOpResult.getResultNumber()),
+      fuse(b, producerOp, producerOp.getTiedIndexingMap(opOperand),
            consumerOpOperand);

   // Replace use.
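A usage note on the `fuseProducerOfTensor` hunk above: when only an `OpResult` is in hand, the producer's indexing map is now recovered through the tied output operand rather than by result position. A hedged sketch; `indexingMapForResult` is an illustrative helper, not part of the patch.

// Recover the indexing map tied to a producer result, as done above.
static AffineMap indexingMapForResult(linalg::LinalgOp producerOp,
                                      OpResult producerOpResult) {
  OpOperand *opOperand =
      producerOp.getOutputOperand(producerOpResult.getResultNumber());
  return producerOp.getTiedIndexingMap(opOperand);
}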
@@ -770,9 +761,9 @@ FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
   FusableOpDependencesTy fusableDependences;
   DenseMap<Operation *, SmallVector<AffineMap, 1>> fusedProducerIndexingMap;
   for (LinalgOp op : reverse(ops)) {
-    for (OpOperand &opOperand : op.getShapedOpOperands()) {
+    for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
       Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
-          fusableDependence = findFusableProducer(opOperand, dependenceGraph);
+          fusableDependence = findFusableProducer(*opOperand, dependenceGraph);
       if (!fusableDependence)
         continue;
       // Canonicalize indexed generic ops before fusion.
@@ -905,10 +896,16 @@ fuseOperations(OpBuilder &b, LinalgOp rootOp, TiledLinalgOp tiledLinalgOp,
     // To keep the second type of information while letting the unfused op die
     // unused, we need to forward the producer output operand.
    if (auto forOp = dyn_cast<scf::ForOp>(tiledLinalgOp.loops.front())) {
-      for (auto &operand : forOp.getIterOpOperands())
-        if (auto opResult = operand.get().dyn_cast<OpResult>())
-          if (opResult.getOwner() == origOp)
-            operand.set(origOp.getOutputTensors()[opResult.getResultNumber()]);
+      for (auto &operand : forOp.getIterOpOperands()) {
+        if (auto opResult = operand.get().dyn_cast<OpResult>()) {
+          if (opResult.getOwner() == origOp) {
+            Value output =
+                origOp.getOutputOperand(opResult.getResultNumber())->get();
+            assert(output.getType().isa<RankedTensorType>());
+            operand.set(output);
+          }
+        }
+      }
     }
   }
   return fusedOps;
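The last hunk expands the one-liner so the forwarded value can be checked: outputs are now fetched through `getOutputOperand(...)->get()` instead of `getOutputTensors()[...]`, with an assert that tensor semantics still hold. A hedged sketch of that lookup; `outputForResult` is an illustrative helper, not part of the patch.

// Fetch the producer output behind an OpResult, mirroring the hunk above.
static Value outputForResult(linalg::LinalgOp origOp, OpResult opResult) {
  Value output = origOp.getOutputOperand(opResult.getResultNumber())->get();
  assert(output.getType().isa<RankedTensorType>() && "expected tensor output");
  return output;
}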