forked from OSchip/llvm-project
Add support for named function argument attributes. The attribute dictionary is printed after the argument type:
func @arg_attrs(i32 {arg_attr: 10}) func @arg_attrs(%arg0: i32 {arg_attr: 10}) PiperOrigin-RevId: 236136830
This commit is contained in:
parent
8cc50208a6
commit
db1757f858
|
@ -1101,8 +1101,8 @@ function ::= `func` function-signature function-attributes? function-body?
|
|||
|
||||
function-signature ::= function-id `(` argument-list `)` (`->` function-result-type)?
|
||||
argument-list ::= named-argument (`,` named-argument)* | /*empty*/
|
||||
argument-list ::= type (`,` type)* | /*empty*/ named-argument ::= ssa-id `:`
|
||||
type
|
||||
argument-list ::= type attribute-dict? (`,` type attribute-dict?)* | /*empty*/
|
||||
named-argument ::= ssa-id `:` type attribute-dict?
|
||||
|
||||
function-attributes ::= `attributes` attribute-dict
|
||||
function-body ::= `{` block+ `}`
|
||||
|
|
|
@ -548,6 +548,7 @@ using NamedAttribute = std::pair<Identifier, Attribute>;
|
|||
/// searches for everything.
|
||||
class NamedAttributeList {
|
||||
public:
|
||||
NamedAttributeList() : attrs(nullptr) {}
|
||||
NamedAttributeList(MLIRContext *context, ArrayRef<NamedAttribute> attributes);
|
||||
|
||||
/// Return all of the attributes on this operation.
|
||||
|
|
|
@ -178,18 +178,41 @@ public:
|
|||
/// constants to names. Attributes may be dynamically added and removed over
|
||||
/// the lifetime of an function.
|
||||
|
||||
/// Return all of the attributes on this instruction.
|
||||
/// Return all of the attributes on this function.
|
||||
ArrayRef<NamedAttribute> getAttrs() const { return attrs.getAttrs(); }
|
||||
|
||||
/// Return all of the attributes for the argument at 'index'.
|
||||
ArrayRef<NamedAttribute> getArgAttrs(unsigned index) const {
|
||||
assert(index < getNumArguments() && "invalid argument number");
|
||||
return argAttrs[index].getAttrs();
|
||||
}
|
||||
|
||||
/// Set the attributes held by this function.
|
||||
void setAttrs(ArrayRef<NamedAttribute> attributes) {
|
||||
attrs.setAttrs(getContext(), attributes);
|
||||
}
|
||||
|
||||
/// Set the attributes held by the argument at 'index'.
|
||||
void setArgAttrs(unsigned index, ArrayRef<NamedAttribute> attributes) {
|
||||
assert(index < getNumArguments() && "invalid argument number");
|
||||
argAttrs[index].setAttrs(getContext(), attributes);
|
||||
}
|
||||
|
||||
/// Return the specified attribute if present, null otherwise.
|
||||
Attribute getAttr(Identifier name) const { return attrs.get(name); }
|
||||
Attribute getAttr(StringRef name) const { return attrs.get(name); }
|
||||
|
||||
/// Return the specified attribute, if present, for the argument at 'index',
|
||||
/// null otherwise.
|
||||
Attribute getArgAttr(unsigned index, Identifier name) const {
|
||||
assert(index < getNumArguments() && "invalid argument number");
|
||||
return argAttrs[index].get(name);
|
||||
}
|
||||
Attribute getArgAttr(unsigned index, StringRef name) const {
|
||||
assert(index < getNumArguments() && "invalid argument number");
|
||||
return argAttrs[index].get(name);
|
||||
}
|
||||
|
||||
template <typename AttrClass> AttrClass getAttrOfType(Identifier name) const {
|
||||
return getAttr(name).dyn_cast_or_null<AttrClass>();
|
||||
}
|
||||
|
@ -198,17 +221,36 @@ public:
|
|||
return getAttr(name).dyn_cast_or_null<AttrClass>();
|
||||
}
|
||||
|
||||
template <typename AttrClass>
|
||||
AttrClass getArgAttrOfType(unsigned index, Identifier name) const {
|
||||
return getArgAttr(index, name).dyn_cast_or_null<AttrClass>();
|
||||
}
|
||||
|
||||
template <typename AttrClass>
|
||||
AttrClass getArgAttrOfType(unsigned index, StringRef name) const {
|
||||
return getArgAttr(index, name).dyn_cast_or_null<AttrClass>();
|
||||
}
|
||||
|
||||
/// If the an attribute exists with the specified name, change it to the new
|
||||
/// value. Otherwise, add a new attribute with the specified name/value.
|
||||
void setAttr(Identifier name, Attribute value) {
|
||||
attrs.set(getContext(), name, value);
|
||||
}
|
||||
void setArgAttr(unsigned index, Identifier name, Attribute value) {
|
||||
assert(index < getNumArguments() && "invalid argument number");
|
||||
argAttrs[index].set(getContext(), name, value);
|
||||
}
|
||||
|
||||
/// Remove the attribute with the specified name if it exists. The return
|
||||
/// value indicates whether the attribute was present or not.
|
||||
NamedAttributeList::RemoveResult removeAttr(Identifier name) {
|
||||
return attrs.remove(getContext(), name);
|
||||
}
|
||||
NamedAttributeList::RemoveResult removeArgAttr(unsigned index,
|
||||
Identifier name) {
|
||||
assert(index < getNumArguments() && "invalid argument number");
|
||||
return attrs.remove(getContext(), name);
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Other
|
||||
|
@ -272,6 +314,9 @@ private:
|
|||
/// This holds general named attributes for the function.
|
||||
NamedAttributeList attrs;
|
||||
|
||||
/// The attributes lists for each of the function arguments.
|
||||
std::vector<NamedAttributeList> argAttrs;
|
||||
|
||||
/// The contents of the body.
|
||||
BlockList blocks;
|
||||
|
||||
|
|
|
@ -38,6 +38,7 @@
|
|||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Instruction.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/PrettyStackTrace.h"
|
||||
#include "llvm/Support/Regex.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
@ -125,15 +126,6 @@ bool FuncVerifier::verify() {
|
|||
if (!funcNameRegex.match(fn.getName().strref()))
|
||||
return failure("invalid function name '" + fn.getName().strref() + "'", fn);
|
||||
|
||||
// External functions have nothing more to check.
|
||||
if (fn.isExternal())
|
||||
return false;
|
||||
|
||||
// Verify the first block has no predecessors.
|
||||
auto *firstBB = &fn.front();
|
||||
if (!firstBB->hasNoPredecessors())
|
||||
return failure("entry block of function may not have predecessors", fn);
|
||||
|
||||
/// Verify that all of the attributes are okay.
|
||||
for (auto attr : fn.getAttrs()) {
|
||||
if (!attrNameRegex.match(attr.first))
|
||||
|
@ -143,6 +135,28 @@ bool FuncVerifier::verify() {
|
|||
return true;
|
||||
}
|
||||
|
||||
/// Verify that all of the argument attributes are okay.
|
||||
for (unsigned i = 0, e = fn.getNumArguments(); i != e; ++i) {
|
||||
for (auto attr : fn.getArgAttrs(i)) {
|
||||
if (!attrNameRegex.match(attr.first))
|
||||
return failure(
|
||||
llvm::formatv("invalid attribute name '{0}' on argument {1}",
|
||||
attr.first.strref(), i),
|
||||
fn);
|
||||
if (verifyAttribute(attr.second, fn))
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// External functions have nothing more to check.
|
||||
if (fn.isExternal())
|
||||
return false;
|
||||
|
||||
// Verify the first block has no predecessors.
|
||||
auto *firstBB = &fn.front();
|
||||
if (!firstBB->hasNoPredecessors())
|
||||
return failure("entry block of function may not have predecessors", fn);
|
||||
|
||||
// Verify that the argument list of the function and the arg list of the first
|
||||
// block line up.
|
||||
auto fnInputTypes = fn.getType().getInputs();
|
||||
|
|
|
@ -1299,20 +1299,21 @@ void FunctionPrinter::printFunctionSignature() {
|
|||
os << "func @" << function->getName() << '(';
|
||||
|
||||
auto fnType = function->getType();
|
||||
bool isExternal = function->isExternal();
|
||||
for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i) {
|
||||
if (i > 0)
|
||||
os << ", ";
|
||||
|
||||
// If this is an external function, don't print argument labels.
|
||||
if (function->isExternal()) {
|
||||
interleaveComma(fnType.getInputs(),
|
||||
[&](Type eltType) { printType(eltType); });
|
||||
} else {
|
||||
for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i) {
|
||||
if (i > 0)
|
||||
os << ", ";
|
||||
auto *arg = function->getArgument(i);
|
||||
printOperand(arg);
|
||||
// If this is an external function, don't print argument labels.
|
||||
if (!isExternal) {
|
||||
printOperand(function->getArgument(i));
|
||||
os << ": ";
|
||||
printType(arg->getType());
|
||||
}
|
||||
|
||||
printType(fnType.getInput(i));
|
||||
|
||||
// Print the attributes for this argument.
|
||||
printOptionalAttrDict(function->getArgAttrs(i));
|
||||
}
|
||||
os << ')';
|
||||
|
||||
|
|
|
@ -501,7 +501,10 @@ Attribute NamedAttributeList::get(StringRef name) const {
|
|||
return nullptr;
|
||||
}
|
||||
Attribute NamedAttributeList::get(Identifier name) const {
|
||||
return get(name.strref());
|
||||
for (auto elt : getAttrs())
|
||||
if (elt.first == name)
|
||||
return elt.second;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// If the an attribute exists with the specified name, change it to the new
|
||||
|
|
|
@ -29,7 +29,8 @@ using namespace mlir;
|
|||
Function::Function(Location location, StringRef name, FunctionType type,
|
||||
ArrayRef<NamedAttribute> attrs)
|
||||
: name(Identifier::get(name, type.getContext())), location(location),
|
||||
type(type), attrs(type.getContext(), attrs), blocks(this) {}
|
||||
type(type), attrs(type.getContext(), attrs),
|
||||
argAttrs(type.getNumInputs()), blocks(this) {}
|
||||
|
||||
Function::~Function() {
|
||||
// Instructions may have cyclic references, which need to be dropped before we
|
||||
|
@ -167,7 +168,8 @@ Function *Function::clone(BlockAndValueMapping &mapper) const {
|
|||
// If the function has a body, then the user might be deleting arguments to
|
||||
// the function by specifying them in the mapper. If so, we don't add the
|
||||
// argument to the input type vector.
|
||||
if (!empty()) {
|
||||
bool isExternalFn = isExternal();
|
||||
if (!isExternalFn) {
|
||||
SmallVector<Type, 4> inputTypes;
|
||||
for (unsigned i = 0, e = getNumArguments(); i != e; ++i)
|
||||
if (!mapper.contains(getArgument(i)))
|
||||
|
@ -175,8 +177,15 @@ Function *Function::clone(BlockAndValueMapping &mapper) const {
|
|||
newType = FunctionType::get(inputTypes, type.getResults(), getContext());
|
||||
}
|
||||
|
||||
// Create a new function and clone the current function into it.
|
||||
// Create the new function.
|
||||
Function *newFunc = new Function(getLoc(), getName(), newType);
|
||||
|
||||
/// Set the argument attributes for arguments that aren't being replaced.
|
||||
for (unsigned i = 0, e = getNumArguments(), destI = 0; i != e; ++i)
|
||||
if (isExternalFn || !mapper.contains(getArgument(i)))
|
||||
newFunc->setArgAttrs(destI++, getArgAttrs(i));
|
||||
|
||||
/// Clone the current function into the new one and return it.
|
||||
cloneInto(newFunc, mapper);
|
||||
return newFunc;
|
||||
}
|
||||
|
|
|
@ -3380,10 +3380,13 @@ private:
|
|||
ParseResult parseTypeAliasDef();
|
||||
|
||||
// Functions.
|
||||
ParseResult parseArgumentList(SmallVectorImpl<Type> &argTypes,
|
||||
SmallVectorImpl<StringRef> &argNames);
|
||||
ParseResult parseFunctionSignature(StringRef &name, FunctionType &type,
|
||||
SmallVectorImpl<StringRef> &argNames);
|
||||
ParseResult
|
||||
parseArgumentList(SmallVectorImpl<Type> &argTypes,
|
||||
SmallVectorImpl<StringRef> &argNames,
|
||||
SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs);
|
||||
ParseResult parseFunctionSignature(
|
||||
StringRef &name, FunctionType &type, SmallVectorImpl<StringRef> &argNames,
|
||||
SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs);
|
||||
ParseResult parseFunc();
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
@ -3466,13 +3469,14 @@ ParseResult ModuleParser::parseTypeAliasDef() {
|
|||
|
||||
/// Parse a (possibly empty) list of Function arguments with types.
|
||||
///
|
||||
/// named-argument ::= ssa-id `:` type
|
||||
/// named-argument ::= ssa-id `:` type attribute-dict?
|
||||
/// argument-list ::= named-argument (`,` named-argument)* | /*empty*/
|
||||
/// argument-list ::= type (`,` type)* | /*empty*/
|
||||
/// argument-list ::= type attribute-dict? (`,` type attribute-dict?)*
|
||||
/// | /*empty*/
|
||||
///
|
||||
ParseResult
|
||||
ModuleParser::parseArgumentList(SmallVectorImpl<Type> &argTypes,
|
||||
SmallVectorImpl<StringRef> &argNames) {
|
||||
ParseResult ModuleParser::parseArgumentList(
|
||||
SmallVectorImpl<Type> &argTypes, SmallVectorImpl<StringRef> &argNames,
|
||||
SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs) {
|
||||
consumeToken(Token::l_paren);
|
||||
|
||||
// The argument list either has to consistently have ssa-id's followed by
|
||||
|
@ -3502,6 +3506,14 @@ ModuleParser::parseArgumentList(SmallVectorImpl<Type> &argTypes,
|
|||
if (!elt)
|
||||
return ParseFailure;
|
||||
argTypes.push_back(elt);
|
||||
|
||||
// Parse the attribute dict.
|
||||
SmallVector<NamedAttribute, 2> attrs;
|
||||
if (getToken().is(Token::l_brace)) {
|
||||
if (parseAttributeDict(attrs))
|
||||
return ParseFailure;
|
||||
}
|
||||
argAttrs.push_back(attrs);
|
||||
return ParseSuccess;
|
||||
};
|
||||
|
||||
|
@ -3514,9 +3526,9 @@ ModuleParser::parseArgumentList(SmallVectorImpl<Type> &argTypes,
|
|||
/// function-signature ::=
|
||||
/// function-id `(` argument-list `)` (`->` type-list)?
|
||||
///
|
||||
ParseResult
|
||||
ModuleParser::parseFunctionSignature(StringRef &name, FunctionType &type,
|
||||
SmallVectorImpl<StringRef> &argNames) {
|
||||
ParseResult ModuleParser::parseFunctionSignature(
|
||||
StringRef &name, FunctionType &type, SmallVectorImpl<StringRef> &argNames,
|
||||
SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs) {
|
||||
if (getToken().isNot(Token::at_identifier))
|
||||
return emitError("expected a function identifier like '@foo'");
|
||||
|
||||
|
@ -3527,7 +3539,7 @@ ModuleParser::parseFunctionSignature(StringRef &name, FunctionType &type,
|
|||
return emitError("expected '(' in function signature");
|
||||
|
||||
SmallVector<Type, 4> argTypes;
|
||||
if (parseArgumentList(argTypes, argNames))
|
||||
if (parseArgumentList(argTypes, argNames, argAttrs))
|
||||
return ParseFailure;
|
||||
|
||||
// Parse the return type if present.
|
||||
|
@ -3553,9 +3565,10 @@ ParseResult ModuleParser::parseFunc() {
|
|||
StringRef name;
|
||||
FunctionType type;
|
||||
SmallVector<StringRef, 4> argNames;
|
||||
SmallVector<SmallVector<NamedAttribute, 2>, 4> argAttrs;
|
||||
|
||||
auto loc = getToken().getLoc();
|
||||
if (parseFunctionSignature(name, type, argNames))
|
||||
if (parseFunctionSignature(name, type, argNames, argAttrs))
|
||||
return ParseFailure;
|
||||
|
||||
// If function attributes are present, parse them.
|
||||
|
@ -3579,6 +3592,10 @@ ParseResult ModuleParser::parseFunc() {
|
|||
if (parseOptionalTrailingLocation(function))
|
||||
return ParseFailure;
|
||||
|
||||
// Add the attributes to the function arguments.
|
||||
for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i)
|
||||
function->setArgAttrs(i, argAttrs[i]);
|
||||
|
||||
// External functions have no body.
|
||||
if (getToken().isNot(Token::l_brace))
|
||||
return ParseSuccess;
|
||||
|
|
|
@ -812,3 +812,11 @@ func @internal_attrs()
|
|||
|
||||
// CHECK-LABEL: func @_valid.function$name
|
||||
func @_valid.function$name()
|
||||
|
||||
// CHECK-LABEL: func @external_func_arg_attrs(i32, i1 {arg.attr: 10}, i32)
|
||||
func @external_func_arg_attrs(i32, i1 {arg.attr: 10}, i32)
|
||||
|
||||
// CHECK-LABEL: func @func_arg_attrs(%arg0: i1 {arg.attr: 10})
|
||||
func @func_arg_attrs(%arg0: i1 {arg.attr: 10}) {
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue