forked from OSchip/llvm-project
[mlir][Symbol] Change Symbol from a Trait into an OpInterface.
This provides a much cleaner interface into Symbols, and allows for users to start injecting op-specific information. For example, derived op can now inject when a symbol can be discarded if use_empty. This would let us drop unused external functions, which generally have public visibility. This revision also adds a new `extraTraitClassDeclaration` field to ODS OpInterface to allow for injecting declarations into the trait class that gets attached to the operations. Differential Revision: https://reviews.llvm.org/D78522
This commit is contained in:
parent
21acc0612a
commit
7c221a7d4f
|
@ -14,6 +14,7 @@
|
|||
#ifndef FIR_DIALECT_FIR_OPS
|
||||
#define FIR_DIALECT_FIR_OPS
|
||||
|
||||
include "mlir/IR/SymbolInterfaces.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/Interfaces/SideEffects.td"
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
include "mlir/Dialect/GPU/GPUBase.td"
|
||||
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
|
||||
include "mlir/IR/SymbolInterfaces.td"
|
||||
include "mlir/Interfaces/SideEffects.td"
|
||||
|
||||
// Type constraint accepting standard integers, indices and wrapped LLVM integer
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
#define LLVMIR_OPS
|
||||
|
||||
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
|
||||
include "mlir/IR/SymbolInterfaces.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/Interfaces/SideEffects.td"
|
||||
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#define SPIRV_STRUCTURE_OPS
|
||||
|
||||
include "mlir/Dialect/SPIRV/SPIRVBase.td"
|
||||
include "mlir/IR/SymbolInterfaces.td"
|
||||
include "mlir/Interfaces/CallInterfaces.td"
|
||||
include "mlir/Interfaces/SideEffects.td"
|
||||
|
||||
|
|
|
@ -2,3 +2,8 @@ set(LLVM_TARGET_DEFINITIONS OpAsmInterface.td)
|
|||
mlir_tablegen(OpAsmInterface.h.inc -gen-op-interface-decls)
|
||||
mlir_tablegen(OpAsmInterface.cpp.inc -gen-op-interface-defs)
|
||||
add_public_tablegen_target(MLIROpAsmInterfacesIncGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS SymbolInterfaces.td)
|
||||
mlir_tablegen(SymbolInterfaces.h.inc -gen-op-interface-decls)
|
||||
mlir_tablegen(SymbolInterfaces.cpp.inc -gen-op-interface-defs)
|
||||
add_public_tablegen_target(MLIRSymbolInterfacesIncGen)
|
||||
|
|
|
@ -30,11 +30,10 @@ namespace mlir {
|
|||
/// implicitly capture global values, and all external references must use
|
||||
/// Function arguments or attributes that establish a symbolic connection(e.g.
|
||||
/// symbols referenced by name via a string attribute).
|
||||
class FuncOp
|
||||
: public Op<FuncOp, OpTrait::ZeroOperands, OpTrait::ZeroResult,
|
||||
OpTrait::IsIsolatedFromAbove, OpTrait::Symbol,
|
||||
OpTrait::FunctionLike, OpTrait::AutomaticAllocationScope,
|
||||
CallableOpInterface::Trait> {
|
||||
class FuncOp : public Op<FuncOp, OpTrait::ZeroOperands, OpTrait::ZeroResult,
|
||||
OpTrait::IsIsolatedFromAbove, OpTrait::FunctionLike,
|
||||
OpTrait::AutomaticAllocationScope,
|
||||
CallableOpInterface::Trait, SymbolOpInterface::Trait> {
|
||||
public:
|
||||
using Op::Op;
|
||||
using Op::print;
|
||||
|
|
|
@ -31,7 +31,8 @@ class ModuleOp
|
|||
: public Op<
|
||||
ModuleOp, OpTrait::ZeroOperands, OpTrait::ZeroResult,
|
||||
OpTrait::IsIsolatedFromAbove, OpTrait::SymbolTable,
|
||||
OpTrait::SingleBlockImplicitTerminator<ModuleTerminatorOp>::Impl> {
|
||||
OpTrait::SingleBlockImplicitTerminator<ModuleTerminatorOp>::Impl,
|
||||
SymbolOpInterface::Trait> {
|
||||
public:
|
||||
using Op::Op;
|
||||
using Op::print;
|
||||
|
@ -95,6 +96,13 @@ public:
|
|||
insertPt = Block::iterator(body->getTerminator());
|
||||
body->getOperations().insert(insertPt, op);
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// SymbolOpInterface Methods
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// A ModuleOp may optionally define a symbol.
|
||||
bool isOptionalSymbol() { return true; }
|
||||
};
|
||||
|
||||
/// The ModuleTerminatorOp is a special terminator operation for the body of a
|
||||
|
|
|
@ -1658,10 +1658,6 @@ def SameOperandsElementType : NativeOpTrait<"SameOperandsElementType">;
|
|||
// Op has the same operand and result element type (or type itself, if scalar).
|
||||
def SameOperandsAndResultElementType :
|
||||
NativeOpTrait<"SameOperandsAndResultElementType">;
|
||||
// Op is a symbol.
|
||||
def Symbol : NativeOpTrait<"Symbol">;
|
||||
// Op defines a symbol table.
|
||||
def SymbolTable : NativeOpTrait<"SymbolTable">;
|
||||
// Op is a terminator.
|
||||
def Terminator : NativeOpTrait<"IsTerminator">;
|
||||
|
||||
|
@ -1721,6 +1717,10 @@ class OpInterfaceTrait<string name, code verifyBody = [{}]> : NativeOpTrait<"">
|
|||
// Specify the body of the verification function. `$_op` will be replaced with
|
||||
// the operation being verified.
|
||||
code verify = verifyBody;
|
||||
|
||||
// An optional code block containing extra declarations to place in the
|
||||
// interface trait declaration.
|
||||
code extraTraitClassDeclaration = "";
|
||||
}
|
||||
|
||||
// This class represents a single, optionally static, interface method.
|
||||
|
|
|
@ -1359,6 +1359,7 @@ class OpInterface : public Op<ConcreteType> {
|
|||
public:
|
||||
using Concept = typename Traits::Concept;
|
||||
template <typename T> using Model = typename Traits::template Model<T>;
|
||||
using Base = OpInterface<ConcreteType, Traits>;
|
||||
|
||||
OpInterface(Operation *op = nullptr)
|
||||
: Op<ConcreteType>(op), impl(op ? getInterfaceFor(op) : nullptr) {
|
||||
|
|
|
@ -0,0 +1,155 @@
|
|||
//===- SymbolInterfaces.td - Interfaces for symbol ops -----*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file contains a set of interfaces and traits that can be used to define
|
||||
// properties of symbol and symbol table operations.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_IR_SYMBOLINTERFACES
|
||||
#define MLIR_IR_SYMBOLINTERFACES
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SymbolOpInterface
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Symbol : OpInterface<"SymbolOpInterface"> {
|
||||
let description = [{
|
||||
This interface describes an operation that may define a `Symbol`. A `Symbol`
|
||||
operation resides immediately within a region that defines a `SymbolTable`.
|
||||
See [Symbols and SymbolTables](SymbolsAndSymbolTables.md) for more details
|
||||
and constraints on `Symbol` operations.
|
||||
}];
|
||||
|
||||
let methods = [
|
||||
InterfaceMethod<"Returns the name of this symbol.",
|
||||
"StringRef", "getName", (ins), [{
|
||||
// Don't rely on the trait implementation as optional symbol operations
|
||||
// may override this.
|
||||
return mlir::SymbolTable::getSymbolName(op);
|
||||
}], /*defaultImplementation=*/[{
|
||||
return mlir::SymbolTable::getSymbolName(this->getOperation());
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<"Sets the name of this symbol.",
|
||||
"void", "setName", (ins "StringRef":$name), [{}],
|
||||
/*defaultImplementation=*/[{
|
||||
this->getOperation()->setAttr(
|
||||
mlir::SymbolTable::getSymbolAttrName(),
|
||||
StringAttr::get(name, this->getOperation()->getContext()));
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<"Gets the visibility of this symbol.",
|
||||
"mlir::SymbolTable::Visibility", "getVisibility", (ins), [{}],
|
||||
/*defaultImplementation=*/[{
|
||||
return mlir::SymbolTable::getSymbolVisibility(this->getOperation());
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<"Sets the visibility of this symbol.",
|
||||
"void", "setVisibility", (ins "mlir::SymbolTable::Visibility":$vis), [{}],
|
||||
/*defaultImplementation=*/[{
|
||||
mlir::SymbolTable::setSymbolVisibility(this->getOperation(), vis);
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<[{
|
||||
Get all of the uses of the current symbol that are nested within the
|
||||
given operation 'from'.
|
||||
Note: See mlir::SymbolTable::getSymbolUses for more details.
|
||||
}],
|
||||
"Optional<::mlir::SymbolTable::UseRange>", "getSymbolUses",
|
||||
(ins "Operation *":$from), [{}],
|
||||
/*defaultImplementation=*/[{
|
||||
return ::mlir::SymbolTable::getSymbolUses(this->getOperation(), from);
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<[{
|
||||
Return if the current symbol is known to have no uses that are nested
|
||||
within the given operation 'from'.
|
||||
Note: See mlir::SymbolTable::symbolKnownUseEmpty for more details.
|
||||
}],
|
||||
"bool", "symbolKnownUseEmpty", (ins "Operation *":$from), [{}],
|
||||
/*defaultImplementation=*/[{
|
||||
return ::mlir::SymbolTable::symbolKnownUseEmpty(this->getOperation(),
|
||||
from);
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<[{
|
||||
Attempt to replace all uses of the current symbol with the provided
|
||||
symbol 'newSymbol' that are nested within the given operation 'from'.
|
||||
Note: See mlir::SymbolTable::replaceAllSymbolUses for more details.
|
||||
}],
|
||||
"LogicalResult", "replaceAllSymbolUses", (ins "StringRef":$newSymbol,
|
||||
"Operation *":$from), [{}],
|
||||
/*defaultImplementation=*/[{
|
||||
return ::mlir::SymbolTable::replaceAllSymbolUses(this->getOperation(),
|
||||
newSymbol, from);
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<[{
|
||||
Returns true if this operation optionally defines a symbol based on the
|
||||
presence of the symbol name.
|
||||
}],
|
||||
"bool", "isOptionalSymbol", (ins), [{}],
|
||||
/*defaultImplementation=*/[{ return false; }]
|
||||
>,
|
||||
InterfaceMethod<[{
|
||||
Returns true if this operation can be discarded if it has no remaining
|
||||
symbol uses.
|
||||
}],
|
||||
"bool", "canDiscardOnUseEmpty", (ins), [{}],
|
||||
/*defaultImplementation=*/[{
|
||||
// By default, base this on the visibility alone. A symbol can be
|
||||
// discarded as long as it is not public. Only public symbols may be
|
||||
// visible from outside of the IR.
|
||||
return getVisibility() != ::mlir::SymbolTable::Visibility::Public;
|
||||
}]
|
||||
>,
|
||||
];
|
||||
|
||||
let verify = [{
|
||||
// If this is an optional symbol, bail out early if possible.
|
||||
auto concreteOp = cast<ConcreteOp>($_op);
|
||||
if (concreteOp.isOptionalSymbol()) {
|
||||
if(!concreteOp.getAttr(::mlir::SymbolTable::getSymbolAttrName()))
|
||||
return success();
|
||||
}
|
||||
return ::mlir::detail::verifySymbol($_op);
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
using Visibility = mlir::SymbolTable::Visibility;
|
||||
|
||||
/// Custom classof that handles the case where the symbol is optional.
|
||||
static bool classof(Operation *op) {
|
||||
return Base::classof(op)
|
||||
&& op->getAttr(::mlir::SymbolTable::getSymbolAttrName());
|
||||
}
|
||||
|
||||
/// Returns true if this symbol has nested visibility.
|
||||
bool isNested() { return getVisibility() == Visibility::Nested; }
|
||||
/// Returns true if this symbol has private visibility.
|
||||
bool isPrivate() { return getVisibility() == Visibility::Private; }
|
||||
/// Returns true if this symbol has public visibility.
|
||||
bool isPublic() { return getVisibility() == Visibility::Public; }
|
||||
}];
|
||||
|
||||
let extraTraitClassDeclaration = [{
|
||||
using Visibility = mlir::SymbolTable::Visibility;
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Symbol Traits
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Op defines a symbol table.
|
||||
def SymbolTable : NativeOpTrait<"SymbolTable">;
|
||||
|
||||
#endif // MLIR_IR_SYMBOLINTERFACES
|
|
@ -72,9 +72,6 @@ public:
|
|||
Nested,
|
||||
};
|
||||
|
||||
/// Returns true if the given operation defines a symbol.
|
||||
static bool isSymbol(Operation *op);
|
||||
|
||||
/// Returns the name of the given symbol operation.
|
||||
static StringRef getSymbolName(Operation *symbol);
|
||||
/// Sets the name of the given symbol operation.
|
||||
|
@ -207,12 +204,12 @@ private:
|
|||
// SymbolTable Trait Types
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace OpTrait {
|
||||
namespace impl {
|
||||
namespace detail {
|
||||
LogicalResult verifySymbolTable(Operation *op);
|
||||
LogicalResult verifySymbol(Operation *op);
|
||||
} // namespace impl
|
||||
} // namespace detail
|
||||
|
||||
namespace OpTrait {
|
||||
/// A trait used to provide symbol table functionalities to a region operation.
|
||||
/// This operation must hold exactly 1 region. Once attached, all operations
|
||||
/// that are directly within the region, i.e not including those within child
|
||||
|
@ -224,7 +221,7 @@ template <typename ConcreteType>
|
|||
class SymbolTable : public TraitBase<ConcreteType, SymbolTable> {
|
||||
public:
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
return impl::verifySymbolTable(op);
|
||||
return ::mlir::detail::verifySymbolTable(op);
|
||||
}
|
||||
|
||||
/// Look up a symbol with the specified name, returning null if no such
|
||||
|
@ -245,68 +242,11 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
/// A trait used to define a symbol that can be used on operations within a
|
||||
/// symbol table. Operations using this trait must adhere to the following:
|
||||
/// * Have a StringAttr attribute named 'SymbolTable::getSymbolAttrName()'.
|
||||
template <typename ConcreteType>
|
||||
class Symbol : public TraitBase<ConcreteType, Symbol> {
|
||||
public:
|
||||
using Visibility = mlir::SymbolTable::Visibility;
|
||||
|
||||
static LogicalResult verifyTrait(Operation *op) {
|
||||
return impl::verifySymbol(op);
|
||||
}
|
||||
|
||||
/// Returns the name of this symbol.
|
||||
StringRef getName() {
|
||||
return this->getOperation()
|
||||
->template getAttrOfType<StringAttr>(
|
||||
mlir::SymbolTable::getSymbolAttrName())
|
||||
.getValue();
|
||||
}
|
||||
|
||||
/// Set the name of this symbol.
|
||||
void setName(StringRef name) {
|
||||
this->getOperation()->setAttr(
|
||||
mlir::SymbolTable::getSymbolAttrName(),
|
||||
StringAttr::get(name, this->getOperation()->getContext()));
|
||||
}
|
||||
|
||||
/// Returns the visibility of the current symbol.
|
||||
Visibility getVisibility() {
|
||||
return mlir::SymbolTable::getSymbolVisibility(this->getOperation());
|
||||
}
|
||||
|
||||
/// Sets the visibility of the current symbol.
|
||||
void setVisibility(Visibility vis) {
|
||||
mlir::SymbolTable::setSymbolVisibility(this->getOperation(), vis);
|
||||
}
|
||||
|
||||
/// Get all of the uses of the current symbol that are nested within the given
|
||||
/// operation 'from'.
|
||||
/// Note: See mlir::SymbolTable::getSymbolUses for more details.
|
||||
Optional<::mlir::SymbolTable::UseRange> getSymbolUses(Operation *from) {
|
||||
return ::mlir::SymbolTable::getSymbolUses(this->getOperation(), from);
|
||||
}
|
||||
|
||||
/// Return if the current symbol is known to have no uses that are nested
|
||||
/// within the given operation 'from'.
|
||||
/// Note: See mlir::SymbolTable::symbolKnownUseEmpty for more details.
|
||||
bool symbolKnownUseEmpty(Operation *from) {
|
||||
return ::mlir::SymbolTable::symbolKnownUseEmpty(this->getOperation(), from);
|
||||
}
|
||||
|
||||
/// Attempt to replace all uses of the current symbol with the provided symbol
|
||||
/// 'newSymbol' that are nested within the given operation 'from'.
|
||||
/// Note: See mlir::SymbolTable::replaceAllSymbolUses for more details.
|
||||
LLVM_NODISCARD LogicalResult replaceAllSymbolUses(StringRef newSymbol,
|
||||
Operation *from) {
|
||||
return ::mlir::SymbolTable::replaceAllSymbolUses(this->getOperation(),
|
||||
newSymbol, from);
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace OpTrait
|
||||
|
||||
/// Include the generated symbol interfaces.
|
||||
#include "mlir/IR/SymbolInterfaces.h.inc"
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_IR_SYMBOLTABLE_H
|
||||
|
|
|
@ -89,6 +89,9 @@ public:
|
|||
// Return the interfaces extra class declaration code.
|
||||
llvm::Optional<StringRef> getExtraClassDeclaration() const;
|
||||
|
||||
// Return the traits extra class declaration code.
|
||||
llvm::Optional<StringRef> getExtraTraitClassDeclaration() const;
|
||||
|
||||
// Return the verify method body if it has one.
|
||||
llvm::Optional<StringRef> getVerify() const;
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ add_mlir_library(MLIRIR
|
|||
DEPENDS
|
||||
MLIRCallInterfacesIncGen
|
||||
MLIROpAsmInterfacesIncGen
|
||||
MLIRSymbolInterfacesIncGen
|
||||
)
|
||||
target_link_libraries(MLIRIR
|
||||
PUBLIC
|
||||
|
|
|
@ -146,11 +146,6 @@ void SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
|
|||
setSymbolName(symbol, nameBuffer);
|
||||
}
|
||||
|
||||
/// Returns true if the given operation defines a symbol.
|
||||
bool SymbolTable::isSymbol(Operation *op) {
|
||||
return op->hasTrait<OpTrait::Symbol>() || getNameIfSymbol(op).hasValue();
|
||||
}
|
||||
|
||||
/// Returns the name of the given symbol operation.
|
||||
StringRef SymbolTable::getSymbolName(Operation *symbol) {
|
||||
Optional<StringRef> name = getNameIfSymbol(symbol);
|
||||
|
@ -286,7 +281,7 @@ Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from,
|
|||
// SymbolTable Trait Types
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult OpTrait::impl::verifySymbolTable(Operation *op) {
|
||||
LogicalResult detail::verifySymbolTable(Operation *op) {
|
||||
if (op->getNumRegions() != 1)
|
||||
return op->emitOpError()
|
||||
<< "Operations with a 'SymbolTable' must have exactly one region";
|
||||
|
@ -316,7 +311,7 @@ LogicalResult OpTrait::impl::verifySymbolTable(Operation *op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult OpTrait::impl::verifySymbol(Operation *op) {
|
||||
LogicalResult detail::verifySymbol(Operation *op) {
|
||||
// Verify the name attribute.
|
||||
if (!op->getAttrOfType<StringAttr>(mlir::SymbolTable::getSymbolAttrName()))
|
||||
return op->emitOpError() << "requires string attribute '"
|
||||
|
@ -866,3 +861,10 @@ LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol,
|
|||
Region *from) {
|
||||
return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Symbol Interfaces
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Include the generated symbol interfaces.
|
||||
#include "mlir/IR/SymbolInterfaces.cpp.inc"
|
||||
|
|
|
@ -92,6 +92,12 @@ llvm::Optional<StringRef> OpInterface::getExtraClassDeclaration() const {
|
|||
return value.empty() ? llvm::Optional<StringRef>() : value;
|
||||
}
|
||||
|
||||
// Return the traits extra class declaration code.
|
||||
llvm::Optional<StringRef> OpInterface::getExtraTraitClassDeclaration() const {
|
||||
auto value = def->getValueAsString("extraTraitClassDeclaration");
|
||||
return value.empty() ? llvm::Optional<StringRef>() : value;
|
||||
}
|
||||
|
||||
// Return the body for this method if it has one.
|
||||
llvm::Optional<StringRef> OpInterface::getVerify() const {
|
||||
auto value = def->getValueAsString("verify");
|
||||
|
|
|
@ -31,26 +31,6 @@ using namespace mlir;
|
|||
// Symbol Use Tracking
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Returns true if this operation can be discarded if it is a symbol and has no
|
||||
/// uses. 'allUsesVisible' corresponds to if the parent symbol table is hidden
|
||||
/// from above.
|
||||
static bool canDiscardSymbolOnUseEmpty(Operation *op, bool allUsesVisible) {
|
||||
if (!SymbolTable::isSymbol(op))
|
||||
return false;
|
||||
|
||||
// TODO: This is essentially the same logic from SymbolDCE. Remove this when
|
||||
// we have a 'Symbol' interface.
|
||||
// Private symbols are always initially considered dead.
|
||||
SymbolTable::Visibility visibility = SymbolTable::getSymbolVisibility(op);
|
||||
if (visibility == mlir::SymbolTable::Visibility::Private)
|
||||
return true;
|
||||
// We only include nested visibility here if all uses are visible.
|
||||
if (allUsesVisible && visibility == SymbolTable::Visibility::Nested)
|
||||
return true;
|
||||
// Otherwise, public symbols are never removable.
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Walk all of the symbol table operations nested with 'op' along with a
|
||||
/// boolean signifying if the symbols within can be treated as if all uses are
|
||||
/// visible. The provided callback is invoked with the symbol table operation,
|
||||
|
@ -59,9 +39,8 @@ static bool canDiscardSymbolOnUseEmpty(Operation *op, bool allUsesVisible) {
|
|||
static void walkSymbolTables(Operation *op, bool allSymUsesVisible,
|
||||
function_ref<void(Operation *, bool)> callback) {
|
||||
if (op->hasTrait<OpTrait::SymbolTable>()) {
|
||||
allSymUsesVisible = allSymUsesVisible || !SymbolTable::isSymbol(op) ||
|
||||
SymbolTable::getSymbolVisibility(op) ==
|
||||
SymbolTable::Visibility::Private;
|
||||
SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
|
||||
allSymUsesVisible = allSymUsesVisible || !symbol || symbol.isPrivate();
|
||||
callback(op, allSymUsesVisible);
|
||||
} else {
|
||||
// Otherwise if 'op' is not a symbol table, any nested symbols are
|
||||
|
@ -171,8 +150,11 @@ CGUseList::CGUseList(Operation *op, CallGraph &cg) {
|
|||
// If this is a callgraph operation, check to see if it is discardable.
|
||||
if (auto callable = dyn_cast<CallableOpInterface>(&op)) {
|
||||
if (auto *node = cg.lookupNode(callable.getCallableRegion())) {
|
||||
if (canDiscardSymbolOnUseEmpty(&op, allUsesVisible))
|
||||
SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
|
||||
if (symbol && (allUsesVisible || symbol.isPrivate()) &&
|
||||
symbol.canDiscardOnUseEmpty()) {
|
||||
discardableSymNodeUses.try_emplace(node, 0);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
@ -224,7 +206,7 @@ void CGUseList::eraseNode(CallGraphNode *node) {
|
|||
bool CGUseList::isDead(CallGraphNode *node) const {
|
||||
// If the parent operation isn't a symbol, simply check normal SSA deadness.
|
||||
Operation *nodeOp = node->getCallableRegion()->getParentOp();
|
||||
if (!SymbolTable::isSymbol(nodeOp))
|
||||
if (!isa<SymbolOpInterface>(nodeOp))
|
||||
return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->use_empty();
|
||||
|
||||
// Otherwise, check the number of symbol uses.
|
||||
|
@ -235,7 +217,7 @@ bool CGUseList::isDead(CallGraphNode *node) const {
|
|||
bool CGUseList::hasOneUseAndDiscardable(CallGraphNode *node) const {
|
||||
// If this isn't a symbol node, check for side-effects and SSA use count.
|
||||
Operation *nodeOp = node->getCallableRegion()->getParentOp();
|
||||
if (!SymbolTable::isSymbol(nodeOp))
|
||||
if (!isa<SymbolOpInterface>(nodeOp))
|
||||
return MemoryEffectOpInterface::hasNoEffect(nodeOp) && nodeOp->hasOneUse();
|
||||
|
||||
// Otherwise, check the number of symbol uses.
|
||||
|
|
|
@ -43,10 +43,9 @@ void SymbolDCE::runOnOperation() {
|
|||
// A flag that signals if the top level symbol table is hidden, i.e. not
|
||||
// accessible from parent scopes.
|
||||
bool symbolTableIsHidden = true;
|
||||
if (symbolTableOp->getParentOp() && SymbolTable::isSymbol(symbolTableOp)) {
|
||||
symbolTableIsHidden = SymbolTable::getSymbolVisibility(symbolTableOp) ==
|
||||
SymbolTable::Visibility::Private;
|
||||
}
|
||||
SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(symbolTableOp);
|
||||
if (symbolTableOp->getParentOp() && symbol)
|
||||
symbolTableIsHidden = symbol.isPrivate();
|
||||
|
||||
// Compute the set of live symbols within the symbol table.
|
||||
DenseSet<Operation *> liveSymbols;
|
||||
|
@ -61,7 +60,7 @@ void SymbolDCE::runOnOperation() {
|
|||
for (auto &block : nestedSymbolTable->getRegion(0)) {
|
||||
for (Operation &op :
|
||||
llvm::make_early_inc_range(block.without_terminator())) {
|
||||
if (SymbolTable::isSymbol(&op) && !liveSymbols.count(&op))
|
||||
if (isa<SymbolOpInterface>(&op) && !liveSymbols.count(&op))
|
||||
op.erase();
|
||||
}
|
||||
}
|
||||
|
@ -80,30 +79,16 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
|
|||
// Walk the symbols within the current symbol table, marking the symbols that
|
||||
// are known to be live.
|
||||
for (auto &block : symbolTableOp->getRegion(0)) {
|
||||
// Add all non-symbols or symbols that can't be discarded.
|
||||
for (Operation &op : block.without_terminator()) {
|
||||
// Always add non symbol operations to the worklist.
|
||||
if (!SymbolTable::isSymbol(&op)) {
|
||||
SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
|
||||
if (!symbol) {
|
||||
worklist.push_back(&op);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check the visibility to see if this symbol may be referenced
|
||||
// externally.
|
||||
SymbolTable::Visibility visibility =
|
||||
SymbolTable::getSymbolVisibility(&op);
|
||||
|
||||
// Private symbols are always initially considered dead.
|
||||
if (visibility == mlir::SymbolTable::Visibility::Private)
|
||||
continue;
|
||||
// We only include nested visibility here if the symbol table isn't
|
||||
// hidden.
|
||||
if (symbolTableIsHidden && visibility == SymbolTable::Visibility::Nested)
|
||||
continue;
|
||||
|
||||
// TODO(riverriddle) Add hooks here to allow symbols to provide additional
|
||||
// information, e.g. linkage can be used to drop some symbols that may
|
||||
// otherwise be considered "live".
|
||||
if (liveSymbols.insert(&op).second)
|
||||
bool isDiscardable = (symbolTableIsHidden || symbol.isPrivate()) &&
|
||||
symbol.canDiscardOnUseEmpty();
|
||||
if (!isDiscardable && liveSymbols.insert(&op).second)
|
||||
worklist.push_back(&op);
|
||||
}
|
||||
}
|
||||
|
@ -117,10 +102,9 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
|
|||
if (op->hasTrait<OpTrait::SymbolTable>()) {
|
||||
// The internal symbol table is hidden if the parent is, if its not a
|
||||
// symbol, or if it is a private symbol.
|
||||
bool symbolIsHidden = symbolTableIsHidden || !SymbolTable::isSymbol(op) ||
|
||||
SymbolTable::getSymbolVisibility(op) ==
|
||||
SymbolTable::Visibility::Private;
|
||||
if (failed(computeLiveness(op, symbolIsHidden, liveSymbols)))
|
||||
SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
|
||||
bool symIsHidden = symbolTableIsHidden || !symbol || symbol.isPrivate();
|
||||
if (failed(computeLiveness(op, symIsHidden, liveSymbols)))
|
||||
return failure();
|
||||
}
|
||||
|
||||
|
|
|
@ -11,6 +11,7 @@
|
|||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/IR/OpAsmInterface.td"
|
||||
include "mlir/IR/SymbolInterfaces.td"
|
||||
include "mlir/Interfaces/SideEffects.td"
|
||||
include "mlir/Interfaces/CallInterfaces.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
|
|
|
@ -66,7 +66,7 @@ struct SymbolUsesPass
|
|||
// Walk nested symbols.
|
||||
SmallVector<FuncOp, 4> deadFunctions;
|
||||
module.getBodyRegion().walk([&](Operation *nestedOp) {
|
||||
if (SymbolTable::isSymbol(nestedOp))
|
||||
if (isa<SymbolOpInterface>(nestedOp))
|
||||
return operateOnSymbol(nestedOp, module, deadFunctions);
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
|
|
@ -174,6 +174,8 @@ static void emitTraitDecl(OpInterface &interface, raw_ostream &os,
|
|||
os << " static LogicalResult verifyTrait(Operation* op) {\n"
|
||||
<< std::string(tblgen::tgfmt(*verify, &traitCtx)) << "\n }\n";
|
||||
}
|
||||
if (auto extraTraitDecls = interface.getExtraTraitClassDeclaration())
|
||||
os << extraTraitDecls << "\n";
|
||||
|
||||
os << " };\n";
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue