forked from OSchip/llvm-project
[mlir][Linalg] NFC - Cleanup internal transform APIs and produce better messages on failure to apply.
This commit is contained in:
parent
92e9bddc49
commit
12831be96c
|
@ -228,12 +228,16 @@ LogicalResult transform::FuseOp::verify() {
|
|||
/// Find the first "extract" user of `producerOp` and tile it right before its
|
||||
/// use. The tiled op is fused under the `containingOp`.
|
||||
/// Return this fused op on success or nullptr if anything fails.
|
||||
static Operation *tileAndFuseFirstExtractUse(Operation *producerOp,
|
||||
Operation *containingOp,
|
||||
RewriterBase &rewriter) {
|
||||
static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter,
|
||||
Diagnostic &diag,
|
||||
Operation *producerOp,
|
||||
Operation *containingOp) {
|
||||
auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
|
||||
if (!tileableProducer)
|
||||
if (!tileableProducer) {
|
||||
diag.attachNote(producerOp->getLoc())
|
||||
<< "producer is not a TileableInterface: " << *producerOp;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Search the producer slices accessed within the containing operation.
|
||||
// TODO: Generalize to more extract/insert/parallel_insert triples, maybe
|
||||
|
@ -244,8 +248,11 @@ static Operation *tileAndFuseFirstExtractUse(Operation *producerOp,
|
|||
});
|
||||
|
||||
// Find a fusion opportunity.
|
||||
if (it == tileableProducer->getUsers().end())
|
||||
if (it == tileableProducer->getUsers().end()) {
|
||||
diag.attachNote(tileableProducer->getLoc())
|
||||
<< "could not find fusion opportunity for: " << *tileableProducer;
|
||||
return nullptr;
|
||||
}
|
||||
auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
|
||||
|
||||
// Try to fuse the producer in-place.
|
||||
|
@ -256,8 +263,11 @@ static Operation *tileAndFuseFirstExtractUse(Operation *producerOp,
|
|||
FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue(
|
||||
rewriter, /*resultNumber=*/0, sliceOpToTile.getMixedOffsets(),
|
||||
sliceOpToTile.getMixedSizes());
|
||||
if (failed(tiledProducer))
|
||||
if (failed(tiledProducer)) {
|
||||
diag.attachNote(tileableProducer->getLoc())
|
||||
<< "failed to tile producer op: " << *tileableProducer;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Replace the extract op.
|
||||
Operation *fusedOp = tiledProducer->getDefiningOp();
|
||||
|
@ -272,11 +282,25 @@ static Operation *tileAndFuseFirstExtractUse(Operation *producerOp,
|
|||
/// `containingOp`.
|
||||
/// Return this fused op on success or nullptr if anything fails.
|
||||
static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
|
||||
Operation *producerOp, Operation *containingOp, RewriterBase &rewriter) {
|
||||
RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
|
||||
Operation *containingOp) {
|
||||
|
||||
auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
|
||||
if (!tileableProducer)
|
||||
if (!tileableProducer) {
|
||||
diag.attachNote(producerOp->getLoc())
|
||||
<< "producer is not a TileableInterface: " << *producerOp;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Ensure `tileableProducer` has exactly one destination operand that we can
|
||||
// replace the ForeachThreadOp bbArg with.
|
||||
auto destinationOperands = tileableProducer.getDestinationOperands(rewriter);
|
||||
if (destinationOperands.size() != 1) {
|
||||
diag.attachNote(tileableProducer->getLoc())
|
||||
<< "tileableProducer must have exactly one destination operand: "
|
||||
<< *tileableProducer;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Search the first use by a "scf::ForeachThreadOp" user.
|
||||
scf::ForeachThreadOp foreachThreadOp;
|
||||
|
@ -286,8 +310,11 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
|
|||
return foreachThreadOp;
|
||||
});
|
||||
// If it's not from the containing op, return.
|
||||
if (!foreachThreadOp || foreachThreadOp != containingOp)
|
||||
if (!foreachThreadOp || foreachThreadOp != containingOp) {
|
||||
diag.attachNote(tileableProducer->getLoc())
|
||||
<< "could not find a use by the containing op: " << *tileableProducer;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Search the producer slices accessed within the containing
|
||||
// operation.
|
||||
|
@ -305,16 +332,13 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
|
|||
});
|
||||
|
||||
// Find a fusion opportunity.
|
||||
if (itBBArgUsers == bbArg.getUsers().end())
|
||||
if (itBBArgUsers == bbArg.getUsers().end()) {
|
||||
diag.attachNote(containingOp->getLoc())
|
||||
<< "could not find fusion opportunity for bbArg: " << bbArg;
|
||||
return nullptr;
|
||||
}
|
||||
auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
|
||||
|
||||
// Ensure `tileableProducer` has exactly one destination operand that we can
|
||||
// replace the ForeachThreadOp bbArg with.
|
||||
auto destinationOperands = tileableProducer.getDestinationOperands(rewriter);
|
||||
if (destinationOperands.size() != 1)
|
||||
return nullptr;
|
||||
|
||||
// Try to fuse the producer in-place.
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPoint(sliceOpToTile);
|
||||
|
@ -333,8 +357,11 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
|
|||
tileableProducerClone.generateResultTileValue(
|
||||
rewriter, /*resultNumber=*/0, sliceOpToTile.getMixedOffsets(),
|
||||
sliceOpToTile.getMixedSizes());
|
||||
if (failed(tiledProducer))
|
||||
if (failed(tiledProducer)) {
|
||||
diag.attachNote(tileableProducer->getLoc())
|
||||
<< "failed to tile producer op: " << *tileableProducer;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Replace the extract op.
|
||||
Operation *fusedOp = tiledProducer->getDefiningOp();
|
||||
|
@ -349,9 +376,9 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
|
|||
return fusedOp;
|
||||
}
|
||||
|
||||
static Operation *cloneAndFuseFirstUse(Operation *producerOp,
|
||||
Operation *containingOp,
|
||||
RewriterBase &rewriter) {
|
||||
static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
|
||||
Operation *producerOp,
|
||||
Operation *containingOp) {
|
||||
// Gather all uses inside the containing op.
|
||||
SmallVector<OpOperand *> uses;
|
||||
for (OpResult result : producerOp->getOpResults()) {
|
||||
|
@ -362,14 +389,19 @@ static Operation *cloneAndFuseFirstUse(Operation *producerOp,
|
|||
}
|
||||
// Cannot clone and fuse if the use is by the containing op itself: fail
|
||||
// immediately.
|
||||
if (containingOp == use.getOwner())
|
||||
if (containingOp == use.getOwner()) {
|
||||
diag.attachNote(producerOp->getLoc())
|
||||
<< "producer op use by containing op cannot be fused by cloning";
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for a non-empty list of fusion opportunities.
|
||||
if (uses.empty())
|
||||
if (uses.empty()) {
|
||||
diag.attachNote(producerOp->getLoc()) << "no fusion opportunity by cloning";
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Clone and fuse inside the containing op.
|
||||
Operation *fusedOp = nullptr;
|
||||
|
@ -441,18 +473,23 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
|
|||
auto nextProducer = getNextProducer();
|
||||
if (failed(nextProducer)) {
|
||||
Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Remark);
|
||||
diag << "could not fuse ops into container";
|
||||
diag << "could not find next producer to fuse into container";
|
||||
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
|
||||
}
|
||||
|
||||
Operation *producerOp = *nextProducer;
|
||||
|
||||
// Detaul diagnostic, to be complemented with more failure information.
|
||||
Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
|
||||
diag << "could not fuse " << *producerOp << " into " << *containingOp;
|
||||
|
||||
// TODO: If there are multiple uses of the producer in the containing op,
|
||||
// we currently tile/clone the op multiple times (once per use). In some
|
||||
// cases, we can tile/clone once and reuse the value for each use.
|
||||
// Futhermore, producers should then be traversed according to a
|
||||
// topological sorting.
|
||||
Operation *tiled =
|
||||
tileAndFuseFirstExtractUse(producerOp, containingOp, rewriter);
|
||||
tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
|
||||
if (tiled) {
|
||||
fusedOps.push_back(tiled);
|
||||
continue;
|
||||
|
@ -460,21 +497,19 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
|
|||
|
||||
Operation *tiledContainingOpOperand =
|
||||
tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
|
||||
producerOp, containingOp, rewriter);
|
||||
rewriter, diag, producerOp, containingOp);
|
||||
if (tiledContainingOpOperand) {
|
||||
fusedOps.push_back(tiledContainingOpOperand);
|
||||
continue;
|
||||
}
|
||||
|
||||
Operation *cloned =
|
||||
cloneAndFuseFirstUse(producerOp, containingOp, rewriter);
|
||||
cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
|
||||
if (cloned) {
|
||||
fusedOps.push_back(cloned);
|
||||
continue;
|
||||
}
|
||||
|
||||
Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
|
||||
diag << "could not fuse " << *producerOp << "into " << *containingOp;
|
||||
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue