[mlir][transform] Introduce transform.sequence op

Sequence is an important transform combination primitive that just indicates
transform ops being applied in a row. The simplest version requires fails
immediately if any transformation in the sequence fails. Introducing this
operation allows one to start placing transform IR within other IR.

Depends On D123135

Reviewed By: Mogball, rriddle

Differential Revision: https://reviews.llvm.org/D123664
This commit is contained in:
Alex Zinenko 2022-04-19 16:36:37 +02:00
parent e37726beb2
commit 0eb403ad1b
15 changed files with 525 additions and 31 deletions

View File

@ -1,8 +1,13 @@
# The dialect does not have its own ops, so just generate the dialect files.
# Generate the dialect files from the dialect .td.
#
# TODO: Make it possible to use XDialect instead of XOpsDialect in
# add_mlir_dialect.
set(LLVM_TARGET_DEFINITIONS TransformDialect.td)
mlir_tablegen(TransformDialect.h.inc -gen-dialect-decls -dialect=transform)
mlir_tablegen(TransformDialect.cpp.inc -gen-dialect-defs -dialect=transform)
add_public_tablegen_target(MLIRTransformDialectIncGen)
add_dependencies(mlir-headers MLIRTransformDialectIncGen)
add_mlir_dialect(TransformOps transform)
add_mlir_interface(TransformInterfaces)

View File

@ -161,6 +161,7 @@ def Transform_Dialect : Dialect {
let name = "transform";
let cppNamespace = "::mlir::transform";
let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
let extraClassDeclaration = [{
// Make addOperations available to the TransformDialectExtension class.
@ -172,4 +173,9 @@ def Transform_Dialect : Dialect {
}];
}
// Base class for ops that belong to the tranfsorm dialect. Ops defined in
// extensions of this dialect may also use this.
class TransformDialectOp<string mnemonic, list<Trait> traits = []>
: Op<Transform_Dialect, mnemonic, traits>;
#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT

View File

@ -33,6 +33,14 @@ class TransformOpInterface;
/// expected to populate the `TransformResults` class instance in order to
/// update the mapping. The `applyTransform` method takes care of propagating
/// the state of `TransformResults` into the instance of this class.
///
/// When applying transform IR operations with regions, the client is expected
/// to create a RegionScope RAII object to create a new "stack frame" for
/// values defined inside the region. The mappings from and to these values will
/// be automatically dropped when the object goes out of scope, typically at the
/// end of the "apply" function of the parent operation. If a region contains
/// blocks with arguments, the client can map those arguments to payload IR ops
/// using "mapBlockArguments".
class TransformState {
/// Mapping between a Value in the transform IR and the corresponding set of
/// operations in the payload IR.
@ -42,9 +50,19 @@ class TransformState {
/// currently associated with.
using TransformOpReverseMapping = DenseMap<Operation *, Value>;
/// Bidirectional mappings between transform IR values and payload IR
/// operations.
struct Mappings {
TransformOpMapping direct;
TransformOpReverseMapping reverse;
};
public:
/// Creates a state for the transformation rooted at the given op.
explicit TransformState(Operation *root);
/// Creates a state for transform ops living in the given region. The parent
/// operation of the region. The second argument points to the root operation
/// in the payload IR beind transformed, which may or may not contain the
/// region with transform ops.
TransformState(Region &region, Operation *root);
/// Returns the op at which the transformation state is rooted. This is
/// typically helpful for transformations that apply globally.
@ -58,10 +76,96 @@ public:
/// the state accordingly.
LogicalResult applyTransform(TransformOpInterface transform);
/// Records the mapping between a block argument in the transform IR and a
/// list of operations in the payload IR. The arguments must be defined in
/// blocks of the currently processed transform IR region, typically after a
/// region scope is defined.
LogicalResult mapBlockArguments(BlockArgument argument,
ArrayRef<Operation *> operations) {
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
assert(argument.getParentRegion() == regionStack.back() &&
"mapping block arguments from a region other than the active one");
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
return setPayloadOps(argument, operations);
}
// Forward declarations to support limited visibility.
class RegionScope;
/// Creates a new region scope for the given region. The region is expected to
/// be nested in the currently processed region.
// Implementation note: this method is inline but implemented outside of the
// class body to comply with visibility and full-declaration requirements.
inline RegionScope make_region_scope(Region &region);
/// A RAII object maintaining a "stack frame" for a transform IR region. When
/// applying a transform IR operation that contains a region, the caller is
/// expected to create a RegionScope before applying the ops contained in the
/// region. This ensures that the mappings between values defined in the
/// transform IR region and payload IR operations are cleared when the region
/// processing ends; such values cannot be accessed outside the region.
class RegionScope {
public:
/// Forgets the mapping from or to values defined in the associated
/// transform IR region.
~RegionScope() {
state.mappings.erase(region);
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
state.regionStack.pop_back();
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
}
private:
/// Creates a new scope for mappings between values defined in the given
/// transform IR region and payload IR operations.
RegionScope(TransformState &state, Region &region)
: state(state), region(&region) {
auto res = state.mappings.try_emplace(this->region);
assert(res.second && "the region scope is already present");
(void)res;
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
assert(state.regionStack.back()->isProperAncestor(&region) &&
"scope started at a non-nested region");
state.regionStack.push_back(&region);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
}
/// Back-reference to the transform state.
TransformState &state;
/// The region this scope is associated with.
Region *region;
friend RegionScope TransformState::make_region_scope(Region &);
};
friend class RegionScope;
private:
/// Identifier for storing top-level value in the `operations` mapping.
static constexpr Value kTopLevelValue = Value();
/// Returns the mappings frame for the reigon in which the value is defined.
const Mappings &getMapping(Value value) const {
return const_cast<TransformState *>(this)->getMapping(value);
}
Mappings &getMapping(Value value) {
auto it = mappings.find(value.getParentRegion());
assert(it != mappings.end() &&
"trying to find a mapping for a value from an unmapped region");
return it->second;
}
/// Returns the mappings frame for the region in which the operation resides.
const Mappings &getMapping(Operation *operation) const {
return const_cast<TransformState *>(this)->getMapping(operation);
}
Mappings &getMapping(Operation *operation) {
auto it = mappings.find(operation->getParentRegion());
assert(it != mappings.end() &&
"trying to find a mapping for an operation from an unmapped region");
return it->second;
}
/// Sets the payload IR ops associated with the given transform IR value.
/// Fails if this would result in multiple transform IR values with uses
/// corresponding to the same payload IR ops. For example, a hypothetical
@ -88,9 +192,19 @@ private:
void updatePayloadOps(Value value,
function_ref<Operation *(Operation *)> callback);
/// The mapping between payload IR values and transform IR ops.
TransformOpMapping operationMapping;
TransformOpReverseMapping reverseMapping;
/// The mappings between transform IR values and payload IR ops, aggregated by
/// the region in which the transform IR values are defined.
llvm::SmallDenseMap<Region *, Mappings> mappings;
/// The top-level operation that contains all payload IR, typically a module.
Operation *topLevel;
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
/// A stack of nested regions that are being processed in the transform IR.
/// Each region must be an ancestor of the following regions in this list.
/// These are also the keys for "mappings".
SmallVector<Region *> regionStack;
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
};
/// Local mapping between values defined by a specific op implementing the
@ -123,6 +237,10 @@ private:
SmallVector<Operation *> operations;
};
TransformState::RegionScope TransformState::make_region_scope(Region &region) {
return RegionScope(*this, region);
}
} // namespace transform
} // namespace mlir

View File

@ -0,0 +1,20 @@
//===- TransformDialect.h - Transform dialect operations --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS_H
#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS_H
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#define GET_OP_CLASSES
#include "mlir/Dialect/Transform/IR/TransformOps.h.inc"
#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS_H

View File

@ -0,0 +1,78 @@
//===- TransformOps.td - Transform dialect operations ------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS
#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS
include "mlir/IR/OpAsmInterface.td"
include "mlir/Dialect/PDL/IR/PDLTypes.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
def SequenceOp : TransformDialectOp<"sequence",
[DeclareOpInterfaceMethods<TransformOpInterface>, OpAsmOpInterface,
SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">]> {
let summary = "Contains a sequence of other transform ops to apply";
let description = [{
The transformations indicated by the sequence are applied in order of their
appearance. Each value produced by a transformation within the sequence
corresponds to an operation or a group of operations in the payload IR.
Each value may be used at most once by another transformation operation as
the transformation is likely to replace the transformed operation with
another operation or a group thereof. In such cases, the transformation
operation is expected to produce a new value to denote the newly produced
operations that can be transformed further. During application, if any
transformation in the sequence fails, the entire sequence fails immediately
leaving the payload IR in potentially invalid state, i.e., this operation
offers no transformation rollback capabilities.
The entry block of this operation has a single argument that maps to either
the operand if provided or the top-level container operation of the payload
IR, typically the root operation of the pass interpreting the transform
dialect. Operand omission is only allowed for sequences not contained in
another sequence.
}];
let arguments = (ins Optional<PDL_Operation>:$root);
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$body);
let assemblyFormat =
"($root^)? attr-dict-with-keyword regions (`:` type($results)^)?";
let extraClassDeclaration = [{
/// Allow the dialect prefix to be omitted.
static StringRef getDefaultDialect() { return "transform"; }
Block *getBodyBlock() {
return &getBody().front();
}
}];
let hasVerifier = 1;
}
def YieldOp : TransformDialectOp<"yield", [Terminator]> {
let summary = "Yields operation handles from a transform IR region";
let description = [{
This terminator operation yields operation handles from regions of the
transform IR ops back to the containing op. It is not itself associated with
any transformation on the payload IR and is used for flow purposes only.
}];
let arguments = (ins Variadic<AnyType>:$operands);
let assemblyFormat = "operands attr-dict (`:` type($operands)^)?";
let builders = [
OpBuilder<(ins), [{
return build($_builder, $_state, ::mlir::ValueRange());
}]>
];
}
#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS

View File

@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRTransformDialect
TransformDialect.cpp
TransformInterfaces.cpp
TransformOps.cpp
DEPENDS
MLIRTransformDialectIncGen
@ -8,4 +9,6 @@ add_mlir_dialect_library(MLIRTransformDialect
LINK_LIBS PUBLIC
MLIRIR
MLIRPDL
MLIRPDLInterp
)

View File

@ -7,9 +7,15 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
using namespace mlir;
void transform::TransformDialect::initialize() {}
#include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc"
void transform::TransformDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
>();
}

View File

@ -9,6 +9,7 @@
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallPtrSet.h"
using namespace mlir;
@ -19,16 +20,21 @@ using namespace mlir;
constexpr const Value transform::TransformState::kTopLevelValue;
transform::TransformState::TransformState(Operation *root) {
operationMapping[kTopLevelValue].push_back(root);
transform::TransformState::TransformState(Region &region, Operation *root)
: topLevel(root) {
auto result = mappings.try_emplace(&region);
assert(result.second && "the region scope is already present");
(void)result;
#if LLVM_ENABLE_ABI_BREAKING_CHECKS
regionStack.push_back(&region);
#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS
}
Operation *transform::TransformState::getTopLevel() const {
return operationMapping.lookup(kTopLevelValue).front();
}
Operation *transform::TransformState::getTopLevel() const { return topLevel; }
ArrayRef<Operation *>
transform::TransformState::getPayloadOps(Value value) const {
const TransformOpMapping &operationMapping = getMapping(value).direct;
auto iter = operationMapping.find(value);
assert(iter != operationMapping.end() && "unknown handle");
return iter->getSecond();
@ -46,8 +52,9 @@ transform::TransformState::setPayloadOps(Value value,
// Setting new payload for the value without cleaning it first is a misuse of
// the API, assert here.
SmallVector<Operation *> storedTargets(targets.begin(), targets.end());
Mappings &mappings = getMapping(value);
bool inserted =
operationMapping.insert({value, std::move(storedTargets)}).second;
mappings.direct.insert({value, std::move(storedTargets)}).second;
assert(inserted && "value is already associated with another list");
(void)inserted;
@ -55,7 +62,7 @@ transform::TransformState::setPayloadOps(Value value,
// expressed using the dialect and may be constructed by valid API calls from
// valid IR. Emit an error here.
for (Operation *op : targets) {
auto insertionResult = reverseMapping.insert({op, value});
auto insertionResult = mappings.reverse.insert({op, value});
if (!insertionResult.second) {
InFlightDiagnostic diag = op->emitError()
<< "operation tracked by two handles";
@ -69,15 +76,16 @@ transform::TransformState::setPayloadOps(Value value,
}
void transform::TransformState::removePayloadOps(Value value) {
for (Operation *op : operationMapping[value])
reverseMapping.erase(op);
operationMapping.erase(value);
Mappings &mappings = getMapping(value);
for (Operation *op : mappings.direct[value])
mappings.reverse.erase(op);
mappings.direct.erase(value);
}
void transform::TransformState::updatePayloadOps(
Value value, function_ref<Operation *(Operation *)> callback) {
auto it = operationMapping.find(value);
assert(it != operationMapping.end() && "unknown handle");
auto it = getMapping(value).direct.find(value);
assert(it != getMapping(value).direct.end() && "unknown handle");
SmallVector<Operation *> &association = it->getSecond();
SmallVector<Operation *> updated;
updated.reserve(association.size());
@ -98,9 +106,13 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
for (Value target : transform->getOperands())
removePayloadOps(target);
for (auto &en : llvm::enumerate(transform->getResults()))
for (auto &en : llvm::enumerate(transform->getResults())) {
assert(en.value().getDefiningOp() == transform.getOperation() &&
"payload IR association for a value other than the result of the "
"current transform op");
if (failed(setPayloadOps(en.value(), results.get(en.index()))))
return failure();
}
return success();
}

View File

@ -0,0 +1,101 @@
//===- TransformDialect.cpp - Transform dialect operations ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
using namespace mlir;
#define GET_OP_CLASSES
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
LogicalResult transform::SequenceOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
SmallVector<Operation *> targets;
if (getRoot())
llvm::append_range(targets, state.getPayloadOps(getRoot()));
else
targets.push_back(state.getTopLevel());
// Map the entry block argument to the list of operations.
auto scope = state.make_region_scope(*getBodyBlock()->getParent());
if (failed(state.mapBlockArguments(getBodyBlock()->getArgument(0), targets)))
return failure();
// Apply the sequenced ops one by one.
for (Operation &transform : getBodyBlock()->without_terminator())
if (failed(state.applyTransform(cast<TransformOpInterface>(transform))))
return failure();
// Forward the operation mapping for values yielded from the sequence to the
// values produced by the sequence op.
for (const auto &pair :
llvm::zip(getBodyBlock()->getTerminator()->getOperands(),
getOperation()->getOpResults())) {
Value terminatorOperand = std::get<0>(pair);
OpResult result = std::get<1>(pair);
results.set(result, state.getPayloadOps(terminatorOperand));
}
return success();
}
LogicalResult transform::SequenceOp::verify() {
if (getBodyBlock()->getNumArguments() != 1 ||
!getBodyBlock()->getArgumentTypes()[0].isa<pdl::OperationType>()) {
return emitOpError()
<< "expected the entry block to have one argument of type "
<< pdl::OperationType::get(getContext());
}
if (auto parent = getOperation()->getParentOfType<transform::SequenceOp>()) {
if (!getRoot()) {
InFlightDiagnostic diag =
emitOpError()
<< "expected the root operation to be provided for a nested sequence";
diag.attachNote(parent.getLoc()) << "nested in another sequence";
return diag;
}
}
for (Operation &child : *getBodyBlock()) {
if (!isa<TransformOpInterface>(child) &&
&child != &getBodyBlock()->back()) {
InFlightDiagnostic diag =
emitOpError()
<< "expected children ops to implement TransformOpInterface";
diag.attachNote(child.getLoc()) << "op without interface";
return diag;
}
for (OpResult result : child.getResults()) {
if (llvm::hasNItemsOrLess(result.getUses(), 1))
continue;
InFlightDiagnostic diag = child.emitError()
<< "result #" << result.getResultNumber()
<< " has more than one use";
for (OpOperand &use : result.getUses()) {
diag.attachNote(use.getOwner()->getLoc())
<< "used here as operand #" << use.getOperandNumber();
}
return diag;
}
}
if (getBodyBlock()->getTerminator()->getOperandTypes() !=
getOperation()->getResultTypes()) {
InFlightDiagnostic diag = emitOpError()
<< "expects the types of the terminator operands "
"to match the types of the result";
diag.attachNote(getBodyBlock()->getTerminator()->getLoc()) << "terminator";
return diag;
}
return success();
}

View File

@ -0,0 +1,52 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
// expected-error @below {{expected the entry block to have one argument of type '!pdl.operation'}}
transform.sequence {
}
// -----
// expected-note @below {{nested in another sequence}}
transform.sequence {
^bb0(%arg0: !pdl.operation):
// expected-error @below {{expected the root operation to be provided for a nested sequence}}
transform.sequence {
^bb1(%arg1: !pdl.operation):
}
}
// -----
// expected-error @below {{expected children ops to implement TransformOpInterface}}
transform.sequence {
^bb0(%arg0: !pdl.operation):
// expected-note @below {{op without interface}}
arith.constant 42.0 : f32
}
// -----
transform.sequence {
^bb0(%arg0: !pdl.operation):
// expected-error @below {{result #0 has more than one use}}
%0 = transform.sequence %arg0 {
^bb1(%arg1: !pdl.operation):
} : !pdl.operation
// expected-note @below {{used here as operand #0}}
transform.sequence %0 {
^bb2(%arg2: !pdl.operation):
}
// expected-note @below {{used here as operand #0}}
transform.sequence %0 {
^bb3(%arg3: !pdl.operation):
}
}
// -----
// expected-error @below {{expects the types of the terminator operands to match the types of the resul}}
%0 = transform.sequence {
^bb0(%arg0: !pdl.operation):
// expected-note @below {{terminator}}
transform.yield
} : !pdl.operation

View File

@ -0,0 +1,12 @@
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
// CHECK: transform.sequence
// CHECK: ^{{.+}}(%{{.+}}: !pdl.operation):
transform.sequence {
^bb0(%arg0: !pdl.operation):
// CHECK: sequence %{{.+}}
// CHECK: ^{{.+}}(%{{.+}}: !pdl.operation):
sequence %arg0 {
^bb1(%arg1: !pdl.operation):
}
}

View File

@ -25,3 +25,47 @@ transform.test_consume_operand_if_matches_param_or_fail %0[21]
%2 = transform.test_produce_param_or_forward_operand from %0
transform.test_consume_operand_if_matches_param_or_fail %1[42]
transform.test_consume_operand_if_matches_param_or_fail %2[42]
// -----
transform.sequence {
^bb0(%arg0: !pdl.operation):
sequence %arg0 {
^bb0(%arg1: !pdl.operation):
// expected-remark @below {{applying transformation "a"}}
test_transform_op "a"
// expected-remark @below {{applying transformation "b"}}
test_transform_op "b"
// expected-remark @below {{applying transformation "c"}}
test_transform_op "c"
}
// expected-remark @below {{applying transformation "d"}}
test_transform_op "d"
// expected-remark @below {{applying transformation "e"}}
test_transform_op "e"
}
// -----
transform.sequence {
^bb0(%arg0: !pdl.operation):
%0 = test_produce_param_or_forward_operand 42
sequence %0 {
^bb0(%arg1: !pdl.operation):
// expected-remark @below {{succeeded}}
test_consume_operand_if_matches_param_or_fail %arg1[42]
}
}
// -----
transform.sequence {
^bb0(%arg0: !pdl.operation):
%0 = sequence %arg0 {
^bb0(%arg1: !pdl.operation):
%1 = test_produce_param_or_forward_operand 42
yield %1 : !pdl.operation
} : !pdl.operation
// expected-remark @below {{succeeded}}
test_consume_operand_if_matches_param_or_fail %0[42]
}

View File

@ -38,31 +38,47 @@ public:
LogicalResult apply(transform::TransformResults &results,
transform::TransformState &state) {
emitRemark() << "applying transformation";
InFlightDiagnostic remark = emitRemark() << "applying transformation";
if (Attribute message = getMessage())
remark << " " << message;
return success();
}
Attribute getMessage() { return getOperation()->getAttr("message"); }
static ParseResult parse(OpAsmParser &parser, OperationState &state) {
return success();
StringAttr message;
OptionalParseResult result = parser.parseOptionalAttribute(message);
if (!result.hasValue())
return success();
if (result.getValue().succeeded())
state.addAttribute("message", message);
return result.getValue();
}
void print(OpAsmPrinter &printer) {}
void print(OpAsmPrinter &printer) {
if (getMessage())
printer << " " << getMessage();
}
};
} // namespace
LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::apply(
transform::TransformResults &results, transform::TransformState &state) {
if (getOperation()->getNumOperands() != 0) {
results.set(getResult().cast<OpResult>(), getOperand(0).getDefiningOp());
results.set(getResult().cast<OpResult>(),
getOperation()->getOperand(0).getDefiningOp());
} else {
results.set(getResult().cast<OpResult>(),
reinterpret_cast<Operation *>(*parameter()));
reinterpret_cast<Operation *>(*getParameter()));
}
return success();
}
LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::verify() {
if (parameter().hasValue() ^ (getNumOperands() != 1))
if (getParameter().hasValue() ^ (getNumOperands() != 1))
return emitOpError() << "expects either a parameter or an operand";
return success();
}
@ -72,9 +88,9 @@ LogicalResult mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply(
ArrayRef<Operation *> payload = state.getPayloadOps(getOperand());
assert(payload.size() == 1 && "expected a single target op");
auto value = reinterpret_cast<intptr_t>(payload[0]);
if (static_cast<uint64_t>(value) != parameter()) {
if (static_cast<uint64_t>(value) != getParameter()) {
return emitOpError() << "expected the operand to be associated with "
<< parameter() << " got " << value;
<< getParameter() << " got " << value;
}
emitRemark() << "succeeded";

View File

@ -37,7 +37,7 @@ public:
void runOnOperation() override {
ModuleOp module = getOperation();
transform::TransformState state(module);
transform::TransformState state(module.getBodyRegion(), module);
for (auto op :
module.getBody()->getOps<transform::TransformOpInterface>()) {
if (failed(state.applyTransform(op)))

View File

@ -7699,6 +7699,7 @@ td_library(
srcs = glob(["include/mlir/Dialect/Transform/IR/*.td"]),
deps = [
":OpBaseTdFiles",
":PDLDialectTdFiles",
],
)
@ -7746,15 +7747,35 @@ gentbl_cc_library(
deps = [":TransformDialectTdFiles"],
)
gentbl_cc_library(
name = "TransformOpsIncGen",
strip_include_prefix = "include",
tbl_outs = [
(
["-gen-op-decls"],
"include/mlir/Dialect/Transform/IR/TransformOps.h.inc",
),
(
["-gen-op-defs"],
"include/mlir/Dialect/Transform/IR/TransformOps.cpp.inc",
),
],
tblgen = ":mlir-tblgen",
td_file = "include/mlir/Dialect/Transform/IR/TransformOps.td",
deps = [":TransformDialectTdFiles"],
)
cc_library(
name = "TransformDialect",
srcs = glob(["lib/Dialect/Transform/IR/*.cpp"]),
hdrs = glob(["include/mlir/Dialect/Transform/IR/*.h"]),
deps = [
":IR",
":PDLDialect",
":Support",
":TransformDialectIncGen",
":TransformDialectInterfacesIncGen",
":TransformOpsIncGen",
"//llvm:Support",
],
)