Add support for named function argument attributes. The attribute dictionary is printed after the argument type:

func @arg_attrs(i32 {arg_attr: 10})

func @arg_attrs(%arg0: i32 {arg_attr: 10})

PiperOrigin-RevId: 236136830
This commit is contained in:
River Riddle 2019-02-28 09:30:52 -08:00 committed by jpienaar
parent 8cc50208a6
commit db1757f858
9 changed files with 139 additions and 41 deletions

View File

@ -1101,8 +1101,8 @@ function ::= `func` function-signature function-attributes? function-body?
function-signature ::= function-id `(` argument-list `)` (`->` function-result-type)?
argument-list ::= named-argument (`,` named-argument)* | /*empty*/
argument-list ::= type (`,` type)* | /*empty*/ named-argument ::= ssa-id `:`
type
argument-list ::= type attribute-dict? (`,` type attribute-dict?)* | /*empty*/
named-argument ::= ssa-id `:` type attribute-dict?
function-attributes ::= `attributes` attribute-dict
function-body ::= `{` block+ `}`

View File

@ -548,6 +548,7 @@ using NamedAttribute = std::pair<Identifier, Attribute>;
/// searches for everything.
class NamedAttributeList {
public:
NamedAttributeList() : attrs(nullptr) {}
NamedAttributeList(MLIRContext *context, ArrayRef<NamedAttribute> attributes);
/// Return all of the attributes on this operation.

View File

@ -178,18 +178,41 @@ public:
/// constants to names. Attributes may be dynamically added and removed over
/// the lifetime of an function.
/// Return all of the attributes on this instruction.
/// Return all of the attributes on this function.
ArrayRef<NamedAttribute> getAttrs() const { return attrs.getAttrs(); }
/// Return all of the attributes for the argument at 'index'.
ArrayRef<NamedAttribute> getArgAttrs(unsigned index) const {
assert(index < getNumArguments() && "invalid argument number");
return argAttrs[index].getAttrs();
}
/// Set the attributes held by this function.
void setAttrs(ArrayRef<NamedAttribute> attributes) {
attrs.setAttrs(getContext(), attributes);
}
/// Set the attributes held by the argument at 'index'.
void setArgAttrs(unsigned index, ArrayRef<NamedAttribute> attributes) {
assert(index < getNumArguments() && "invalid argument number");
argAttrs[index].setAttrs(getContext(), attributes);
}
/// Return the specified attribute if present, null otherwise.
Attribute getAttr(Identifier name) const { return attrs.get(name); }
Attribute getAttr(StringRef name) const { return attrs.get(name); }
/// Return the specified attribute, if present, for the argument at 'index',
/// null otherwise.
Attribute getArgAttr(unsigned index, Identifier name) const {
assert(index < getNumArguments() && "invalid argument number");
return argAttrs[index].get(name);
}
Attribute getArgAttr(unsigned index, StringRef name) const {
assert(index < getNumArguments() && "invalid argument number");
return argAttrs[index].get(name);
}
template <typename AttrClass> AttrClass getAttrOfType(Identifier name) const {
return getAttr(name).dyn_cast_or_null<AttrClass>();
}
@ -198,17 +221,36 @@ public:
return getAttr(name).dyn_cast_or_null<AttrClass>();
}
template <typename AttrClass>
AttrClass getArgAttrOfType(unsigned index, Identifier name) const {
return getArgAttr(index, name).dyn_cast_or_null<AttrClass>();
}
template <typename AttrClass>
AttrClass getArgAttrOfType(unsigned index, StringRef name) const {
return getArgAttr(index, name).dyn_cast_or_null<AttrClass>();
}
/// If the an attribute exists with the specified name, change it to the new
/// value. Otherwise, add a new attribute with the specified name/value.
void setAttr(Identifier name, Attribute value) {
attrs.set(getContext(), name, value);
}
void setArgAttr(unsigned index, Identifier name, Attribute value) {
assert(index < getNumArguments() && "invalid argument number");
argAttrs[index].set(getContext(), name, value);
}
/// Remove the attribute with the specified name if it exists. The return
/// value indicates whether the attribute was present or not.
NamedAttributeList::RemoveResult removeAttr(Identifier name) {
return attrs.remove(getContext(), name);
}
NamedAttributeList::RemoveResult removeArgAttr(unsigned index,
Identifier name) {
assert(index < getNumArguments() && "invalid argument number");
return attrs.remove(getContext(), name);
}
//===--------------------------------------------------------------------===//
// Other
@ -272,6 +314,9 @@ private:
/// This holds general named attributes for the function.
NamedAttributeList attrs;
/// The attributes lists for each of the function arguments.
std::vector<NamedAttributeList> argAttrs;
/// The contents of the body.
BlockList blocks;

View File

@ -38,6 +38,7 @@
#include "mlir/IR/Function.h"
#include "mlir/IR/Instruction.h"
#include "mlir/IR/Module.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/Regex.h"
#include "llvm/Support/raw_ostream.h"
@ -125,15 +126,6 @@ bool FuncVerifier::verify() {
if (!funcNameRegex.match(fn.getName().strref()))
return failure("invalid function name '" + fn.getName().strref() + "'", fn);
// External functions have nothing more to check.
if (fn.isExternal())
return false;
// Verify the first block has no predecessors.
auto *firstBB = &fn.front();
if (!firstBB->hasNoPredecessors())
return failure("entry block of function may not have predecessors", fn);
/// Verify that all of the attributes are okay.
for (auto attr : fn.getAttrs()) {
if (!attrNameRegex.match(attr.first))
@ -143,6 +135,28 @@ bool FuncVerifier::verify() {
return true;
}
/// Verify that all of the argument attributes are okay.
for (unsigned i = 0, e = fn.getNumArguments(); i != e; ++i) {
for (auto attr : fn.getArgAttrs(i)) {
if (!attrNameRegex.match(attr.first))
return failure(
llvm::formatv("invalid attribute name '{0}' on argument {1}",
attr.first.strref(), i),
fn);
if (verifyAttribute(attr.second, fn))
return true;
}
}
// External functions have nothing more to check.
if (fn.isExternal())
return false;
// Verify the first block has no predecessors.
auto *firstBB = &fn.front();
if (!firstBB->hasNoPredecessors())
return failure("entry block of function may not have predecessors", fn);
// Verify that the argument list of the function and the arg list of the first
// block line up.
auto fnInputTypes = fn.getType().getInputs();

View File

@ -1299,20 +1299,21 @@ void FunctionPrinter::printFunctionSignature() {
os << "func @" << function->getName() << '(';
auto fnType = function->getType();
bool isExternal = function->isExternal();
for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i) {
if (i > 0)
os << ", ";
// If this is an external function, don't print argument labels.
if (function->isExternal()) {
interleaveComma(fnType.getInputs(),
[&](Type eltType) { printType(eltType); });
} else {
for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i) {
if (i > 0)
os << ", ";
auto *arg = function->getArgument(i);
printOperand(arg);
// If this is an external function, don't print argument labels.
if (!isExternal) {
printOperand(function->getArgument(i));
os << ": ";
printType(arg->getType());
}
printType(fnType.getInput(i));
// Print the attributes for this argument.
printOptionalAttrDict(function->getArgAttrs(i));
}
os << ')';

View File

@ -501,7 +501,10 @@ Attribute NamedAttributeList::get(StringRef name) const {
return nullptr;
}
Attribute NamedAttributeList::get(Identifier name) const {
return get(name.strref());
for (auto elt : getAttrs())
if (elt.first == name)
return elt.second;
return nullptr;
}
/// If the an attribute exists with the specified name, change it to the new

View File

@ -29,7 +29,8 @@ using namespace mlir;
Function::Function(Location location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs)
: name(Identifier::get(name, type.getContext())), location(location),
type(type), attrs(type.getContext(), attrs), blocks(this) {}
type(type), attrs(type.getContext(), attrs),
argAttrs(type.getNumInputs()), blocks(this) {}
Function::~Function() {
// Instructions may have cyclic references, which need to be dropped before we
@ -167,7 +168,8 @@ Function *Function::clone(BlockAndValueMapping &mapper) const {
// If the function has a body, then the user might be deleting arguments to
// the function by specifying them in the mapper. If so, we don't add the
// argument to the input type vector.
if (!empty()) {
bool isExternalFn = isExternal();
if (!isExternalFn) {
SmallVector<Type, 4> inputTypes;
for (unsigned i = 0, e = getNumArguments(); i != e; ++i)
if (!mapper.contains(getArgument(i)))
@ -175,8 +177,15 @@ Function *Function::clone(BlockAndValueMapping &mapper) const {
newType = FunctionType::get(inputTypes, type.getResults(), getContext());
}
// Create a new function and clone the current function into it.
// Create the new function.
Function *newFunc = new Function(getLoc(), getName(), newType);
/// Set the argument attributes for arguments that aren't being replaced.
for (unsigned i = 0, e = getNumArguments(), destI = 0; i != e; ++i)
if (isExternalFn || !mapper.contains(getArgument(i)))
newFunc->setArgAttrs(destI++, getArgAttrs(i));
/// Clone the current function into the new one and return it.
cloneInto(newFunc, mapper);
return newFunc;
}

View File

@ -3380,10 +3380,13 @@ private:
ParseResult parseTypeAliasDef();
// Functions.
ParseResult parseArgumentList(SmallVectorImpl<Type> &argTypes,
SmallVectorImpl<StringRef> &argNames);
ParseResult parseFunctionSignature(StringRef &name, FunctionType &type,
SmallVectorImpl<StringRef> &argNames);
ParseResult
parseArgumentList(SmallVectorImpl<Type> &argTypes,
SmallVectorImpl<StringRef> &argNames,
SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs);
ParseResult parseFunctionSignature(
StringRef &name, FunctionType &type, SmallVectorImpl<StringRef> &argNames,
SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs);
ParseResult parseFunc();
};
} // end anonymous namespace
@ -3466,13 +3469,14 @@ ParseResult ModuleParser::parseTypeAliasDef() {
/// Parse a (possibly empty) list of Function arguments with types.
///
/// named-argument ::= ssa-id `:` type
/// named-argument ::= ssa-id `:` type attribute-dict?
/// argument-list ::= named-argument (`,` named-argument)* | /*empty*/
/// argument-list ::= type (`,` type)* | /*empty*/
/// argument-list ::= type attribute-dict? (`,` type attribute-dict?)*
/// | /*empty*/
///
ParseResult
ModuleParser::parseArgumentList(SmallVectorImpl<Type> &argTypes,
SmallVectorImpl<StringRef> &argNames) {
ParseResult ModuleParser::parseArgumentList(
SmallVectorImpl<Type> &argTypes, SmallVectorImpl<StringRef> &argNames,
SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs) {
consumeToken(Token::l_paren);
// The argument list either has to consistently have ssa-id's followed by
@ -3502,6 +3506,14 @@ ModuleParser::parseArgumentList(SmallVectorImpl<Type> &argTypes,
if (!elt)
return ParseFailure;
argTypes.push_back(elt);
// Parse the attribute dict.
SmallVector<NamedAttribute, 2> attrs;
if (getToken().is(Token::l_brace)) {
if (parseAttributeDict(attrs))
return ParseFailure;
}
argAttrs.push_back(attrs);
return ParseSuccess;
};
@ -3514,9 +3526,9 @@ ModuleParser::parseArgumentList(SmallVectorImpl<Type> &argTypes,
/// function-signature ::=
/// function-id `(` argument-list `)` (`->` type-list)?
///
ParseResult
ModuleParser::parseFunctionSignature(StringRef &name, FunctionType &type,
SmallVectorImpl<StringRef> &argNames) {
ParseResult ModuleParser::parseFunctionSignature(
StringRef &name, FunctionType &type, SmallVectorImpl<StringRef> &argNames,
SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs) {
if (getToken().isNot(Token::at_identifier))
return emitError("expected a function identifier like '@foo'");
@ -3527,7 +3539,7 @@ ModuleParser::parseFunctionSignature(StringRef &name, FunctionType &type,
return emitError("expected '(' in function signature");
SmallVector<Type, 4> argTypes;
if (parseArgumentList(argTypes, argNames))
if (parseArgumentList(argTypes, argNames, argAttrs))
return ParseFailure;
// Parse the return type if present.
@ -3553,9 +3565,10 @@ ParseResult ModuleParser::parseFunc() {
StringRef name;
FunctionType type;
SmallVector<StringRef, 4> argNames;
SmallVector<SmallVector<NamedAttribute, 2>, 4> argAttrs;
auto loc = getToken().getLoc();
if (parseFunctionSignature(name, type, argNames))
if (parseFunctionSignature(name, type, argNames, argAttrs))
return ParseFailure;
// If function attributes are present, parse them.
@ -3579,6 +3592,10 @@ ParseResult ModuleParser::parseFunc() {
if (parseOptionalTrailingLocation(function))
return ParseFailure;
// Add the attributes to the function arguments.
for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i)
function->setArgAttrs(i, argAttrs[i]);
// External functions have no body.
if (getToken().isNot(Token::l_brace))
return ParseSuccess;

View File

@ -812,3 +812,11 @@ func @internal_attrs()
// CHECK-LABEL: func @_valid.function$name
func @_valid.function$name()
// CHECK-LABEL: func @external_func_arg_attrs(i32, i1 {arg.attr: 10}, i32)
func @external_func_arg_attrs(i32, i1 {arg.attr: 10}, i32)
// CHECK-LABEL: func @func_arg_attrs(%arg0: i1 {arg.attr: 10})
func @func_arg_attrs(%arg0: i1 {arg.attr: 10}) {
return
}