[mlir][Linalg] NFC - Clean up internal transform APIs and produce better messages on failure to apply.

This commit is contained in:
Nicolas Vasilache 2022-09-19 02:03:48 -07:00
parent 92e9bddc49
commit 12831be96c
1 changed files with 63 additions and 28 deletions

View File

@ -228,12 +228,16 @@ LogicalResult transform::FuseOp::verify() {
/// Find the first "extract" user of `producerOp` and tile it right before its
/// use. The tiled op is fused under the `containingOp`.
/// Return this fused op on success or nullptr if anything fails.
static Operation *tileAndFuseFirstExtractUse(Operation *producerOp,
Operation *containingOp,
RewriterBase &rewriter) {
static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter,
Diagnostic &diag,
Operation *producerOp,
Operation *containingOp) {
auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
if (!tileableProducer)
if (!tileableProducer) {
diag.attachNote(producerOp->getLoc())
<< "producer is not a TileableInterface: " << *producerOp;
return nullptr;
}
// Search the producer slices accessed within the containing operation.
// TODO: Generalize to more extract/insert/parallel_insert triples, maybe
@ -244,8 +248,11 @@ static Operation *tileAndFuseFirstExtractUse(Operation *producerOp,
});
// Find a fusion opportunity.
if (it == tileableProducer->getUsers().end())
if (it == tileableProducer->getUsers().end()) {
diag.attachNote(tileableProducer->getLoc())
<< "could not find fusion opportunity for: " << *tileableProducer;
return nullptr;
}
auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
// Try to fuse the producer in-place.
@ -256,8 +263,11 @@ static Operation *tileAndFuseFirstExtractUse(Operation *producerOp,
FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue(
rewriter, /*resultNumber=*/0, sliceOpToTile.getMixedOffsets(),
sliceOpToTile.getMixedSizes());
if (failed(tiledProducer))
if (failed(tiledProducer)) {
diag.attachNote(tileableProducer->getLoc())
<< "failed to tile producer op: " << *tileableProducer;
return nullptr;
}
// Replace the extract op.
Operation *fusedOp = tiledProducer->getDefiningOp();
@ -272,11 +282,25 @@ static Operation *tileAndFuseFirstExtractUse(Operation *producerOp,
/// `containingOp`.
/// Return this fused op on success or nullptr if anything fails.
static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
Operation *producerOp, Operation *containingOp, RewriterBase &rewriter) {
RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
Operation *containingOp) {
auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
if (!tileableProducer)
if (!tileableProducer) {
diag.attachNote(producerOp->getLoc())
<< "producer is not a TileableInterface: " << *producerOp;
return nullptr;
}
// Ensure `tileableProducer` has exactly one destination operand that we can
// replace the ForeachThreadOp bbArg with.
auto destinationOperands = tileableProducer.getDestinationOperands(rewriter);
if (destinationOperands.size() != 1) {
diag.attachNote(tileableProducer->getLoc())
<< "tileableProducer must have exactly one destination operand: "
<< *tileableProducer;
return nullptr;
}
// Search the first use by a "scf::ForeachThreadOp" user.
scf::ForeachThreadOp foreachThreadOp;
@ -286,8 +310,11 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
return foreachThreadOp;
});
// If it's not from the containing op, return.
if (!foreachThreadOp || foreachThreadOp != containingOp)
if (!foreachThreadOp || foreachThreadOp != containingOp) {
diag.attachNote(tileableProducer->getLoc())
<< "could not find a use by the containing op: " << *tileableProducer;
return nullptr;
}
// Search the producer slices accessed within the containing
// operation.
@ -305,16 +332,13 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
});
// Find a fusion opportunity.
if (itBBArgUsers == bbArg.getUsers().end())
if (itBBArgUsers == bbArg.getUsers().end()) {
diag.attachNote(containingOp->getLoc())
<< "could not find fusion opportunity for bbArg: " << bbArg;
return nullptr;
}
auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
// Ensure `tileableProducer` has exactly one destination operand that we can
// replace the ForeachThreadOp bbArg with.
auto destinationOperands = tileableProducer.getDestinationOperands(rewriter);
if (destinationOperands.size() != 1)
return nullptr;
// Try to fuse the producer in-place.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(sliceOpToTile);
@ -333,8 +357,11 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
tileableProducerClone.generateResultTileValue(
rewriter, /*resultNumber=*/0, sliceOpToTile.getMixedOffsets(),
sliceOpToTile.getMixedSizes());
if (failed(tiledProducer))
if (failed(tiledProducer)) {
diag.attachNote(tileableProducer->getLoc())
<< "failed to tile producer op: " << *tileableProducer;
return nullptr;
}
// Replace the extract op.
Operation *fusedOp = tiledProducer->getDefiningOp();
@ -349,9 +376,9 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
return fusedOp;
}
static Operation *cloneAndFuseFirstUse(Operation *producerOp,
Operation *containingOp,
RewriterBase &rewriter) {
static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
Operation *producerOp,
Operation *containingOp) {
// Gather all uses inside the containing op.
SmallVector<OpOperand *> uses;
for (OpResult result : producerOp->getOpResults()) {
@ -362,14 +389,19 @@ static Operation *cloneAndFuseFirstUse(Operation *producerOp,
}
// Cannot clone and fuse if the use is by the containing op itself: fail
// immediately.
if (containingOp == use.getOwner())
if (containingOp == use.getOwner()) {
diag.attachNote(producerOp->getLoc())
<< "producer op use by containing op cannot be fused by cloning";
return nullptr;
}
}
}
// Check for a non-empty list of fusion opportunities.
if (uses.empty())
if (uses.empty()) {
diag.attachNote(producerOp->getLoc()) << "no fusion opportunity by cloning";
return nullptr;
}
// Clone and fuse inside the containing op.
Operation *fusedOp = nullptr;
@ -441,18 +473,23 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
auto nextProducer = getNextProducer();
if (failed(nextProducer)) {
Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Remark);
diag << "could not fuse ops into container";
diag << "could not find next producer to fuse into container";
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
}
Operation *producerOp = *nextProducer;
// Default diagnostic, to be complemented with more failure information.
Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
diag << "could not fuse " << *producerOp << " into " << *containingOp;
// TODO: If there are multiple uses of the producer in the containing op,
// we currently tile/clone the op multiple times (once per use). In some
// cases, we can tile/clone once and reuse the value for each use.
// Furthermore, producers should then be traversed according to a
// topological sorting.
Operation *tiled =
tileAndFuseFirstExtractUse(producerOp, containingOp, rewriter);
tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
if (tiled) {
fusedOps.push_back(tiled);
continue;
@ -460,21 +497,19 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
Operation *tiledContainingOpOperand =
tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
producerOp, containingOp, rewriter);
rewriter, diag, producerOp, containingOp);
if (tiledContainingOpOperand) {
fusedOps.push_back(tiledContainingOpOperand);
continue;
}
Operation *cloned =
cloneAndFuseFirstUse(producerOp, containingOp, rewriter);
cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
if (cloned) {
fusedOps.push_back(cloned);
continue;
}
Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
diag << "could not fuse " << *producerOp << "into " << *containingOp;
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
}