[mlir][Transform] Fix applyToOne corner case when no op is matched.

Such situations manifest themselves with an empty payload which ends up producing empty results.
In such cases, we still want to match the transform op contract and return as many empty SmallVector<Operation*>
as the op requires.

Differential Revision: https://reviews.llvm.org/D128456
This commit is contained in:
Nicolas Vasilache 2022-06-23 12:14:23 -07:00
parent fbf611ed2a
commit 8c6da76483
2 changed files with 42 additions and 6 deletions

View File

@ -824,6 +824,17 @@ mlir::transform::TransformEachOpTrait<OpTy>::apply(
decltype(&OpTy::applyToOne)>::template arg_t<0>;
ArrayRef<Operation *> targets =
state.getPayloadOps(this->getOperation()->getOperand(0));
// Handle the corner case where no target is specified.
// This is typically the case when the matcher fails to apply and we need to
// propagate gracefully.
// In this case, we fill all results with an empty vector.
if (targets.empty()) {
SmallVector<Operation *> emptyResult;
for (auto r : this->getOperation()->getResults())
transformResults.set(r.template cast<OpResult>(), emptyResult);
return DiagnosedSilenceableFailure::success();
}
SmallVector<SmallVector<Operation *>, 1> results;
// In the multi-result case, collect the number of results each transform
// produced.
@ -831,14 +842,17 @@ mlir::transform::TransformEachOpTrait<OpTy>::apply(
targets, results, [&](TransformOpType specificOp) {
return static_cast<OpTy *>(this)->applyToOne(specificOp, state);
});
// Propagate the failure (definite or silencable) if any.
if (!result.succeeded())
return result;
if (results.empty())
// Legitimately no results, bail early.
if (results.empty() && OpTy::template hasTrait<OpTrait::ZeroResults>())
return DiagnosedSilenceableFailure::success();
// Ensure all applications return the same number of results.
// Variadic cases are much trickier to handle in a generic fashion.
int64_t nRes = results[0].size();
int64_t nRes = results.empty() ? 0 : results[0].size();
if (llvm::any_of(results, [&](const auto &r) {
return static_cast<int64_t>(r.size()) != nRes;
})) {
@ -849,6 +863,8 @@ mlir::transform::TransformEachOpTrait<OpTy>::apply(
"generic `apply` instead of the specialized `applyToOne`";
}
// Ensure the number of results agrees with what the transform op expects.
// Unless we see empty results, in which case we just want to propagate the
// emptiness.
if (this->getOperation()->getNumResults() != nRes) {
InFlightDiagnostic diag = static_cast<OpTy *>(this)->emitError()
<< "unexpected number of results (got " << nRes
@ -857,10 +873,6 @@ mlir::transform::TransformEachOpTrait<OpTy>::apply(
return DiagnosedSilenceableFailure::definiteFailure();
}
// If no results, bail early.
if (OpTy::template hasTrait<OpTrait::ZeroResults>())
return DiagnosedSilenceableFailure::success();
// Perform transposition of M applications producing N results each into N
// results for each of the M applications.
SmallVector<SmallVector<Operation *, 1>> transposedResults =

View File

@ -436,3 +436,27 @@ transform.with_pdl_patterns {
%1:2 = transform.test_correct_number_of_multi_results %0
}
}
// -----
func.func @foo() {
"wrong_op_name" () : () -> ()
return
}
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @some : benefit(1) {
%0 = pdl.operands
%1 = pdl.types
%2 = pdl.operation "op"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
pdl.rewrite %2 with "transform.dialect"
}
transform.sequence %arg0 {
^bb0(%arg1: !pdl.operation):
%0 = pdl_match @some in %arg1
// Transform fails to match any but still produces 2 results.
%1:2 = transform.test_correct_number_of_multi_results %0
}
}