forked from OSchip/llvm-project
[SymbolRefAttr] Revise SymbolRefAttr to hold a StringAttr.
SymbolRefAttr is fundamentally a base string plus a sequence of nested references. Instead of storing the string data as a copies StringRef, store it as an already-uniqued StringAttr. This makes a lot of things simpler and more efficient because: 1) references to the symbol are already stored as StringAttr's: there is no need to copy the string data into MLIRContext multiple times. 2) This allows pointer comparisons instead of string comparisons (or redundant uniquing) within SymbolTable.cpp. 3) This allows SymbolTable to hold a DenseMap instead of a StringMap (which again copies the string data and slows lookup). This is a moderately invasive patch, so I kept a lot of compatibility APIs around. It would be nice to explore changing getName() to return a StringAttr for example (right now you have to use getNameAttr()), and eliminate things like the StringRef version of getSymbol. Differential Revision: https://reviews.llvm.org/D108899
This commit is contained in:
parent
3383ec5fdd
commit
41d4aa7de6
|
@ -378,10 +378,10 @@ def GPU_LaunchFuncOp : GPU_Op<"launch_func",
|
|||
unsigned getNumKernelOperands();
|
||||
|
||||
/// The name of the kernel's containing module.
|
||||
StringRef getKernelModuleName();
|
||||
StringAttr getKernelModuleName();
|
||||
|
||||
/// The name of the kernel.
|
||||
StringRef getKernelName();
|
||||
StringAttr getKernelName();
|
||||
|
||||
/// The i-th operand passed to the kernel function.
|
||||
Value getKernelOperand(unsigned i);
|
||||
|
|
|
@ -98,9 +98,16 @@ public:
|
|||
StringAttr getStringAttr(const Twine &bytes);
|
||||
ArrayAttr getArrayAttr(ArrayRef<Attribute> value);
|
||||
FlatSymbolRefAttr getSymbolRefAttr(Operation *value);
|
||||
FlatSymbolRefAttr getSymbolRefAttr(StringRef value);
|
||||
SymbolRefAttr getSymbolRefAttr(StringRef value,
|
||||
FlatSymbolRefAttr getSymbolRefAttr(StringAttr value);
|
||||
SymbolRefAttr getSymbolRefAttr(StringAttr value,
|
||||
ArrayRef<FlatSymbolRefAttr> nestedReferences);
|
||||
SymbolRefAttr getSymbolRefAttr(StringRef value,
|
||||
ArrayRef<FlatSymbolRefAttr> nestedReferences) {
|
||||
return getSymbolRefAttr(getStringAttr(value), nestedReferences);
|
||||
}
|
||||
FlatSymbolRefAttr getSymbolRefAttr(StringRef value) {
|
||||
return getSymbolRefAttr(getStringAttr(value));
|
||||
}
|
||||
|
||||
// Returns a 0-valued attribute of the given `type`. This function only
|
||||
// supports boolean, integer, and 16-/32-/64-bit float types, and vector or
|
||||
|
|
|
@ -30,8 +30,10 @@ class ShapedType;
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace detail {
|
||||
template <typename T> class ElementsAttrIterator;
|
||||
template <typename T> class ElementsAttrRange;
|
||||
template <typename T>
|
||||
class ElementsAttrIterator;
|
||||
template <typename T>
|
||||
class ElementsAttrRange;
|
||||
} // namespace detail
|
||||
|
||||
/// A base attribute that represents a reference to a static shaped tensor or
|
||||
|
@ -39,8 +41,10 @@ template <typename T> class ElementsAttrRange;
|
|||
class ElementsAttr : public Attribute {
|
||||
public:
|
||||
using Attribute::Attribute;
|
||||
template <typename T> using iterator = detail::ElementsAttrIterator<T>;
|
||||
template <typename T> using iterator_range = detail::ElementsAttrRange<T>;
|
||||
template <typename T>
|
||||
using iterator = detail::ElementsAttrIterator<T>;
|
||||
template <typename T>
|
||||
using iterator_range = detail::ElementsAttrRange<T>;
|
||||
|
||||
/// Return the type of this ElementsAttr, guaranteed to be a vector or tensor
|
||||
/// with static shape.
|
||||
|
@ -52,14 +56,16 @@ public:
|
|||
|
||||
/// Return the value of type 'T' at the given index, where 'T' corresponds to
|
||||
/// an Attribute type.
|
||||
template <typename T> T getValue(ArrayRef<uint64_t> index) const {
|
||||
template <typename T>
|
||||
T getValue(ArrayRef<uint64_t> index) const {
|
||||
return getValue(index).template cast<T>();
|
||||
}
|
||||
|
||||
/// Return the elements of this attribute as a value of type 'T'. Note:
|
||||
/// Aborts if the subclass is OpaqueElementsAttrs, these attrs do not support
|
||||
/// iteration.
|
||||
template <typename T> iterator_range<T> getValues() const;
|
||||
template <typename T>
|
||||
iterator_range<T> getValues() const;
|
||||
|
||||
/// Return if the given 'index' refers to a valid element in this attribute.
|
||||
bool isValidIndex(ArrayRef<uint64_t> index) const;
|
||||
|
@ -139,7 +145,8 @@ protected:
|
|||
};
|
||||
|
||||
/// Type trait detector that checks if a given type T is a complex type.
|
||||
template <typename T> struct is_complex_t : public std::false_type {};
|
||||
template <typename T>
|
||||
struct is_complex_t : public std::false_type {};
|
||||
template <typename T>
|
||||
struct is_complex_t<std::complex<T>> : public std::true_type {};
|
||||
} // namespace detail
|
||||
|
@ -154,7 +161,8 @@ public:
|
|||
/// floating point type that can be used to access the underlying element
|
||||
/// types of a DenseElementsAttr.
|
||||
// TODO: Use std::disjunction when C++17 is supported.
|
||||
template <typename T> struct is_valid_cpp_fp_type {
|
||||
template <typename T>
|
||||
struct is_valid_cpp_fp_type {
|
||||
/// The type is a valid floating point type if it is a builtin floating
|
||||
/// point type, or is a potentially user defined floating point type. The
|
||||
/// latter allows for supporting users that have custom types defined for
|
||||
|
@ -423,7 +431,8 @@ public:
|
|||
Attribute getValue(ArrayRef<uint64_t> index) const {
|
||||
return getValue<Attribute>(index);
|
||||
}
|
||||
template <typename T> T getValue(ArrayRef<uint64_t> index) const {
|
||||
template <typename T>
|
||||
T getValue(ArrayRef<uint64_t> index) const {
|
||||
// Skip to the element corresponding to the flattened index.
|
||||
return *std::next(getValues<T>().begin(), getFlattenedIndex(index));
|
||||
}
|
||||
|
@ -680,8 +689,15 @@ public:
|
|||
return SymbolRefAttr::get(ctx, value);
|
||||
}
|
||||
|
||||
static FlatSymbolRefAttr get(StringAttr value) {
|
||||
return SymbolRefAttr::get(value);
|
||||
}
|
||||
|
||||
/// Returns the name of the held symbol reference as a StringAttr.
|
||||
StringAttr getAttr() const { return getRootReference(); }
|
||||
|
||||
/// Returns the name of the held symbol reference.
|
||||
StringRef getValue() const { return getRootReference(); }
|
||||
StringRef getValue() const { return getAttr().getValue(); }
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(Attribute attr) {
|
||||
|
@ -845,22 +861,28 @@ class ElementsAttrIterator
|
|||
}
|
||||
|
||||
/// Utility functors used to generically implement the iterators methods.
|
||||
template <typename ItT> struct PlusAssign {
|
||||
template <typename ItT>
|
||||
struct PlusAssign {
|
||||
void operator()(ItT &it, ptrdiff_t offset) { it += offset; }
|
||||
};
|
||||
template <typename ItT> struct Minus {
|
||||
template <typename ItT>
|
||||
struct Minus {
|
||||
ptrdiff_t operator()(const ItT &lhs, const ItT &rhs) { return lhs - rhs; }
|
||||
};
|
||||
template <typename ItT> struct MinusAssign {
|
||||
template <typename ItT>
|
||||
struct MinusAssign {
|
||||
void operator()(ItT &it, ptrdiff_t offset) { it -= offset; }
|
||||
};
|
||||
template <typename ItT> struct Dereference {
|
||||
template <typename ItT>
|
||||
struct Dereference {
|
||||
T operator()(ItT &it) { return *it; }
|
||||
};
|
||||
template <typename ItT> struct ConstructIter {
|
||||
template <typename ItT>
|
||||
struct ConstructIter {
|
||||
void operator()(ItT &dest, const ItT &it) { ::new (&dest) ItT(it); }
|
||||
};
|
||||
template <typename ItT> struct DestructIter {
|
||||
template <typename ItT>
|
||||
struct DestructIter {
|
||||
void operator()(ItT &it) { it.~ItT(); }
|
||||
};
|
||||
|
||||
|
|
|
@ -881,17 +881,26 @@ def Builtin_SymbolRefAttr : Builtin_Attr<"SymbolRef"> {
|
|||
@parent_reference::@nested_reference
|
||||
```
|
||||
}];
|
||||
let parameters = (ins
|
||||
StringRefParameter<"">:$rootReference,
|
||||
ArrayRefParameter<"FlatSymbolRefAttr", "">:$nestedReferences
|
||||
);
|
||||
let parameters =
|
||||
(ins "StringAttr":$rootReference,
|
||||
ArrayRefParameter<"FlatSymbolRefAttr", "">:$nestedReferences);
|
||||
|
||||
let builders = [
|
||||
AttrBuilderWithInferredContext<
|
||||
(ins "StringAttr":$rootReference,
|
||||
"ArrayRef<FlatSymbolRefAttr>":$nestedReferences), [{
|
||||
return $_get(rootReference.getContext(), rootReference, nestedReferences);
|
||||
}]>,
|
||||
];
|
||||
let extraClassDeclaration = [{
|
||||
static FlatSymbolRefAttr get(MLIRContext *ctx, StringRef value);
|
||||
static FlatSymbolRefAttr get(StringAttr value);
|
||||
|
||||
/// Returns the name of the fully resolved symbol, i.e. the leaf of the
|
||||
/// reference path.
|
||||
StringRef getLeafReference() const;
|
||||
StringAttr getLeafReference() const;
|
||||
}];
|
||||
let skipDefaultBuilders = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -1734,7 +1734,7 @@ def IsNullAttr : AttrConstraint<
|
|||
class ReferToOp<string opClass> : AttrConstraint<
|
||||
CPred<"isa_and_nonnull<" # opClass # ">("
|
||||
"::mlir::SymbolTable::lookupNearestSymbolFrom("
|
||||
"&$_op, $_self.cast<::mlir::FlatSymbolRefAttr>().getValue()))">,
|
||||
"&$_op, $_self.cast<::mlir::FlatSymbolRefAttr>().getAttr()))">,
|
||||
"referencing to a '" # opClass # "' symbol">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -31,7 +31,7 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
|
|||
|
||||
let methods = [
|
||||
InterfaceMethod<"Returns the name of this symbol.",
|
||||
"StringRef", "getName", (ins), [{
|
||||
"StringAttr", "getNameAttr", (ins), [{
|
||||
// Don't rely on the trait implementation as optional symbol operations
|
||||
// may override this.
|
||||
return mlir::SymbolTable::getSymbolName($_op);
|
||||
|
@ -40,11 +40,10 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
|
|||
}]
|
||||
>,
|
||||
InterfaceMethod<"Sets the name of this symbol.",
|
||||
"void", "setName", (ins "StringRef":$name), [{}],
|
||||
"void", "setName", (ins "StringAttr":$name), [{}],
|
||||
/*defaultImplementation=*/[{
|
||||
this->getOperation()->setAttr(
|
||||
mlir::SymbolTable::getSymbolAttrName(),
|
||||
StringAttr::get(this->getOperation()->getContext(), name));
|
||||
mlir::SymbolTable::getSymbolAttrName(), name);
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<"Gets the visibility of this symbol.",
|
||||
|
@ -122,7 +121,7 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
|
|||
symbol 'newSymbol' that are nested within the given operation 'from'.
|
||||
Note: See mlir::SymbolTable::replaceAllSymbolUses for more details.
|
||||
}],
|
||||
"LogicalResult", "replaceAllSymbolUses", (ins "StringRef":$newSymbol,
|
||||
"LogicalResult", "replaceAllSymbolUses", (ins "StringAttr":$newSymbol,
|
||||
"Operation *":$from), [{}],
|
||||
/*defaultImplementation=*/[{
|
||||
return ::mlir::SymbolTable::replaceAllSymbolUses(this->getOperation(),
|
||||
|
@ -176,6 +175,16 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
|
|||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Convenience version of `getNameAttr` that returns a StringRef.
|
||||
StringRef getName() {
|
||||
return getNameAttr().getValue();
|
||||
}
|
||||
|
||||
/// Convenience version of `setName` that take a StringRef.
|
||||
void setName(StringRef name) {
|
||||
setName(StringAttr::get(this->getContext(), name));
|
||||
}
|
||||
|
||||
/// Custom classof that handles the case where the symbol is optional.
|
||||
static bool classof(Operation *op) {
|
||||
auto *opConcept = getInterfaceFor(op);
|
||||
|
@ -188,6 +197,16 @@ def Symbol : OpInterface<"SymbolOpInterface"> {
|
|||
|
||||
let extraTraitClassDeclaration = [{
|
||||
using Visibility = mlir::SymbolTable::Visibility;
|
||||
|
||||
/// Convenience version of `getNameAttr` that returns a StringRef.
|
||||
StringRef getName() {
|
||||
return getNameAttr().getValue();
|
||||
}
|
||||
|
||||
/// Convenience version of `setName` that take a StringRef.
|
||||
void setName(StringRef name) {
|
||||
setName(StringAttr::get(this->getContext(), name));
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,16 @@ public:
|
|||
/// Look up a symbol with the specified name, returning null if no such
|
||||
/// name exists. Names never include the @ on them.
|
||||
Operation *lookup(StringRef name) const;
|
||||
template <typename T> T lookup(StringRef name) const {
|
||||
template <typename T>
|
||||
T lookup(StringRef name) const {
|
||||
return dyn_cast_or_null<T>(lookup(name));
|
||||
}
|
||||
|
||||
/// Look up a symbol with the specified name, returning null if no such
|
||||
/// name exists. Names never include the @ on them.
|
||||
Operation *lookup(StringAttr name) const;
|
||||
template <typename T>
|
||||
T lookup(StringAttr name) const {
|
||||
return dyn_cast_or_null<T>(lookup(name));
|
||||
}
|
||||
|
||||
|
@ -74,10 +83,15 @@ public:
|
|||
Nested,
|
||||
};
|
||||
|
||||
/// Returns the name of the given symbol operation.
|
||||
static StringRef getSymbolName(Operation *symbol);
|
||||
/// Returns the name of the given symbol operation, aborting if no symbol is
|
||||
/// present.
|
||||
static StringAttr getSymbolName(Operation *symbol);
|
||||
|
||||
/// Sets the name of the given symbol operation.
|
||||
static void setSymbolName(Operation *symbol, StringRef name);
|
||||
static void setSymbolName(Operation *symbol, StringAttr name);
|
||||
static void setSymbolName(Operation *symbol, StringRef name) {
|
||||
setSymbolName(symbol, StringAttr::get(symbol->getContext(), name));
|
||||
}
|
||||
|
||||
/// Returns the visibility of the given symbol operation.
|
||||
static Visibility getSymbolVisibility(Operation *symbol);
|
||||
|
@ -100,7 +114,10 @@ public:
|
|||
/// Returns the operation registered with the given symbol name with the
|
||||
/// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
|
||||
/// with the 'OpTrait::SymbolTable' trait.
|
||||
static Operation *lookupSymbolIn(Operation *op, StringRef symbol);
|
||||
static Operation *lookupSymbolIn(Operation *op, StringAttr symbol);
|
||||
static Operation *lookupSymbolIn(Operation *op, StringRef symbol) {
|
||||
return lookupSymbolIn(op, StringAttr::get(op->getContext(), symbol));
|
||||
}
|
||||
static Operation *lookupSymbolIn(Operation *op, SymbolRefAttr symbol);
|
||||
/// A variant of 'lookupSymbolIn' that returns all of the symbols referenced
|
||||
/// by a given SymbolRefAttr. Returns failure if any of the nested references
|
||||
|
@ -112,11 +129,11 @@ public:
|
|||
/// closest parent operation of, or including, 'from' with the
|
||||
/// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
|
||||
/// found.
|
||||
static Operation *lookupNearestSymbolFrom(Operation *from, StringRef symbol);
|
||||
static Operation *lookupNearestSymbolFrom(Operation *from, StringAttr symbol);
|
||||
static Operation *lookupNearestSymbolFrom(Operation *from,
|
||||
SymbolRefAttr symbol);
|
||||
template <typename T>
|
||||
static T lookupNearestSymbolFrom(Operation *from, StringRef symbol) {
|
||||
static T lookupNearestSymbolFrom(Operation *from, StringAttr symbol) {
|
||||
return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
|
||||
}
|
||||
template <typename T>
|
||||
|
@ -169,9 +186,9 @@ public:
|
|||
/// operation 'from'. This does not traverse into any nested symbol tables.
|
||||
/// This function returns None if there are any unknown operations that may
|
||||
/// potentially be symbol tables.
|
||||
static Optional<UseRange> getSymbolUses(StringRef symbol, Operation *from);
|
||||
static Optional<UseRange> getSymbolUses(StringAttr symbol, Operation *from);
|
||||
static Optional<UseRange> getSymbolUses(Operation *symbol, Operation *from);
|
||||
static Optional<UseRange> getSymbolUses(StringRef symbol, Region *from);
|
||||
static Optional<UseRange> getSymbolUses(StringAttr symbol, Region *from);
|
||||
static Optional<UseRange> getSymbolUses(Operation *symbol, Region *from);
|
||||
|
||||
/// Return if the given symbol is known to have no uses that are nested
|
||||
|
@ -180,9 +197,9 @@ public:
|
|||
/// unknown operations that may potentially be symbol tables. This doesn't
|
||||
/// necessarily mean that there are no uses, we just can't conservatively
|
||||
/// prove it.
|
||||
static bool symbolKnownUseEmpty(StringRef symbol, Operation *from);
|
||||
static bool symbolKnownUseEmpty(StringAttr symbol, Operation *from);
|
||||
static bool symbolKnownUseEmpty(Operation *symbol, Operation *from);
|
||||
static bool symbolKnownUseEmpty(StringRef symbol, Region *from);
|
||||
static bool symbolKnownUseEmpty(StringAttr symbol, Region *from);
|
||||
static bool symbolKnownUseEmpty(Operation *symbol, Region *from);
|
||||
|
||||
/// Attempt to replace all uses of the given symbol 'oldSymbol' with the
|
||||
|
@ -190,23 +207,24 @@ public:
|
|||
/// 'from'. This does not traverse into any nested symbol tables. If there are
|
||||
/// any unknown operations that may potentially be symbol tables, no uses are
|
||||
/// replaced and failure is returned.
|
||||
static LogicalResult replaceAllSymbolUses(StringRef oldSymbol,
|
||||
StringRef newSymbol,
|
||||
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol,
|
||||
StringAttr newSymbol,
|
||||
Operation *from);
|
||||
static LogicalResult replaceAllSymbolUses(Operation *oldSymbol,
|
||||
StringRef newSymbolName,
|
||||
StringAttr newSymbolName,
|
||||
Operation *from);
|
||||
static LogicalResult replaceAllSymbolUses(StringRef oldSymbol,
|
||||
StringRef newSymbol, Region *from);
|
||||
static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol,
|
||||
StringAttr newSymbol, Region *from);
|
||||
static LogicalResult replaceAllSymbolUses(Operation *oldSymbol,
|
||||
StringRef newSymbolName,
|
||||
StringAttr newSymbolName,
|
||||
Region *from);
|
||||
|
||||
private:
|
||||
Operation *symbolTableOp;
|
||||
|
||||
/// This is a mapping from a name to the symbol with that name.
|
||||
llvm::StringMap<Operation *> symbolTable;
|
||||
/// This is a mapping from a name to the symbol with that name. They key is
|
||||
/// always known to be a StringAttr.
|
||||
DenseMap<Attribute, Operation *> symbolTable;
|
||||
|
||||
/// This is used when name conflicts are detected.
|
||||
unsigned uniquingCounter = 0;
|
||||
|
@ -226,7 +244,7 @@ class SymbolTableCollection {
|
|||
public:
|
||||
/// Look up a symbol with the specified name within the specified symbol table
|
||||
/// operation, returning null if no such name exists.
|
||||
Operation *lookupSymbolIn(Operation *symbolTableOp, StringRef symbol);
|
||||
Operation *lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol);
|
||||
Operation *lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name);
|
||||
template <typename T, typename NameT>
|
||||
T lookupSymbolIn(Operation *symbolTableOp, NameT &&name) const {
|
||||
|
@ -244,10 +262,10 @@ public:
|
|||
/// closest parent operation of, or including, 'from' with the
|
||||
/// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
|
||||
/// found.
|
||||
Operation *lookupNearestSymbolFrom(Operation *from, StringRef symbol);
|
||||
Operation *lookupNearestSymbolFrom(Operation *from, StringAttr symbol);
|
||||
Operation *lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol);
|
||||
template <typename T>
|
||||
T lookupNearestSymbolFrom(Operation *from, StringRef symbol) {
|
||||
T lookupNearestSymbolFrom(Operation *from, StringAttr symbol) {
|
||||
return dyn_cast_or_null<T>(lookupNearestSymbolFrom(from, symbol));
|
||||
}
|
||||
template <typename T>
|
||||
|
@ -290,7 +308,7 @@ public:
|
|||
}
|
||||
|
||||
/// Replace all of the uses of the given symbol with `newSymbolName`.
|
||||
void replaceAllUsesWith(Operation *symbol, StringRef newSymbolName);
|
||||
void replaceAllUsesWith(Operation *symbol, StringAttr newSymbolName);
|
||||
|
||||
private:
|
||||
/// A reference to the symbol table used to construct this map.
|
||||
|
@ -327,18 +345,28 @@ public:
|
|||
/// Look up a symbol with the specified name, returning null if no such
|
||||
/// name exists. Symbol names never include the @ on them. Note: This
|
||||
/// performs a linear scan of held symbols.
|
||||
Operation *lookupSymbol(StringRef name) {
|
||||
Operation *lookupSymbol(StringAttr name) {
|
||||
return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), name);
|
||||
}
|
||||
template <typename T> T lookupSymbol(StringRef name) {
|
||||
template <typename T>
|
||||
T lookupSymbol(StringAttr name) {
|
||||
return dyn_cast_or_null<T>(lookupSymbol(name));
|
||||
}
|
||||
Operation *lookupSymbol(SymbolRefAttr symbol) {
|
||||
return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), symbol);
|
||||
}
|
||||
template <typename T> T lookupSymbol(SymbolRefAttr symbol) {
|
||||
template <typename T>
|
||||
T lookupSymbol(SymbolRefAttr symbol) {
|
||||
return dyn_cast_or_null<T>(lookupSymbol(symbol));
|
||||
}
|
||||
|
||||
Operation *lookupSymbol(StringRef name) {
|
||||
return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), name);
|
||||
}
|
||||
template <typename T>
|
||||
T lookupSymbol(StringRef name) {
|
||||
return dyn_cast_or_null<T>(lookupSymbol(name));
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace OpTrait
|
||||
|
|
|
@ -212,15 +212,16 @@ MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol,
|
|||
refs.reserve(numReferences);
|
||||
for (intptr_t i = 0; i < numReferences; ++i)
|
||||
refs.push_back(unwrap(references[i]).cast<FlatSymbolRefAttr>());
|
||||
return wrap(SymbolRefAttr::get(unwrap(ctx), unwrap(symbol), refs));
|
||||
auto symbolAttr = StringAttr::get(unwrap(ctx), unwrap(symbol));
|
||||
return wrap(SymbolRefAttr::get(symbolAttr, refs));
|
||||
}
|
||||
|
||||
MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr) {
|
||||
return wrap(unwrap(attr).cast<SymbolRefAttr>().getRootReference());
|
||||
return wrap(unwrap(attr).cast<SymbolRefAttr>().getRootReference().getValue());
|
||||
}
|
||||
|
||||
MlirStringRef mlirSymbolRefAttrGetLeafReference(MlirAttribute attr) {
|
||||
return wrap(unwrap(attr).cast<SymbolRefAttr>().getLeafReference());
|
||||
return wrap(unwrap(attr).cast<SymbolRefAttr>().getLeafReference().getValue());
|
||||
}
|
||||
|
||||
intptr_t mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr) {
|
||||
|
|
|
@ -704,7 +704,8 @@ LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
|
|||
// Get the function from the module. The name corresponds to the name of
|
||||
// the kernel function.
|
||||
auto kernelName = generateKernelNameConstant(
|
||||
launchOp.getKernelModuleName(), launchOp.getKernelName(), loc, rewriter);
|
||||
launchOp.getKernelModuleName().getValue(),
|
||||
launchOp.getKernelName().getValue(), loc, rewriter);
|
||||
auto function = moduleGetFunctionCallBuilder.create(
|
||||
loc, rewriter, {module.getResult(0), kernelName});
|
||||
auto zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
|
||||
|
|
|
@ -106,7 +106,8 @@ private:
|
|||
Operation *op) const {
|
||||
using LLVM::LLVMFuncOp;
|
||||
|
||||
Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcName);
|
||||
auto funcAttr = StringAttr::get(op->getContext(), funcName);
|
||||
Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr);
|
||||
if (funcOp)
|
||||
return cast<LLVMFuncOp>(*funcOp);
|
||||
|
||||
|
|
|
@ -181,9 +181,8 @@ void ConvertGpuLaunchFuncToVulkanLaunchFunc::convertGpuLaunchFunc(
|
|||
StringRef(binary.data(), binary.size())));
|
||||
|
||||
// Set entry point name as an attribute.
|
||||
vulkanLaunchCallOp->setAttr(
|
||||
kSPIRVEntryPointAttrName,
|
||||
StringAttr::get(loc->getContext(), launchOp.getKernelName()));
|
||||
vulkanLaunchCallOp->setAttr(kSPIRVEntryPointAttrName,
|
||||
launchOp.getKernelName());
|
||||
|
||||
launchOp.erase();
|
||||
}
|
||||
|
|
|
@ -52,9 +52,8 @@ static FlatSymbolRefAttr getLibraryCallSymbolRef(Operation *op,
|
|||
// fnName is a dynamic std::string, unique it via a SymbolRefAttr.
|
||||
FlatSymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName);
|
||||
auto module = op->getParentOfType<ModuleOp>();
|
||||
if (module.lookupSymbol(fnName)) {
|
||||
if (module.lookupSymbol(fnNameAttr.getAttr()))
|
||||
return fnNameAttr;
|
||||
}
|
||||
|
||||
SmallVector<Type, 4> inputTypes(extractOperandTypes(op));
|
||||
assert(op->getNumResults() == 0 &&
|
||||
|
|
|
@ -127,8 +127,9 @@ static LogicalResult encodeKernelName(spirv::ModuleOp module) {
|
|||
// {spv_module_name}_{function_name}
|
||||
auto entryPoint = *module.getOps<spirv::EntryPointOp>().begin();
|
||||
StringRef funcName = entryPoint.fn();
|
||||
auto funcOp = module.lookupSymbol<spirv::FuncOp>(funcName);
|
||||
std::string newFuncName = spvModuleName.str() + "_" + funcName.str();
|
||||
auto funcOp = module.lookupSymbol<spirv::FuncOp>(entryPoint.fnAttr());
|
||||
StringAttr newFuncName =
|
||||
StringAttr::get(module->getContext(), spvModuleName + "_" + funcName);
|
||||
if (failed(SymbolTable::replaceAllSymbolUses(funcOp, newFuncName, module)))
|
||||
return failure();
|
||||
SymbolTable::setSymbolName(funcOp, newFuncName);
|
||||
|
@ -166,9 +167,10 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
|
|||
// is named:
|
||||
// __spv__{kernel_module_name}
|
||||
// based on GPU to SPIR-V conversion.
|
||||
StringRef kernelModuleName = launchOp.getKernelModuleName();
|
||||
StringRef kernelModuleName = launchOp.getKernelModuleName().getValue();
|
||||
std::string spvModuleName = kSPIRVModule + kernelModuleName.str();
|
||||
auto spvModule = module.lookupSymbol<spirv::ModuleOp>(spvModuleName);
|
||||
auto spvModule = module.lookupSymbol<spirv::ModuleOp>(
|
||||
StringAttr::get(context, spvModuleName));
|
||||
if (!spvModule) {
|
||||
return launchOp.emitOpError("SPIR-V kernel module '")
|
||||
<< spvModuleName << "' is not found";
|
||||
|
@ -180,9 +182,10 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
|
|||
// variables. The name of the kernel will be
|
||||
// {spv_module_name}_{kernel_function_name}
|
||||
// to avoid symbolic name conflicts.
|
||||
StringRef kernelFuncName = launchOp.getKernelName();
|
||||
StringRef kernelFuncName = launchOp.getKernelName().getValue();
|
||||
std::string newKernelFuncName = spvModuleName + "_" + kernelFuncName.str();
|
||||
auto kernelFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(newKernelFuncName);
|
||||
auto kernelFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(
|
||||
StringAttr::get(context, newKernelFuncName));
|
||||
if (!kernelFunc) {
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(module.getBody());
|
||||
|
|
|
@ -1523,12 +1523,13 @@ void mlir::encodeBindAttribute(ModuleOp module) {
|
|||
llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
|
||||
std::to_string(descriptorSet.getInt()),
|
||||
std::to_string(binding.getInt()));
|
||||
auto nameAttr = StringAttr::get(op->getContext(), name);
|
||||
|
||||
// Replace all symbol uses and set the new symbol name. Finally, remove
|
||||
// descriptor set and binding attributes.
|
||||
if (failed(SymbolTable::replaceAllSymbolUses(op, name, spvModule)))
|
||||
if (failed(SymbolTable::replaceAllSymbolUses(op, nameAttr, spvModule)))
|
||||
op.emitError("unable to replace all symbol uses for ") << name;
|
||||
SymbolTable::setSymbolName(op, name);
|
||||
SymbolTable::setSymbolName(op, nameAttr);
|
||||
op->removeAttr(kDescriptorSet);
|
||||
op->removeAttr(kBinding);
|
||||
}
|
||||
|
|
|
@ -196,14 +196,15 @@ LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,
|
|||
return success();
|
||||
|
||||
// Check that `launch_func` refers to a well-formed GPU kernel module.
|
||||
StringRef kernelModuleName = launchOp.getKernelModuleName();
|
||||
StringAttr kernelModuleName = launchOp.getKernelModuleName();
|
||||
auto kernelModule = module.lookupSymbol<GPUModuleOp>(kernelModuleName);
|
||||
if (!kernelModule)
|
||||
return launchOp.emitOpError()
|
||||
<< "kernel module '" << kernelModuleName << "' is undefined";
|
||||
<< "kernel module '" << kernelModuleName.getValue()
|
||||
<< "' is undefined";
|
||||
|
||||
// Check that `launch_func` refers to a well-formed kernel function.
|
||||
Operation *kernelFunc = module.lookupSymbol(launchOp.kernel());
|
||||
Operation *kernelFunc = module.lookupSymbol(launchOp.kernelAttr());
|
||||
auto kernelGPUFunction = dyn_cast_or_null<gpu::GPUFuncOp>(kernelFunc);
|
||||
auto kernelLLVMFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(kernelFunc);
|
||||
if (!kernelGPUFunction && !kernelLLVMFunction)
|
||||
|
@ -555,11 +556,11 @@ unsigned LaunchFuncOp::getNumKernelOperands() {
|
|||
return getNumOperands() - asyncDependencies().size() - kNumConfigOperands;
|
||||
}
|
||||
|
||||
StringRef LaunchFuncOp::getKernelModuleName() {
|
||||
StringAttr LaunchFuncOp::getKernelModuleName() {
|
||||
return kernel().getRootReference();
|
||||
}
|
||||
|
||||
StringRef LaunchFuncOp::getKernelName() { return kernel().getLeafReference(); }
|
||||
StringAttr LaunchFuncOp::getKernelName() { return kernel().getLeafReference(); }
|
||||
|
||||
Value LaunchFuncOp::getKernelOperand(unsigned i) {
|
||||
return getOperand(asyncDependencies().size() + kNumConfigOperands + i);
|
||||
|
|
|
@ -343,8 +343,8 @@ LogicalResult verifySymbolAttribute(
|
|||
// a constraint in the operation definition.
|
||||
for (SymbolRefAttr symbolRef :
|
||||
attribute.cast<ArrayAttr>().getAsRange<SymbolRefAttr>()) {
|
||||
StringRef metadataName = symbolRef.getRootReference();
|
||||
StringRef symbolName = symbolRef.getLeafReference();
|
||||
StringAttr metadataName = symbolRef.getRootReference();
|
||||
StringAttr symbolName = symbolRef.getLeafReference();
|
||||
// We want @metadata::@symbol, not just @symbol
|
||||
if (metadataName == symbolName) {
|
||||
return op->emitOpError() << "expected '" << symbolRef
|
||||
|
@ -770,7 +770,7 @@ static LogicalResult verify(CallOp &op) {
|
|||
bool isIndirect = false;
|
||||
|
||||
// If this is an indirect call, the callee attribute is missing.
|
||||
Optional<StringRef> calleeName = op.callee();
|
||||
FlatSymbolRefAttr calleeName = op.calleeAttr();
|
||||
if (!calleeName) {
|
||||
isIndirect = true;
|
||||
if (!op.getNumOperands())
|
||||
|
@ -782,14 +782,15 @@ static LogicalResult verify(CallOp &op) {
|
|||
<< ptrType;
|
||||
fnType = ptrType.getElementType();
|
||||
} else {
|
||||
Operation *callee = SymbolTable::lookupNearestSymbolFrom(op, *calleeName);
|
||||
Operation *callee =
|
||||
SymbolTable::lookupNearestSymbolFrom(op, calleeName.getAttr());
|
||||
if (!callee)
|
||||
return op.emitOpError()
|
||||
<< "'" << *calleeName
|
||||
<< "'" << calleeName.getValue()
|
||||
<< "' does not reference a symbol in the current scope";
|
||||
auto fn = dyn_cast<LLVMFuncOp>(callee);
|
||||
if (!fn)
|
||||
return op.emitOpError() << "'" << *calleeName
|
||||
return op.emitOpError() << "'" << calleeName.getValue()
|
||||
<< "' does not reference a valid LLVM function";
|
||||
|
||||
fnType = fn.getType();
|
||||
|
@ -2253,14 +2254,14 @@ LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op,
|
|||
if (!accessGroupRef)
|
||||
return op->emitOpError()
|
||||
<< "expected '" << attr << "' to be a symbol reference";
|
||||
StringRef metadataName = accessGroupRef.getRootReference();
|
||||
StringAttr metadataName = accessGroupRef.getRootReference();
|
||||
auto metadataOp =
|
||||
SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
|
||||
op->getParentOp(), metadataName);
|
||||
if (!metadataOp)
|
||||
return op->emitOpError()
|
||||
<< "expected '" << attr << "' to reference a metadata op";
|
||||
StringRef accessGroupName = accessGroupRef.getLeafReference();
|
||||
StringAttr accessGroupName = accessGroupRef.getLeafReference();
|
||||
Operation *accessGroupOp =
|
||||
SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName);
|
||||
if (!accessGroupOp)
|
||||
|
|
|
@ -1066,7 +1066,7 @@ void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
|
|||
static LogicalResult verify(spirv::AddressOfOp addressOfOp) {
|
||||
auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
|
||||
SymbolTable::lookupNearestSymbolFrom(addressOfOp->getParentOp(),
|
||||
addressOfOp.variable()));
|
||||
addressOfOp.variableAttr()));
|
||||
if (!varOp) {
|
||||
return addressOfOp.emitOpError("expected spv.GlobalVariable symbol");
|
||||
}
|
||||
|
@ -1953,14 +1953,14 @@ ArrayRef<Type> spirv::FuncOp::getCallableResults() {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(spirv::FunctionCallOp functionCallOp) {
|
||||
auto fnName = functionCallOp.callee();
|
||||
auto fnName = functionCallOp.calleeAttr();
|
||||
|
||||
auto funcOp =
|
||||
dyn_cast_or_null<spirv::FuncOp>(SymbolTable::lookupNearestSymbolFrom(
|
||||
functionCallOp->getParentOp(), fnName));
|
||||
if (!funcOp) {
|
||||
return functionCallOp.emitOpError("callee function '")
|
||||
<< fnName << "' not found in nearest symbol table";
|
||||
<< fnName.getValue() << "' not found in nearest symbol table";
|
||||
}
|
||||
|
||||
auto functionType = funcOp.getType();
|
||||
|
@ -2115,7 +2115,7 @@ static LogicalResult verify(spirv::GlobalVariableOp varOp) {
|
|||
if (auto init =
|
||||
varOp->getAttrOfType<FlatSymbolRefAttr>(kInitializerAttrName)) {
|
||||
Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
|
||||
varOp->getParentOp(), init.getValue());
|
||||
varOp->getParentOp(), init.getAttr());
|
||||
// TODO: Currently only variable initialization with specialization
|
||||
// constants and other variables is supported. They could be normal
|
||||
// constants in the module scope as well.
|
||||
|
@ -2691,7 +2691,7 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) {
|
|||
|
||||
static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) {
|
||||
auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
|
||||
referenceOfOp->getParentOp(), referenceOfOp.spec_const());
|
||||
referenceOfOp->getParentOp(), referenceOfOp.spec_constAttr());
|
||||
Type constType;
|
||||
|
||||
auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
|
||||
|
@ -3516,17 +3516,17 @@ static LogicalResult verify(spirv::SpecConstantCompositeOp constOp) {
|
|||
|
||||
if (cType.isa<spirv::CooperativeMatrixNVType>())
|
||||
return constOp.emitError("unsupported composite type ") << cType;
|
||||
else if (constituents.size() != cType.getNumElements())
|
||||
if (constituents.size() != cType.getNumElements())
|
||||
return constOp.emitError("has incorrect number of operands: expected ")
|
||||
<< cType.getNumElements() << ", but provided "
|
||||
<< constituents.size();
|
||||
|
||||
for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
|
||||
auto constituent = constituents[index].dyn_cast<FlatSymbolRefAttr>();
|
||||
auto constituent = constituents[index].cast<FlatSymbolRefAttr>();
|
||||
|
||||
auto constituentSpecConstOp =
|
||||
dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
|
||||
constOp->getParentOp(), constituent.getValue()));
|
||||
constOp->getParentOp(), constituent.getAttr()));
|
||||
|
||||
if (constituentSpecConstOp.default_value().getType() !=
|
||||
cType.getElementType(index))
|
||||
|
|
|
@ -30,21 +30,20 @@ static constexpr unsigned maxFreeID = 1 << 20;
|
|||
|
||||
/// Returns an unsed symbol in `module` for `oldSymbolName` by trying numeric
|
||||
/// suffix in `lastUsedID`.
|
||||
static SmallString<64> renameSymbol(StringRef oldSymName, unsigned &lastUsedID,
|
||||
spirv::ModuleOp module) {
|
||||
static StringAttr renameSymbol(StringRef oldSymName, unsigned &lastUsedID,
|
||||
spirv::ModuleOp module) {
|
||||
SmallString<64> newSymName(oldSymName);
|
||||
newSymName.push_back('_');
|
||||
|
||||
while (lastUsedID < maxFreeID) {
|
||||
std::string possible = (newSymName + llvm::utostr(++lastUsedID)).str();
|
||||
MLIRContext *ctx = module->getContext();
|
||||
|
||||
if (!SymbolTable::lookupSymbolIn(module, possible)) {
|
||||
newSymName += llvm::utostr(lastUsedID);
|
||||
break;
|
||||
}
|
||||
while (lastUsedID < maxFreeID) {
|
||||
auto possible = StringAttr::get(ctx, newSymName + Twine(++lastUsedID));
|
||||
if (!SymbolTable::lookupSymbolIn(module, possible))
|
||||
return possible;
|
||||
}
|
||||
|
||||
return newSymName;
|
||||
return StringAttr::get(ctx, newSymName);
|
||||
}
|
||||
|
||||
/// Checks if a symbol with the same name as `op` already exists in `source`.
|
||||
|
@ -57,7 +56,7 @@ static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op,
|
|||
return success();
|
||||
|
||||
StringRef oldSymName = op.getName();
|
||||
SmallString<64> newSymName = renameSymbol(oldSymName, lastUsedID, target);
|
||||
StringAttr newSymName = renameSymbol(oldSymName, lastUsedID, target);
|
||||
|
||||
if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, target)))
|
||||
return op.emitError("unable to update all symbol uses for ")
|
||||
|
@ -234,7 +233,7 @@ OwningOpRef<spirv::ModuleOp> combine(ArrayRef<spirv::ModuleOp> inputModules,
|
|||
SymbolOpInterface replacementSymOp = result.first->second;
|
||||
|
||||
if (failed(SymbolTable::replaceAllSymbolUses(
|
||||
symbolOp, replacementSymOp.getName(), combinedModule))) {
|
||||
symbolOp, replacementSymOp.getNameAttr(), combinedModule))) {
|
||||
symbolOp.emitError("unable to update all symbol uses for ")
|
||||
<< symbolOp.getName() << " to " << replacementSymOp.getName();
|
||||
return nullptr;
|
||||
|
|
|
@ -64,11 +64,11 @@ public:
|
|||
LogicalResult matchAndRewrite(spirv::AddressOfOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto spirvModule = op->getParentOfType<spirv::ModuleOp>();
|
||||
auto varName = op.variable();
|
||||
auto varName = op.variableAttr();
|
||||
auto varOp = spirvModule.lookupSymbol<spirv::GlobalVariableOp>(varName);
|
||||
|
||||
rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(
|
||||
op, varOp.type(), rewriter.getSymbolRefAttr(varName));
|
||||
op, varOp.type(), rewriter.getSymbolRefAttr(varName.getAttr()));
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
|
|
@ -96,19 +96,21 @@ static Value getTensor(ConversionPatternRewriter &rewriter, unsigned width,
|
|||
}
|
||||
|
||||
/// Returns function reference (first hit also inserts into module).
|
||||
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result,
|
||||
static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type resultType,
|
||||
ValueRange operands) {
|
||||
MLIRContext *context = op->getContext();
|
||||
auto module = op->getParentOfType<ModuleOp>();
|
||||
auto func = module.lookupSymbol<FuncOp>(name);
|
||||
auto result = SymbolRefAttr::get(context, name);
|
||||
auto func = module.lookupSymbol<FuncOp>(result.getAttr());
|
||||
if (!func) {
|
||||
OpBuilder moduleBuilder(module.getBodyRegion());
|
||||
moduleBuilder
|
||||
.create<FuncOp>(op->getLoc(), name,
|
||||
FunctionType::get(context, operands.getTypes(), result))
|
||||
.create<FuncOp>(
|
||||
op->getLoc(), name,
|
||||
FunctionType::get(context, operands.getTypes(), resultType))
|
||||
.setPrivate();
|
||||
}
|
||||
return SymbolRefAttr::get(context, name);
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Generates a call into the "swiss army knife" method of the sparse runtime
|
||||
|
|
|
@ -1659,7 +1659,7 @@ void ModulePrinter::printAttribute(Attribute attr,
|
|||
printType(typeAttr.getValue());
|
||||
|
||||
} else if (auto refAttr = attr.dyn_cast<SymbolRefAttr>()) {
|
||||
printSymbolReference(refAttr.getRootReference(), os);
|
||||
printSymbolReference(refAttr.getRootReference().getValue(), os);
|
||||
for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) {
|
||||
os << "::";
|
||||
printSymbolReference(nestedRef.getValue(), os);
|
||||
|
|
|
@ -216,13 +216,15 @@ FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) {
|
|||
assert(symName && "value does not have a valid symbol name");
|
||||
return getSymbolRefAttr(symName.getValue());
|
||||
}
|
||||
FlatSymbolRefAttr Builder::getSymbolRefAttr(StringRef value) {
|
||||
return SymbolRefAttr::get(getContext(), value);
|
||||
|
||||
FlatSymbolRefAttr Builder::getSymbolRefAttr(StringAttr value) {
|
||||
return SymbolRefAttr::get(value);
|
||||
}
|
||||
|
||||
SymbolRefAttr
|
||||
Builder::getSymbolRefAttr(StringRef value,
|
||||
Builder::getSymbolRefAttr(StringAttr value,
|
||||
ArrayRef<FlatSymbolRefAttr> nestedReferences) {
|
||||
return SymbolRefAttr::get(getContext(), value, nestedReferences);
|
||||
return SymbolRefAttr::get(value, nestedReferences);
|
||||
}
|
||||
|
||||
ArrayAttr Builder::getBoolArrayAttr(ArrayRef<bool> values) {
|
||||
|
|
|
@ -273,12 +273,16 @@ LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) {
|
||||
return get(ctx, value, llvm::None).cast<FlatSymbolRefAttr>();
|
||||
return get(StringAttr::get(ctx, value));
|
||||
}
|
||||
|
||||
StringRef SymbolRefAttr::getLeafReference() const {
|
||||
FlatSymbolRefAttr SymbolRefAttr::get(StringAttr value) {
|
||||
return get(value, {}).cast<FlatSymbolRefAttr>();
|
||||
}
|
||||
|
||||
StringAttr SymbolRefAttr::getLeafReference() const {
|
||||
ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
|
||||
return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue();
|
||||
return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getAttr();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -22,17 +22,13 @@ static bool isPotentiallyUnknownSymbolTable(Operation *op) {
|
|||
return op->getNumRegions() == 1 && !op->getDialect();
|
||||
}
|
||||
|
||||
/// Returns the string name of the given symbol, or None if this is not a
|
||||
/// Returns the string name of the given symbol, or null if this is not a
|
||||
/// symbol.
|
||||
static Optional<StringRef> getNameIfSymbol(Operation *symbol) {
|
||||
auto nameAttr =
|
||||
symbol->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
|
||||
return nameAttr ? nameAttr.getValue() : Optional<StringRef>();
|
||||
static StringAttr getNameIfSymbol(Operation *op) {
|
||||
return op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
|
||||
}
|
||||
static Optional<StringRef> getNameIfSymbol(Operation *symbol,
|
||||
Identifier symbolAttrNameId) {
|
||||
auto nameAttr = symbol->getAttrOfType<StringAttr>(symbolAttrNameId);
|
||||
return nameAttr ? nameAttr.getValue() : Optional<StringRef>();
|
||||
static StringAttr getNameIfSymbol(Operation *op, Identifier symbolAttrNameId) {
|
||||
return op->getAttrOfType<StringAttr>(symbolAttrNameId);
|
||||
}
|
||||
|
||||
/// Computes the nested symbol reference attribute for the symbol 'symbolName'
|
||||
|
@ -40,13 +36,13 @@ static Optional<StringRef> getNameIfSymbol(Operation *symbol,
|
|||
/// to the given operation 'within', where 'within' is an ancestor of 'symbol'.
|
||||
/// Returns success if all references up to 'within' could be computed.
|
||||
static LogicalResult
|
||||
collectValidReferencesFor(Operation *symbol, StringRef symbolName,
|
||||
collectValidReferencesFor(Operation *symbol, StringAttr symbolName,
|
||||
Operation *within,
|
||||
SmallVectorImpl<SymbolRefAttr> &results) {
|
||||
assert(within->isAncestor(symbol) && "expected 'within' to be an ancestor");
|
||||
MLIRContext *ctx = symbol->getContext();
|
||||
|
||||
auto leafRef = FlatSymbolRefAttr::get(ctx, symbolName);
|
||||
auto leafRef = FlatSymbolRefAttr::get(symbolName);
|
||||
results.push_back(leafRef);
|
||||
|
||||
// Early exit for when 'within' is the parent of 'symbol'.
|
||||
|
@ -63,17 +59,16 @@ collectValidReferencesFor(Operation *symbol, StringRef symbolName,
|
|||
if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>())
|
||||
return failure();
|
||||
// Each parent of 'symbol' should also be a symbol.
|
||||
Optional<StringRef> symbolTableName =
|
||||
getNameIfSymbol(symbolTableOp, symbolNameId);
|
||||
StringAttr symbolTableName = getNameIfSymbol(symbolTableOp, symbolNameId);
|
||||
if (!symbolTableName)
|
||||
return failure();
|
||||
results.push_back(SymbolRefAttr::get(ctx, *symbolTableName, nestedRefs));
|
||||
results.push_back(SymbolRefAttr::get(symbolTableName, nestedRefs));
|
||||
|
||||
symbolTableOp = symbolTableOp->getParentOp();
|
||||
if (symbolTableOp == within)
|
||||
break;
|
||||
nestedRefs.insert(nestedRefs.begin(),
|
||||
FlatSymbolRefAttr::get(ctx, *symbolTableName));
|
||||
FlatSymbolRefAttr::get(symbolTableName));
|
||||
} while (true);
|
||||
return success();
|
||||
}
|
||||
|
@ -119,11 +114,11 @@ SymbolTable::SymbolTable(Operation *symbolTableOp)
|
|||
Identifier symbolNameId = Identifier::get(SymbolTable::getSymbolAttrName(),
|
||||
symbolTableOp->getContext());
|
||||
for (auto &op : symbolTableOp->getRegion(0).front()) {
|
||||
Optional<StringRef> name = getNameIfSymbol(&op, symbolNameId);
|
||||
StringAttr name = getNameIfSymbol(&op, symbolNameId);
|
||||
if (!name)
|
||||
continue;
|
||||
|
||||
auto inserted = symbolTable.insert({*name, &op});
|
||||
auto inserted = symbolTable.insert({name, &op});
|
||||
(void)inserted;
|
||||
assert(inserted.second &&
|
||||
"expected region to contain uniquely named symbol operations");
|
||||
|
@ -133,18 +128,21 @@ SymbolTable::SymbolTable(Operation *symbolTableOp)
|
|||
/// Look up a symbol with the specified name, returning null if no such name
|
||||
/// exists. Names never include the @ on them.
|
||||
Operation *SymbolTable::lookup(StringRef name) const {
|
||||
return lookup(StringAttr::get(symbolTableOp->getContext(), name));
|
||||
}
|
||||
Operation *SymbolTable::lookup(StringAttr name) const {
|
||||
return symbolTable.lookup(name);
|
||||
}
|
||||
|
||||
/// Erase the given symbol from the table.
|
||||
void SymbolTable::erase(Operation *symbol) {
|
||||
Optional<StringRef> name = getNameIfSymbol(symbol);
|
||||
StringAttr name = getNameIfSymbol(symbol);
|
||||
assert(name && "expected valid 'name' attribute");
|
||||
assert(symbol->getParentOp() == symbolTableOp &&
|
||||
"expected this operation to be inside of the operation with this "
|
||||
"SymbolTable");
|
||||
|
||||
auto it = symbolTable.find(*name);
|
||||
auto it = symbolTable.find(name);
|
||||
if (it != symbolTable.end() && it->second == symbol) {
|
||||
symbolTable.erase(it);
|
||||
symbol->erase();
|
||||
|
@ -180,7 +178,7 @@ void SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
|
|||
|
||||
// Add this symbol to the symbol table, uniquing the name if a conflict is
|
||||
// detected.
|
||||
StringRef name = getSymbolName(symbol);
|
||||
StringAttr name = getSymbolName(symbol);
|
||||
if (symbolTable.insert({name, symbol}).second)
|
||||
return;
|
||||
// If the symbol was already in the table, also return.
|
||||
|
@ -188,28 +186,31 @@ void SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
|
|||
return;
|
||||
// If a conflict was detected, then the symbol will not have been added to
|
||||
// the symbol table. Try suffixes until we get to a unique name that works.
|
||||
SmallString<128> nameBuffer(name);
|
||||
SmallString<128> nameBuffer(name.getValue());
|
||||
unsigned originalLength = nameBuffer.size();
|
||||
|
||||
MLIRContext *context = symbol->getContext();
|
||||
|
||||
// Iteratively try suffixes until we find one that isn't used.
|
||||
do {
|
||||
nameBuffer.resize(originalLength);
|
||||
nameBuffer += '_';
|
||||
nameBuffer += std::to_string(uniquingCounter++);
|
||||
} while (!symbolTable.insert({nameBuffer, symbol}).second);
|
||||
} while (!symbolTable.insert({StringAttr::get(context, nameBuffer), symbol})
|
||||
.second);
|
||||
setSymbolName(symbol, nameBuffer);
|
||||
}
|
||||
|
||||
/// Returns the name of the given symbol operation.
|
||||
StringRef SymbolTable::getSymbolName(Operation *symbol) {
|
||||
Optional<StringRef> name = getNameIfSymbol(symbol);
|
||||
StringAttr SymbolTable::getSymbolName(Operation *symbol) {
|
||||
StringAttr name = getNameIfSymbol(symbol);
|
||||
assert(name && "expected valid symbol name");
|
||||
return *name;
|
||||
return name;
|
||||
}
|
||||
|
||||
/// Sets the name of the given symbol operation.
|
||||
void SymbolTable::setSymbolName(Operation *symbol, StringRef name) {
|
||||
symbol->setAttr(getSymbolAttrName(),
|
||||
StringAttr::get(symbol->getContext(), name));
|
||||
void SymbolTable::setSymbolName(Operation *symbol, StringAttr name) {
|
||||
symbol->setAttr(getSymbolAttrName(), name);
|
||||
}
|
||||
|
||||
/// Returns the visibility of the given symbol operation.
|
||||
|
@ -295,7 +296,7 @@ void SymbolTable::walkSymbolTables(
|
|||
/// with the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol
|
||||
/// was found.
|
||||
Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
|
||||
StringRef symbol) {
|
||||
StringAttr symbol) {
|
||||
assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
|
||||
Region ®ion = symbolTableOp->getRegion(0);
|
||||
if (region.empty())
|
||||
|
@ -322,7 +323,7 @@ Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
|
|||
static LogicalResult lookupSymbolInImpl(
|
||||
Operation *symbolTableOp, SymbolRefAttr symbol,
|
||||
SmallVectorImpl<Operation *> &symbols,
|
||||
function_ref<Operation *(Operation *, StringRef)> lookupSymbolFn) {
|
||||
function_ref<Operation *(Operation *, StringAttr)> lookupSymbolFn) {
|
||||
assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
|
||||
|
||||
// Lookup the root reference for this symbol.
|
||||
|
@ -343,7 +344,7 @@ static LogicalResult lookupSymbolInImpl(
|
|||
// Otherwise, lookup each of the nested non-leaf references and ensure that
|
||||
// each corresponds to a valid symbol table.
|
||||
for (FlatSymbolRefAttr ref : nestedRefs.drop_back()) {
|
||||
symbolTableOp = lookupSymbolFn(symbolTableOp, ref.getValue());
|
||||
symbolTableOp = lookupSymbolFn(symbolTableOp, ref.getAttr());
|
||||
if (!symbolTableOp || !symbolTableOp->hasTrait<OpTrait::SymbolTable>())
|
||||
return failure();
|
||||
symbols.push_back(symbolTableOp);
|
||||
|
@ -355,7 +356,7 @@ static LogicalResult lookupSymbolInImpl(
|
|||
LogicalResult
|
||||
SymbolTable::lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr symbol,
|
||||
SmallVectorImpl<Operation *> &symbols) {
|
||||
auto lookupFn = [](Operation *symbolTableOp, StringRef symbol) {
|
||||
auto lookupFn = [](Operation *symbolTableOp, StringAttr symbol) {
|
||||
return lookupSymbolIn(symbolTableOp, symbol);
|
||||
};
|
||||
return lookupSymbolInImpl(symbolTableOp, symbol, symbols, lookupFn);
|
||||
|
@ -365,7 +366,7 @@ SymbolTable::lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr symbol,
|
|||
/// closes parent operation with the 'OpTrait::SymbolTable' trait. Returns
|
||||
/// nullptr if no valid symbol was found.
|
||||
Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from,
|
||||
StringRef symbol) {
|
||||
StringAttr symbol) {
|
||||
Operation *symbolTableOp = getNearestSymbolTable(from);
|
||||
return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
|
||||
}
|
||||
|
@ -610,7 +611,7 @@ struct SymbolScope {
|
|||
/// Collect all of the symbol scopes from 'symbol' to (inclusive) 'limit'.
|
||||
static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
|
||||
Operation *limit) {
|
||||
StringRef symName = SymbolTable::getSymbolName(symbol);
|
||||
StringAttr symName = SymbolTable::getSymbolName(symbol);
|
||||
assert(!symbol->hasTrait<OpTrait::SymbolTable>() || symbol != limit);
|
||||
|
||||
// Compute the ancestors of 'limit'.
|
||||
|
@ -625,7 +626,7 @@ static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
|
|||
// doesn't support parent references.
|
||||
if (SymbolTable::getNearestSymbolTable(limit->getParentOp()) ==
|
||||
symbol->getParentOp())
|
||||
return {{SymbolRefAttr::get(symbol->getContext(), symName), limit}};
|
||||
return {{SymbolRefAttr::get(symName), limit}};
|
||||
return {};
|
||||
}
|
||||
|
||||
|
@ -679,9 +680,9 @@ static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
|
|||
return scopes;
|
||||
}
|
||||
template <typename IRUnit>
|
||||
static SmallVector<SymbolScope, 1> collectSymbolScopes(StringRef symbol,
|
||||
static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol,
|
||||
IRUnit *limit) {
|
||||
return {{SymbolRefAttr::get(limit->getContext(), symbol), limit}};
|
||||
return {{SymbolRefAttr::get(symbol), limit}};
|
||||
}
|
||||
|
||||
/// Returns true if the given reference 'SubRef' is a sub reference of the
|
||||
|
@ -753,7 +754,7 @@ static Optional<SymbolTable::UseRange> getSymbolUsesImpl(SymbolT symbol,
|
|||
/// operation 'from', invoking the provided callback for each. This does not
|
||||
/// traverse into any nested symbol tables. This function returns None if there
|
||||
/// are any unknown operations that may potentially be symbol tables.
|
||||
auto SymbolTable::getSymbolUses(StringRef symbol, Operation *from)
|
||||
auto SymbolTable::getSymbolUses(StringAttr symbol, Operation *from)
|
||||
-> Optional<UseRange> {
|
||||
return getSymbolUsesImpl(symbol, from);
|
||||
}
|
||||
|
@ -761,7 +762,7 @@ auto SymbolTable::getSymbolUses(Operation *symbol, Operation *from)
|
|||
-> Optional<UseRange> {
|
||||
return getSymbolUsesImpl(symbol, from);
|
||||
}
|
||||
auto SymbolTable::getSymbolUses(StringRef symbol, Region *from)
|
||||
auto SymbolTable::getSymbolUses(StringAttr symbol, Region *from)
|
||||
-> Optional<UseRange> {
|
||||
return getSymbolUsesImpl(symbol, from);
|
||||
}
|
||||
|
@ -792,13 +793,13 @@ static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit) {
|
|||
/// the given operation 'from'. This does not traverse into any nested symbol
|
||||
/// tables. This function will also return false if there are any unknown
|
||||
/// operations that may potentially be symbol tables.
|
||||
bool SymbolTable::symbolKnownUseEmpty(StringRef symbol, Operation *from) {
|
||||
bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Operation *from) {
|
||||
return symbolKnownUseEmptyImpl(symbol, from);
|
||||
}
|
||||
bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Operation *from) {
|
||||
return symbolKnownUseEmptyImpl(symbol, from);
|
||||
}
|
||||
bool SymbolTable::symbolKnownUseEmpty(StringRef symbol, Region *from) {
|
||||
bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Region *from) {
|
||||
return symbolKnownUseEmptyImpl(symbol, from);
|
||||
}
|
||||
bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Region *from) {
|
||||
|
@ -861,14 +862,13 @@ static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr,
|
|||
return newLeafAttr;
|
||||
auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences());
|
||||
nestedRefs.back() = newLeafAttr;
|
||||
return SymbolRefAttr::get(oldAttr.getContext(), oldAttr.getRootReference(),
|
||||
nestedRefs);
|
||||
return SymbolRefAttr::get(oldAttr.getRootReference(), nestedRefs);
|
||||
}
|
||||
|
||||
/// The implementation of SymbolTable::replaceAllSymbolUses below.
|
||||
template <typename SymbolT, typename IRUnitT>
|
||||
static LogicalResult
|
||||
replaceAllSymbolUsesImpl(SymbolT symbol, StringRef newSymbol, IRUnitT *limit) {
|
||||
replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit) {
|
||||
// A collection of operations along with their new attribute dictionary.
|
||||
std::vector<std::pair<Operation *, DictionaryAttr>> updatedAttrDicts;
|
||||
|
||||
|
@ -888,8 +888,7 @@ replaceAllSymbolUsesImpl(SymbolT symbol, StringRef newSymbol, IRUnitT *limit) {
|
|||
};
|
||||
|
||||
// Generate a new attribute to replace the given attribute.
|
||||
MLIRContext *ctx = limit->getContext();
|
||||
FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(ctx, newSymbol);
|
||||
FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol);
|
||||
for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
|
||||
SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr);
|
||||
auto walkFn = [&](SymbolTable::SymbolUse symbolUse,
|
||||
|
@ -905,13 +904,13 @@ replaceAllSymbolUsesImpl(SymbolT symbol, StringRef newSymbol, IRUnitT *limit) {
|
|||
if (useRef != scope.symbol) {
|
||||
if (scope.symbol.isa<FlatSymbolRefAttr>()) {
|
||||
replacementRef =
|
||||
SymbolRefAttr::get(ctx, newSymbol, useRef.getNestedReferences());
|
||||
SymbolRefAttr::get(newSymbol, useRef.getNestedReferences());
|
||||
} else {
|
||||
auto nestedRefs = llvm::to_vector<4>(useRef.getNestedReferences());
|
||||
nestedRefs[scope.symbol.getNestedReferences().size() - 1] =
|
||||
newLeafAttr;
|
||||
replacementRef =
|
||||
SymbolRefAttr::get(ctx, useRef.getRootReference(), nestedRefs);
|
||||
SymbolRefAttr::get(useRef.getRootReference(), nestedRefs);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -949,23 +948,23 @@ replaceAllSymbolUsesImpl(SymbolT symbol, StringRef newSymbol, IRUnitT *limit) {
|
|||
/// 'from'. This does not traverse into any nested symbol tables. If there are
|
||||
/// any unknown operations that may potentially be symbol tables, no uses are
|
||||
/// replaced and failure is returned.
|
||||
LogicalResult SymbolTable::replaceAllSymbolUses(StringRef oldSymbol,
|
||||
StringRef newSymbol,
|
||||
LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol,
|
||||
StringAttr newSymbol,
|
||||
Operation *from) {
|
||||
return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
|
||||
}
|
||||
LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol,
|
||||
StringRef newSymbol,
|
||||
StringAttr newSymbol,
|
||||
Operation *from) {
|
||||
return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
|
||||
}
|
||||
LogicalResult SymbolTable::replaceAllSymbolUses(StringRef oldSymbol,
|
||||
StringRef newSymbol,
|
||||
LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol,
|
||||
StringAttr newSymbol,
|
||||
Region *from) {
|
||||
return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
|
||||
}
|
||||
LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol,
|
||||
StringRef newSymbol,
|
||||
StringAttr newSymbol,
|
||||
Region *from) {
|
||||
return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
|
||||
}
|
||||
|
@ -975,7 +974,7 @@ LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol,
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Operation *SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
|
||||
StringRef symbol) {
|
||||
StringAttr symbol) {
|
||||
return getSymbolTable(symbolTableOp).lookup(symbol);
|
||||
}
|
||||
Operation *SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
|
||||
|
@ -992,7 +991,7 @@ LogicalResult
|
|||
SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
|
||||
SymbolRefAttr name,
|
||||
SmallVectorImpl<Operation *> &symbols) {
|
||||
auto lookupFn = [this](Operation *symbolTableOp, StringRef symbol) {
|
||||
auto lookupFn = [this](Operation *symbolTableOp, StringAttr symbol) {
|
||||
return lookupSymbolIn(symbolTableOp, symbol);
|
||||
};
|
||||
return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn);
|
||||
|
@ -1003,7 +1002,7 @@ SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
|
|||
/// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was
|
||||
/// found.
|
||||
Operation *SymbolTableCollection::lookupNearestSymbolFrom(Operation *from,
|
||||
StringRef symbol) {
|
||||
StringAttr symbol) {
|
||||
Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from);
|
||||
return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr;
|
||||
}
|
||||
|
@ -1052,7 +1051,7 @@ SymbolUserMap::SymbolUserMap(SymbolTableCollection &symbolTable,
|
|||
}
|
||||
|
||||
void SymbolUserMap::replaceAllUsesWith(Operation *symbol,
|
||||
StringRef newSymbolName) {
|
||||
StringAttr newSymbolName) {
|
||||
auto it = symbolToUsers.find(symbol);
|
||||
if (it == symbolToUsers.end())
|
||||
return;
|
||||
|
|
|
@ -818,7 +818,7 @@ void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
|
|||
void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
|
||||
ByteCodeField patternIndex = patterns.size();
|
||||
patterns.emplace_back(PDLByteCodePattern::create(
|
||||
op, rewriterToAddr[op.rewriter().getLeafReference()]));
|
||||
op, rewriterToAddr[op.rewriter().getLeafReference().getValue()]));
|
||||
writer.append(OpCode::RecordMatch, patternIndex,
|
||||
SuccessorRange(op.getOperation()), op.matchedOps());
|
||||
writer.appendPDLValueList(op.inputs());
|
||||
|
|
|
@ -814,8 +814,8 @@ LogicalResult ModuleTranslation::createAliasScopeMetadata() {
|
|||
llvm::MDNode *
|
||||
ModuleTranslation::getAliasScope(Operation &opInst,
|
||||
SymbolRefAttr aliasScopeRef) const {
|
||||
StringRef metadataName = aliasScopeRef.getRootReference();
|
||||
StringRef scopeName = aliasScopeRef.getLeafReference();
|
||||
StringAttr metadataName = aliasScopeRef.getRootReference();
|
||||
StringAttr scopeName = aliasScopeRef.getLeafReference();
|
||||
auto metadataOp = SymbolTable::lookupNearestSymbolFrom<LLVM::MetadataOp>(
|
||||
opInst.getParentOp(), metadataName);
|
||||
Operation *aliasScopeOp =
|
||||
|
|
|
@ -84,7 +84,7 @@ struct SymbolUsesPass
|
|||
table.erase(op);
|
||||
assert(!table.lookup(name) &&
|
||||
"expected erased operation to be unknown now");
|
||||
module.emitRemark() << name << " function successfully erased";
|
||||
module.emitRemark() << name.getValue() << " function successfully erased";
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -110,8 +110,8 @@ struct SymbolReplacementPass
|
|||
StringAttr newName = nestedOp->getAttrOfType<StringAttr>("sym.new_name");
|
||||
if (!newName)
|
||||
return;
|
||||
symbolUsers.replaceAllUsesWith(nestedOp, newName.getValue());
|
||||
SymbolTable::setSymbolName(nestedOp, newName.getValue());
|
||||
symbolUsers.replaceAllUsesWith(nestedOp, newName);
|
||||
SymbolTable::setSymbolName(nestedOp, newName);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
|
|
@ -80,8 +80,10 @@ struct TestPDLByteCodePass
|
|||
|
||||
// The test cases are encompassed via two modules, one containing the
|
||||
// patterns and one containing the operations to rewrite.
|
||||
ModuleOp patternModule = module.lookupSymbol<ModuleOp>("patterns");
|
||||
ModuleOp irModule = module.lookupSymbol<ModuleOp>("ir");
|
||||
ModuleOp patternModule = module.lookupSymbol<ModuleOp>(
|
||||
StringAttr::get(module->getContext(), "patterns"));
|
||||
ModuleOp irModule = module.lookupSymbol<ModuleOp>(
|
||||
StringAttr::get(module->getContext(), "ir"));
|
||||
if (!patternModule || !irModule)
|
||||
return;
|
||||
|
||||
|
|
Loading…
Reference in New Issue