[mlir] add transform dialect entry point

Introduce `transform::applyTransforms` as a top-level entry point to the
Transform dialect-driven transformation infrastructure, by analogy with
`applyFull/PartialConversion`. Clients are expected to use this function
and no longer need to maintain the transformation state. Make the
constructor of the TransformState private for that purpose.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D135681
This commit is contained in:
Alex Zinenko 2022-10-11 15:23:48 +00:00
parent 812ad2167b
commit 32f0bde548
5 changed files with 95 additions and 37 deletions

View File

@ -16,19 +16,18 @@ def Transform_Dialect : Dialect {
let description = [{
## Disclaimer
** Proceed with care: not ready for general use. **
**This dialect is actively developed and may change frequently.**
This dialect is evolving rapidly and may change on a very short notice. To
decrease the maintenance burden and churn, only a few in-tree use cases are
currently supported in the main tree:
To decrease the maintenance burden and churn, please post a description of
the intended use case on the MLIR forum. A few in-tree use cases are
currently supported:
- high-level transformations on "structured ops" (i.e. ops that operate on
chunks of data in a way that can be decomposed into operations on
smaller chunks of data and control flow) in Linalg, Tensor and Vector
dialects.
*Please post a description of the intended use case on the MLIR forum and
wait for confirmation.*
dialects;
- loop transformations in the SCF dialect.
## Overview
@ -79,6 +78,18 @@ def Transform_Dialect : Dialect {
expected to have the `PossibleTopLevelTransformOpTrait` and may be used
without arguments.
A program transformation expressed using the Transform dialect can be
programmatically triggered by calling:
```c++
LogicalResult transform::applyTransforms(Operation *payloadRoot,
TransformOpInterface transform,
const TransformOptions &options);
```
that applies the transformations specified by the top-level `transform` to
payload IR contained in `payloadRoot`.
## Dialect Extension Mechanism
This dialect is designed to be extensible, that is, clients of this dialect

View File

@ -206,6 +206,16 @@ private:
bool expensiveChecksEnabled = true;
};
/// Entry point to the Transform dialect infrastructure. Applies the
/// transformation specified by `transform` to payload IR contained in
/// `payloadRoot`. The `transform` operation may contain other operations that
/// will be executed following the internal logic of the operation. It must
/// have the `PossibleTopLevelTransformOp` trait and not have any operands.
/// This function internally keeps track of the transformation state.
LogicalResult
applyTransforms(Operation *payloadRoot, TransformOpInterface transform,
const TransformOptions &options = TransformOptions());
/// The state maintained across applications of various ops implementing the
/// TransformOpInterface. The operations implementing this interface and the
/// surrounding structure are referred to as transform IR. The operations to
@ -250,15 +260,11 @@ class TransformState {
TransformOpReverseMapping reverse;
};
public:
/// Creates a state for transform ops living in the given region. The parent
/// operation of the region. The second argument points to the root operation
/// in the payload IR being transformed, which may or may not contain the
/// region with transform ops. Additional options can be provided through the
/// trailing configuration object.
TransformState(Region &region, Operation *root,
const TransformOptions &options = TransformOptions());
friend LogicalResult applyTransforms(Operation *payloadRoot,
TransformOpInterface transform,
const TransformOptions &options);
public:
/// Returns the op at which the transformation state is rooted. This is
/// typically helpful for transformations that apply globally.
Operation *getTopLevel() const;
@ -438,6 +444,13 @@ private:
/// Identifier for storing top-level value in the `operations` mapping.
static constexpr Value kTopLevelValue = Value();
/// Creates a state for transform ops living in the given region. The second
/// argument points to the root operation in the payload IR being transformed,
/// which may or may not contain the region with transform ops. Additional
/// options can be provided through the trailing configuration object.
TransformState(Region *region, Operation *payloadRoot,
const TransformOptions &options = TransformOptions());
/// Returns the mappings frame for the reigon in which the value is defined.
const Mappings &getMapping(Value value) const {
return const_cast<TransformState *>(this)->getMapping(value);

View File

@ -12,6 +12,7 @@
#include "mlir/IR/Operation.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#define DEBUG_TYPE "transform-dialect"
#define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
@ -25,14 +26,15 @@ using namespace mlir;
constexpr const Value transform::TransformState::kTopLevelValue;
transform::TransformState::TransformState(Region &region, Operation *root,
transform::TransformState::TransformState(Region *region,
Operation *payloadRoot,
const TransformOptions &options)
: topLevel(root), options(options) {
auto result = mappings.try_emplace(&region);
: topLevel(payloadRoot), options(options) {
auto result = mappings.try_emplace(region);
assert(result.second && "the region scope is already present");
(void)result;
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
regionStack.push_back(&region);
regionStack.push_back(region);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
}
@ -447,6 +449,27 @@ void transform::onlyReadsPayload(
effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
}
//===----------------------------------------------------------------------===//
// Entry point.
//===----------------------------------------------------------------------===//
LogicalResult transform::applyTransforms(Operation *payloadRoot,
TransformOpInterface transform,
const TransformOptions &options) {
#ifndef NDEBUG
if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
transform->getNumOperands() != 0) {
transform->emitError()
<< "expected transform to start at the top-level transform op";
llvm::report_fatal_error("could not run transforms",
/*gen_crash_diag=*/false);
}
#endif // NDEBUG
TransformState state(transform->getParentRegion(), payloadRoot, options);
return state.applyTransform(transform).checkAndReport();
}
//===----------------------------------------------------------------------===//
// Generated interface implementation.
//===----------------------------------------------------------------------===//

View File

@ -1,29 +1,41 @@
// RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics
// expected-remark @below {{applying transformation}}
transform.test_transform_op
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
// expected-remark @below {{applying transformation}}
transform.test_transform_op
}
// -----
%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
// expected-remark @below {{succeeded}}
transform.test_consume_operand_if_matches_param_or_fail %0[42]
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
// expected-remark @below {{succeeded}}
transform.test_consume_operand_if_matches_param_or_fail %0[42]
}
// -----
%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
// expected-error @below {{expected the operand to be associated with 21 got 42}}
transform.test_consume_operand_if_matches_param_or_fail %0[21]
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
// expected-error @below {{expected the operand to be associated with 21 got 42}}
transform.test_consume_operand_if_matches_param_or_fail %0[21]
}
// -----
// It is okay to have multiple handles to the same payload op as long
// as only one of them is consumed. The expensive checks mode is necessary
// to detect double-consumption.
%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
%1 = transform.test_copy_payload %0
// expected-remark @below {{succeeded}}
transform.test_consume_operand_if_matches_param_or_fail %0[42]
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
%1 = transform.test_copy_payload %0
// expected-remark @below {{succeeded}}
transform.test_consume_operand_if_matches_param_or_fail %0[42]
}
// -----

View File

@ -41,13 +41,12 @@ public:
void runOnOperation() override {
ModuleOp module = getOperation();
transform::TransformState state(
module.getBodyRegion(), module,
transform::TransformOptions().enableExpensiveChecks(
enableExpensiveChecks));
for (auto op :
module.getBody()->getOps<transform::TransformOpInterface>()) {
if (failed(state.applyTransform(op).checkAndReport()))
if (failed(transform::applyTransforms(
module, op,
transform::TransformOptions().enableExpensiveChecks(
enableExpensiveChecks))))
return signalPassFailure();
}
}