Simplify several usages of attributes now that they always have a type and, transitively, access to the context.

This also fixes a bug where FunctionAttrs were not being remapped for function and function argument attributes.

--

PiperOrigin-RevId: 246876924
This commit is contained in:
River Riddle 2019-05-06 12:40:43 -07:00 committed by Mehdi Amini
parent 94afc426e2
commit 983e0eea95
13 changed files with 95 additions and 68 deletions

View File

@ -571,7 +571,7 @@ PythonMLIRModule::declareFunction(const std::string &name,
inAttrs.emplace_back(Identifier::get(named.name, &mlirContext),
mlir::Attribute::getFromOpaquePointer(
reinterpret_cast<const void *>(named.value)));
inputAttrs.emplace_back(&mlirContext, inAttrs);
inputAttrs.emplace_back(inAttrs);
}
// Create the function itself.
@ -634,7 +634,7 @@ PYBIND11_MODULE(pybind, m) {
});
m.def("constant_function", [](PythonFunction func) -> PythonValueHandle {
auto *function = reinterpret_cast<Function *>(func.function);
auto attr = FunctionAttr::get(function, function->getContext());
auto attr = FunctionAttr::get(function);
return ValueHandle::create<ConstantOp>(function->getType(), attr);
});
m.def("appendTo", [](const PythonBlockHandle &handle) {

View File

@ -127,6 +127,9 @@ public:
/// Return the type of this attribute.
Type getType() const;
/// Return the context this attribute belongs to.
MLIRContext *getContext() const;
/// Return true if this field is, or contains, a function attribute.
bool isOrContainsFunction() const;
@ -135,8 +138,7 @@ public:
/// remapping table. Return the original attribute if it (or any of nested
/// attributes) is not present in the table.
Attribute remapFunctionAttrs(
const llvm::DenseMap<Attribute, FunctionAttr> &remappingTable,
MLIRContext *context) const;
const llvm::DenseMap<Attribute, FunctionAttr> &remappingTable) const;
/// Print the attribute.
void print(raw_ostream &os) const;
@ -299,7 +301,7 @@ public:
using ImplType = detail::TypeAttributeStorage;
using ValueType = Type;
static TypeAttr get(Type type, MLIRContext *context);
static TypeAttr get(Type value);
Type getValue() const;
@ -320,7 +322,7 @@ public:
using ImplType = detail::FunctionAttributeStorage;
using ValueType = Function *;
static FunctionAttr get(Function *value, MLIRContext *context);
static FunctionAttr get(Function *value);
Function *getValue() const;
@ -642,13 +644,13 @@ using NamedAttribute = std::pair<Identifier, Attribute>;
class NamedAttributeList {
public:
NamedAttributeList() : attrs(nullptr) {}
NamedAttributeList(MLIRContext *context, ArrayRef<NamedAttribute> attributes);
NamedAttributeList(ArrayRef<NamedAttribute> attributes);
/// Return all of the attributes on this operation.
ArrayRef<NamedAttribute> getAttrs() const;
/// Replace the held attributes with ones provided in 'newAttrs'.
void setAttrs(MLIRContext *context, ArrayRef<NamedAttribute> attributes);
void setAttrs(ArrayRef<NamedAttribute> attributes);
/// Return the specified attribute if present, null otherwise.
Attribute get(StringRef name) const;
@ -656,13 +658,13 @@ public:
/// If the an attribute exists with the specified name, change it to the new
/// value. Otherwise, add a new attribute with the specified name/value.
void set(MLIRContext *context, Identifier name, Attribute value);
void set(Identifier name, Attribute value);
enum class RemoveResult { Removed, NotFound };
/// Remove the attribute with the specified name if it exists. The return
/// value indicates whether the attribute was present or not.
RemoveResult remove(MLIRContext *context, Identifier name);
RemoveResult remove(Identifier name);
private:
detail::AttributeListStorage *attrs;

View File

@ -157,6 +157,9 @@ public:
/// Return all of the attributes on this function.
ArrayRef<NamedAttribute> getAttrs() { return attrs.getAttrs(); }
/// Return the internal attribute list on this function.
NamedAttributeList &getAttrList() { return attrs; }
/// Return all of the attributes for the argument at 'index'.
ArrayRef<NamedAttribute> getArgAttrs(unsigned index) {
assert(index < getNumArguments() && "invalid argument number");
@ -165,13 +168,13 @@ public:
/// Set the attributes held by this function.
void setAttrs(ArrayRef<NamedAttribute> attributes) {
attrs.setAttrs(getContext(), attributes);
attrs.setAttrs(attributes);
}
/// Set the attributes held by the argument at 'index'.
void setArgAttrs(unsigned index, ArrayRef<NamedAttribute> attributes) {
assert(index < getNumArguments() && "invalid argument number");
argAttrs[index].setAttrs(getContext(), attributes);
argAttrs[index].setAttrs(attributes);
}
/// Return all argument attributes of this function.
@ -212,15 +215,13 @@ public:
/// If the an attribute exists with the specified name, change it to the new
/// value. Otherwise, add a new attribute with the specified name/value.
void setAttr(Identifier name, Attribute value) {
attrs.set(getContext(), name, value);
}
void setAttr(Identifier name, Attribute value) { attrs.set(name, value); }
void setAttr(StringRef name, Attribute value) {
setAttr(Identifier::get(name, getContext()), value);
}
void setArgAttr(unsigned index, Identifier name, Attribute value) {
assert(index < getNumArguments() && "invalid argument number");
argAttrs[index].set(getContext(), name, value);
argAttrs[index].set(name, value);
}
void setArgAttr(unsigned index, StringRef name, Attribute value) {
setArgAttr(index, Identifier::get(name, getContext()), value);
@ -229,12 +230,12 @@ public:
/// Remove the attribute with the specified name if it exists. The return
/// value indicates whether the attribute was present or not.
NamedAttributeList::RemoveResult removeAttr(Identifier name) {
return attrs.remove(getContext(), name);
return attrs.remove(name);
}
NamedAttributeList::RemoveResult removeArgAttr(unsigned index,
Identifier name) {
assert(index < getNumArguments() && "invalid argument number");
return attrs.remove(getContext(), name);
return attrs.remove(name);
}
//===--------------------------------------------------------------------===//

View File

@ -239,6 +239,9 @@ public:
/// Return all of the attributes on this operation.
ArrayRef<NamedAttribute> getAttrs() { return attrs.getAttrs(); }
/// Return the internal attribute list on this operation.
NamedAttributeList &getAttrList() { return attrs; }
/// Return the specified attribute if present, null otherwise.
Attribute getAttr(Identifier name) { return attrs.get(name); }
Attribute getAttr(StringRef name) { return attrs.get(name); }
@ -253,9 +256,7 @@ public:
/// If the an attribute exists with the specified name, change it to the new
/// value. Otherwise, add a new attribute with the specified name/value.
void setAttr(Identifier name, Attribute value) {
attrs.set(getContext(), name, value);
}
void setAttr(Identifier name, Attribute value) { attrs.set(name, value); }
void setAttr(StringRef name, Attribute value) {
setAttr(Identifier::get(name, getContext()), value);
}
@ -263,7 +264,7 @@ public:
/// Remove the attribute with the specified name if it exists. The return
/// value indicates whether the attribute was present or not.
NamedAttributeList::RemoveResult removeAttr(Identifier name) {
return attrs.remove(getContext(), name);
return attrs.remove(name);
}
//===--------------------------------------------------------------------===//

View File

@ -408,8 +408,7 @@ public:
/// Given a list of NamedAttribute's, canonicalize the list (sorting
/// by name) and return the unique'd result. Note that the empty list is
/// represented with a null pointer.
static AttributeListStorage *get(ArrayRef<NamedAttribute> attrs,
MLIRContext *context);
static AttributeListStorage *get(ArrayRef<NamedAttribute> attrs);
/// Return the element constants for this aggregate constant. These are
/// known to all be constants.

View File

@ -67,6 +67,9 @@ Attribute::Kind Attribute::getKind() const {
/// Return the type of this attribute.
Type Attribute::getType() const { return attr->getType(); }
/// Return the context this attribute belongs to.
MLIRContext *Attribute::getContext() const { return getType().getContext(); }
bool Attribute::isOrContainsFunction() const {
return attr->isOrContainsFunctionCache();
}
@ -75,8 +78,7 @@ bool Attribute::isOrContainsFunction() const {
// table, walk it and rewrite it to use the mapped function. If it doesn't
// refer to anything in the table, then it is returned unmodified.
Attribute Attribute::remapFunctionAttrs(
const llvm::DenseMap<Attribute, FunctionAttr> &remappingTable,
MLIRContext *context) const {
const llvm::DenseMap<Attribute, FunctionAttr> &remappingTable) const {
// Most attributes are trivially unrelated to function attributes, skip them
// rapidly.
if (!isOrContainsFunction())
@ -93,7 +95,7 @@ Attribute Attribute::remapFunctionAttrs(
SmallVector<Attribute, 8> remappedElts;
bool anyChange = false;
for (auto elt : arrayAttr.getValue()) {
auto newElt = elt.remapFunctionAttrs(remappingTable, context);
auto newElt = elt.remapFunctionAttrs(remappingTable);
remappedElts.push_back(newElt);
anyChange |= (elt != newElt);
}
@ -101,7 +103,7 @@ Attribute Attribute::remapFunctionAttrs(
if (!anyChange)
return *this;
return ArrayAttr::get(remappedElts, context);
return ArrayAttr::get(remappedElts, getContext());
}
//===----------------------------------------------------------------------===//
@ -262,8 +264,9 @@ IntegerSet IntegerSetAttr::getValue() const {
// TypeAttr
//===----------------------------------------------------------------------===//
TypeAttr TypeAttr::get(Type value, MLIRContext *context) {
return AttributeUniquer::get<TypeAttr>(context, Attribute::Kind::Type, value);
TypeAttr TypeAttr::get(Type value) {
return AttributeUniquer::get<TypeAttr>(value.getContext(),
Attribute::Kind::Type, value);
}
Type TypeAttr::getValue() const { return static_cast<ImplType *>(attr)->value; }
@ -272,10 +275,10 @@ Type TypeAttr::getValue() const { return static_cast<ImplType *>(attr)->value; }
// FunctionAttr
//===----------------------------------------------------------------------===//
FunctionAttr FunctionAttr::get(Function *value, MLIRContext *context) {
FunctionAttr FunctionAttr::get(Function *value) {
assert(value && "Cannot get FunctionAttr for a null function");
return AttributeUniquer::get<FunctionAttr>(context, Attribute::Kind::Function,
value);
return AttributeUniquer::get<FunctionAttr>(value->getContext(),
Attribute::Kind::Function, value);
}
/// This function is used by the internals of the Function class to null out
@ -737,9 +740,8 @@ Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
// NamedAttributeList
//===----------------------------------------------------------------------===//
NamedAttributeList::NamedAttributeList(MLIRContext *context,
ArrayRef<NamedAttribute> attributes) {
setAttrs(context, attributes);
NamedAttributeList::NamedAttributeList(ArrayRef<NamedAttribute> attributes) {
setAttrs(attributes);
}
/// Return all of the attributes on this operation.
@ -748,8 +750,7 @@ ArrayRef<NamedAttribute> NamedAttributeList::getAttrs() const {
}
/// Replace the held attributes with ones provided in 'newAttrs'.
void NamedAttributeList::setAttrs(MLIRContext *context,
ArrayRef<NamedAttribute> attributes) {
void NamedAttributeList::setAttrs(ArrayRef<NamedAttribute> attributes) {
// Don't create an attribute list if there are no attributes.
if (attributes.empty()) {
attrs = nullptr;
@ -759,7 +760,7 @@ void NamedAttributeList::setAttrs(MLIRContext *context,
assert(llvm::all_of(attributes,
[](const NamedAttribute &attr) { return attr.second; }) &&
"attributes cannot have null entries");
attrs = AttributeListStorage::get(attributes, context);
attrs = AttributeListStorage::get(attributes);
}
/// Return the specified attribute if present, null otherwise.
@ -778,8 +779,7 @@ Attribute NamedAttributeList::get(Identifier name) const {
/// If the an attribute exists with the specified name, change it to the new
/// value. Otherwise, add a new attribute with the specified name/value.
void NamedAttributeList::set(MLIRContext *context, Identifier name,
Attribute value) {
void NamedAttributeList::set(Identifier name, Attribute value) {
assert(value && "attributes may never be null");
// If we already have this attribute, replace it.
@ -788,27 +788,32 @@ void NamedAttributeList::set(MLIRContext *context, Identifier name,
for (auto &elt : newAttrs)
if (elt.first == name) {
elt.second = value;
attrs = AttributeListStorage::get(newAttrs, context);
attrs = AttributeListStorage::get(newAttrs);
return;
}
// Otherwise, add it.
newAttrs.push_back({name, value});
attrs = AttributeListStorage::get(newAttrs, context);
attrs = AttributeListStorage::get(newAttrs);
}
/// Remove the attribute with the specified name if it exists. The return
/// value indicates whether the attribute was present or not.
auto NamedAttributeList::remove(MLIRContext *context, Identifier name)
-> RemoveResult {
auto NamedAttributeList::remove(Identifier name) -> RemoveResult {
auto origAttrs = getAttrs();
for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) {
if (origAttrs[i].first == name) {
// Handle the simple case of removing the only attribute in the list.
if (e == 1) {
attrs = nullptr;
return RemoveResult::Removed;
}
SmallVector<NamedAttribute, 8> newAttrs;
newAttrs.reserve(origAttrs.size() - 1);
newAttrs.append(origAttrs.begin(), origAttrs.begin() + i);
newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end());
attrs = AttributeListStorage::get(newAttrs, context);
attrs = AttributeListStorage::get(newAttrs);
return RemoveResult::Removed;
}
}

View File

@ -167,12 +167,10 @@ IntegerSetAttr Builder::getIntegerSetAttr(IntegerSet set) {
return IntegerSetAttr::get(set);
}
TypeAttr Builder::getTypeAttr(Type type) {
return TypeAttr::get(type, context);
}
TypeAttr Builder::getTypeAttr(Type type) { return TypeAttr::get(type); }
FunctionAttr Builder::getFunctionAttr(Function *value) {
return FunctionAttr::get(value, context);
return FunctionAttr::get(value);
}
ElementsAttr Builder::getSplatElementsAttr(VectorOrTensorType type,

View File

@ -30,15 +30,13 @@ using namespace mlir;
Function::Function(Location location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs)
: name(Identifier::get(name, type.getContext())), location(location),
type(type), attrs(type.getContext(), attrs),
argAttrs(type.getNumInputs()), body(this) {}
type(type), attrs(attrs), argAttrs(type.getNumInputs()), body(this) {}
Function::Function(Location location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs,
ArrayRef<NamedAttributeList> argAttrs)
: name(Identifier::get(name, type.getContext())), location(location),
type(type), attrs(type.getContext(), attrs), argAttrs(argAttrs),
body(this) {}
type(type), attrs(attrs), argAttrs(argAttrs), body(this) {}
Function::~Function() {
// Clean up function attributes referring to this function.

View File

@ -849,8 +849,8 @@ static int compareNamedAttributes(const NamedAttribute *lhs,
/// Given a list of NamedAttribute's, canonicalize the list (sorting
/// by name) and return the unique'd result. Note that the empty list is
/// represented with a null pointer.
AttributeListStorage *AttributeListStorage::get(ArrayRef<NamedAttribute> attrs,
MLIRContext *context) {
AttributeListStorage *
AttributeListStorage::get(ArrayRef<NamedAttribute> attrs) {
// We need to sort the element list to canonicalize it, but we also don't want
// to do a ton of work in the super common case where the element list is
// already sorted.
@ -888,7 +888,7 @@ AttributeListStorage *AttributeListStorage::get(ArrayRef<NamedAttribute> attrs,
}
}
auto &impl = context->getImpl();
auto &impl = attrs[0].second.getContext()->getImpl();
// Safely get or create an attribute instance.
return safeGetOrCreate(impl.attributeLists, attrs, impl.attributeMutex, [&] {

View File

@ -102,7 +102,7 @@ Operation *Operation::create(Location location, OperationName name,
ArrayRef<Block *> successors, unsigned numRegions,
bool resizableOperandList, MLIRContext *context) {
return create(location, name, operands, resultTypes,
NamedAttributeList(context, attributes), successors, numRegions,
NamedAttributeList(attributes), successors, numRegions,
resizableOperandList, context);
}

View File

@ -316,8 +316,8 @@ LogicalResult impl::FunctionConversion::run(Module *module) {
if (!converted)
return failure();
auto origFuncAttr = FunctionAttr::get(func, context);
auto convertedFuncAttr = FunctionAttr::get(converted, context);
auto origFuncAttr = FunctionAttr::get(func);
auto convertedFuncAttr = FunctionAttr::get(converted);
convertedFuncs.push_back(converted);
functionAttrRemapping.insert({origFuncAttr, convertedFuncAttr});
}

View File

@ -290,33 +290,44 @@ void mlir::createAffineComputationSlice(
}
}
void mlir::remapFunctionAttrs(
Operation &op, const DenseMap<Attribute, FunctionAttr> &remappingTable) {
for (auto attr : op.getAttrs()) {
static void
remapFunctionAttrs(NamedAttributeList &attrs,
const DenseMap<Attribute, FunctionAttr> &remappingTable) {
for (auto attr : attrs.getAttrs()) {
// Do the remapping, if we got the same thing back, then it must contain
// functions that aren't getting remapped.
auto newVal =
attr.second.remapFunctionAttrs(remappingTable, op.getContext());
auto newVal = attr.second.remapFunctionAttrs(remappingTable);
if (newVal == attr.second)
continue;
// Otherwise, replace the existing attribute with the new one. It is safe
// to mutate the attribute list while we walk it because underlying
// attribute lists are uniqued and immortal.
op.setAttr(attr.first, newVal);
attrs.set(attr.first, newVal);
}
}
void mlir::remapFunctionAttrs(
Operation &op, const DenseMap<Attribute, FunctionAttr> &remappingTable) {
::remapFunctionAttrs(op.getAttrList(), remappingTable);
}
void mlir::remapFunctionAttrs(
Function &fn, const DenseMap<Attribute, FunctionAttr> &remappingTable) {
// Remap the attributes of the function.
::remapFunctionAttrs(fn.getAttrList(), remappingTable);
// Remap the attributes of the arguments of this function.
for (auto &attrList : fn.getAllArgAttrs())
::remapFunctionAttrs(attrList, remappingTable);
// Look at all operations in a Function.
fn.walk([&](Operation *op) { remapFunctionAttrs(*op, remappingTable); });
}
void mlir::remapFunctionAttrs(
Module &module, const DenseMap<Attribute, FunctionAttr> &remappingTable) {
for (auto &fn : module) {
for (auto &fn : module)
remapFunctionAttrs(fn, remappingTable);
}
}

View File

@ -901,3 +901,15 @@ func @none_type() {
%none_val = "foo.unknown_op"() : () -> none
return
}
// CHECK-LABEL: func @fn_attr_remap
// CHECK: {some_dialect.arg_attr: @fn_attr_ref : () -> ()}
func @fn_attr_remap(%arg0: i1 {some_dialect.arg_attr: @fn_attr_ref : () -> ()}) -> ()
// CHECK-NEXT: {some_dialect.fn_attr: @fn_attr_ref : () -> ()}
attributes {some_dialect.fn_attr: @fn_attr_ref : () -> ()} {
return
}
// CHECK-LABEL: func @fn_attr_ref
func @fn_attr_ref() -> ()