forked from OSchip/llvm-project
[mlir] Add ReplicateOp to the Transform dialect
This handle manipulation operation allows one to define a new handle that is associated with a the same payload IR operations N times, where N can be driven by the size of payload IR operation list associated with another handle. This can be seen as a sort of broadcast that can be used to ensure the lists associated with two handles have equal numbers of payload IR ops as expected by many pairwise transform operations. Introduce an additional "expensive" check that guards against consuming a handle that is assocaited with the same payload IR operation more than once as this is likely to lead to double-free or other undesired effects. Depends On D129110 Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D129216
This commit is contained in:
parent
d4c53202eb
commit
00d1a1a25f
|
@ -845,6 +845,27 @@ transposeResults(const SmallVector<SmallVector<Operation *>, 1> &m) {
|
|||
return res;
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
/// Populates `effects` with the memory effects indicating the operation on the
|
||||
/// given handle value:
|
||||
/// - consumes = Read + Free,
|
||||
/// - produces = Allocate + Write,
|
||||
/// - onlyReads = Read.
|
||||
void consumesHandle(ValueRange handles,
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
|
||||
void producesHandle(ValueRange handles,
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
|
||||
void onlyReadsHandle(ValueRange handles,
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
|
||||
|
||||
/// Checks whether the transform op consumes the given handle.
|
||||
bool isHandleConsumed(Value handle, transform::TransformOpInterface transform);
|
||||
|
||||
/// Populates `effects` with the memory effects indicating the access to payload
|
||||
/// IR resource.
|
||||
void modifiesPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
|
||||
void onlyReadsPayload(SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
|
||||
|
||||
} // namespace transform
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -174,6 +174,42 @@ def PDLMatchOp : TransformDialectOp<"pdl_match",
|
|||
let assemblyFormat = "$pattern_name `in` $root attr-dict";
|
||||
}
|
||||
|
||||
def ReplicateOp : TransformDialectOp<"replicate",
|
||||
[DeclareOpInterfaceMethods<TransformOpInterface>,
|
||||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
|
||||
let summary = "Lists payload ops multiple times in the new handle";
|
||||
let description = [{
|
||||
Produces a new handle associated with a list of payload IR ops that is
|
||||
computed by repeating the list of payload IR ops associated with the
|
||||
operand handle as many times as the "pattern" handle has associated
|
||||
operations. For example, if pattern is associated with [op1, op2] and the
|
||||
operand handle is associated with [op3, op4, op5], the resulting handle
|
||||
will be associated with [op3, op4, op5, op3, op4, op5].
|
||||
|
||||
This transformation is useful to "align" the sizes of payload IR lists
|
||||
before a transformation that expects, e.g., identically-sized lists. For
|
||||
example, a transformation may be parameterized by same notional per-target
|
||||
size computed at runtime and supplied as another handle, the replication
|
||||
allows this size to be computed only once and used for every target instead
|
||||
of replicating the computation itself.
|
||||
|
||||
Note that it is undesirable to pass a handle with duplicate operations to
|
||||
an operation that consumes the handle. Handle consumption often indicates
|
||||
that the associated payload IR ops are destroyed, so having the same op
|
||||
listed more than once will lead to double-free. Single-operand
|
||||
MergeHandlesOp may be used to deduplicate the associated list of payload IR
|
||||
ops when necessary. Furthermore, a combination of ReplicateOp and
|
||||
MergeHandlesOp can be used to construct arbitrary lists with repetitions.
|
||||
}];
|
||||
|
||||
let arguments = (ins PDL_Operation:$pattern,
|
||||
Variadic<PDL_Operation>:$handles);
|
||||
let results = (outs Variadic<PDL_Operation>:$replicated);
|
||||
let assemblyFormat =
|
||||
"`num` `(` $pattern `)` $handles "
|
||||
"custom<PDLOpTypedResults>(type($replicated), ref($handles)) attr-dict";
|
||||
}
|
||||
|
||||
def SequenceOp : TransformDialectOp<"sequence",
|
||||
[DeclareOpInterfaceMethods<RegionBranchOpInterface,
|
||||
["getSuccessorEntryOperands", "getSuccessorRegions",
|
||||
|
|
|
@ -55,7 +55,7 @@ Value transform::TransformState::getHandleForPayloadOp(Operation *op) const {
|
|||
LogicalResult transform::TransformState::tryEmplaceReverseMapping(
|
||||
Mappings &map, Operation *operation, Value handle) {
|
||||
auto insertionResult = map.reverse.insert({operation, handle});
|
||||
if (!insertionResult.second) {
|
||||
if (!insertionResult.second && insertionResult.first->second != handle) {
|
||||
InFlightDiagnostic diag = operation->emitError()
|
||||
<< "operation tracked by two handles";
|
||||
diag.attachNote(handle.getLoc()) << "handle";
|
||||
|
@ -191,9 +191,27 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidation(
|
|||
DiagnosedSilenceableFailure
|
||||
transform::TransformState::applyTransform(TransformOpInterface transform) {
|
||||
LLVM_DEBUG(DBGS() << "applying: " << transform << "\n");
|
||||
if (options.getExpensiveChecksEnabled() &&
|
||||
failed(checkAndRecordHandleInvalidation(transform))) {
|
||||
return DiagnosedSilenceableFailure::definiteFailure();
|
||||
if (options.getExpensiveChecksEnabled()) {
|
||||
if (failed(checkAndRecordHandleInvalidation(transform)))
|
||||
return DiagnosedSilenceableFailure::definiteFailure();
|
||||
|
||||
for (OpOperand &operand : transform->getOpOperands()) {
|
||||
if (!isHandleConsumed(operand.get(), transform))
|
||||
continue;
|
||||
|
||||
DenseSet<Operation *> seen;
|
||||
for (Operation *op : getPayloadOps(operand.get())) {
|
||||
if (!seen.insert(op).second) {
|
||||
DiagnosedSilenceableFailure diag =
|
||||
transform.emitSilenceableError()
|
||||
<< "a handle passed as operand #" << operand.getOperandNumber()
|
||||
<< " and consumed by this operation points to a payload "
|
||||
"operation more than once";
|
||||
diag.attachNote(op->getLoc()) << "repeated target op";
|
||||
return diag;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
transform::TransformResults results(transform->getNumResults());
|
||||
|
@ -326,6 +344,70 @@ transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Memory effects.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void transform::consumesHandle(
|
||||
ValueRange handles,
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
||||
for (Value handle : handles) {
|
||||
effects.emplace_back(MemoryEffects::Read::get(), handle,
|
||||
TransformMappingResource::get());
|
||||
effects.emplace_back(MemoryEffects::Free::get(), handle,
|
||||
TransformMappingResource::get());
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `true` if the given list of effects instances contains an instance
|
||||
/// with the effect type specified as template parameter.
|
||||
template <typename EffectTy>
|
||||
static bool hasEffect(ArrayRef<MemoryEffects::EffectInstance> effects) {
|
||||
return llvm::any_of(effects, [](const MemoryEffects::EffectInstance &effect) {
|
||||
return isa<EffectTy>(effect.getEffect());
|
||||
});
|
||||
}
|
||||
|
||||
bool transform::isHandleConsumed(Value handle,
|
||||
transform::TransformOpInterface transform) {
|
||||
auto iface = cast<MemoryEffectOpInterface>(transform.getOperation());
|
||||
SmallVector<MemoryEffects::EffectInstance> effects;
|
||||
iface.getEffectsOnValue(handle, effects);
|
||||
return hasEffect<MemoryEffects::Read>(effects) &&
|
||||
hasEffect<MemoryEffects::Free>(effects);
|
||||
}
|
||||
|
||||
void transform::producesHandle(
|
||||
ValueRange handles,
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
||||
for (Value handle : handles) {
|
||||
effects.emplace_back(MemoryEffects::Allocate::get(), handle,
|
||||
TransformMappingResource::get());
|
||||
effects.emplace_back(MemoryEffects::Write::get(), handle,
|
||||
TransformMappingResource::get());
|
||||
}
|
||||
}
|
||||
|
||||
void transform::onlyReadsHandle(
|
||||
ValueRange handles,
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
||||
for (Value handle : handles) {
|
||||
effects.emplace_back(MemoryEffects::Read::get(), handle,
|
||||
TransformMappingResource::get());
|
||||
}
|
||||
}
|
||||
|
||||
void transform::modifiesPayload(
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
||||
effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
|
||||
effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
|
||||
}
|
||||
|
||||
void transform::onlyReadsPayload(
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
||||
effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Generated interface implementation.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -23,6 +23,16 @@
|
|||
|
||||
using namespace mlir;
|
||||
|
||||
static ParseResult parsePDLOpTypedResults(
|
||||
OpAsmParser &parser, SmallVectorImpl<Type> &types,
|
||||
const SmallVectorImpl<OpAsmParser::UnresolvedOperand> &handles) {
|
||||
types.resize(handles.size(), pdl::OperationType::get(parser.getContext()));
|
||||
return success();
|
||||
}
|
||||
|
||||
static void printPDLOpTypedResults(OpAsmPrinter &, Operation *, TypeRange,
|
||||
ValueRange) {}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
|
||||
|
||||
|
@ -354,6 +364,33 @@ transform::PDLMatchOp::apply(transform::TransformResults &results,
|
|||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ReplicateOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
DiagnosedSilenceableFailure
|
||||
transform::ReplicateOp::apply(transform::TransformResults &results,
|
||||
transform::TransformState &state) {
|
||||
unsigned numRepetitions = state.getPayloadOps(getPattern()).size();
|
||||
for (const auto &en : llvm::enumerate(getHandles())) {
|
||||
Value handle = en.value();
|
||||
ArrayRef<Operation *> current = state.getPayloadOps(handle);
|
||||
SmallVector<Operation *> payload;
|
||||
payload.reserve(numRepetitions * current.size());
|
||||
for (unsigned i = 0; i < numRepetitions; ++i)
|
||||
llvm::append_range(payload, current);
|
||||
results.set(getReplicated()[en.index()].cast<OpResult>(), payload);
|
||||
}
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
void transform::ReplicateOp::getEffects(
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
||||
onlyReadsHandle(getPattern(), effects);
|
||||
consumesHandle(getHandles(), effects);
|
||||
producesHandle(getReplicated(), effects);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SequenceOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -59,6 +59,22 @@ class PDLMatchOp:
|
|||
ip=ip)
|
||||
|
||||
|
||||
class ReplicateOp:
|
||||
|
||||
def __init__(self,
|
||||
pattern: Union[Operation, Value],
|
||||
handles: Sequence[Union[Operation, Value]],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None):
|
||||
super().__init__(
|
||||
[pdl.OperationType.get()] * len(handles),
|
||||
_get_op_result_or_value(pattern),
|
||||
[_get_op_result_or_value(h) for h in handles],
|
||||
loc=loc,
|
||||
ip=ip)
|
||||
|
||||
|
||||
class SequenceOp:
|
||||
|
||||
@overload
|
||||
|
|
|
@ -25,3 +25,37 @@ transform.with_pdl_patterns {
|
|||
test_print_remark_at_operand %0, "remark"
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @func1() {
|
||||
// expected-note @below {{repeated target op}}
|
||||
return
|
||||
}
|
||||
func.func private @func2()
|
||||
|
||||
transform.with_pdl_patterns {
|
||||
^bb0(%arg0: !pdl.operation):
|
||||
pdl.pattern @func : benefit(1) {
|
||||
%0 = operands
|
||||
%1 = types
|
||||
%2 = operation "func.func"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
|
||||
rewrite %2 with "transform.dialect"
|
||||
}
|
||||
pdl.pattern @return : benefit(1) {
|
||||
%0 = operands
|
||||
%1 = types
|
||||
%2 = operation "func.return"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
|
||||
rewrite %2 with "transform.dialect"
|
||||
}
|
||||
|
||||
sequence %arg0 {
|
||||
^bb1(%arg1: !pdl.operation):
|
||||
%0 = pdl_match @func in %arg1
|
||||
%1 = pdl_match @return in %arg1
|
||||
%2 = replicate num(%0) %1
|
||||
// expected-error @below {{a handle passed as operand #0 and consumed by this operation points to a payload operation more than once}}
|
||||
test_consume_operand %2
|
||||
test_print_remark_at_operand %0, "remark"
|
||||
}
|
||||
}
|
||||
|
|
|
@ -569,3 +569,31 @@ transform.with_pdl_patterns {
|
|||
transform.test_mixed_sucess_and_silenceable %0
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module {
|
||||
func.func private @foo()
|
||||
func.func private @bar()
|
||||
|
||||
transform.with_pdl_patterns {
|
||||
^bb0(%arg0: !pdl.operation):
|
||||
pdl.pattern @func : benefit(1) {
|
||||
%0 = pdl.operands
|
||||
%1 = pdl.types
|
||||
%2 = pdl.operation "func.func"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
|
||||
pdl.rewrite %2 with "transform.dialect"
|
||||
}
|
||||
|
||||
transform.sequence %arg0 {
|
||||
^bb0(%arg1: !pdl.operation):
|
||||
%0 = pdl_match @func in %arg1
|
||||
%1 = replicate num(%0) %arg1
|
||||
// expected-remark @below {{2}}
|
||||
test_print_number_of_associated_payload_ir_ops %1
|
||||
%2 = replicate num(%0) %1
|
||||
// expected-remark @below {{4}}
|
||||
test_print_number_of_associated_payload_ir_ops %2
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -275,6 +275,18 @@ mlir::test::TestMixedSuccessAndSilenceableOp::applyToOne(
|
|||
return emitDefaultSilenceableFailure(target);
|
||||
}
|
||||
|
||||
DiagnosedSilenceableFailure
|
||||
mlir::test::TestPrintNumberOfAssociatedPayloadIROps::apply(
|
||||
transform::TransformResults &results, transform::TransformState &state) {
|
||||
emitRemark() << state.getPayloadOps(getHandle()).size();
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
void mlir::test::TestPrintNumberOfAssociatedPayloadIROps::getEffects(
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
||||
transform::onlyReadsHandle(getHandle(), effects);
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// Test extension of the Transform dialect. Registers additional ops and
|
||||
/// declares PDL as dependent dialect since the additional ops are using PDL
|
||||
|
|
|
@ -212,4 +212,13 @@ def TestMixedSuccessAndSilenceableOp
|
|||
}];
|
||||
}
|
||||
|
||||
def TestPrintNumberOfAssociatedPayloadIROps
|
||||
: Op<Transform_Dialect, "test_print_number_of_associated_payload_ir_ops",
|
||||
[DeclareOpInterfaceMethods<TransformOpInterface>,
|
||||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
|
||||
let arguments = (ins PDL_Operation:$handle);
|
||||
let assemblyFormat = "$handle attr-dict";
|
||||
let cppNamespace = "::mlir::test";
|
||||
}
|
||||
|
||||
#endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
|
||||
|
|
|
@ -94,3 +94,19 @@ def testMergeHandlesOp():
|
|||
# CHECK: transform.sequence
|
||||
# CHECK: ^{{.*}}(%[[ARG1:.+]]: !pdl.operation):
|
||||
# CHECK: = merge_handles %[[ARG1]]
|
||||
|
||||
|
||||
@run
|
||||
def testReplicateOp():
|
||||
with_pdl = transform.WithPDLPatternsOp()
|
||||
with InsertionPoint(with_pdl.body):
|
||||
sequence = transform.SequenceOp(with_pdl.bodyTarget)
|
||||
with InsertionPoint(sequence.body):
|
||||
m1 = transform.PDLMatchOp(sequence.bodyTarget, "first")
|
||||
m2 = transform.PDLMatchOp(sequence.bodyTarget, "second")
|
||||
transform.ReplicateOp(m1, [m2])
|
||||
transform.YieldOp()
|
||||
# CHECK-LABEL: TEST: testReplicateOp
|
||||
# CHECK: %[[FIRST:.+]] = pdl_match
|
||||
# CHECK: %[[SECOND:.+]] = pdl_match
|
||||
# CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]
|
||||
|
|
Loading…
Reference in New Issue