[mlir][Transform] Add a new navigation op to retrieve the producer of an operand

Given an opOperand uniquely determined by the operation `%op` and the operand number `num`,
the `transform.get_producer_of_operand %op[num]` returns the handle to the unique operation
that produced the SSA value used as opOperand.

The transform fails if the operand is a block argument.

Differential Revision: https://reviews.llvm.org/D134171
This commit is contained in:
Nicolas Vasilache 2022-09-19 02:04:39 -07:00
parent 12831be96c
commit ecd9dc0499
3 changed files with 82 additions and 0 deletions

View File

@ -169,6 +169,25 @@ def GetClosestIsolatedParentOp : TransformDialectOp<"get_closest_isolated_parent
let assemblyFormat = "$target attr-dict";
}
def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand",
[DeclareOpInterfaceMethods<TransformOpInterface>,
NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
let summary = "Get handle to the producer of this operation's operand number";
let description = [{
The handle defined by this Transform op corresponds to operation that
produces the SSA value defined by the `target` and `operand_number`
arguments. If the origin of the SSA value is not an operations (i.e. it is
a block argument), the transform silently fails.
The return handle points to only the subset of successfully produced
computational operations, which can be empty.
}];
let arguments = (ins PDL_Operation:$target,
I64Attr:$operand_number);
let results = (outs PDL_Operation:$parent);
let assemblyFormat = "$target `[` $operand_number `]` attr-dict";
}
def MergeHandlesOp : TransformDialectOp<"merge_handles",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {

View File

@ -386,6 +386,36 @@ DiagnosedSilenceableFailure transform::GetClosestIsolatedParentOp::apply(
return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// GetProducerOfOperand
//===----------------------------------------------------------------------===//
DiagnosedSilenceableFailure
transform::GetProducerOfOperand::apply(transform::TransformResults &results,
transform::TransformState &state) {
int64_t operandNumber = getOperandNumber();
SmallVector<Operation *> producers;
for (Operation *target : state.getPayloadOps(getTarget())) {
Operation *producer =
target->getNumOperands() <= operandNumber
? nullptr
: target->getOperand(operandNumber).getDefiningOp();
if (!producer) {
DiagnosedSilenceableFailure diag =
emitSilenceableError()
<< "could not find a producer for operand number: " << operandNumber
<< " of " << *target;
diag.attachNote(target->getLoc()) << "target op";
results.set(getResult().cast<OpResult>(),
SmallVector<mlir::Operation *>{});
return diag;
}
producers.push_back(producer);
}
results.set(getResult().cast<OpResult>(), producers);
return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// MergeHandlesOp
//===----------------------------------------------------------------------===//

View File

@ -727,3 +727,36 @@ transform.with_pdl_patterns {
transform.test_print_remark_at_operand %results, "transform applied"
}
}
// -----
func.func @get_parent_for_op_no_loop(%arg0: index, %arg1: index) {
// expected-remark @below {{found muli}}
%0 = arith.muli %arg0, %arg1 : index
arith.addi %0, %arg1 : index
return
}
transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%addi = transform.structured.match ops{["arith.addi"]} in %arg1
%muli = get_producer_of_operand %addi[0]
transform.test_print_remark_at_operand %muli, "found muli"
}
// -----
func.func @get_parent_for_op_no_loop(%arg0: index, %arg1: index) {
// expected-note @below {{target op}}
%0 = arith.muli %arg0, %arg1 : index
return
}
transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
%muli = transform.structured.match ops{["arith.muli"]} in %arg1
// expected-error @below {{could not find a producer for operand number: 0 of}}
%bbarg = get_producer_of_operand %muli[0]
}