forked from OSchip/llvm-project
[mlir] add transform dialect entry point
Introduce `transform::applyTransforms` as a top-level entry point to the Transform dialect-driven transformation infrastructure, by analogy with `applyFull/PartialConversion`. Clients are expected to use this function and no longer need to maintain the transformation state. Make the constructor of the TransformState private for that purpose. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D135681
This commit is contained in:
parent
812ad2167b
commit
32f0bde548
|
@ -16,19 +16,18 @@ def Transform_Dialect : Dialect {
|
|||
let description = [{
|
||||
## Disclaimer
|
||||
|
||||
** Proceed with care: not ready for general use. **
|
||||
**This dialect is actively developed and may change frequently.**
|
||||
|
||||
This dialect is evolving rapidly and may change on a very short notice. To
|
||||
decrease the maintenance burden and churn, only a few in-tree use cases are
|
||||
currently supported in the main tree:
|
||||
To decrease the maintenance burden and churn, please post a description of
|
||||
the intended use case on the MLIR forum. A few in-tree use cases are
|
||||
currently supported:
|
||||
|
||||
- high-level transformations on "structured ops" (i.e. ops that operate on
|
||||
chunks of data in a way that can be decomposed into operations on
|
||||
smaller chunks of data and control flow) in Linalg, Tensor and Vector
|
||||
dialects.
|
||||
|
||||
*Please post a description of the intended use case on the MLIR forum and
|
||||
wait for confirmation.*
|
||||
dialects;
|
||||
- loop transformations in the SCF dialect.
|
||||
|
||||
|
||||
## Overview
|
||||
|
||||
|
@ -79,6 +78,18 @@ def Transform_Dialect : Dialect {
|
|||
expected to have the `PossibleTopLevelTransformOpTrait` and may be used
|
||||
without arguments.
|
||||
|
||||
A program transformation expressed using the Transform dialect can be
|
||||
programmatically triggered by calling:
|
||||
|
||||
```c++
|
||||
LogicalResult transform::applyTransforms(Operation *payloadRoot,
|
||||
TransformOpInterface transform,
|
||||
const TransformOptions &options);
|
||||
```
|
||||
|
||||
that applies the transformations specified by the top-level `transform` to
|
||||
payload IR contained in `payloadRoot`.
|
||||
|
||||
## Dialect Extension Mechanism
|
||||
|
||||
This dialect is designed to be extensible, that is, clients of this dialect
|
||||
|
|
|
@ -206,6 +206,16 @@ private:
|
|||
bool expensiveChecksEnabled = true;
|
||||
};
|
||||
|
||||
/// Entry point to the Transform dialect infrastructure. Applies the
|
||||
/// transformation specified by `transform` to payload IR contained in
|
||||
/// `payloadRoot`. The `transform` operation may contain other operations that
|
||||
/// will be executed following the internal logic of the operation. It must
|
||||
/// have the `PossibleTopLevelTransformOp` trait and not have any operands.
|
||||
/// This function internally keeps track of the transformation state.
|
||||
LogicalResult
|
||||
applyTransforms(Operation *payloadRoot, TransformOpInterface transform,
|
||||
const TransformOptions &options = TransformOptions());
|
||||
|
||||
/// The state maintained across applications of various ops implementing the
|
||||
/// TransformOpInterface. The operations implementing this interface and the
|
||||
/// surrounding structure are referred to as transform IR. The operations to
|
||||
|
@ -250,15 +260,11 @@ class TransformState {
|
|||
TransformOpReverseMapping reverse;
|
||||
};
|
||||
|
||||
public:
|
||||
/// Creates a state for transform ops living in the given region. The parent
|
||||
/// operation of the region. The second argument points to the root operation
|
||||
/// in the payload IR being transformed, which may or may not contain the
|
||||
/// region with transform ops. Additional options can be provided through the
|
||||
/// trailing configuration object.
|
||||
TransformState(Region ®ion, Operation *root,
|
||||
const TransformOptions &options = TransformOptions());
|
||||
friend LogicalResult applyTransforms(Operation *payloadRoot,
|
||||
TransformOpInterface transform,
|
||||
const TransformOptions &options);
|
||||
|
||||
public:
|
||||
/// Returns the op at which the transformation state is rooted. This is
|
||||
/// typically helpful for transformations that apply globally.
|
||||
Operation *getTopLevel() const;
|
||||
|
@ -438,6 +444,13 @@ private:
|
|||
/// Identifier for storing top-level value in the `operations` mapping.
|
||||
static constexpr Value kTopLevelValue = Value();
|
||||
|
||||
/// Creates a state for transform ops living in the given region. The second
|
||||
/// argument points to the root operation in the payload IR being transformed,
|
||||
/// which may or may not contain the region with transform ops. Additional
|
||||
/// options can be provided through the trailing configuration object.
|
||||
TransformState(Region *region, Operation *payloadRoot,
|
||||
const TransformOptions &options = TransformOptions());
|
||||
|
||||
/// Returns the mappings frame for the reigon in which the value is defined.
|
||||
const Mappings &getMapping(Value value) const {
|
||||
return const_cast<TransformState *>(this)->getMapping(value);
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
#include "mlir/IR/Operation.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
|
||||
#define DEBUG_TYPE "transform-dialect"
|
||||
#define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
|
||||
|
@ -25,14 +26,15 @@ using namespace mlir;
|
|||
|
||||
constexpr const Value transform::TransformState::kTopLevelValue;
|
||||
|
||||
transform::TransformState::TransformState(Region ®ion, Operation *root,
|
||||
transform::TransformState::TransformState(Region *region,
|
||||
Operation *payloadRoot,
|
||||
const TransformOptions &options)
|
||||
: topLevel(root), options(options) {
|
||||
auto result = mappings.try_emplace(®ion);
|
||||
: topLevel(payloadRoot), options(options) {
|
||||
auto result = mappings.try_emplace(region);
|
||||
assert(result.second && "the region scope is already present");
|
||||
(void)result;
|
||||
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
|
||||
regionStack.push_back(®ion);
|
||||
regionStack.push_back(region);
|
||||
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
|
||||
}
|
||||
|
||||
|
@ -447,6 +449,27 @@ void transform::onlyReadsPayload(
|
|||
effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Entry point.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult transform::applyTransforms(Operation *payloadRoot,
|
||||
TransformOpInterface transform,
|
||||
const TransformOptions &options) {
|
||||
#ifndef NDEBUG
|
||||
if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
|
||||
transform->getNumOperands() != 0) {
|
||||
transform->emitError()
|
||||
<< "expected transform to start at the top-level transform op";
|
||||
llvm::report_fatal_error("could not run transforms",
|
||||
/*gen_crash_diag=*/false);
|
||||
}
|
||||
#endif // NDEBUG
|
||||
|
||||
TransformState state(transform->getParentRegion(), payloadRoot, options);
|
||||
return state.applyTransform(transform).checkAndReport();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Generated interface implementation.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -1,29 +1,41 @@
|
|||
// RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics
|
||||
|
||||
// expected-remark @below {{applying transformation}}
|
||||
transform.test_transform_op
|
||||
transform.sequence failures(propagate) {
|
||||
^bb0(%arg0: !transform.any_op):
|
||||
// expected-remark @below {{applying transformation}}
|
||||
transform.test_transform_op
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
|
||||
// expected-remark @below {{succeeded}}
|
||||
transform.test_consume_operand_if_matches_param_or_fail %0[42]
|
||||
transform.sequence failures(propagate) {
|
||||
^bb0(%arg0: !transform.any_op):
|
||||
%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
|
||||
// expected-remark @below {{succeeded}}
|
||||
transform.test_consume_operand_if_matches_param_or_fail %0[42]
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
|
||||
// expected-error @below {{expected the operand to be associated with 21 got 42}}
|
||||
transform.test_consume_operand_if_matches_param_or_fail %0[21]
|
||||
transform.sequence failures(propagate) {
|
||||
^bb0(%arg0: !transform.any_op):
|
||||
%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
|
||||
// expected-error @below {{expected the operand to be associated with 21 got 42}}
|
||||
transform.test_consume_operand_if_matches_param_or_fail %0[21]
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// It is okay to have multiple handles to the same payload op as long
|
||||
// as only one of them is consumed. The expensive checks mode is necessary
|
||||
// to detect double-consumption.
|
||||
%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
|
||||
%1 = transform.test_copy_payload %0
|
||||
// expected-remark @below {{succeeded}}
|
||||
transform.test_consume_operand_if_matches_param_or_fail %0[42]
|
||||
transform.sequence failures(propagate) {
|
||||
^bb0(%arg0: !transform.any_op):
|
||||
%0 = transform.test_produce_param_or_forward_operand 42 { foo = "bar" }
|
||||
%1 = transform.test_copy_payload %0
|
||||
// expected-remark @below {{succeeded}}
|
||||
transform.test_consume_operand_if_matches_param_or_fail %0[42]
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
|
|
|
@ -41,13 +41,12 @@ public:
|
|||
|
||||
void runOnOperation() override {
|
||||
ModuleOp module = getOperation();
|
||||
transform::TransformState state(
|
||||
module.getBodyRegion(), module,
|
||||
transform::TransformOptions().enableExpensiveChecks(
|
||||
enableExpensiveChecks));
|
||||
for (auto op :
|
||||
module.getBody()->getOps<transform::TransformOpInterface>()) {
|
||||
if (failed(state.applyTransform(op).checkAndReport()))
|
||||
if (failed(transform::applyTransforms(
|
||||
module, op,
|
||||
transform::TransformOptions().enableExpensiveChecks(
|
||||
enableExpensiveChecks))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue