[mlir] Add ReplicateOp to the Transform dialect

This handle manipulation operation allows one to define a new handle that is
associated with a the same payload IR operations N times, where N can be driven
by the size of payload IR operation list associated with another handle. This
can be seen as a sort of broadcast that can be used to ensure the lists
associated with two handles have equal numbers of payload IR ops as expected by
many pairwise transform operations.

Introduce an additional "expensive" check that guards against consuming a
handle that is assocaited with the same payload IR operation more than once as
this is likely to lead to double-free or other undesired effects.

Depends On D129110

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D129216
This commit is contained in:
Alex Zinenko 2022-07-07 15:55:23 +02:00
parent d4c53202eb
commit 00d1a1a25f
10 changed files with 295 additions and 4 deletions

View File

@ -845,6 +845,27 @@ transposeResults(const SmallVector<SmallVector<Operation *>, 1> &m) {
return res;
}
} // namespace detail
/// Populates `effects` with the memory effects indicating the operation on the
/// given handle value:
/// - consumes = Read + Free,
/// - produces = Allocate + Write,
/// - onlyReads = Read.
void consumesHandle(ValueRange handles,
SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
void producesHandle(ValueRange handles,
SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
void onlyReadsHandle(ValueRange handles,
SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
/// Checks whether the transform op consumes the given handle.
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform);
/// Populates `effects` with the memory effects indicating the access to payload
/// IR resource.
void modifiesPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
void onlyReadsPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
} // namespace transform
} // namespace mlir

View File

@ -174,6 +174,42 @@ def PDLMatchOp : TransformDialectOp<"pdl_match",
let assemblyFormat = "$pattern_name `in` $root attr-dict";
}
def ReplicateOp : TransformDialectOp<"replicate",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "Lists payload ops multiple times in the new handle";
let description = [{
Produces a new handle associated with a list of payload IR ops that is
computed by repeating the list of payload IR ops associated with the
operand handle as many times as the "pattern" handle has associated
operations. For example, if pattern is associated with [op1, op2] and the
operand handle is associated with [op3, op4, op5], the resulting handle
will be associated with [op3, op4, op5, op3, op4, op5].
This transformation is useful to "align" the sizes of payload IR lists
before a transformation that expects, e.g., identically-sized lists. For
example, a transformation may be parameterized by same notional per-target
size computed at runtime and supplied as another handle, the replication
allows this size to be computed only once and used for every target instead
of replicating the computation itself.
Note that it is undesirable to pass a handle with duplicate operations to
an operation that consumes the handle. Handle consumption often indicates
that the associated payload IR ops are destroyed, so having the same op
listed more than once will lead to double-free. Single-operand
MergeHandlesOp may be used to deduplicate the associated list of payload IR
ops when necessary. Furthermore, a combination of ReplicateOp and
MergeHandlesOp can be used to construct arbitrary lists with repetitions.
}];
let arguments = (ins PDL_Operation:$pattern,
Variadic<PDL_Operation>:$handles);
let results = (outs Variadic<PDL_Operation>:$replicated);
let assemblyFormat =
"`num` `(` $pattern `)` $handles "
"custom<PDLOpTypedResults>(type($replicated), ref($handles)) attr-dict";
}
def SequenceOp : TransformDialectOp<"sequence",
[DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getSuccessorEntryOperands", "getSuccessorRegions",

View File

@ -55,7 +55,7 @@ Value transform::TransformState::getHandleForPayloadOp(Operation *op) const {
LogicalResult transform::TransformState::tryEmplaceReverseMapping(
Mappings &map, Operation *operation, Value handle) {
auto insertionResult = map.reverse.insert({operation, handle});
if (!insertionResult.second) {
if (!insertionResult.second && insertionResult.first->second != handle) {
InFlightDiagnostic diag = operation->emitError()
<< "operation tracked by two handles";
diag.attachNote(handle.getLoc()) << "handle";
@ -191,9 +191,27 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
DiagnosedSilenceableFailure
transform::TransformState::applyTransform(TransformOpInterface transform) {
LLVM_DEBUG(DBGS() << "applying: " << transform << "\n");
if (options.getExpensiveChecksEnabled() &&
failed(checkAndRecordHandleInvalidation(transform))) {
return DiagnosedSilenceableFailure::definiteFailure();
if (options.getExpensiveChecksEnabled()) {
if (failed(checkAndRecordHandleInvalidation(transform)))
return DiagnosedSilenceableFailure::definiteFailure();
for (OpOperand &operand : transform->getOpOperands()) {
if (!isHandleConsumed(operand.get(), transform))
continue;
DenseSet<Operation *> seen;
for (Operation *op : getPayloadOps(operand.get())) {
if (!seen.insert(op).second) {
DiagnosedSilenceableFailure diag =
transform.emitSilenceableError()
<< "a handle passed as operand #" << operand.getOperandNumber()
<< " and consumed by this operation points to a payload "
"operation more than once";
diag.attachNote(op->getLoc()) << "repeated target op";
return diag;
}
}
}
}
transform::TransformResults results(transform->getNumResults());
@ -326,6 +344,70 @@ transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) {
return success();
}
//===----------------------------------------------------------------------===//
// Memory effects.
//===----------------------------------------------------------------------===//
void transform::consumesHandle(
ValueRange handles,
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
for (Value handle : handles) {
effects.emplace_back(MemoryEffects::Read::get(), handle,
TransformMappingResource::get());
effects.emplace_back(MemoryEffects::Free::get(), handle,
TransformMappingResource::get());
}
}
/// Returns `true` if the given list of effects instances contains an instance
/// with the effect type specified as template parameter.
template <typename EffectTy>
static bool hasEffect(ArrayRef<MemoryEffects::EffectInstance> effects) {
return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
return isa<EffectTy>(effect.getEffect());
});
}
bool transform::isHandleConsumed(Value handle,
transform::TransformOpInterface transform) {
auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
SmallVector<MemoryEffects::EffectInstance> effects;
iface.getEffectsOnValue(handle, effects);
return hasEffect<MemoryEffects::Read>(effects) &&
hasEffect<MemoryEffects::Free>(effects);
}
void transform::producesHandle(
ValueRange handles,
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
for (Value handle : handles) {
effects.emplace_back(MemoryEffects::Allocate::get(), handle,
TransformMappingResource::get());
effects.emplace_back(MemoryEffects::Write::get(), handle,
TransformMappingResource::get());
}
}
void transform::onlyReadsHandle(
ValueRange handles,
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
for (Value handle : handles) {
effects.emplace_back(MemoryEffects::Read::get(), handle,
TransformMappingResource::get());
}
}
void transform::modifiesPayload(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
}
void transform::onlyReadsPayload(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
}
//===----------------------------------------------------------------------===//
// Generated interface implementation.
//===----------------------------------------------------------------------===//

View File

@ -23,6 +23,16 @@
using namespace mlir;
static ParseResult parsePDLOpTypedResults(
OpAsmParser &parser, SmallVectorImpl<Type> &types,
const SmallVectorImpl<OpAsmParser::UnresolvedOperand> &handles) {
types.resize(handles.size(), pdl::OperationType::get(parser.getContext()));
return success();
}
static void printPDLOpTypedResults(OpAsmPrinter &, Operation *, TypeRange,
ValueRange) {}
#define GET_OP_CLASSES
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
@ -354,6 +364,33 @@ transform::PDLMatchOp::apply(transform::TransformResults &results,
return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// ReplicateOp
//===----------------------------------------------------------------------===//
DiagnosedSilenceableFailure
transform::ReplicateOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
unsigned numRepetitions = state.getPayloadOps(getPattern()).size();
for (const auto &en : llvm::enumerate(getHandles())) {
Value handle = en.value();
ArrayRef<Operation *> current = state.getPayloadOps(handle);
SmallVector<Operation *> payload;
payload.reserve(numRepetitions * current.size());
for (unsigned i = 0; i < numRepetitions; ++i)
llvm::append_range(payload, current);
results.set(getReplicated()[en.index()].cast<OpResult>(), payload);
}
return DiagnosedSilenceableFailure::success();
}
void transform::ReplicateOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
onlyReadsHandle(getPattern(), effects);
consumesHandle(getHandles(), effects);
producesHandle(getReplicated(), effects);
}
//===----------------------------------------------------------------------===//
// SequenceOp
//===----------------------------------------------------------------------===//

View File

@ -59,6 +59,22 @@ class PDLMatchOp:
ip=ip)
class ReplicateOp:
def __init__(self,
pattern: Union[Operation, Value],
handles: Sequence[Union[Operation, Value]],
*,
loc=None,
ip=None):
super().__init__(
[pdl.OperationType.get()] * len(handles),
_get_op_result_or_value(pattern),
[_get_op_result_or_value(h) for h in handles],
loc=loc,
ip=ip)
class SequenceOp:
@overload

View File

@ -25,3 +25,37 @@ transform.with_pdl_patterns {
test_print_remark_at_operand %0, "remark"
}
}
// -----
func.func @func1() {
// expected-note @below {{repeated target op}}
return
}
func.func private @func2()
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @func : benefit(1) {
%0 = operands
%1 = types
%2 = operation "func.func"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
rewrite %2 with "transform.dialect"
}
pdl.pattern @return : benefit(1) {
%0 = operands
%1 = types
%2 = operation "func.return"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
rewrite %2 with "transform.dialect"
}
sequence %arg0 {
^bb1(%arg1: !pdl.operation):
%0 = pdl_match @func in %arg1
%1 = pdl_match @return in %arg1
%2 = replicate num(%0) %1
// expected-error @below {{a handle passed as operand #0 and consumed by this operation points to a payload operation more than once}}
test_consume_operand %2
test_print_remark_at_operand %0, "remark"
}
}

View File

@ -569,3 +569,31 @@ transform.with_pdl_patterns {
transform.test_mixed_sucess_and_silenceable %0
}
}
// -----
module {
func.func private @foo()
func.func private @bar()
transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
pdl.pattern @func : benefit(1) {
%0 = pdl.operands
%1 = pdl.types
%2 = pdl.operation "func.func"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
pdl.rewrite %2 with "transform.dialect"
}
transform.sequence %arg0 {
^bb0(%arg1: !pdl.operation):
%0 = pdl_match @func in %arg1
%1 = replicate num(%0) %arg1
// expected-remark @below {{2}}
test_print_number_of_associated_payload_ir_ops %1
%2 = replicate num(%0) %1
// expected-remark @below {{4}}
test_print_number_of_associated_payload_ir_ops %2
}
}
}

View File

@ -275,6 +275,18 @@ mlir::test::TestMixedSuccessAndSilenceableOp::applyToOne(
return emitDefaultSilenceableFailure(target);
}
DiagnosedSilenceableFailure
mlir::test::TestPrintNumberOfAssociatedPayloadIROps::apply(
transform::TransformResults &results, transform::TransformState &state) {
emitRemark() << state.getPayloadOps(getHandle()).size();
return DiagnosedSilenceableFailure::success();
}
void mlir::test::TestPrintNumberOfAssociatedPayloadIROps::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
transform::onlyReadsHandle(getHandle(), effects);
}
namespace {
/// Test extension of the Transform dialect. Registers additional ops and
/// declares PDL as dependent dialect since the additional ops are using PDL

View File

@ -212,4 +212,13 @@ def TestMixedSuccessAndSilenceableOp
}];
}
def TestPrintNumberOfAssociatedPayloadIROps
: Op<Transform_Dialect, "test_print_number_of_associated_payload_ir_ops",
[DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let arguments = (ins PDL_Operation:$handle);
let assemblyFormat = "$handle attr-dict";
let cppNamespace = "::mlir::test";
}
#endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD

View File

@ -94,3 +94,19 @@ def testMergeHandlesOp():
# CHECK: transform.sequence
# CHECK: ^{{.*}}(%[[ARG1:.+]]: !pdl.operation):
# CHECK: = merge_handles %[[ARG1]]
@run
def testReplicateOp():
with_pdl = transform.WithPDLPatternsOp()
with InsertionPoint(with_pdl.body):
sequence = transform.SequenceOp(with_pdl.bodyTarget)
with InsertionPoint(sequence.body):
m1 = transform.PDLMatchOp(sequence.bodyTarget, "first")
m2 = transform.PDLMatchOp(sequence.bodyTarget, "second")
transform.ReplicateOp(m1, [m2])
transform.YieldOp()
# CHECK-LABEL: TEST: testReplicateOp
# CHECK: %[[FIRST:.+]] = pdl_match
# CHECK: %[[SECOND:.+]] = pdl_match
# CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]