forked from OSchip/llvm-project
Implement call and call_indirect ops.
This also fixes an infinite recursion in VariadicOperands that this turned up. PiperOrigin-RevId: 209692932
This commit is contained in:
parent
00bed4bd99
commit
84259c7def
|
@ -24,6 +24,7 @@
|
|||
namespace mlir {
|
||||
class AffineMap;
|
||||
class Function;
|
||||
class FunctionType;
|
||||
class MLIRContext;
|
||||
class Type;
|
||||
|
||||
|
@ -222,10 +223,12 @@ private:
|
|||
/// remain in MLIRContext.
|
||||
class FunctionAttr : public Attribute {
|
||||
public:
|
||||
static FunctionAttr *get(Function *value, MLIRContext *context);
|
||||
static FunctionAttr *get(const Function *value, MLIRContext *context);
|
||||
|
||||
Function *getValue() const { return value; }
|
||||
|
||||
FunctionType *getType() const;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(const Attribute *attr) {
|
||||
return attr->getKind() == Kind::Function;
|
||||
|
|
|
@ -86,7 +86,7 @@ public:
|
|||
ArrayAttr *getArrayAttr(ArrayRef<Attribute *> value);
|
||||
AffineMapAttr *getAffineMapAttr(AffineMap *value);
|
||||
TypeAttr *getTypeAttr(Type *type);
|
||||
FunctionAttr *getFunctionAttr(Function *value);
|
||||
FunctionAttr *getFunctionAttr(const Function *value);
|
||||
|
||||
// Affine Expressions and Affine Map.
|
||||
AffineMap *getAffineMap(unsigned dimCount, unsigned symbolCount,
|
||||
|
|
|
@ -370,7 +370,7 @@ public:
|
|||
return this->getOperation()->operand_end();
|
||||
}
|
||||
llvm::iterator_range<const_operand_iterator> getOperands() const {
|
||||
return this->getOperands();
|
||||
return this->getOperation()->getOperands();
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -31,6 +31,7 @@ namespace mlir {
|
|||
class AffineMap;
|
||||
class AffineExpr;
|
||||
class Builder;
|
||||
class Function;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OpAsmPrinter
|
||||
|
@ -65,6 +66,7 @@ public:
|
|||
}
|
||||
}
|
||||
virtual void printType(const Type *type) = 0;
|
||||
virtual void printFunctionReference(const Function *func) = 0;
|
||||
virtual void printAttribute(const Attribute *attr) = 0;
|
||||
virtual void printAffineMap(const AffineMap *map) = 0;
|
||||
virtual void printAffineExpr(const AffineExpr *expr) = 0;
|
||||
|
@ -147,9 +149,12 @@ public:
|
|||
// High level parsing methods.
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
// These return void if they always succeed. If they can fail, they emit an
|
||||
// error and return "true". On success, they can optionally provide location
|
||||
// information for clients who want it.
|
||||
// These emit an error and return "true" on failure. On success, they can
|
||||
// optionally provide location information for clients who want it.
|
||||
|
||||
/// Get the location of the next token and store it into the argument. This
|
||||
/// always succeeds.
|
||||
virtual bool getCurrentLocation(llvm::SMLoc *loc) = 0;
|
||||
|
||||
/// This parses... a comma!
|
||||
virtual bool parseComma(llvm::SMLoc *loc = nullptr) = 0;
|
||||
|
@ -190,6 +195,14 @@ public:
|
|||
return false;
|
||||
}
|
||||
|
||||
/// Add the specified types to the end of the specified type list and return
|
||||
/// false. This is a helper designed to allow parse methods to be simple and
|
||||
/// chain through || operators.
|
||||
bool addTypesToList(ArrayRef<Type *> types, SmallVectorImpl<Type *> &result) {
|
||||
result.append(types.begin(), types.end());
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Parse an arbitrary attribute and return it in result. This also adds the
|
||||
/// attribute to the specified attribute list with the specified name. this
|
||||
/// captures the location of the attribute in 'loc' if it is non-null.
|
||||
|
@ -226,6 +239,10 @@ public:
|
|||
parseOptionalAttributeDict(SmallVectorImpl<NamedAttribute> &result,
|
||||
llvm::SMLoc *loc = nullptr) = 0;
|
||||
|
||||
/// Parse a function name like '@foo' and return the name in a form that can
|
||||
/// be passed to resolveFunctionName when a function type is available.
|
||||
virtual bool parseFunctionName(StringRef &result, llvm::SMLoc &loc) = 0;
|
||||
|
||||
/// This is the representation of an operand reference.
|
||||
struct OperandType {
|
||||
llvm::SMLoc location; // Location of the token.
|
||||
|
@ -270,7 +287,7 @@ public:
|
|||
|
||||
/// Resolve an operand to an SSA value, emitting an error and returning true
|
||||
/// on failure.
|
||||
virtual bool resolveOperand(OperandType operand, Type *type,
|
||||
virtual bool resolveOperand(const OperandType &operand, Type *type,
|
||||
SmallVectorImpl<SSAValue *> &result) = 0;
|
||||
|
||||
/// Resolve a list of operands to SSA values, emitting an error and returning
|
||||
|
@ -288,8 +305,13 @@ public:
|
|||
/// emitting an error and returning true on failure, or appending the results
|
||||
/// to the list on success.
|
||||
virtual bool resolveOperands(ArrayRef<OperandType> operands,
|
||||
ArrayRef<Type *> types,
|
||||
ArrayRef<Type *> types, llvm::SMLoc loc,
|
||||
SmallVectorImpl<SSAValue *> &result) {
|
||||
if (operands.size() != types.size())
|
||||
return emitError(loc, Twine(operands.size()) +
|
||||
" operands present, but expected " +
|
||||
Twine(types.size()));
|
||||
|
||||
for (unsigned i = 0, e = operands.size(); i != e; ++i) {
|
||||
if (resolveOperand(operands[i], types[i], result))
|
||||
return true;
|
||||
|
@ -297,6 +319,10 @@ public:
|
|||
return false;
|
||||
}
|
||||
|
||||
/// Resolve a parse function name and a type into a function reference.
|
||||
virtual bool resolveFunctionName(StringRef name, FunctionType *type,
|
||||
llvm::SMLoc loc, Function *&result) = 0;
|
||||
|
||||
/// Emit a diagnostic at the specified location and return true.
|
||||
virtual bool emitError(llvm::SMLoc loc, const Twine &message) = 0;
|
||||
};
|
||||
|
|
|
@ -41,7 +41,6 @@ class AddFOp
|
|||
: public OpBase<AddFOp, OpTrait::NOperands<2>::Impl, OpTrait::OneResult,
|
||||
OpTrait::SameOperandsAndResultType> {
|
||||
public:
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static StringRef getOperationName() { return "addf"; }
|
||||
|
||||
template <class Builder, class Value>
|
||||
|
@ -86,7 +85,6 @@ public:
|
|||
return getAttrOfType<AffineMapAttr>("map")->getValue();
|
||||
}
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static StringRef getOperationName() { return "affine_apply"; }
|
||||
|
||||
// Hooks to customize behavior of this op.
|
||||
|
@ -136,6 +134,63 @@ private:
|
|||
explicit AllocOp(const Operation *state) : OpBase(state) {}
|
||||
};
|
||||
|
||||
/// The "call" operation represents a direct call to a function. The operands
|
||||
/// and result types of the call must match the specified function type. The
|
||||
/// callee is encoded as a function attribute named "callee".
|
||||
///
|
||||
/// %31 = call @my_add(%0, %1)
|
||||
/// : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32>
|
||||
class CallOp : public OpBase<CallOp, OpTrait::VariadicOperands,
|
||||
OpTrait::VariadicResults> {
|
||||
public:
|
||||
static StringRef getOperationName() { return "call"; }
|
||||
|
||||
static OperationState build(Builder *builder, Function *callee,
|
||||
ArrayRef<SSAValue *> operands);
|
||||
|
||||
Function *getCallee() const {
|
||||
return getAttrOfType<FunctionAttr>("callee")->getValue();
|
||||
}
|
||||
|
||||
// Hooks to customize behavior of this op.
|
||||
static bool parse(OpAsmParser *parser, OperationState *result);
|
||||
void print(OpAsmPrinter *p) const;
|
||||
const char *verify() const;
|
||||
|
||||
protected:
|
||||
friend class Operation;
|
||||
explicit CallOp(const Operation *state) : OpBase(state) {}
|
||||
};
|
||||
|
||||
/// The "call_indirect" operation represents an indirect call to a value of
|
||||
/// function type. Functions are first class types in MLIR, and may be passed
|
||||
/// as arguments and merged together with basic block arguments. The operands
|
||||
/// and result types of the call must match the specified function type.
|
||||
///
|
||||
/// %31 = call_indirect %15(%0, %1)
|
||||
/// : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32>
|
||||
///
|
||||
class CallIndirectOp : public OpBase<CallIndirectOp, OpTrait::VariadicOperands,
|
||||
OpTrait::VariadicResults> {
|
||||
public:
|
||||
static StringRef getOperationName() { return "call_indirect"; }
|
||||
|
||||
static OperationState build(Builder *builder, SSAValue *callee,
|
||||
ArrayRef<SSAValue *> operands);
|
||||
|
||||
const SSAValue *getCallee() const { return getOperand(0); }
|
||||
SSAValue *getCallee() { return getOperand(0); }
|
||||
|
||||
// Hooks to customize behavior of this op.
|
||||
static bool parse(OpAsmParser *parser, OperationState *result);
|
||||
void print(OpAsmPrinter *p) const;
|
||||
const char *verify() const;
|
||||
|
||||
protected:
|
||||
friend class Operation;
|
||||
explicit CallIndirectOp(const Operation *state) : OpBase(state) {}
|
||||
};
|
||||
|
||||
/// The "constant" operation requires a single attribute named "value".
|
||||
/// It returns its value as an SSA value. For example:
|
||||
///
|
||||
|
@ -147,7 +202,6 @@ class ConstantOp
|
|||
public:
|
||||
Attribute *getValue() const { return getAttr("value"); }
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static StringRef getOperationName() { return "constant"; }
|
||||
|
||||
// Hooks to customize behavior of this op.
|
||||
|
@ -264,7 +318,6 @@ public:
|
|||
return (unsigned)getAttrOfType<IntegerAttr>("index")->getValue();
|
||||
}
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static StringRef getOperationName() { return "dim"; }
|
||||
|
||||
// Hooks to customize behavior of this op.
|
||||
|
@ -362,7 +415,6 @@ private:
|
|||
class ReturnOp
|
||||
: public OpBase<ReturnOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
|
||||
public:
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static StringRef getOperationName() { return "return"; }
|
||||
|
||||
// Hooks to customize behavior of this op.
|
||||
|
|
|
@ -224,12 +224,22 @@ public:
|
|||
static FunctionType *get(ArrayRef<Type*> inputs, ArrayRef<Type*> results,
|
||||
MLIRContext *context);
|
||||
|
||||
// Input types.
|
||||
unsigned getNumInputs() const { return getSubclassData(); }
|
||||
|
||||
Type *getInput(unsigned i) const { return getInputs()[i]; }
|
||||
|
||||
ArrayRef<Type*> getInputs() const {
|
||||
return ArrayRef<Type*>(inputsAndResults, getSubclassData());
|
||||
return ArrayRef<Type *>(inputsAndResults, getNumInputs());
|
||||
}
|
||||
|
||||
// Result types.
|
||||
unsigned getNumResults() const { return numResults; }
|
||||
|
||||
Type *getResult(unsigned i) const { return getResults()[i]; }
|
||||
|
||||
ArrayRef<Type*> getResults() const {
|
||||
return ArrayRef<Type*>(inputsAndResults+getSubclassData(), numResults);
|
||||
return ArrayRef<Type *>(inputsAndResults + getSubclassData(), numResults);
|
||||
}
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
|
|
|
@ -247,6 +247,7 @@ public:
|
|||
}
|
||||
|
||||
void print(const Module *module);
|
||||
void printFunctionReference(const Function *func);
|
||||
void printAttribute(const Attribute *attr);
|
||||
void printType(const Type *type);
|
||||
void print(const Function *fn);
|
||||
|
@ -387,6 +388,10 @@ static void printFloatValue(double value, raw_ostream &os) {
|
|||
}
|
||||
}
|
||||
|
||||
void ModulePrinter::printFunctionReference(const Function *func) {
|
||||
os << '@' << func->getName();
|
||||
}
|
||||
|
||||
void ModulePrinter::printAttribute(const Attribute *attr) {
|
||||
switch (attr->getKind()) {
|
||||
case Attribute::Kind::Bool:
|
||||
|
@ -420,7 +425,8 @@ void ModulePrinter::printAttribute(const Attribute *attr) {
|
|||
if (!function) {
|
||||
os << "<<FUNCTION ATTR FOR DELETED FUNCTION>>";
|
||||
} else {
|
||||
os << '@' << function->getName() << " : ";
|
||||
printFunctionReference(function);
|
||||
os << " : ";
|
||||
printType(function->getType());
|
||||
}
|
||||
break;
|
||||
|
@ -768,6 +774,9 @@ public:
|
|||
void printAffineExpr(const AffineExpr *expr) {
|
||||
return ModulePrinter::printAffineExpr(expr);
|
||||
}
|
||||
void printFunctionReference(const Function *func) {
|
||||
return ModulePrinter::printFunctionReference(func);
|
||||
}
|
||||
|
||||
void printOperand(const SSAValue *value) { printValueID(value); }
|
||||
|
||||
|
|
|
@ -111,7 +111,7 @@ TypeAttr *Builder::getTypeAttr(Type *type) {
|
|||
return TypeAttr::get(type, context);
|
||||
}
|
||||
|
||||
FunctionAttr *Builder::getFunctionAttr(Function *value) {
|
||||
FunctionAttr *Builder::getFunctionAttr(const Function *value) {
|
||||
return FunctionAttr::get(value, context);
|
||||
}
|
||||
|
||||
|
|
|
@ -255,7 +255,7 @@ public:
|
|||
using AttributeListSet =
|
||||
DenseSet<AttributeListStorage *, AttributeListKeyInfo>;
|
||||
AttributeListSet attributeLists;
|
||||
DenseMap<Function *, FunctionAttr *> functionAttrs;
|
||||
DenseMap<const Function *, FunctionAttr *> functionAttrs;
|
||||
|
||||
public:
|
||||
MLIRContextImpl() : identifiers(allocator) {
|
||||
|
@ -648,16 +648,20 @@ TypeAttr *TypeAttr::get(Type *type, MLIRContext *context) {
|
|||
return result;
|
||||
}
|
||||
|
||||
FunctionAttr *FunctionAttr::get(Function *value, MLIRContext *context) {
|
||||
FunctionAttr *FunctionAttr::get(const Function *value, MLIRContext *context) {
|
||||
assert(value && "Cannot get FunctionAttr for a null function");
|
||||
|
||||
auto *&result = context->getImpl().functionAttrs[value];
|
||||
if (result)
|
||||
return result;
|
||||
|
||||
result = context->getImpl().allocator.Allocate<FunctionAttr>();
|
||||
new (result) FunctionAttr(value);
|
||||
new (result) FunctionAttr(const_cast<Function *>(value));
|
||||
return result;
|
||||
}
|
||||
|
||||
FunctionType *FunctionAttr::getType() const { return getValue()->getType(); }
|
||||
|
||||
/// This function is used by the internals of the Function class to null out
|
||||
/// attributes refering to functions that are about to be deleted.
|
||||
void FunctionAttr::dropFunctionReference(Function *value) {
|
||||
|
|
|
@ -129,7 +129,7 @@ const char *AffineApplyOp::verify() const {
|
|||
// Check that affine map attribute was specified.
|
||||
auto *affineMapAttr = getAttrOfType<AffineMapAttr>("map");
|
||||
if (!affineMapAttr)
|
||||
return "requires an affine map.";
|
||||
return "requires an affine map";
|
||||
|
||||
// Check input and output dimensions match.
|
||||
auto *map = affineMapAttr->getValue();
|
||||
|
@ -198,7 +198,152 @@ const char *AllocOp::verify() const {
|
|||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConstantOp
|
||||
// CallOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OperationState CallOp::build(Builder *builder, Function *callee,
|
||||
ArrayRef<SSAValue *> operands) {
|
||||
OperationState result(builder->getIdentifier("call"));
|
||||
result.operands.append(operands.begin(), operands.end());
|
||||
result.attributes.push_back(
|
||||
{builder->getIdentifier("callee"), builder->getFunctionAttr(callee)});
|
||||
result.types.append(callee->getType()->getResults().begin(),
|
||||
callee->getType()->getResults().end());
|
||||
return result;
|
||||
}
|
||||
|
||||
bool CallOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
StringRef calleeName;
|
||||
llvm::SMLoc calleeLoc;
|
||||
FunctionType *calleeType = nullptr;
|
||||
SmallVector<OpAsmParser::OperandType, 4> operands;
|
||||
Function *callee = nullptr;
|
||||
if (parser->parseFunctionName(calleeName, calleeLoc) ||
|
||||
parser->parseOperandList(operands, /*requiredOperandCount=*/-1,
|
||||
OpAsmParser::Delimiter::Paren) ||
|
||||
parser->parseOptionalAttributeDict(result->attributes) ||
|
||||
parser->parseColonType(calleeType) ||
|
||||
parser->resolveFunctionName(calleeName, calleeType, calleeLoc, callee) ||
|
||||
parser->addTypesToList(calleeType->getResults(), result->types) ||
|
||||
parser->resolveOperands(operands, calleeType->getInputs(), calleeLoc,
|
||||
result->operands))
|
||||
return true;
|
||||
|
||||
auto &builder = parser->getBuilder();
|
||||
result->attributes.push_back(
|
||||
{builder.getIdentifier("callee"), builder.getFunctionAttr(callee)});
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void CallOp::print(OpAsmPrinter *p) const {
|
||||
*p << "call ";
|
||||
p->printFunctionReference(getCallee());
|
||||
*p << '(';
|
||||
p->printOperands(getOperands());
|
||||
*p << ')';
|
||||
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");
|
||||
*p << " : " << *getCallee()->getType();
|
||||
}
|
||||
|
||||
const char *CallOp::verify() const {
|
||||
// Check that the callee attribute was specified.
|
||||
auto *fnAttr = getAttrOfType<FunctionAttr>("callee");
|
||||
if (!fnAttr)
|
||||
return "requires a 'callee' function attribute";
|
||||
|
||||
// Verify that the operand and result types match the callee.
|
||||
auto *fnType = fnAttr->getValue()->getType();
|
||||
if (fnType->getNumInputs() != getNumOperands())
|
||||
return "incorrect number of operands for callee";
|
||||
|
||||
for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) {
|
||||
if (getOperand(i)->getType() != fnType->getInput(i))
|
||||
return "operand type mismatch";
|
||||
}
|
||||
|
||||
if (fnType->getNumResults() != getNumResults())
|
||||
return "incorrect number of results for callee";
|
||||
|
||||
for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) {
|
||||
if (getResult(i)->getType() != fnType->getResult(i))
|
||||
return "result type mismatch";
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CallIndirectOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OperationState CallIndirectOp::build(Builder *builder, SSAValue *callee,
|
||||
ArrayRef<SSAValue *> operands) {
|
||||
auto *fnType = cast<FunctionType>(callee->getType());
|
||||
|
||||
OperationState result(builder->getIdentifier("call_indirect"));
|
||||
result.operands.push_back(callee);
|
||||
result.operands.append(operands.begin(), operands.end());
|
||||
result.types.append(fnType->getResults().begin(), fnType->getResults().end());
|
||||
return result;
|
||||
}
|
||||
|
||||
bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
FunctionType *calleeType = nullptr;
|
||||
OpAsmParser::OperandType callee;
|
||||
llvm::SMLoc operandsLoc;
|
||||
SmallVector<OpAsmParser::OperandType, 4> operands;
|
||||
return parser->parseOperand(callee) ||
|
||||
parser->getCurrentLocation(&operandsLoc) ||
|
||||
parser->parseOperandList(operands, /*requiredOperandCount=*/-1,
|
||||
OpAsmParser::Delimiter::Paren) ||
|
||||
parser->parseOptionalAttributeDict(result->attributes) ||
|
||||
parser->parseColonType(calleeType) ||
|
||||
parser->resolveOperand(callee, calleeType, result->operands) ||
|
||||
parser->resolveOperands(operands, calleeType->getInputs(), operandsLoc,
|
||||
result->operands) ||
|
||||
parser->addTypesToList(calleeType->getResults(), result->types);
|
||||
}
|
||||
|
||||
void CallIndirectOp::print(OpAsmPrinter *p) const {
|
||||
*p << "call_indirect ";
|
||||
p->printOperand(getCallee());
|
||||
*p << '(';
|
||||
auto operandRange = getOperands();
|
||||
p->printOperands(++operandRange.begin(), operandRange.end());
|
||||
*p << ')';
|
||||
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");
|
||||
*p << " : " << *getCallee()->getType();
|
||||
}
|
||||
|
||||
const char *CallIndirectOp::verify() const {
|
||||
// The callee must be a function.
|
||||
auto *fnType = dyn_cast<FunctionType>(getCallee()->getType());
|
||||
if (!fnType)
|
||||
return "callee must have function type";
|
||||
|
||||
// Verify that the operand and result types match the callee.
|
||||
if (fnType->getNumInputs() != getNumOperands() - 1)
|
||||
return "incorrect number of operands for callee";
|
||||
|
||||
for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) {
|
||||
if (getOperand(i + 1)->getType() != fnType->getInput(i))
|
||||
return "operand type mismatch";
|
||||
}
|
||||
|
||||
if (fnType->getNumResults() != getNumResults())
|
||||
return "incorrect number of results for callee";
|
||||
|
||||
for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) {
|
||||
if (getResult(i)->getType() != fnType->getResult(i))
|
||||
return "result type mismatch";
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Constant*Op
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void ConstantOp::print(OpAsmPrinter *p) const {
|
||||
|
@ -444,10 +589,10 @@ const char *LoadOp::verify() const {
|
|||
bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
SmallVector<OpAsmParser::OperandType, 2> opInfo;
|
||||
SmallVector<Type *, 2> types;
|
||||
|
||||
return parser->parseOperandList(opInfo, -1, OpAsmParser::Delimiter::None) ||
|
||||
llvm::SMLoc loc;
|
||||
return parser->getCurrentLocation(&loc) || parser->parseOperandList(opInfo) ||
|
||||
(!opInfo.empty() && parser->parseColonTypeList(types)) ||
|
||||
parser->resolveOperands(opInfo, types, result->operands);
|
||||
parser->resolveOperands(opInfo, types, loc, result->operands);
|
||||
}
|
||||
|
||||
void ReturnOp::print(OpAsmPrinter *p) const {
|
||||
|
@ -541,7 +686,7 @@ const char *StoreOp::verify() const {
|
|||
|
||||
/// Install the standard operations in the specified operation set.
|
||||
void mlir::registerStandardOperations(OperationSet &opSet) {
|
||||
opSet.addOperations<AddFOp, AffineApplyOp, AllocOp, ConstantOp, DeallocOp,
|
||||
DimOp, LoadOp, ReturnOp, StoreOp>(
|
||||
opSet.addOperations<AddFOp, AffineApplyOp, AllocOp, CallOp, CallIndirectOp,
|
||||
ConstantOp, DeallocOp, DimOp, LoadOp, ReturnOp, StoreOp>(
|
||||
/*prefix=*/"");
|
||||
}
|
||||
|
|
|
@ -33,6 +33,7 @@
|
|||
#include "mlir/IR/StmtVisitor.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/Support/PrettyStackTrace.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
using namespace mlir;
|
||||
using llvm::SMLoc;
|
||||
|
@ -180,6 +181,8 @@ public:
|
|||
ParseResult parseTypeList(SmallVectorImpl<Type *> &elements);
|
||||
|
||||
// Attribute parsing.
|
||||
Function *resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
|
||||
FunctionType *type);
|
||||
Attribute *parseAttribute();
|
||||
ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes);
|
||||
|
||||
|
@ -578,6 +581,33 @@ ParseResult Parser::parseTypeList(SmallVectorImpl<Type *> &elements) {
|
|||
// Attribute parsing.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Given a parsed reference to a function name like @foo and a type that it
|
||||
/// corresponds to, resolve it to a concrete function object (possibly
|
||||
/// synthesizing a forward reference) or emit an error and return null on
|
||||
/// failure.
|
||||
Function *Parser::resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
|
||||
FunctionType *type) {
|
||||
Identifier name = builder.getIdentifier(nameStr.drop_front());
|
||||
|
||||
// See if the function has already been defined in the module.
|
||||
Function *function = getModule()->getNamedFunction(name);
|
||||
|
||||
// If not, get or create a forward reference to one.
|
||||
if (!function) {
|
||||
auto &entry = state.functionForwardRefs[name];
|
||||
if (!entry.first) {
|
||||
entry.first = new ExtFunction(name, type);
|
||||
entry.second = nameLoc;
|
||||
}
|
||||
function = entry.first;
|
||||
}
|
||||
|
||||
if (function->getType() != type)
|
||||
return (emitError(nameLoc, "reference to function with mismatched type"),
|
||||
nullptr);
|
||||
return function;
|
||||
}
|
||||
|
||||
/// Attribute parsing.
|
||||
///
|
||||
/// attribute-value ::= bool-literal
|
||||
|
@ -664,7 +694,7 @@ Attribute *Parser::parseAttribute() {
|
|||
|
||||
case Token::at_identifier: {
|
||||
auto nameLoc = getToken().getLoc();
|
||||
Identifier name = builder.getIdentifier(getTokenSpelling().drop_front());
|
||||
auto nameStr = getTokenSpelling();
|
||||
consumeToken(Token::at_identifier);
|
||||
|
||||
if (parseToken(Token::colon, "expected ':' and function type"))
|
||||
|
@ -673,28 +703,12 @@ Attribute *Parser::parseAttribute() {
|
|||
Type *type = parseType();
|
||||
if (!type)
|
||||
return nullptr;
|
||||
auto fnType = dyn_cast<FunctionType>(type);
|
||||
auto *fnType = dyn_cast<FunctionType>(type);
|
||||
if (!fnType)
|
||||
return (emitError(typeLoc, "expected function type"), nullptr);
|
||||
|
||||
// See if the function has already been defined in the module.
|
||||
Function *function = getModule()->getNamedFunction(name);
|
||||
|
||||
// If not, get or create a forward reference to one.
|
||||
if (!function) {
|
||||
auto &entry = state.functionForwardRefs[name];
|
||||
if (!entry.first) {
|
||||
entry.first = new ExtFunction(name, fnType);
|
||||
entry.second = nameLoc;
|
||||
}
|
||||
function = entry.first;
|
||||
}
|
||||
|
||||
if (function->getType() != type)
|
||||
return (emitError(typeLoc, "reference to function with mismatched type"),
|
||||
nullptr);
|
||||
|
||||
return builder.getFunctionAttr(function);
|
||||
auto *function = resolveFunctionReference(nameStr, nameLoc, fnType);
|
||||
return function ? builder.getFunctionAttr(function) : nullptr;
|
||||
}
|
||||
|
||||
default: {
|
||||
|
@ -1701,6 +1715,10 @@ public:
|
|||
// High level parsing methods.
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
bool getCurrentLocation(llvm::SMLoc *loc) override {
|
||||
*loc = parser.getToken().getLoc();
|
||||
return false;
|
||||
}
|
||||
bool parseComma(llvm::SMLoc *loc = nullptr) override {
|
||||
if (loc)
|
||||
*loc = parser.getToken().getLoc();
|
||||
|
@ -1753,6 +1771,19 @@ public:
|
|||
return parser.parseAttributeDict(result) == ParseFailure;
|
||||
}
|
||||
|
||||
/// Parse a function name like '@foo' and return the name in a form that can
|
||||
/// be passed to resolveFunctionName when a function type is available.
|
||||
virtual bool parseFunctionName(StringRef &result, llvm::SMLoc &loc) {
|
||||
loc = parser.getToken().getLoc();
|
||||
|
||||
if (parser.getToken().isNot(Token::at_identifier))
|
||||
return emitError(loc, "expected function name");
|
||||
|
||||
result = parser.getTokenSpelling();
|
||||
parser.consumeToken(Token::at_identifier);
|
||||
return false;
|
||||
}
|
||||
|
||||
bool parseOperand(OperandType &result) override {
|
||||
FunctionParser::SSAUseInfo useInfo;
|
||||
if (parser.parseSSAUse(useInfo))
|
||||
|
@ -1822,6 +1853,13 @@ public:
|
|||
return false;
|
||||
}
|
||||
|
||||
/// Resolve a parse function name and a type into a function reference.
|
||||
virtual bool resolveFunctionName(StringRef name, FunctionType *type,
|
||||
llvm::SMLoc loc, Function *&result) {
|
||||
result = parser.resolveFunctionReference(name, loc, type);
|
||||
return result == nullptr;
|
||||
}
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Methods for interacting with the parser
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
@ -1830,7 +1868,7 @@ public:
|
|||
|
||||
llvm::SMLoc getNameLoc() const override { return nameLoc; }
|
||||
|
||||
bool resolveOperand(OperandType operand, Type *type,
|
||||
bool resolveOperand(const OperandType &operand, Type *type,
|
||||
SmallVectorImpl<SSAValue *> &result) override {
|
||||
FunctionParser::SSAUseInfo operandInfo = {operand.name, operand.number,
|
||||
operand.location};
|
||||
|
@ -1872,6 +1910,11 @@ Operation *FunctionParser::parseCustomOperation(
|
|||
|
||||
consumeToken();
|
||||
|
||||
// If the custom op parser crashes, produce some indication to help debugging.
|
||||
std::string opNameStr = opName.str();
|
||||
llvm::PrettyStackTraceFormat fmt("MLIR Parser: custom op parser '%s'",
|
||||
opNameStr.c_str());
|
||||
|
||||
// Have the op implementation take a crack and parsing this.
|
||||
OperationState opState(builder.getIdentifier(opName));
|
||||
if (opDefinition->parseAssembly(&opAsmParser, &opState))
|
||||
|
|
|
@ -58,6 +58,9 @@ bb42(%t: tensor<4x4x?xf32>, %f: f32):
|
|||
// CHECK: %f_1 = constant @affine_apply : () -> ()
|
||||
%11 = constant @affine_apply : () -> ()
|
||||
|
||||
// CHECK: %f_2 = constant @affine_apply : () -> ()
|
||||
%12 = constant @affine_apply : () -> ()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -101,3 +104,32 @@ mlfunc @return_op(%a : i32) -> i32 {
|
|||
// CHECK: return %arg0 : i32
|
||||
"return" (%a) : (i32)->()
|
||||
}
|
||||
|
||||
// CHECK-LABEL: mlfunc @calls(%arg0 : i32) {
|
||||
mlfunc @calls(%arg0 : i32) {
|
||||
// CHECK: %0 = call @return_op(%arg0) : (i32) -> i32
|
||||
%x = call @return_op(%arg0) : (i32) -> i32
|
||||
// CHECK: %1 = call @return_op(%0) : (i32) -> i32
|
||||
%y = call @return_op(%x) : (i32) -> i32
|
||||
// CHECK: %2 = call @return_op(%0) : (i32) -> i32
|
||||
%z = "call"(%x) {callee: @return_op : (i32) -> i32} : (i32) -> i32
|
||||
|
||||
// CHECK: %f = constant @affine_apply : () -> ()
|
||||
%f = constant @affine_apply : () -> ()
|
||||
|
||||
// CHECK: call_indirect %f() : () -> ()
|
||||
call_indirect %f() : () -> ()
|
||||
|
||||
// CHECK: %f_0 = constant @return_op : (i32) -> i32
|
||||
%f_0 = constant @return_op : (i32) -> i32
|
||||
|
||||
// CHECK: %3 = call_indirect %f_0(%arg0) : (i32) -> i32
|
||||
%2 = call_indirect %f_0(%arg0) : (i32) -> i32
|
||||
|
||||
// CHECK: %4 = call_indirect %f_0(%arg0) : (i32) -> i32
|
||||
%3 = "call_indirect"(%f_0, %arg0) : ((i32) -> i32, i32) -> i32
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@ bb:
|
|||
cfgfunc @affine_apply_no_map() {
|
||||
bb0:
|
||||
%i = "constant"() {value: 0} : () -> affineint
|
||||
%x = "affine_apply" (%i) { } : (affineint) -> (affineint) // expected-error {{'affine_apply' op requires an affine map.}}
|
||||
%x = "affine_apply" (%i) { } : (affineint) -> (affineint) // expected-error {{'affine_apply' op requires an affine map}}
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -105,3 +105,11 @@ mlfunc @mlfunc_constant() {
|
|||
%x = "constant"(){value: "xyz"} : () -> i32 // expected-error {{'constant' op requires 'value' to be an integer for an integer result type}}
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
mlfunc @calls(%arg0 : i32) {
|
||||
%x = call @calls() : () -> i32 // expected-error {{reference to function with mismatched type}}
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -410,7 +410,7 @@ bb0:
|
|||
|
||||
// -----
|
||||
|
||||
cfgfunc @foo() {
|
||||
cfgfunc @undefined_function() {
|
||||
bb0:
|
||||
%x = constant @bar : (i32) -> () // expected-error {{reference to undefined function 'bar'}}
|
||||
return
|
||||
|
|
Loading…
Reference in New Issue