From 8c6da76483935d172c34e04e6c0106e33d803c61 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Thu, 23 Jun 2022 12:14:23 -0700 Subject: [PATCH] [mlir][Transform] Fix applyToOne corner case when no op is matched. Such situations manifest themselves with an empty payload which ends up producing empty results. In such cases, we still want to match the transform op contract and return as many empty SmallVector as the op requires. Differential Revision: https://reviews.llvm.org/D128456 --- .../Transform/IR/TransformInterfaces.h | 24 ++++++++++++++----- .../Dialect/Transform/test-interpreter.mlir | 24 +++++++++++++++++++ 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h index e6fbfc88e31e..ef891dd2ddc5 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -824,6 +824,17 @@ mlir::transform::TransformEachOpTrait::apply( decltype(&OpTy::applyToOne)>::template arg_t<0>; ArrayRef targets = state.getPayloadOps(this->getOperation()->getOperand(0)); + // Handle the corner case where no target is specified. + // This is typically the case when the matcher fails to apply and we need to + // propagate gracefully. + // In this case, we fill all results with an empty vector. + if (targets.empty()) { + SmallVector emptyResult; + for (auto r : this->getOperation()->getResults()) + transformResults.set(r.template cast(), emptyResult); + return DiagnosedSilenceableFailure::success(); + } + SmallVector, 1> results; // In the multi-result case, collect the number of results each transform // produced. @@ -831,14 +842,17 @@ mlir::transform::TransformEachOpTrait::apply( targets, results, [&](TransformOpType specificOp) { return static_cast(this)->applyToOne(specificOp, state); }); + // Propagate the failure (definite or silencable) if any. if (!result.succeeded()) return result; - if (results.empty()) + + // Legitimately no results, bail early. + if (results.empty() && OpTy::template hasTrait()) return DiagnosedSilenceableFailure::success(); // Ensure all applications return the same number of results. // Variadic cases are much trickier to handle in a generic fashion. - int64_t nRes = results[0].size(); + int64_t nRes = results.empty() ? 0 : results[0].size(); if (llvm::any_of(results, [&](const auto &r) { return static_cast(r.size()) != nRes; })) { @@ -849,6 +863,8 @@ mlir::transform::TransformEachOpTrait::apply( "generic `apply` instead of the specialized `applyToOne`"; } // Ensure the number of results agrees with what the transform op expects. + // Unless we see empty results, in which case we just want to propagate the + // emptiness. if (this->getOperation()->getNumResults() != nRes) { InFlightDiagnostic diag = static_cast(this)->emitError() << "unexpected number of results (got " << nRes @@ -857,10 +873,6 @@ mlir::transform::TransformEachOpTrait::apply( return DiagnosedSilenceableFailure::definiteFailure(); } - // If no results, bail early. - if (OpTy::template hasTrait()) - return DiagnosedSilenceableFailure::success(); - // Perform transposition of M applications producing N results each into N // results for each of the M applications. SmallVector> transposedResults = diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir index a487d1dbef19..34d1fc8a2b17 100644 --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -436,3 +436,27 @@ transform.with_pdl_patterns { %1:2 = transform.test_correct_number_of_multi_results %0 } } + +// ----- + +func.func @foo() { + "wrong_op_name" () : () -> () + return +} + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + pdl.pattern @some : benefit(1) { + %0 = pdl.operands + %1 = pdl.types + %2 = pdl.operation "op"(%0 : !pdl.range) -> (%1 : !pdl.range) + pdl.rewrite %2 with "transform.dialect" + } + + transform.sequence %arg0 { + ^bb0(%arg1: !pdl.operation): + %0 = pdl_match @some in %arg1 + // Transform fails to match any but still produces 2 results. + %1:2 = transform.test_correct_number_of_multi_results %0 + } +}