[mlir][Transform] Fix dropReverseMapping early exit condition

Previously, the erasure would not trigger and result in surprising behavior.

Differential Revision: https://reviews.llvm.org/D135881
This commit is contained in:
Nicolas Vasilache 2022-10-13 08:13:25 -07:00
parent f386f7690d
commit d8cab3f407
4 changed files with 55 additions and 1 deletions

View File

@ -92,7 +92,7 @@ transform::TransformState::setPayloadOps(Value value,
void transform::TransformState::dropReverseMapping(Mappings &mappings,
Operation *op, Value value) {
auto it = mappings.reverse.find(op);
if (it != mappings.reverse.end())
if (it == mappings.reverse.end())
return;
llvm::erase_value(it->getSecond(), value);

View File

@ -895,3 +895,28 @@ transform.with_pdl_patterns {
transform.cast %2 : !transform.op<"test.some_op"> to !pdl.operation
}
}
// -----
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
transform.sequence %arg0 : !pdl.operation failures(propagate) {
^bb0(%arg1: !pdl.operation):
%0 = pdl_match @some in %arg1 : (!pdl.operation) -> !pdl.operation
// here, the handles nested under are {%arg0, %arg1, %0}
// expected-remark @below {{3 handles nested under}}
transform.test_report_number_of_tracked_handles_nested_under %arg1
// expected-remark @below {{erased}}
transform.test_emit_remark_and_erase_operand %0, "erased"
// here, the handles nested under are only {%arg0, %arg1}
// expected-remark @below {{2 handles nested under}}
transform.test_report_number_of_tracked_handles_nested_under %arg1
}
pdl.pattern @some : benefit(1) {
%0 = pdl.operation "test.some_op"
pdl.rewrite %0 with "transform.dialect"
}
}
"test.some_op"() : () -> ()

View File

@ -328,6 +328,26 @@ DiagnosedSilenceableFailure mlir::transform::TestDialectOpType::checkPayload(
return DiagnosedSilenceableFailure::success();
}
void mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(getTarget(), effects);
}
DiagnosedSilenceableFailure
mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::apply(
transform::TransformResults &results, transform::TransformState &state) {
int64_t count = 0;
for (Operation *op : state.getPayloadOps(getTarget())) {
op->walk([&](Operation *nested) {
SmallVector<Value> handles;
(void)state.getHandlesForPayloadOp(nested, handles);
count += handles.size();
});
}
emitRemark() << count << " handles nested under";
return DiagnosedSilenceableFailure::success();
}
namespace {
/// Test extension of the Transform dialect. Registers additional ops and
/// declares PDL as dependent dialect since the additional ops are using PDL

View File

@ -253,4 +253,13 @@ def TestCopyPayloadOp
let assemblyFormat = "$handle attr-dict";
}
def TestReportNumberOfTrackedHandlesNestedUnder
: Op<Transform_Dialect, "test_report_number_of_tracked_handles_nested_under",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<TransformOpInterface>]> {
let arguments = (ins PDL_Operation:$target);
let assemblyFormat = "$target attr-dict";
let cppNamespace = "::mlir::test";
}
#endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD