forked from OSchip/llvm-project
[mlir] Transform dialect: introduce merge_handles op
This Transform dialect op allows one to merge the lists of Payload IR operations pointed to by several handles into a single list associated with one handle. This is an important Transform dialect usability improvement for cases where transformations may temporarily diverge for different groups of Payload IR ops before converging back to the same script. Without this op, several copies of the trailing transformations would have to be present in the transformation script. Depends On D129090 Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D129110
This commit is contained in:
parent
ff6e5508d6
commit
8e03bfc368
|
@ -121,6 +121,28 @@ def GetClosestIsolatedParentOp : TransformDialectOp<"get_closest_isolated_parent
|
|||
let assemblyFormat = "$target attr-dict";
|
||||
}
|
||||
|
||||
def MergeHandlesOp : TransformDialectOp<"merge_handles",
|
||||
[DeclareOpInterfaceMethods<TransformOpInterface>,
|
||||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
|
||||
let summary = "Merges handles into one pointing to the union of payload ops";
|
||||
let description = [{
|
||||
Creates a new Transform IR handle value that points to the same Payload IR
|
||||
operations as the operand handles. The Payload IR operations are listed
|
||||
in the same order as they are in the operand handles, grouped by operand
|
||||
handle, e.g., all Payload IR operations associated with the first handle
|
||||
come first, then all Payload IR operations associated with the second handle
|
||||
and so on. If `deduplicate` is set, do not add the given Payload IR
|
||||
operation more than once to the final list regardless of it coming from the
|
||||
same or different handles. Consumes the operands and produces a new handle.
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<PDL_Operation>:$handles,
|
||||
UnitAttr:$deduplicate);
|
||||
let results = (outs PDL_Operation:$result);
|
||||
let assemblyFormat = "($deduplicate^)? $handles attr-dict";
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def PDLMatchOp : TransformDialectOp<"pdl_match",
|
||||
[DeclareOpInterfaceMethods<TransformOpInterface>]> {
|
||||
let summary = "Finds ops that match the named PDL pattern";
|
||||
|
|
|
@ -286,6 +286,52 @@ DiagnosedSilenceableFailure transform::GetClosestIsolatedParentOp::apply(
|
|||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// MergeHandlesOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
DiagnosedSilenceableFailure
|
||||
transform::MergeHandlesOp::apply(transform::TransformResults &results,
|
||||
transform::TransformState &state) {
|
||||
SmallVector<Operation *> operations;
|
||||
for (Value operand : getHandles())
|
||||
llvm::append_range(operations, state.getPayloadOps(operand));
|
||||
if (!getDeduplicate()) {
|
||||
results.set(getResult().cast<OpResult>(), operations);
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
SetVector<Operation *> uniqued(operations.begin(), operations.end());
|
||||
results.set(getResult().cast<OpResult>(), uniqued.getArrayRef());
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
void transform::MergeHandlesOp::getEffects(
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
||||
for (Value operand : getHandles()) {
|
||||
effects.emplace_back(MemoryEffects::Read::get(), operand,
|
||||
transform::TransformMappingResource::get());
|
||||
effects.emplace_back(MemoryEffects::Free::get(), operand,
|
||||
transform::TransformMappingResource::get());
|
||||
}
|
||||
effects.emplace_back(MemoryEffects::Allocate::get(), getResult(),
|
||||
transform::TransformMappingResource::get());
|
||||
effects.emplace_back(MemoryEffects::Write::get(), getResult(),
|
||||
transform::TransformMappingResource::get());
|
||||
|
||||
// There are no effects on the Payload IR as this is only a handle
|
||||
// manipulation.
|
||||
}
|
||||
|
||||
OpFoldResult transform::MergeHandlesOp::fold(ArrayRef<Attribute> operands) {
|
||||
if (getDeduplicate() || getHandles().size() != 1)
|
||||
return {};
|
||||
|
||||
// If deduplication is not required and there is only one operand, it can be
|
||||
// used directly instead of merging.
|
||||
return getHandles().front();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// PDLMatchOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -28,6 +28,21 @@ class GetClosestIsolatedParentOp:
|
|||
ip=ip)
|
||||
|
||||
|
||||
class MergeHandlesOp:
|
||||
|
||||
def __init__(self,
|
||||
handles: Sequence[Union[Operation, Value]],
|
||||
*,
|
||||
deduplicate: bool = False,
|
||||
loc=None,
|
||||
ip=None):
|
||||
super().__init__(
|
||||
pdl.OperationType.get(), [_get_op_result_or_value(h) for h in handles],
|
||||
deduplicate=deduplicate,
|
||||
loc=loc,
|
||||
ip=ip)
|
||||
|
||||
|
||||
class PDLMatchOp:
|
||||
|
||||
def __init__(self,
|
||||
|
|
|
@ -460,3 +460,42 @@ transform.with_pdl_patterns {
|
|||
%1:2 = transform.test_correct_number_of_multi_results %0
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Expecting to match all operations by merging the handles that matched addi
|
||||
// and subi separately.
|
||||
func.func @foo(%arg0: index) {
|
||||
// expected-remark @below {{matched}}
|
||||
%0 = arith.addi %arg0, %arg0 : index
|
||||
// expected-remark @below {{matched}}
|
||||
%1 = arith.subi %arg0, %arg0 : index
|
||||
// expected-remark @below {{matched}}
|
||||
%2 = arith.addi %0, %1 : index
|
||||
return
|
||||
}
|
||||
|
||||
transform.with_pdl_patterns {
|
||||
^bb0(%arg0: !pdl.operation):
|
||||
pdl.pattern @addi : benefit(1) {
|
||||
%0 = pdl.operands
|
||||
%1 = pdl.types
|
||||
%2 = pdl.operation "arith.addi"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
|
||||
pdl.rewrite %2 with "transform.dialect"
|
||||
}
|
||||
pdl.pattern @subi : benefit(1) {
|
||||
%0 = pdl.operands
|
||||
%1 = pdl.types
|
||||
%2 = pdl.operation "arith.subi"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
|
||||
pdl.rewrite %2 with "transform.dialect"
|
||||
}
|
||||
|
||||
transform.sequence %arg0 {
|
||||
^bb0(%arg1: !pdl.operation):
|
||||
%0 = pdl_match @addi in %arg1
|
||||
%1 = pdl_match @subi in %arg1
|
||||
%2 = merge_handles %0, %1
|
||||
test_print_remark_at_operand %2, "matched"
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -82,3 +82,15 @@ def testGetClosestIsolatedParentOp():
|
|||
# CHECK: transform.sequence
|
||||
# CHECK: ^{{.*}}(%[[ARG1:.+]]: !pdl.operation):
|
||||
# CHECK: = get_closest_isolated_parent %[[ARG1]]
|
||||
|
||||
|
||||
@run
|
||||
def testMergeHandlesOp():
|
||||
sequence = transform.SequenceOp()
|
||||
with InsertionPoint(sequence.body):
|
||||
transform.MergeHandlesOp([sequence.bodyTarget])
|
||||
transform.YieldOp()
|
||||
# CHECK-LABEL: TEST: testMergeHandlesOp
|
||||
# CHECK: transform.sequence
|
||||
# CHECK: ^{{.*}}(%[[ARG1:.+]]: !pdl.operation):
|
||||
# CHECK: = merge_handles %[[ARG1]]
|
||||
|
|
Loading…
Reference in New Issue