forked from OSchip/llvm-project
Define a `NoTerminator` traits that allows operations with a single block region to not provide a terminator
In particular for Graph Regions, the terminator needs is just a historical artifact of the generalization of MLIR from CFG region. Operations like Module don't need a terminator, and before Module migrated to be an operation with region there wasn't any needed. To validate the feature, the ModuleOp is migrated to use this trait and the ModuleTerminator operation is deleted. This patch is likely to break clients, if you're in this case: - you may iterate on a ModuleOp with `getBody()->without_terminator()`, the solution is simple: just remove the ->without_terminator! - you created a builder with `Builder::atBlockTerminator(module_body)`, just use `Builder::atBlockEnd(module_body)` instead. - you were handling ModuleTerminator: it isn't needed anymore. - for generic code, a `Block::mayNotHaveTerminator()` may be used. Differential Revision: https://reviews.llvm.org/D98468
This commit is contained in:
parent
0f99c6c56e
commit
973ddb7d6e
|
@ -351,13 +351,18 @@ value-id-and-type-list ::= value-id-and-type (`,` value-id-and-type)*
|
|||
block-arg-list ::= `(` value-id-and-type-list? `)`
|
||||
```
|
||||
|
||||
A *Block* is an ordered list of operations, concluding with a single
|
||||
[terminator operation](#terminator-operations). In [SSACFG
|
||||
A *Block* is a list of operations. In [SSACFG
|
||||
regions](#control-flow-and-ssacfg-regions), each block represents a compiler
|
||||
[basic block](https://en.wikipedia.org/wiki/Basic_block) where instructions
|
||||
inside the block are executed in order and terminator operations implement
|
||||
control flow branches between basic blocks.
|
||||
|
||||
A region with a single block may not include a [terminator
|
||||
operation](#terminator-operations). The enclosing op can opt-out of this
|
||||
requirement with the `NoTerminator` trait. The top-level `ModuleOp` is an
|
||||
example of such operation which defined this trait and whose block body does
|
||||
not have a terminator.
|
||||
|
||||
Blocks in MLIR take a list of block arguments, notated in a function-like
|
||||
way. Block arguments are bound to values specified by the semantics of
|
||||
individual operations. Block arguments of the entry block of a region are also
|
||||
|
|
|
@ -323,13 +323,20 @@ index expression that can express the equivalent of the memory-layout
|
|||
specification of the MemRef type. See [the -normalize-memrefs pass].
|
||||
(https://mlir.llvm.org/docs/Passes/#-normalize-memrefs-normalize-memrefs)
|
||||
|
||||
### Single Block with Implicit Terminator
|
||||
### Single Block Region
|
||||
|
||||
* `OpTrait::SingleBlockImplicitTerminator<typename TerminatorOpType>` :
|
||||
`SingleBlockImplicitTerminator<string op>`
|
||||
* `OpTrait::SingleBlock` -- `SingleBlock`
|
||||
|
||||
This trait provides APIs and verifiers for operations with regions that have a
|
||||
single block that must terminate with `TerminatorOpType`.
|
||||
single block.
|
||||
|
||||
### Single Block with Implicit Terminator
|
||||
|
||||
* `OpTrait::SingleBlockImplicitTerminator<typename TerminatorOpType>` --
|
||||
`SingleBlockImplicitTerminator<string op>`
|
||||
|
||||
This trait implies the `SingleBlock` above, but adds the additional requirement
|
||||
that the single block must terminate with `TerminatorOpType`.
|
||||
|
||||
### SymbolTable
|
||||
|
||||
|
@ -344,3 +351,10 @@ This trait is used for operations that define a
|
|||
|
||||
This trait provides verification and functionality for operations that are known
|
||||
to be [terminators](LangRef.md#terminator-operations).
|
||||
|
||||
* `OpTrait::NoTerminator` -- `NoTerminator`
|
||||
|
||||
This trait removes the requirement on regions held by an operation to have
|
||||
[terminator operations](LangRef.md#terminator-operations) at the end of a block.
|
||||
This requires that these regions have a single block. An example of operation
|
||||
using this trait is the top-level `ModuleOp`.
|
||||
|
|
|
@ -63,7 +63,7 @@ everything to the LLVM dialect.
|
|||
```c++
|
||||
mlir::ConversionTarget target(getContext());
|
||||
target.addLegalDialect<mlir::LLVMDialect>();
|
||||
target.addLegalOp<mlir::ModuleOp, mlir::ModuleTerminatorOp>();
|
||||
target.addLegalOp<mlir::ModuleOp>();
|
||||
```
|
||||
|
||||
### Type Converter
|
||||
|
|
|
@ -110,7 +110,6 @@ llvm-project/mlir/test/IR/print-ir-nesting.mlir`:
|
|||
"dialect.innerop6"() : () -> ()
|
||||
"dialect.innerop7"() : () -> ()
|
||||
}) {"other attribute" = 42 : i64} : () -> ()
|
||||
"module_terminator"() : () -> ()
|
||||
}) : () -> ()
|
||||
```
|
||||
|
||||
|
@ -147,7 +146,6 @@ visiting op: 'module' with 0 operands and 0 results
|
|||
0 nested regions:
|
||||
visiting op: 'dialect.innerop7' with 0 operands and 0 results
|
||||
0 nested regions:
|
||||
visiting op: 'module_terminator' with 0 operands and 0 results
|
||||
0 nested regions:
|
||||
```
|
||||
|
||||
|
|
|
@ -174,7 +174,7 @@ void ToyToLLVMLoweringPass::runOnOperation() {
|
|||
// final target for this lowering. For this lowering, we are only targeting
|
||||
// the LLVM dialect.
|
||||
LLVMConversionTarget target(getContext());
|
||||
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
||||
target.addLegalOp<ModuleOp>();
|
||||
|
||||
// During this lowering, we will also be lowering the MemRef types, that are
|
||||
// currently being operated on, to a representation in LLVM. To perform this
|
||||
|
|
|
@ -174,7 +174,7 @@ void ToyToLLVMLoweringPass::runOnOperation() {
|
|||
// final target for this lowering. For this lowering, we are only targeting
|
||||
// the LLVM dialect.
|
||||
LLVMConversionTarget target(getContext());
|
||||
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
||||
target.addLegalOp<ModuleOp>();
|
||||
|
||||
// During this lowering, we will also be lowering the MemRef types, that are
|
||||
// currently being operated on, to a representation in LLVM. To perform this
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
#include "mlir/IR/FunctionSupport.h"
|
||||
#include "mlir/IR/OwningOpRef.h"
|
||||
#include "mlir/IR/RegionKindInterface.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/Interfaces/CallInterfaces.h"
|
||||
#include "mlir/Interfaces/CastInterfaces.h"
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#define BUILTIN_OPS
|
||||
|
||||
include "mlir/IR/BuiltinDialect.td"
|
||||
include "mlir/IR/RegionKindInterface.td"
|
||||
include "mlir/IR/SymbolInterfaces.td"
|
||||
include "mlir/Interfaces/CallInterfaces.td"
|
||||
include "mlir/Interfaces/CastInterfaces.td"
|
||||
|
@ -159,17 +160,17 @@ def FuncOp : Builtin_Op<"func", [
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ModuleOp : Builtin_Op<"module", [
|
||||
AffineScope, IsolatedFromAbove, NoRegionArguments, SymbolTable, Symbol,
|
||||
SingleBlockImplicitTerminator<"ModuleTerminatorOp">
|
||||
]> {
|
||||
AffineScope, IsolatedFromAbove, NoRegionArguments, SymbolTable, Symbol]
|
||||
# GraphRegionNoTerminator.traits> {
|
||||
let summary = "A top level container operation";
|
||||
let description = [{
|
||||
A `module` represents a top-level container operation. It contains a single
|
||||
SSACFG region containing a single block which can contain any
|
||||
operations. Operations within this region cannot implicitly capture values
|
||||
defined outside the module, i.e. Modules are `IsolatedFromAbove`. Modules
|
||||
have an optional symbol name which can be used to refer to them in
|
||||
operations.
|
||||
[graph region](#control-flow-and-ssacfg-regions) containing a single block
|
||||
which can contain any operations and does not have a terminator. Operations
|
||||
within this region cannot implicitly capture values defined outside the module,
|
||||
i.e. Modules are [IsolatedFromAbove](Traits.md#isolatedfromabove). Modules have
|
||||
an optional [symbol name](SymbolsAndSymbolTables.md) which can be used to refer
|
||||
to them in operations.
|
||||
|
||||
Example:
|
||||
|
||||
|
@ -213,22 +214,6 @@ def ModuleOp : Builtin_Op<"module", [
|
|||
let skipDefaultBuilders = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ModuleTerminatorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def ModuleTerminatorOp : Builtin_Op<"module_terminator", [
|
||||
Terminator, HasParent<"ModuleOp">
|
||||
]> {
|
||||
let summary = "A pseudo op that marks the end of a module";
|
||||
let description = [{
|
||||
`module_terminator` is a special terminator operation for the body of a
|
||||
`module`, it has no semantic meaning beyond keeping the body of a `module`
|
||||
well-formed.
|
||||
}];
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// UnrealizedConversionCastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -1827,10 +1827,16 @@ def ElementwiseMappable {
|
|||
];
|
||||
}
|
||||
|
||||
// Op's regions have a single block.
|
||||
def SingleBlock : NativeOpTrait<"SingleBlock">;
|
||||
|
||||
// Op's regions have a single block with the specified terminator.
|
||||
class SingleBlockImplicitTerminator<string op>
|
||||
: ParamNativeOpTrait<"SingleBlockImplicitTerminator", op>;
|
||||
|
||||
// Op's regions don't have terminator.
|
||||
def NoTerminator : NativeOpTrait<"NoTerminator">;
|
||||
|
||||
// Op's parent operation is the provided one.
|
||||
class HasParent<string op>
|
||||
: ParamNativeOpTrait<"HasParent", op>;
|
||||
|
|
|
@ -654,6 +654,11 @@ class VariadicResults
|
|||
//===----------------------------------------------------------------------===//
|
||||
// Terminator Traits
|
||||
|
||||
/// This class indicates that the regions associated with this op don't have
|
||||
/// terminators.
|
||||
template <typename ConcreteType>
|
||||
class NoTerminator : public TraitBase<ConcreteType, NoTerminator> {};
|
||||
|
||||
/// This class provides the API for ops that are known to be terminators.
|
||||
template <typename ConcreteType>
|
||||
class IsTerminator : public TraitBase<ConcreteType, IsTerminator> {
|
||||
|
@ -757,6 +762,87 @@ class VariadicSuccessors
|
|||
: public detail::MultiSuccessorTraitBase<ConcreteType, VariadicSuccessors> {
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SingleBlock
|
||||
|
||||
/// This class provides APIs and verifiers for ops with regions having a single
|
||||
/// block.
|
||||
template <typename ConcreteType>
|
||||
struct SingleBlock : public TraitBase<ConcreteType, SingleBlock> {
|
||||
public:
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) {
|
||||
Region ®ion = op->getRegion(i);
|
||||
|
||||
// Empty regions are fine.
|
||||
if (region.empty())
|
||||
continue;
|
||||
|
||||
// Non-empty regions must contain a single basic block.
|
||||
if (!llvm::hasSingleElement(region))
|
||||
return op->emitOpError("expects region #")
|
||||
<< i << " to have 0 or 1 blocks";
|
||||
|
||||
if (!ConcreteType::template hasTrait<NoTerminator>()) {
|
||||
Block &block = region.front();
|
||||
if (block.empty())
|
||||
return op->emitOpError() << "expects a non-empty block";
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
Block *getBody(unsigned idx = 0) {
|
||||
Region ®ion = this->getOperation()->getRegion(idx);
|
||||
assert(!region.empty() && "unexpected empty region");
|
||||
return ®ion.front();
|
||||
}
|
||||
Region &getBodyRegion(unsigned idx = 0) {
|
||||
return this->getOperation()->getRegion(idx);
|
||||
}
|
||||
|
||||
//===------------------------------------------------------------------===//
|
||||
// Single Region Utilities
|
||||
//===------------------------------------------------------------------===//
|
||||
|
||||
/// The following are a set of methods only enabled when the parent
|
||||
/// operation has a single region. Each of these methods take an additional
|
||||
/// template parameter that represents the concrete operation so that we
|
||||
/// can use SFINAE to disable the methods for non-single region operations.
|
||||
template <typename OpT, typename T = void>
|
||||
using enable_if_single_region =
|
||||
typename std::enable_if_t<OpT::template hasTrait<OneRegion>(), T>;
|
||||
|
||||
template <typename OpT = ConcreteType>
|
||||
enable_if_single_region<OpT, Block::iterator> begin() {
|
||||
return getBody()->begin();
|
||||
}
|
||||
template <typename OpT = ConcreteType>
|
||||
enable_if_single_region<OpT, Block::iterator> end() {
|
||||
return getBody()->end();
|
||||
}
|
||||
template <typename OpT = ConcreteType>
|
||||
enable_if_single_region<OpT, Operation &> front() {
|
||||
return *begin();
|
||||
}
|
||||
|
||||
/// Insert the operation into the back of the body.
|
||||
template <typename OpT = ConcreteType>
|
||||
enable_if_single_region<OpT> push_back(Operation *op) {
|
||||
insert(Block::iterator(getBody()->end()), op);
|
||||
}
|
||||
|
||||
/// Insert the operation at the given insertion point.
|
||||
template <typename OpT = ConcreteType>
|
||||
enable_if_single_region<OpT> insert(Operation *insertPt, Operation *op) {
|
||||
insert(Block::iterator(insertPt), op);
|
||||
}
|
||||
template <typename OpT = ConcreteType>
|
||||
enable_if_single_region<OpT> insert(Block::iterator insertPt, Operation *op) {
|
||||
getBody()->getOperations().insert(insertPt, op);
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SingleBlockImplicitTerminator
|
||||
|
||||
|
@ -765,8 +851,9 @@ class VariadicSuccessors
|
|||
template <typename TerminatorOpType>
|
||||
struct SingleBlockImplicitTerminator {
|
||||
template <typename ConcreteType>
|
||||
class Impl : public TraitBase<ConcreteType, Impl> {
|
||||
class Impl : public SingleBlock<ConcreteType> {
|
||||
private:
|
||||
using Base = SingleBlock<ConcreteType>;
|
||||
/// Builds a terminator operation without relying on OpBuilder APIs to avoid
|
||||
/// cyclic header inclusion.
|
||||
static Operation *buildTerminator(OpBuilder &builder, Location loc) {
|
||||
|
@ -780,22 +867,14 @@ struct SingleBlockImplicitTerminator {
|
|||
using ImplicitTerminatorOpT = TerminatorOpType;
|
||||
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
if (failed(Base::verifyTrait(op)))
|
||||
return failure();
|
||||
for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) {
|
||||
Region ®ion = op->getRegion(i);
|
||||
|
||||
// Empty regions are fine.
|
||||
if (region.empty())
|
||||
continue;
|
||||
|
||||
// Non-empty regions must contain a single basic block.
|
||||
if (std::next(region.begin()) != region.end())
|
||||
return op->emitOpError("expects region #")
|
||||
<< i << " to have 0 or 1 blocks";
|
||||
|
||||
Block &block = region.front();
|
||||
if (block.empty())
|
||||
return op->emitOpError() << "expects a non-empty block";
|
||||
Operation &terminator = block.back();
|
||||
Operation &terminator = region.front().back();
|
||||
if (isa<TerminatorOpType>(terminator))
|
||||
continue;
|
||||
|
||||
|
@ -828,40 +907,15 @@ struct SingleBlockImplicitTerminator {
|
|||
buildTerminator);
|
||||
}
|
||||
|
||||
Block *getBody(unsigned idx = 0) {
|
||||
Region ®ion = this->getOperation()->getRegion(idx);
|
||||
assert(!region.empty() && "unexpected empty region");
|
||||
return ®ion.front();
|
||||
}
|
||||
Region &getBodyRegion(unsigned idx = 0) {
|
||||
return this->getOperation()->getRegion(idx);
|
||||
}
|
||||
|
||||
//===------------------------------------------------------------------===//
|
||||
// Single Region Utilities
|
||||
//===------------------------------------------------------------------===//
|
||||
using Base::getBody;
|
||||
|
||||
/// The following are a set of methods only enabled when the parent
|
||||
/// operation has a single region. Each of these methods take an additional
|
||||
/// template parameter that represents the concrete operation so that we
|
||||
/// can use SFINAE to disable the methods for non-single region operations.
|
||||
template <typename OpT, typename T = void>
|
||||
using enable_if_single_region =
|
||||
typename std::enable_if_t<OpT::template hasTrait<OneRegion>(), T>;
|
||||
|
||||
template <typename OpT = ConcreteType>
|
||||
enable_if_single_region<OpT, Block::iterator> begin() {
|
||||
return getBody()->begin();
|
||||
}
|
||||
template <typename OpT = ConcreteType>
|
||||
enable_if_single_region<OpT, Block::iterator> end() {
|
||||
return getBody()->end();
|
||||
}
|
||||
template <typename OpT = ConcreteType>
|
||||
enable_if_single_region<OpT, Operation &> front() {
|
||||
return *begin();
|
||||
}
|
||||
|
||||
/// Insert the operation into the back of the body, before the terminator.
|
||||
template <typename OpT = ConcreteType>
|
||||
enable_if_single_region<OpT> push_back(Operation *op) {
|
||||
|
@ -886,6 +940,27 @@ struct SingleBlockImplicitTerminator {
|
|||
};
|
||||
};
|
||||
|
||||
/// Check is an op defines the `ImplicitTerminatorOpT` member. This is intended
|
||||
/// to be used with `llvm::is_detected`.
|
||||
template <class T>
|
||||
using has_implicit_terminator_t = typename T::ImplicitTerminatorOpT;
|
||||
|
||||
/// Support to check if an operation has the SingleBlockImplicitTerminator
|
||||
/// trait. We can't just use `hasTrait` because this class is templated on a
|
||||
/// specific terminator op.
|
||||
template <class Op, bool hasTerminator =
|
||||
llvm::is_detected<has_implicit_terminator_t, Op>::value>
|
||||
struct hasSingleBlockImplicitTerminator {
|
||||
static constexpr bool value = std::is_base_of<
|
||||
typename OpTrait::SingleBlockImplicitTerminator<
|
||||
typename Op::ImplicitTerminatorOpT>::template Impl<Op>,
|
||||
Op>::value;
|
||||
};
|
||||
template <class Op>
|
||||
struct hasSingleBlockImplicitTerminator<Op, false> {
|
||||
static constexpr bool value = false;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Misc Traits
|
||||
|
||||
|
|
|
@ -92,8 +92,13 @@ public:
|
|||
virtual void printGenericOp(Operation *op) = 0;
|
||||
|
||||
/// Prints a region.
|
||||
/// If 'printEntryBlockArgs' is false, the arguments of the
|
||||
/// block are not printed. If 'printBlockTerminator' is false, the terminator
|
||||
/// operation of the block is not printed. If printEmptyBlock is true, then
|
||||
/// the block header is printed even if the block is empty.
|
||||
virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true,
|
||||
bool printBlockTerminators = true) = 0;
|
||||
bool printBlockTerminators = true,
|
||||
bool printEmptyBlock = false) = 0;
|
||||
|
||||
/// Renumber the arguments for the specified region to the same names as the
|
||||
/// SSA values in namesToUse. This may only be used for IsolatedFromAbove
|
||||
|
|
|
@ -43,6 +43,10 @@ public:
|
|||
|
||||
using BlockListType = llvm::iplist<Block>;
|
||||
BlockListType &getBlocks() { return blocks; }
|
||||
Block &emplaceBlock() {
|
||||
push_back(new Block);
|
||||
return back();
|
||||
}
|
||||
|
||||
// Iteration over the blocks in the region.
|
||||
using iterator = BlockListType::iterator;
|
||||
|
|
|
@ -28,6 +28,16 @@ enum class RegionKind {
|
|||
Graph,
|
||||
};
|
||||
|
||||
namespace OpTrait {
|
||||
/// A trait that specifies that an operation only defines graph regions.
|
||||
template <typename ConcreteType>
|
||||
class HasOnlyGraphRegion : public TraitBase<ConcreteType, HasOnlyGraphRegion> {
|
||||
public:
|
||||
static RegionKind getRegionKind(unsigned index) { return RegionKind::Graph; }
|
||||
static bool hasSSADominance(unsigned index) { return false; }
|
||||
};
|
||||
} // namespace OpTrait
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#include "mlir/IR/RegionKindInterface.h.inc"
|
||||
|
|
|
@ -50,4 +50,17 @@ def RegionKindInterface : OpInterface<"RegionKindInterface"> {
|
|||
];
|
||||
}
|
||||
|
||||
def HasOnlyGraphRegion : NativeOpTrait<"HasOnlyGraphRegion">;
|
||||
|
||||
// Op's regions that don't need a terminator: requires some other traits
|
||||
// so it defines a list that must be concatenated.
|
||||
def GraphRegionNoTerminator {
|
||||
list<OpTrait> traits = [
|
||||
NoTerminator,
|
||||
SingleBlock,
|
||||
RegionKindInterface,
|
||||
HasOnlyGraphRegion
|
||||
];
|
||||
}
|
||||
|
||||
#endif // MLIR_IR_REGIONKINDINTERFACE
|
||||
|
|
|
@ -25,6 +25,7 @@ class StringRef;
|
|||
|
||||
namespace mlir {
|
||||
namespace detail {
|
||||
|
||||
/// Given a block containing operations that have just been parsed, if the block
|
||||
/// contains a single operation of `ContainerOpT` type then remove it from the
|
||||
/// block and return it. If the block does not contain just that operation,
|
||||
|
@ -37,12 +38,11 @@ inline OwningOpRef<ContainerOpT> constructContainerOpForParserIfNecessary(
|
|||
Block *parsedBlock, MLIRContext *context, Location sourceFileLoc) {
|
||||
static_assert(
|
||||
ContainerOpT::template hasTrait<OpTrait::OneRegion>() &&
|
||||
std::is_base_of<typename OpTrait::SingleBlockImplicitTerminator<
|
||||
typename ContainerOpT::ImplicitTerminatorOpT>::
|
||||
template Impl<ContainerOpT>,
|
||||
ContainerOpT>::value,
|
||||
(ContainerOpT::template hasTrait<OpTrait::NoTerminator>() ||
|
||||
OpTrait::template hasSingleBlockImplicitTerminator<
|
||||
ContainerOpT>::value),
|
||||
"Expected `ContainerOpT` to have a single region with a single "
|
||||
"block that has an implicit terminator");
|
||||
"block that has an implicit terminator or does not require one");
|
||||
|
||||
// Check to see if we parsed a single instance of this operation.
|
||||
if (llvm::hasSingleElement(*parsedBlock)) {
|
||||
|
|
|
@ -16,8 +16,6 @@ class ModuleOp:
|
|||
super().__init__(self.build_generic(results=[], operands=[], loc=loc,
|
||||
ip=ip))
|
||||
body = self.regions[0].blocks.append()
|
||||
with InsertionPoint(body):
|
||||
Operation.create("module_terminator")
|
||||
|
||||
@property
|
||||
def body(self):
|
||||
|
|
|
@ -156,8 +156,8 @@ struct AsyncAPI {
|
|||
|
||||
/// Adds Async Runtime C API declarations to the module.
|
||||
static void addAsyncRuntimeApiDeclarations(ModuleOp module) {
|
||||
auto builder = ImplicitLocOpBuilder::atBlockTerminator(module.getLoc(),
|
||||
module.getBody());
|
||||
auto builder =
|
||||
ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody());
|
||||
|
||||
auto addFuncDecl = [&](StringRef name, FunctionType type) {
|
||||
if (module.lookupSymbol(name))
|
||||
|
@ -207,8 +207,8 @@ static void addCRuntimeDeclarations(ModuleOp module) {
|
|||
using namespace mlir::LLVM;
|
||||
|
||||
MLIRContext *ctx = module.getContext();
|
||||
ImplicitLocOpBuilder builder(module.getLoc(),
|
||||
module.getBody()->getTerminator());
|
||||
auto builder =
|
||||
ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody());
|
||||
|
||||
auto voidTy = LLVMVoidType::get(ctx);
|
||||
auto i64 = IntegerType::get(ctx, 64);
|
||||
|
@ -232,15 +232,14 @@ static void addResumeFunction(ModuleOp module) {
|
|||
return;
|
||||
|
||||
MLIRContext *ctx = module.getContext();
|
||||
|
||||
OpBuilder moduleBuilder(module.getBody()->getTerminator());
|
||||
Location loc = module.getLoc();
|
||||
auto loc = module.getLoc();
|
||||
auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody());
|
||||
|
||||
auto voidTy = LLVM::LLVMVoidType::get(ctx);
|
||||
auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8));
|
||||
|
||||
auto resumeOp = moduleBuilder.create<LLVM::LLVMFuncOp>(
|
||||
loc, kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}));
|
||||
kResume, LLVM::LLVMFunctionType::get(voidTy, {i8Ptr}));
|
||||
resumeOp.setPrivate();
|
||||
|
||||
auto *block = resumeOp.addEntryBlock();
|
||||
|
|
|
@ -342,7 +342,7 @@ LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
|
|||
auto function = [&] {
|
||||
if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName))
|
||||
return function;
|
||||
return OpBuilder(module.getBody()->getTerminator())
|
||||
return OpBuilder::atBlockEnd(module.getBody())
|
||||
.create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
|
||||
}();
|
||||
return builder.create<LLVM::CallOp>(
|
||||
|
|
|
@ -99,7 +99,7 @@ void ConvertGpuLaunchFuncToVulkanLaunchFunc::runOnOperation() {
|
|||
|
||||
LogicalResult ConvertGpuLaunchFuncToVulkanLaunchFunc::declareVulkanLaunchFunc(
|
||||
Location loc, gpu::LaunchFuncOp launchOp) {
|
||||
OpBuilder builder(getOperation().getBody()->getTerminator());
|
||||
auto builder = OpBuilder::atBlockEnd(getOperation().getBody());
|
||||
|
||||
// Workgroup size is written into the kernel. So to properly modelling
|
||||
// vulkan launch, we have to skip local workgroup size configuration here.
|
||||
|
|
|
@ -291,7 +291,7 @@ LogicalResult VulkanLaunchFuncToVulkanCallsPass::deduceMemRefRankAndType(
|
|||
|
||||
void VulkanLaunchFuncToVulkanCallsPass::declareVulkanFunctions(Location loc) {
|
||||
ModuleOp module = getOperation();
|
||||
OpBuilder builder(module.getBody()->getTerminator());
|
||||
auto builder = OpBuilder::atBlockEnd(module.getBody());
|
||||
|
||||
if (!module.lookupSymbol(kSetEntryPoint)) {
|
||||
builder.create<LLVM::LLVMFuncOp>(
|
||||
|
|
|
@ -227,7 +227,7 @@ void ConvertLinalgToLLVMPass::runOnOperation() {
|
|||
|
||||
LLVMConversionTarget target(getContext());
|
||||
target.addIllegalOp<RangeOp>();
|
||||
target.addLegalOp<ModuleOp, ModuleTerminatorOp, LLVM::DialectCastOp>();
|
||||
target.addLegalOp<ModuleOp, LLVM::DialectCastOp>();
|
||||
if (failed(applyPartialConversion(module, target, std::move(patterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
|
|
@ -35,7 +35,7 @@ void LinalgToSPIRVPass::runOnOperation() {
|
|||
populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
|
||||
|
||||
// Allow builtin ops.
|
||||
target->addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
||||
target->addLegalOp<ModuleOp>();
|
||||
target->addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
|
||||
return typeConverter.isSignatureLegal(op.getType()) &&
|
||||
typeConverter.isLegal(&op.getBody());
|
||||
|
|
|
@ -216,7 +216,7 @@ void ConvertLinalgToStandardPass::runOnOperation() {
|
|||
ConversionTarget target(getContext());
|
||||
target.addLegalDialect<AffineDialect, memref::MemRefDialect, scf::SCFDialect,
|
||||
StandardOpsDialect>();
|
||||
target.addLegalOp<ModuleOp, FuncOp, ModuleTerminatorOp, ReturnOp>();
|
||||
target.addLegalOp<ModuleOp, FuncOp, ReturnOp>();
|
||||
target.addLegalOp<linalg::ReshapeOp, linalg::RangeOp>();
|
||||
RewritePatternSet patterns(&getContext());
|
||||
populateLinalgToStandardConversionPatterns(patterns);
|
||||
|
|
|
@ -1358,7 +1358,7 @@ public:
|
|||
matchAndRewrite(spirv::ModuleEndOp moduleEndOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
rewriter.replaceOpWithNewOp<ModuleTerminatorOp>(moduleEndOp);
|
||||
rewriter.eraseOp(moduleEndOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -48,10 +48,8 @@ void ConvertSPIRVToLLVMPass::runOnOperation() {
|
|||
target.addIllegalDialect<spirv::SPIRVDialect>();
|
||||
target.addLegalDialect<LLVM::LLVMDialect>();
|
||||
|
||||
// Set `ModuleOp` and `ModuleTerminatorOp` as legal for `spv.module`
|
||||
// conversion.
|
||||
// Set `ModuleOp` as legal for `spv.module` conversion.
|
||||
target.addLegalOp<ModuleOp>();
|
||||
target.addLegalOp<ModuleTerminatorOp>();
|
||||
if (failed(applyPartialConversion(module, target, std::move(patterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
|
|
|
@ -675,7 +675,7 @@ void ConvertShapeToStandardPass::runOnOperation() {
|
|||
ConversionTarget target(ctx);
|
||||
target.addLegalDialect<memref::MemRefDialect, StandardOpsDialect, SCFDialect,
|
||||
tensor::TensorDialect>();
|
||||
target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp, ModuleTerminatorOp>();
|
||||
target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp>();
|
||||
|
||||
// Setup conversion patterns.
|
||||
RewritePatternSet patterns(&ctx);
|
||||
|
|
|
@ -40,7 +40,7 @@ void LowerVectorToSPIRVPass::runOnOperation() {
|
|||
RewritePatternSet patterns(context);
|
||||
populateVectorToSPIRVPatterns(typeConverter, patterns);
|
||||
|
||||
target->addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
||||
target->addLegalOp<ModuleOp>();
|
||||
target->addLegalOp<FuncOp>();
|
||||
|
||||
if (failed(applyFullConversion(module, *target, std::move(patterns))))
|
||||
|
|
|
@ -199,7 +199,7 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) {
|
|||
// TODO: Derive outlined function name from the parent FuncOp (support
|
||||
// multiple nested async.execute operations).
|
||||
FuncOp func = FuncOp::create(loc, kAsyncFnPrefix, funcType, funcAttrs);
|
||||
symbolTable.insert(func, Block::iterator(module.getBody()->getTerminator()));
|
||||
symbolTable.insert(func);
|
||||
|
||||
SymbolTable::setSymbolVisibility(func, SymbolTable::Visibility::Private);
|
||||
|
||||
|
|
|
@ -42,8 +42,7 @@ struct FuncBufferizePass : public FuncBufferizeBase<FuncBufferizePass> {
|
|||
|
||||
populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter);
|
||||
populateReturnOpTypeConversionPattern(patterns, typeConverter);
|
||||
target.addLegalOp<ModuleOp, ModuleTerminatorOp, memref::TensorLoadOp,
|
||||
memref::BufferCastOp>();
|
||||
target.addLegalOp<ModuleOp, memref::TensorLoadOp, memref::BufferCastOp>();
|
||||
|
||||
target.markUnknownOpDynamicallyLegal([&](Operation *op) {
|
||||
return isNotBranchOpInterfaceOrReturnLikeOp(op) ||
|
||||
|
|
|
@ -418,7 +418,8 @@ private:
|
|||
|
||||
/// Print the given region.
|
||||
void printRegion(Region ®ion, bool printEntryBlockArgs,
|
||||
bool printBlockTerminators) override {
|
||||
bool printBlockTerminators,
|
||||
bool printEmptyBlock = false) override {
|
||||
if (region.empty())
|
||||
return;
|
||||
|
||||
|
@ -2324,7 +2325,7 @@ public:
|
|||
|
||||
/// Print the given region.
|
||||
void printRegion(Region ®ion, bool printEntryBlockArgs,
|
||||
bool printBlockTerminators) override;
|
||||
bool printBlockTerminators, bool printEmptyBlock) override;
|
||||
|
||||
/// Renumber the arguments for the specified region to the same names as the
|
||||
/// SSA values in namesToUse. This may only be used for IsolatedFromAbove
|
||||
|
@ -2440,7 +2441,7 @@ void OperationPrinter::printGenericOp(Operation *op) {
|
|||
os << " (";
|
||||
interleaveComma(op->getRegions(), [&](Region ®ion) {
|
||||
printRegion(region, /*printEntryBlockArgs=*/true,
|
||||
/*printBlockTerminators=*/true);
|
||||
/*printBlockTerminators=*/true, /*printEmptyBlock=*/true);
|
||||
});
|
||||
os << ')';
|
||||
}
|
||||
|
@ -2541,12 +2542,18 @@ void OperationPrinter::printSuccessorAndUseList(Block *successor,
|
|||
}
|
||||
|
||||
void OperationPrinter::printRegion(Region ®ion, bool printEntryBlockArgs,
|
||||
bool printBlockTerminators) {
|
||||
bool printBlockTerminators,
|
||||
bool printEmptyBlock) {
|
||||
os << " {" << newLine;
|
||||
if (!region.empty()) {
|
||||
auto *entryBlock = ®ion.front();
|
||||
print(entryBlock, printEntryBlockArgs && entryBlock->getNumArguments() != 0,
|
||||
printBlockTerminators);
|
||||
// Force printing the block header if printEmptyBlock is set and the block
|
||||
// is empty or if printEntryBlockArgs is set and there are arguments to
|
||||
// print.
|
||||
bool shouldAlwaysPrintBlockHeader =
|
||||
(printEmptyBlock && entryBlock->empty()) ||
|
||||
(printEntryBlockArgs && entryBlock->getNumArguments() != 0);
|
||||
print(entryBlock, shouldAlwaysPrintBlockHeader, printBlockTerminators);
|
||||
for (auto &b : llvm::drop_begin(region.getBlocks(), 1))
|
||||
print(&b);
|
||||
}
|
||||
|
|
|
@ -294,6 +294,21 @@ Block *Block::splitBlock(iterator splitBefore) {
|
|||
return newBB;
|
||||
}
|
||||
|
||||
/// Returns true if this block may be valid without terminator. That is if:
|
||||
/// - it does not have a parent region.
|
||||
/// - Or the parent region have a single block and:
|
||||
/// - This region does not have a parent op.
|
||||
/// - Or the parent op is unregistered.
|
||||
/// - Or the parent op has the NoTerminator trait.
|
||||
static bool mayNotHaveTerminator(Block *block) {
|
||||
if (!block->getParent())
|
||||
return true;
|
||||
if (!llvm::hasSingleElement(*block->getParent()))
|
||||
return false;
|
||||
Operation *op = block->getParentOp();
|
||||
return !op || op->mightHaveTrait<OpTrait::NoTerminator>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Predecessors
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -314,9 +329,11 @@ unsigned PredecessorIterator::getSuccessorIndex() const {
|
|||
SuccessorRange::SuccessorRange() : SuccessorRange(nullptr, 0) {}
|
||||
|
||||
SuccessorRange::SuccessorRange(Block *block) : SuccessorRange() {
|
||||
if (Operation *term = block->getTerminator())
|
||||
if (!llvm::hasSingleElement(*block->getParent())) {
|
||||
Operation *term = block->getTerminator();
|
||||
if ((count = term->getNumSuccessors()))
|
||||
base = term->getBlockOperands().data();
|
||||
}
|
||||
}
|
||||
|
||||
SuccessorRange::SuccessorRange(Operation *term) : SuccessorRange() {
|
||||
|
|
|
@ -209,7 +209,7 @@ FuncOp FuncOp::clone() {
|
|||
|
||||
void ModuleOp::build(OpBuilder &builder, OperationState &state,
|
||||
Optional<StringRef> name) {
|
||||
ensureTerminator(*state.addRegion(), builder, state.location);
|
||||
state.addRegion()->emplaceBlock();
|
||||
if (name) {
|
||||
state.attributes.push_back(builder.getNamedAttr(
|
||||
mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(*name)));
|
||||
|
|
|
@ -161,11 +161,17 @@ void SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
|
|||
// TODO: consider if SymbolTable's constructor should behave the same.
|
||||
if (!symbol->getParentOp()) {
|
||||
auto &body = symbolTableOp->getRegion(0).front();
|
||||
if (insertPt == Block::iterator() || insertPt == body.end())
|
||||
insertPt = Block::iterator(body.getTerminator());
|
||||
|
||||
assert(insertPt->getParentOp() == symbolTableOp &&
|
||||
"expected insertPt to be in the associated module operation");
|
||||
if (insertPt == Block::iterator()) {
|
||||
insertPt = Block::iterator(body.end());
|
||||
} else {
|
||||
assert((insertPt == body.end() ||
|
||||
insertPt->getParentOp() == symbolTableOp) &&
|
||||
"expected insertPt to be in the associated module operation");
|
||||
}
|
||||
// Insert before the terminator, if any.
|
||||
if (insertPt == Block::iterator(body.end()) && !body.empty() &&
|
||||
std::prev(body.end())->hasTrait<OpTrait::IsTerminator>())
|
||||
insertPt = std::prev(body.end());
|
||||
|
||||
body.getOperations().insert(insertPt, symbol);
|
||||
}
|
||||
|
@ -291,11 +297,14 @@ void SymbolTable::walkSymbolTables(
|
|||
Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
|
||||
StringRef symbol) {
|
||||
assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
|
||||
Region ®ion = symbolTableOp->getRegion(0);
|
||||
if (region.empty())
|
||||
return nullptr;
|
||||
|
||||
// Look for a symbol with the given name.
|
||||
Identifier symbolNameId = Identifier::get(SymbolTable::getSymbolAttrName(),
|
||||
symbolTableOp->getContext());
|
||||
for (auto &op : symbolTableOp->getRegion(0).front().without_terminator())
|
||||
for (auto &op : region.front())
|
||||
if (getNameIfSymbol(&op, symbolNameId) == symbol)
|
||||
return &op;
|
||||
return nullptr;
|
||||
|
|
|
@ -113,17 +113,36 @@ LogicalResult OperationVerifier::verifyRegion(Region ®ion) {
|
|||
return success();
|
||||
}
|
||||
|
||||
/// Returns true if this block may be valid without terminator. That is if:
|
||||
/// - it does not have a parent region.
|
||||
/// - Or the parent region have a single block and:
|
||||
/// - This region does not have a parent op.
|
||||
/// - Or the parent op is unregistered.
|
||||
/// - Or the parent op has the NoTerminator trait.
|
||||
static bool mayNotHaveTerminator(Block *block) {
|
||||
if (!block->getParent())
|
||||
return true;
|
||||
if (!llvm::hasSingleElement(*block->getParent()))
|
||||
return false;
|
||||
Operation *op = block->getParentOp();
|
||||
return !op || op->mightHaveTrait<OpTrait::NoTerminator>();
|
||||
}
|
||||
|
||||
LogicalResult OperationVerifier::verifyBlock(Block &block) {
|
||||
for (auto arg : block.getArguments())
|
||||
if (arg.getOwner() != &block)
|
||||
return emitError(block, "block argument not owned by block");
|
||||
|
||||
// Verify that this block has a terminator.
|
||||
if (block.empty())
|
||||
return emitError(block, "block with no terminator");
|
||||
|
||||
if (block.empty()) {
|
||||
if (mayNotHaveTerminator(&block))
|
||||
return success();
|
||||
return emitError(block, "empty block: expect at least a terminator");
|
||||
}
|
||||
|
||||
// Verify the non-terminator operations separately so that we can verify
|
||||
// they has no successors.
|
||||
// they have no successors.
|
||||
for (auto &op : llvm::make_range(block.begin(), std::prev(block.end()))) {
|
||||
if (op.getNumSuccessors() != 0)
|
||||
return op.emitError(
|
||||
|
@ -137,8 +156,13 @@ LogicalResult OperationVerifier::verifyBlock(Block &block) {
|
|||
Operation &terminator = block.back();
|
||||
if (failed(verifyOperation(terminator)))
|
||||
return failure();
|
||||
|
||||
if (mayNotHaveTerminator(&block))
|
||||
return success();
|
||||
|
||||
if (!terminator.mightHaveTrait<OpTrait::IsTerminator>())
|
||||
return block.back().emitError("block with no terminator");
|
||||
return block.back().emitError("block with no terminator, has ")
|
||||
<< terminator;
|
||||
|
||||
// Verify that this block is not branching to a block of a different
|
||||
// region.
|
||||
|
@ -176,13 +200,14 @@ LogicalResult OperationVerifier::verifyOperation(Operation &op) {
|
|||
unsigned numRegions = op.getNumRegions();
|
||||
for (unsigned i = 0; i < numRegions; i++) {
|
||||
Region ®ion = op.getRegion(i);
|
||||
RegionKind kind =
|
||||
kindInterface ? kindInterface.getRegionKind(i) : RegionKind::SSACFG;
|
||||
// Check that Graph Regions only have a single basic block. This is
|
||||
// similar to the code in SingleBlockImplicitTerminator, but doesn't
|
||||
// require the trait to be specified. This arbitrary limitation is
|
||||
// designed to limit the number of cases that have to be handled by
|
||||
// transforms and conversions until the concept stabilizes.
|
||||
if (op.isRegistered() && kindInterface &&
|
||||
kindInterface.getRegionKind(i) == RegionKind::Graph) {
|
||||
if (op.isRegistered() && kind == RegionKind::Graph) {
|
||||
// Empty regions are fine.
|
||||
if (region.empty())
|
||||
continue;
|
||||
|
|
|
@ -2121,7 +2121,7 @@ ParseResult TopLevelOperationParser::parse(Block *topLevelBlock,
|
|||
auto &parsedOps = (*topLevelOp)->getRegion(0).front().getOperations();
|
||||
auto &destOps = topLevelBlock->getOperations();
|
||||
destOps.splice(destOps.empty() ? destOps.end() : std::prev(destOps.end()),
|
||||
parsedOps, parsedOps.begin(), std::prev(parsedOps.end()));
|
||||
parsedOps, parsedOps.begin(), parsedOps.end());
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -269,10 +269,11 @@ private:
|
|||
|
||||
/// Globals are inserted before the first function, if any.
|
||||
Block::iterator getGlobalInsertPt() {
|
||||
auto i = module.getBody()->begin();
|
||||
while (!isa<LLVMFuncOp, ModuleTerminatorOp>(i))
|
||||
++i;
|
||||
return i;
|
||||
auto it = module.getBody()->begin();
|
||||
auto endIt = module.getBody()->end();
|
||||
while (it != endIt && !isa<LLVMFuncOp>(it))
|
||||
++it;
|
||||
return it;
|
||||
}
|
||||
|
||||
/// Functions are always inserted before the module terminator.
|
||||
|
|
|
@ -61,8 +61,7 @@ void SymbolDCE::runOnOperation() {
|
|||
if (!nestedSymbolTable->hasTrait<OpTrait::SymbolTable>())
|
||||
return;
|
||||
for (auto &block : nestedSymbolTable->getRegion(0)) {
|
||||
for (Operation &op :
|
||||
llvm::make_early_inc_range(block.without_terminator())) {
|
||||
for (Operation &op : llvm::make_early_inc_range(block)) {
|
||||
if (isa<SymbolOpInterface>(&op) && !liveSymbols.count(&op))
|
||||
op.erase();
|
||||
}
|
||||
|
@ -84,7 +83,7 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
|
|||
// are known to be live.
|
||||
for (auto &block : symbolTableOp->getRegion(0)) {
|
||||
// Add all non-symbols or symbols that can't be discarded.
|
||||
for (Operation &op : block.without_terminator()) {
|
||||
for (Operation &op : block) {
|
||||
SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
|
||||
if (!symbol) {
|
||||
worklist.push_back(&op);
|
||||
|
|
|
@ -314,6 +314,7 @@ static LogicalResult deleteDeadness(RewriterBase &rewriter,
|
|||
for (Region ®ion : regions) {
|
||||
if (region.empty())
|
||||
continue;
|
||||
bool hasSingleBlock = llvm::hasSingleElement(region);
|
||||
|
||||
// Delete every operation that is not live. Graph regions may have cycles
|
||||
// in the use-def graph, so we must explicitly dropAllUses() from each
|
||||
|
@ -321,7 +322,8 @@ static LogicalResult deleteDeadness(RewriterBase &rewriter,
|
|||
// guarantees that in SSA CFG regions value uses are removed before defs,
|
||||
// which makes dropAllUses() a no-op.
|
||||
for (Block *block : llvm::post_order(®ion.front())) {
|
||||
eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap);
|
||||
if (!hasSingleBlock)
|
||||
eraseTerminatorSuccessorOperands(block->getTerminator(), liveMap);
|
||||
for (Operation &childOp :
|
||||
llvm::make_early_inc_range(llvm::reverse(block->getOperations()))) {
|
||||
if (!liveMap.wasProvenLive(&childOp)) {
|
||||
|
|
|
@ -62,7 +62,7 @@ run(testLocationEnterExit)
|
|||
def testInsertionPointEnterExit():
|
||||
ctx1 = Context()
|
||||
m = Module.create(Location.unknown(ctx1))
|
||||
ip = InsertionPoint.at_block_terminator(m.body)
|
||||
ip = InsertionPoint(m.body)
|
||||
|
||||
with ip:
|
||||
assert InsertionPoint.current is ip
|
||||
|
|
|
@ -77,7 +77,7 @@ def testCustomOpView():
|
|||
ctx.allow_unregistered_dialects = True
|
||||
m = Module.create()
|
||||
|
||||
with InsertionPoint.at_block_terminator(m.body):
|
||||
with InsertionPoint(m.body):
|
||||
f32 = F32Type.get()
|
||||
# Create via dialects context collection.
|
||||
input1 = createInput()
|
||||
|
|
|
@ -18,7 +18,7 @@ def testFromPyFunc():
|
|||
m = builtin.ModuleOp()
|
||||
f32 = F32Type.get()
|
||||
f64 = F64Type.get()
|
||||
with InsertionPoint.at_block_terminator(m.body):
|
||||
with InsertionPoint(m.body):
|
||||
# CHECK-LABEL: func @unary_return(%arg0: f64) -> f64
|
||||
# CHECK: return %arg0 : f64
|
||||
@builtin.FuncOp.from_py_func(f64)
|
||||
|
@ -95,7 +95,7 @@ def testFromPyFuncErrors():
|
|||
m = builtin.ModuleOp()
|
||||
f32 = F32Type.get()
|
||||
f64 = F64Type.get()
|
||||
with InsertionPoint.at_block_terminator(m.body):
|
||||
with InsertionPoint(m.body):
|
||||
try:
|
||||
|
||||
@builtin.FuncOp.from_py_func(f64, results=[f64])
|
||||
|
|
|
@ -32,7 +32,7 @@ with Context() as ctx, Location.unknown():
|
|||
i8 = IntegerType.get_signless(8)
|
||||
i16 = IntegerType.get_signless(16)
|
||||
i32 = IntegerType.get_signless(32)
|
||||
with InsertionPoint.at_block_terminator(module.body):
|
||||
with InsertionPoint(module.body):
|
||||
|
||||
# Note that these all have the same indexing maps. We verify the first and
|
||||
# then do more permutation tests on casting and body generation
|
||||
|
|
|
@ -17,7 +17,7 @@ def testStructuredOpOnTensors():
|
|||
module = Module.create()
|
||||
f32 = F32Type.get()
|
||||
tensor_type = RankedTensorType.get((2, 3, 4), f32)
|
||||
with InsertionPoint.at_block_terminator(module.body):
|
||||
with InsertionPoint(module.body):
|
||||
func = builtin.FuncOp(name="matmul_test",
|
||||
type=FunctionType.get(
|
||||
inputs=[tensor_type, tensor_type],
|
||||
|
@ -40,7 +40,7 @@ def testStructuredOpOnBuffers():
|
|||
module = Module.create()
|
||||
f32 = F32Type.get()
|
||||
memref_type = MemRefType.get((2, 3, 4), f32)
|
||||
with InsertionPoint.at_block_terminator(module.body):
|
||||
with InsertionPoint(module.body):
|
||||
func = builtin.FuncOp(name="matmul_test",
|
||||
type=FunctionType.get(
|
||||
inputs=[memref_type, memref_type, memref_type],
|
||||
|
|
|
@ -129,8 +129,13 @@ run(test_insert_at_block_terminator_missing)
|
|||
def test_insert_at_end_with_terminator_errors():
|
||||
with Context() as ctx, Location.unknown():
|
||||
ctx.allow_unregistered_dialects = True
|
||||
m = Module.create() # Module is created with a terminator.
|
||||
with InsertionPoint(m.body):
|
||||
module = Module.parse(r"""
|
||||
func @foo() -> () {
|
||||
return
|
||||
}
|
||||
""")
|
||||
entry_block = module.body.operations[0].regions[0].blocks[0]
|
||||
with InsertionPoint(entry_block):
|
||||
try:
|
||||
Operation.create("custom.op1", results=[], operands=[])
|
||||
except IndexError as e:
|
||||
|
|
|
@ -64,7 +64,6 @@ def testTraverseOpRegionBlockIterators():
|
|||
# CHECK: BLOCK 0:
|
||||
# CHECK: OP 0: %0 = "custom.addi"
|
||||
# CHECK: OP 1: return
|
||||
# CHECK: OP 1: module_terminator
|
||||
walk_operations("", op)
|
||||
|
||||
run(testTraverseOpRegionBlockIterators)
|
||||
|
@ -101,7 +100,6 @@ def testTraverseOpRegionBlockIndices():
|
|||
# CHECK: BLOCK 0:
|
||||
# CHECK: OP 0: %0 = "custom.addi"
|
||||
# CHECK: OP 1: return
|
||||
# CHECK: OP 1: module_terminator
|
||||
walk_operations("", module.operation)
|
||||
|
||||
run(testTraverseOpRegionBlockIndices)
|
||||
|
@ -546,9 +544,9 @@ run(testSingleResultProperty)
|
|||
def testPrintInvalidOperation():
|
||||
ctx = Context()
|
||||
with Location.unknown(ctx):
|
||||
module = Operation.create("module", regions=1)
|
||||
# This block does not have a terminator, it may crash the custom printer.
|
||||
# Verify that we fallback to the generic printer for safety.
|
||||
module = Operation.create("module", regions=2)
|
||||
# This module has two region and is invalid verify that we fallback
|
||||
# to the generic printer for safety.
|
||||
block = module.regions[0].blocks.append()
|
||||
# CHECK: // Verification failed, printing generic form
|
||||
# CHECK: "module"() ( {
|
||||
|
|
|
@ -29,7 +29,7 @@ def testOdsBuildDefaultImplicitRegions():
|
|||
with Context() as ctx, Location.unknown():
|
||||
ctx.allow_unregistered_dialects = True
|
||||
m = Module.create()
|
||||
with InsertionPoint.at_block_terminator(m.body):
|
||||
with InsertionPoint(m.body):
|
||||
op = TestFixedRegionsOp.build_generic(results=[], operands=[])
|
||||
# CHECK: NUM_REGIONS: 2
|
||||
print(f"NUM_REGIONS: {len(op.regions)}")
|
||||
|
@ -84,7 +84,7 @@ def testOdsBuildDefaultNonVariadic():
|
|||
with Context() as ctx, Location.unknown():
|
||||
ctx.allow_unregistered_dialects = True
|
||||
m = Module.create()
|
||||
with InsertionPoint.at_block_terminator(m.body):
|
||||
with InsertionPoint(m.body):
|
||||
v0 = add_dummy_value()
|
||||
v1 = add_dummy_value()
|
||||
t0 = IntegerType.get_signless(8)
|
||||
|
@ -111,7 +111,7 @@ def testOdsBuildDefaultSizedVariadic():
|
|||
with Context() as ctx, Location.unknown():
|
||||
ctx.allow_unregistered_dialects = True
|
||||
m = Module.create()
|
||||
with InsertionPoint.at_block_terminator(m.body):
|
||||
with InsertionPoint(m.body):
|
||||
v0 = add_dummy_value()
|
||||
v1 = add_dummy_value()
|
||||
v2 = add_dummy_value()
|
||||
|
@ -187,7 +187,7 @@ def testOdsBuildDefaultCastError():
|
|||
with Context() as ctx, Location.unknown():
|
||||
ctx.allow_unregistered_dialects = True
|
||||
m = Module.create()
|
||||
with InsertionPoint.at_block_terminator(m.body):
|
||||
with InsertionPoint(m.body):
|
||||
v0 = add_dummy_value()
|
||||
v1 = add_dummy_value()
|
||||
t0 = IntegerType.get_signless(8)
|
||||
|
|
|
@ -91,6 +91,5 @@ def testRunPipeline():
|
|||
# CHECK: Operations encountered:
|
||||
# CHECK: func , 1
|
||||
# CHECK: module , 1
|
||||
# CHECK: module_terminator , 1
|
||||
# CHECK: std.return , 1
|
||||
run(testRunPipeline)
|
||||
|
|
|
@ -293,7 +293,7 @@ int collectStats(MlirOperation operation) {
|
|||
fprintf(stderr, "Number of op results: %u\n", stats.numOpResults);
|
||||
// clang-format off
|
||||
// CHECK-LABEL: @stats
|
||||
// CHECK: Number of operations: 13
|
||||
// CHECK: Number of operations: 12
|
||||
// CHECK: Number of attributes: 4
|
||||
// CHECK: Number of blocks: 3
|
||||
// CHECK: Number of regions: 3
|
||||
|
|
|
@ -42,7 +42,6 @@ void testRunPassOnModule() {
|
|||
// Run the print-op-stats pass on the top-level module:
|
||||
// CHECK-LABEL: Operations encountered:
|
||||
// CHECK: func , 1
|
||||
// CHECK: module_terminator , 1
|
||||
// CHECK: std.addi , 1
|
||||
// CHECK: std.return , 1
|
||||
{
|
||||
|
@ -84,7 +83,6 @@ void testRunPassOnNestedModule() {
|
|||
|
||||
// Run the print-op-stats pass on functions under the top-level module:
|
||||
// CHECK-LABEL: Operations encountered:
|
||||
// CHECK-NOT: module_terminator
|
||||
// CHECK: func , 1
|
||||
// CHECK: std.addi , 1
|
||||
// CHECK: std.return , 1
|
||||
|
@ -101,7 +99,6 @@ void testRunPassOnNestedModule() {
|
|||
}
|
||||
// Run the print-op-stats pass on functions under the nested module:
|
||||
// CHECK-LABEL: Operations encountered:
|
||||
// CHECK-NOT: module_terminator
|
||||
// CHECK: func , 1
|
||||
// CHECK: std.addf , 1
|
||||
// CHECK: std.return , 1
|
||||
|
|
|
@ -19,31 +19,12 @@ func @module_op() {
|
|||
// expected-error@+1 {{region should have no arguments}}
|
||||
module {
|
||||
^bb1(%arg: i32):
|
||||
"module_terminator"() : () -> ()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @module_op() {
|
||||
// expected-error@below {{expects regions to end with 'module_terminator'}}
|
||||
// expected-note@below {{the absence of terminator implies 'module_terminator'}}
|
||||
module {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @module_op() {
|
||||
// expected-error@+1 {{expects parent op 'module'}}
|
||||
"module_terminator"() : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error@+1 {{can only contain attributes with dialect-prefixed names}}
|
||||
module attributes {attr} {
|
||||
}
|
||||
|
|
|
@ -120,7 +120,7 @@ func @block_redef() {
|
|||
|
||||
// -----
|
||||
|
||||
func @no_terminator() { // expected-error {{block with no terminator}}
|
||||
func @no_terminator() { // expected-error {{empty block: expect at least a terminator}}
|
||||
^bb40:
|
||||
return
|
||||
^bb41:
|
||||
|
|
|
@ -4,16 +4,14 @@
|
|||
module {
|
||||
}
|
||||
|
||||
// CHECK: module {
|
||||
// CHECK-NEXT: }
|
||||
module {
|
||||
"module_terminator"() : () -> ()
|
||||
}
|
||||
// -----
|
||||
|
||||
// CHECK: module attributes {foo.attr = true} {
|
||||
module attributes {foo.attr = true} {
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: module {
|
||||
module {
|
||||
// CHECK-NEXT: "foo.result_op"() : () -> i32
|
||||
|
|
|
@ -18,8 +18,6 @@
|
|||
// CHECK: Has 0 results:
|
||||
// CHECK: Visiting op 'dialect.op3' with 0 operands:
|
||||
// CHECK: Has 0 results:
|
||||
// CHECK: Visiting op 'module_terminator' with 0 operands:
|
||||
// CHECK: Has 0 results:
|
||||
// CHECK: Visiting op 'module' with 0 operands:
|
||||
// CHECK: Has 0 results:
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
// CHECK: visiting op: 'module' with 0 operands and 0 results
|
||||
// CHECK: 1 nested regions:
|
||||
// CHECK: Region with 1 blocks:
|
||||
// CHECK: Block with 0 arguments, 0 successors, and 3 operations
|
||||
// CHECK: Block with 0 arguments, 0 successors, and 2 operations
|
||||
module {
|
||||
|
||||
|
||||
|
@ -52,6 +52,4 @@ module {
|
|||
"dialect.innerop7"() : () -> ()
|
||||
}) : () -> ()
|
||||
|
||||
// CHECK: visiting op: 'module_terminator' with 0 operands and 0 results
|
||||
|
||||
} // module
|
||||
|
|
|
@ -73,3 +73,11 @@ func @named_region_has_wrong_number_of_blocks() {
|
|||
}) : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Region with single block and not terminator.
|
||||
// CHECK: unregistered_without_terminator
|
||||
"test.unregistered_without_terminator"() ( {
|
||||
^bb0: // no predecessors
|
||||
}) : () -> ()
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
// RUN: mlir-opt -allow-unregistered-dialect -test-legalize-patterns -verify-diagnostics -test-legalize-mode=analysis %s | FileCheck %s
|
||||
// expected-remark@-2 {{op 'module' is legalizable}}
|
||||
// expected-remark@-3 {{op 'module_terminator' is legalizable}}
|
||||
|
||||
// expected-remark@+1 {{op 'func' is legalizable}}
|
||||
func @test(%arg0: f32) {
|
||||
|
|
|
@ -33,6 +33,16 @@ void mlir::test::registerTestDialect(DialectRegistry ®istry) {
|
|||
|
||||
namespace {
|
||||
|
||||
/// Testing the correctness of some traits.
|
||||
static_assert(
|
||||
llvm::is_detected<OpTrait::has_implicit_terminator_t,
|
||||
SingleBlockImplicitTerminatorOp>::value,
|
||||
"has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
|
||||
static_assert(OpTrait::hasSingleBlockImplicitTerminator<
|
||||
SingleBlockImplicitTerminatorOp>::value,
|
||||
"hasSingleBlockImplicitTerminator does not match "
|
||||
"SingleBlockImplicitTerminatorOp");
|
||||
|
||||
// Test support for interacting with the AsmPrinter.
|
||||
struct TestOpAsmInterface : public OpAsmDialectInterface {
|
||||
using OpAsmDialectInterface::OpAsmDialectInterface;
|
||||
|
|
|
@ -573,7 +573,7 @@ struct TestLegalizePatternDriver
|
|||
|
||||
// Define the conversion target used for the test.
|
||||
ConversionTarget target(getContext());
|
||||
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
||||
target.addLegalOp<ModuleOp>();
|
||||
target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp,
|
||||
TerminatorOp>();
|
||||
target
|
||||
|
@ -702,7 +702,7 @@ struct TestRemappedValue
|
|||
patterns.add<OneVResOneVOperandOp1Converter>(&getContext());
|
||||
|
||||
mlir::ConversionTarget target(getContext());
|
||||
target.addLegalOp<ModuleOp, ModuleTerminatorOp, FuncOp, TestReturnOp>();
|
||||
target.addLegalOp<ModuleOp, FuncOp, TestReturnOp>();
|
||||
// We make OneVResOneVOperandOp1 legal only when it has more that one
|
||||
// operand. This will trigger the conversion that will replace one-operand
|
||||
// OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1.
|
||||
|
@ -969,9 +969,8 @@ struct TestMergeBlocksPatternDriver
|
|||
patterns.add<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>(
|
||||
context);
|
||||
ConversionTarget target(*context);
|
||||
target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp, TerminatorOp,
|
||||
TestBranchOp, TestTypeConsumerOp, TestTypeProducerOp,
|
||||
TestReturnOp>();
|
||||
target.addLegalOp<FuncOp, ModuleOp, TerminatorOp, TestBranchOp,
|
||||
TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>();
|
||||
target.addIllegalOp<ILLegalOpF>();
|
||||
|
||||
/// Expect the op to have a single block after legalization.
|
||||
|
|
|
@ -56,7 +56,7 @@ void TestConvVectorization::runOnOperation() {
|
|||
ConversionTarget target(*context);
|
||||
target.addLegalDialect<AffineDialect, scf::SCFDialect, StandardOpsDialect,
|
||||
VectorDialect>();
|
||||
target.addLegalOp<ModuleOp, FuncOp, ModuleTerminatorOp, ReturnOp>();
|
||||
target.addLegalOp<ModuleOp, FuncOp, ReturnOp>();
|
||||
target.addLegalOp<linalg::FillOp, linalg::YieldOp>();
|
||||
|
||||
SmallVector<RewritePatternSet, 4> stage1Patterns;
|
||||
|
|
|
@ -449,6 +449,14 @@ struct OperationFormat {
|
|||
llvm::any_of(op.getTraits(), [](const OpTrait &trait) {
|
||||
return trait.getDef().isSubClassOf("SingleBlockImplicitTerminator");
|
||||
});
|
||||
|
||||
hasSingleBlockTrait =
|
||||
hasImplicitTermTrait ||
|
||||
llvm::any_of(op.getTraits(), [](const OpTrait &trait) {
|
||||
if (auto *native = dyn_cast<NativeOpTrait>(&trait))
|
||||
return native->getTrait() == "::mlir::OpTrait::SingleBlock";
|
||||
return false;
|
||||
});
|
||||
}
|
||||
|
||||
/// Generate the operation parser from this format.
|
||||
|
@ -484,6 +492,9 @@ struct OperationFormat {
|
|||
/// trait.
|
||||
bool hasImplicitTermTrait;
|
||||
|
||||
/// A flag indicating if this operation has the SingleBlock trait.
|
||||
bool hasSingleBlockTrait;
|
||||
|
||||
/// A map of buildable types to indices.
|
||||
llvm::MapVector<StringRef, int, llvm::StringMap<int>> buildableTypes;
|
||||
|
||||
|
@ -679,6 +690,14 @@ const char *regionListEnsureTerminatorParserCode = R"(
|
|||
ensureTerminator(*region, parser.getBuilder(), result.location);
|
||||
)";
|
||||
|
||||
/// The code snippet used to ensure a list of regions have a block.
|
||||
///
|
||||
/// {0}: The name of the region list.
|
||||
const char *regionListEnsureSingleBlockParserCode = R"(
|
||||
for (auto ®ion : {0}Regions)
|
||||
if (region.empty()) *{0}Region.emplaceBlock();
|
||||
)";
|
||||
|
||||
/// The code snippet used to generate a parser call for an optional region.
|
||||
///
|
||||
/// {0}: The name of the region.
|
||||
|
@ -705,6 +724,13 @@ const char *regionEnsureTerminatorParserCode = R"(
|
|||
ensureTerminator(*{0}Region, parser.getBuilder(), result.location);
|
||||
)";
|
||||
|
||||
/// The code snippet used to ensure a region has a block.
|
||||
///
|
||||
/// {0}: The name of the region.
|
||||
const char *regionEnsureSingleBlockParserCode = R"(
|
||||
if ({0}Region->empty()) {0}Region->emplaceBlock();
|
||||
)";
|
||||
|
||||
/// The code snippet used to generate a parser call for a successor list.
|
||||
///
|
||||
/// {0}: The name for the successor list.
|
||||
|
@ -1134,6 +1160,9 @@ void OperationFormat::genElementParser(Element *element, OpMethodBody &body,
|
|||
body << " if (!" << region->name << "Region->empty()) {\n ";
|
||||
if (hasImplicitTermTrait)
|
||||
body << llvm::formatv(regionEnsureTerminatorParserCode, region->name);
|
||||
else if (hasSingleBlockTrait)
|
||||
body << llvm::formatv(regionEnsureSingleBlockParserCode,
|
||||
region->name);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1217,11 +1246,14 @@ void OperationFormat::genElementParser(Element *element, OpMethodBody &body,
|
|||
bool isVariadic = region->getVar()->isVariadic();
|
||||
body << llvm::formatv(isVariadic ? regionListParserCode : regionParserCode,
|
||||
region->getVar()->name);
|
||||
if (hasImplicitTermTrait) {
|
||||
if (hasImplicitTermTrait)
|
||||
body << llvm::formatv(isVariadic ? regionListEnsureTerminatorParserCode
|
||||
: regionEnsureTerminatorParserCode,
|
||||
region->getVar()->name);
|
||||
}
|
||||
else if (hasSingleBlockTrait)
|
||||
body << llvm::formatv(isVariadic ? regionListEnsureSingleBlockParserCode
|
||||
: regionEnsureSingleBlockParserCode,
|
||||
region->getVar()->name);
|
||||
|
||||
} else if (auto *successor = dyn_cast<SuccessorVariable>(element)) {
|
||||
bool isVariadic = successor->getVar()->isVariadic();
|
||||
|
@ -1246,6 +1278,8 @@ void OperationFormat::genElementParser(Element *element, OpMethodBody &body,
|
|||
body << llvm::formatv(regionListParserCode, "full");
|
||||
if (hasImplicitTermTrait)
|
||||
body << llvm::formatv(regionListEnsureTerminatorParserCode, "full");
|
||||
else if (hasSingleBlockTrait)
|
||||
body << llvm::formatv(regionListEnsureSingleBlockParserCode, "full");
|
||||
|
||||
} else if (isa<SuccessorsDirective>(element)) {
|
||||
body << llvm::formatv(successorListParserCode, "full");
|
||||
|
|
Loading…
Reference in New Issue