[mlir] improve and test TransformState::Extension

Add the mechanism for TransformState extensions to update the mapping between
Transform IR values and Payload IR operations held by the state. The mechanism
is intentionally restrictive, similarly to how results of the transform op are
handled.

Introduce test ops that exercise a simple extension that maintains information
across the application of multiple transform ops.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D124778
This commit is contained in:
Alex Zinenko 2022-05-02 18:22:19 +02:00
parent ad47114ad8
commit 6c57b0debe
6 changed files with 236 additions and 17 deletions

View File

@ -74,6 +74,10 @@ public:
/// This is helpful for transformations that apply to a particular handle.
ArrayRef<Operation *> getPayloadOps(Value value) const;
/// Returns the Transform IR handle for the given Payload IR op if it exists
/// in the state, null otherwise.
Value getHandleForPayloadOp(Operation *op) const;
/// Applies the transformation specified by the given transform op and updates
/// the state accordingly.
LogicalResult applyTransform(TransformOpInterface transform);
@ -185,6 +189,10 @@ public:
/// Provides read-only access to the parent TransformState object.
const TransformState &getTransformState() const { return state; }
/// Replaces the given payload op with another op. If the replacement op is
/// null, removes the association of the payload op with its handle.
LogicalResult replacePayloadOp(Operation *op, Operation *replacement);
private:
/// Back-reference to the state that is being extended.
TransformState &state;
@ -276,9 +284,17 @@ private:
/// The callback function is called once per associated operation and is
/// expected to return the modified operation or nullptr. In the latter case,
/// the corresponding operation is no longer associated with the transform IR
/// value.
void updatePayloadOps(Value value,
function_ref<Operation *(Operation *)> callback);
/// value. May fail if the operation produced by the update callback is
/// already associated with a different Transform IR handle value.
LogicalResult
updatePayloadOps(Value value,
function_ref<Operation *(Operation *)> callback);
/// Attempts to record the mapping between the given Payload IR operation and
/// the given Transform IR handle. Fails and reports an error if the operation
/// is already tracked by another handle.
static LogicalResult tryEmplaceReverseMapping(Mappings &map, Operation *op,
Value handle);
/// The mappings between transform IR values and payload IR ops, aggregated by
/// the region in which the transform IR values are defined.

View File

@ -41,6 +41,27 @@ transform::TransformState::getPayloadOps(Value value) const {
return iter->getSecond();
}
Value transform::TransformState::getHandleForPayloadOp(Operation *op) const {
for (const Mappings &mapping : llvm::make_second_range(mappings)) {
if (Value handle = mapping.reverse.lookup(op))
return handle;
}
return Value();
}
LogicalResult transform::TransformState::tryEmplaceReverseMapping(
Mappings &map, Operation *operation, Value handle) {
auto insertionResult = map.reverse.insert({operation, handle});
if (!insertionResult.second) {
InFlightDiagnostic diag = operation->emitError()
<< "operation tracked by two handles";
diag.attachNote(handle.getLoc()) << "handle";
diag.attachNote(insertionResult.first->second.getLoc()) << "handle";
return diag;
}
return success();
}
LogicalResult
transform::TransformState::setPayloadOps(Value value,
ArrayRef<Operation *> targets) {
@ -63,14 +84,8 @@ transform::TransformState::setPayloadOps(Value value,
// expressed using the dialect and may be constructed by valid API calls from
// valid IR. Emit an error here.
for (Operation *op : targets) {
auto insertionResult = mappings.reverse.insert({op, value});
if (!insertionResult.second) {
InFlightDiagnostic diag = op->emitError()
<< "operation tracked by two handles";
diag.attachNote(value.getLoc()) << "handle";
diag.attachNote(insertionResult.first->second.getLoc()) << "handle";
return diag;
}
if (failed(tryEmplaceReverseMapping(mappings, op, value)))
return failure();
}
return success();
@ -83,19 +98,26 @@ void transform::TransformState::removePayloadOps(Value value) {
mappings.direct.erase(value);
}
void transform::TransformState::updatePayloadOps(
LogicalResult transform::TransformState::updatePayloadOps(
Value value, function_ref<Operation *(Operation *)> callback) {
auto it = getMapping(value).direct.find(value);
assert(it != getMapping(value).direct.end() && "unknown handle");
Mappings &mappings = getMapping(value);
auto it = mappings.direct.find(value);
assert(it != mappings.direct.end() && "unknown handle");
SmallVector<Operation *> &association = it->getSecond();
SmallVector<Operation *> updated;
updated.reserve(association.size());
for (Operation *op : association)
if (Operation *updatedOp = callback(op))
for (Operation *op : association) {
mappings.reverse.erase(op);
if (Operation *updatedOp = callback(op)) {
updated.push_back(updatedOp);
if (failed(tryEmplaceReverseMapping(mappings, updatedOp, value)))
return failure();
}
}
std::swap(association, updated);
return success();
}
LogicalResult
@ -132,8 +154,21 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
return success();
}
//===----------------------------------------------------------------------===//
// TransformState::Extension
//===----------------------------------------------------------------------===//
transform::TransformState::Extension::~Extension() = default;
LogicalResult
transform::TransformState::Extension::replacePayloadOp(Operation *op,
Operation *replacement) {
return state.updatePayloadOps(state.getHandleForPayloadOp(op),
[&](Operation *current) {
return current == op ? replacement : current;
});
}
//===----------------------------------------------------------------------===//
// TransformResults
//===----------------------------------------------------------------------===//

View File

@ -0,0 +1,46 @@
// RUN: mlir-opt %s -test-transform-dialect-interpreter -verify-diagnostics -split-input-file
// expected-note @below {{associated payload op}}
module {
transform.sequence {
^bb0(%arg0: !pdl.operation):
// expected-remark @below {{extension absent}}
test_check_if_test_extension_present %arg0
test_add_test_extension "A"
// expected-remark @below {{extension present, A}}
test_check_if_test_extension_present %arg0
test_remove_test_extension
// expected-remark @below {{extension absent}}
test_check_if_test_extension_present %arg0
}
}
// -----
// expected-note @below {{associated payload op}}
module {
transform.sequence {
^bb0(%arg0: !pdl.operation):
test_add_test_extension "A"
test_remove_test_extension
test_add_test_extension "B"
// expected-remark @below {{extension present, B}}
test_check_if_test_extension_present %arg0
}
}
// -----
// expected-note @below {{associated payload op}}
module {
transform.sequence {
^bb0(%arg0: !pdl.operation):
test_add_test_extension "A"
// expected-remark @below {{extension present, A}}
test_check_if_test_extension_present %arg0
// expected-note @below {{associated payload op}}
test_remap_operand_to_self %arg0
// expected-remark @below {{extension present, A}}
test_check_if_test_extension_present %arg0
}
}

View File

@ -12,10 +12,10 @@
//===----------------------------------------------------------------------===//
#include "TestTransformDialectExtension.h"
#include "TestTransformStateExtension.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
using namespace mlir;
@ -142,6 +142,49 @@ LogicalResult mlir::test::TestPrintRemarkAtOperandOp::apply(
return success();
}
LogicalResult
mlir::test::TestAddTestExtensionOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
state.addExtension<TestTransformStateExtension>(getMessageAttr());
return success();
}
LogicalResult mlir::test::TestCheckIfTestExtensionPresentOp::apply(
transform::TransformResults &results, transform::TransformState &state) {
auto *extension = state.getExtension<TestTransformStateExtension>();
if (!extension) {
emitRemark() << "extension absent";
return success();
}
InFlightDiagnostic diag = emitRemark()
<< "extension present, " << extension->getMessage();
for (Operation *payload : state.getPayloadOps(getOperand())) {
diag.attachNote(payload->getLoc()) << "associated payload op";
assert(state.getHandleForPayloadOp(payload) == getOperand() &&
"inconsistent mapping between transform IR handles and payload IR "
"operations");
}
return success();
}
LogicalResult mlir::test::TestRemapOperandPayloadToSelfOp::apply(
transform::TransformResults &results, transform::TransformState &state) {
auto *extension = state.getExtension<TestTransformStateExtension>();
if (!extension)
return emitError() << "TestTransformStateExtension missing";
return extension->updateMapping(state.getPayloadOps(getOperand()).front(),
getOperation());
}
LogicalResult mlir::test::TestRemoveTestExtensionOp::apply(
transform::TransformResults &results, transform::TransformState &state) {
state.removeExtension<TestTransformStateExtension>();
return success();
}
namespace {
/// Test extension of the Transform dialect. Registers additional ops and
/// declares PDL as dependent dialect since the additional ops are using PDL

View File

@ -56,4 +56,41 @@ def TestPrintRemarkAtOperandOp
let cppNamespace = "::mlir::test";
}
def TestAddTestExtensionOp
: Op<Transform_Dialect, "test_add_test_extension",
[DeclareOpInterfaceMethods<TransformOpInterface>,
NoSideEffect]> {
let arguments = (ins StrAttr:$message);
let assemblyFormat = "$message attr-dict";
let cppNamespace = "::mlir::test";
}
def TestCheckIfTestExtensionPresentOp
: Op<Transform_Dialect, "test_check_if_test_extension_present",
[DeclareOpInterfaceMethods<TransformOpInterface>]> {
let arguments = (ins
Arg<PDL_Operation, "", [TransformMappingRead, PayloadIRRead]>:$operand);
let assemblyFormat = "$operand attr-dict";
let cppNamespace = "::mlir::test";
}
def TestRemapOperandPayloadToSelfOp
: Op<Transform_Dialect, "test_remap_operand_to_self",
[DeclareOpInterfaceMethods<TransformOpInterface>]> {
let arguments = (ins
Arg<PDL_Operation, "",
[TransformMappingRead, TransformMappingWrite, PayloadIRRead]>:$operand);
let assemblyFormat = "$operand attr-dict";
let cppNamespace = "::mlir::test";
}
def TestRemoveTestExtensionOp
: Op<Transform_Dialect, "test_remove_test_extension",
[DeclareOpInterfaceMethods<TransformOpInterface>,
NoSideEffect]> {
let assemblyFormat = "attr-dict";
let cppNamespace = "::mlir::test";
}
#endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD

View File

@ -0,0 +1,42 @@
//===- TestTransformStateExtension.h - Test Utility -------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines an TransformState extension for the purpose of testing the
// relevant APIs.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_TEST_LIB_DIALECT_TRANSFORM_TESTTRANSFORMSTATEEXTENSION_H
#define MLIR_TEST_LIB_DIALECT_TRANSFORM_TESTTRANSFORMSTATEEXTENSION_H
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
using namespace mlir;
namespace mlir {
namespace test {
class TestTransformStateExtension
: public transform::TransformState::Extension {
public:
TestTransformStateExtension(transform::TransformState &state,
StringAttr message)
: Extension(state), message(message) {}
StringRef getMessage() const { return message.getValue(); }
LogicalResult updateMapping(Operation *previous, Operation *updated) {
return replacePayloadOp(previous, updated);
}
private:
StringAttr message;
};
} // namespace test
} // namespace mlir
#endif // MLIR_TEST_LIB_DIALECT_TRANSFORM_TESTTRANSFORMSTATEEXTENSION_H