Implement call and call_indirect ops.

This also fixes an infinite recursion in VariadicOperands that this turned up.

PiperOrigin-RevId: 209692932
This commit is contained in:
Chris Lattner 2018-08-21 17:55:22 -07:00 committed by jpienaar
parent 00bed4bd99
commit 84259c7def
14 changed files with 382 additions and 50 deletions

View File

@ -24,6 +24,7 @@
namespace mlir {
class AffineMap;
class Function;
class FunctionType;
class MLIRContext;
class Type;
@ -222,10 +223,12 @@ private:
/// remain in MLIRContext.
class FunctionAttr : public Attribute {
public:
static FunctionAttr *get(Function *value, MLIRContext *context);
static FunctionAttr *get(const Function *value, MLIRContext *context);
Function *getValue() const { return value; }
FunctionType *getType() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Attribute *attr) {
return attr->getKind() == Kind::Function;

View File

@ -86,7 +86,7 @@ public:
ArrayAttr *getArrayAttr(ArrayRef<Attribute *> value);
AffineMapAttr *getAffineMapAttr(AffineMap *value);
TypeAttr *getTypeAttr(Type *type);
FunctionAttr *getFunctionAttr(Function *value);
FunctionAttr *getFunctionAttr(const Function *value);
// Affine Expressions and Affine Map.
AffineMap *getAffineMap(unsigned dimCount, unsigned symbolCount,

View File

@ -370,7 +370,7 @@ public:
return this->getOperation()->operand_end();
}
llvm::iterator_range<const_operand_iterator> getOperands() const {
return this->getOperands();
return this->getOperation()->getOperands();
}
};

View File

@ -31,6 +31,7 @@ namespace mlir {
class AffineMap;
class AffineExpr;
class Builder;
class Function;
//===----------------------------------------------------------------------===//
// OpAsmPrinter
@ -65,6 +66,7 @@ public:
}
}
virtual void printType(const Type *type) = 0;
virtual void printFunctionReference(const Function *func) = 0;
virtual void printAttribute(const Attribute *attr) = 0;
virtual void printAffineMap(const AffineMap *map) = 0;
virtual void printAffineExpr(const AffineExpr *expr) = 0;
@ -147,9 +149,12 @@ public:
// High level parsing methods.
//===--------------------------------------------------------------------===//
// These return void if they always succeed. If they can fail, they emit an
// error and return "true". On success, they can optionally provide location
// information for clients who want it.
// These emit an error and return "true" on failure. On success, they can
// optionally provide location information for clients who want it.
/// Get the location of the next token and store it into the argument. This
/// always succeeds.
virtual bool getCurrentLocation(llvm::SMLoc *loc) = 0;
/// This parses... a comma!
virtual bool parseComma(llvm::SMLoc *loc = nullptr) = 0;
@ -190,6 +195,14 @@ public:
return false;
}
/// Add the specified types to the end of the specified type list and return
/// false. This is a helper designed to allow parse methods to be simple and
/// chain through || operators.
bool addTypesToList(ArrayRef<Type *> types, SmallVectorImpl<Type *> &result) {
result.append(types.begin(), types.end());
return false;
}
/// Parse an arbitrary attribute and return it in result. This also adds the
/// attribute to the specified attribute list with the specified name. this
/// captures the location of the attribute in 'loc' if it is non-null.
@ -226,6 +239,10 @@ public:
parseOptionalAttributeDict(SmallVectorImpl<NamedAttribute> &result,
llvm::SMLoc *loc = nullptr) = 0;
/// Parse a function name like '@foo' and return the name in a form that can
/// be passed to resolveFunctionName when a function type is available.
virtual bool parseFunctionName(StringRef &result, llvm::SMLoc &loc) = 0;
/// This is the representation of an operand reference.
struct OperandType {
llvm::SMLoc location; // Location of the token.
@ -270,7 +287,7 @@ public:
/// Resolve an operand to an SSA value, emitting an error and returning true
/// on failure.
virtual bool resolveOperand(OperandType operand, Type *type,
virtual bool resolveOperand(const OperandType &operand, Type *type,
SmallVectorImpl<SSAValue *> &result) = 0;
/// Resolve a list of operands to SSA values, emitting an error and returning
@ -288,8 +305,13 @@ public:
/// emitting an error and returning true on failure, or appending the results
/// to the list on success.
virtual bool resolveOperands(ArrayRef<OperandType> operands,
ArrayRef<Type *> types,
ArrayRef<Type *> types, llvm::SMLoc loc,
SmallVectorImpl<SSAValue *> &result) {
if (operands.size() != types.size())
return emitError(loc, Twine(operands.size()) +
" operands present, but expected " +
Twine(types.size()));
for (unsigned i = 0, e = operands.size(); i != e; ++i) {
if (resolveOperand(operands[i], types[i], result))
return true;
@ -297,6 +319,10 @@ public:
return false;
}
/// Resolve a parse function name and a type into a function reference.
virtual bool resolveFunctionName(StringRef name, FunctionType *type,
llvm::SMLoc loc, Function *&result) = 0;
/// Emit a diagnostic at the specified location and return true.
virtual bool emitError(llvm::SMLoc loc, const Twine &message) = 0;
};

View File

@ -41,7 +41,6 @@ class AddFOp
: public OpBase<AddFOp, OpTrait::NOperands<2>::Impl, OpTrait::OneResult,
OpTrait::SameOperandsAndResultType> {
public:
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static StringRef getOperationName() { return "addf"; }
template <class Builder, class Value>
@ -86,7 +85,6 @@ public:
return getAttrOfType<AffineMapAttr>("map")->getValue();
}
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static StringRef getOperationName() { return "affine_apply"; }
// Hooks to customize behavior of this op.
@ -136,6 +134,63 @@ private:
explicit AllocOp(const Operation *state) : OpBase(state) {}
};
/// The "call" operation represents a direct call to a function. The operands
/// and result types of the call must match the specified function type. The
/// callee is encoded as a function attribute named "callee".
///
/// %31 = call @my_add(%0, %1)
/// : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32>
class CallOp : public OpBase<CallOp, OpTrait::VariadicOperands,
OpTrait::VariadicResults> {
public:
static StringRef getOperationName() { return "call"; }
static OperationState build(Builder *builder, Function *callee,
ArrayRef<SSAValue *> operands);
Function *getCallee() const {
return getAttrOfType<FunctionAttr>("callee")->getValue();
}
// Hooks to customize behavior of this op.
static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
const char *verify() const;
protected:
friend class Operation;
explicit CallOp(const Operation *state) : OpBase(state) {}
};
/// The "call_indirect" operation represents an indirect call to a value of
/// function type. Functions are first class types in MLIR, and may be passed
/// as arguments and merged together with basic block arguments. The operands
/// and result types of the call must match the specified function type.
///
/// %31 = call_indirect %15(%0, %1)
/// : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32>
///
class CallIndirectOp : public OpBase<CallIndirectOp, OpTrait::VariadicOperands,
OpTrait::VariadicResults> {
public:
static StringRef getOperationName() { return "call_indirect"; }
static OperationState build(Builder *builder, SSAValue *callee,
ArrayRef<SSAValue *> operands);
const SSAValue *getCallee() const { return getOperand(0); }
SSAValue *getCallee() { return getOperand(0); }
// Hooks to customize behavior of this op.
static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
const char *verify() const;
protected:
friend class Operation;
explicit CallIndirectOp(const Operation *state) : OpBase(state) {}
};
/// The "constant" operation requires a single attribute named "value".
/// It returns its value as an SSA value. For example:
///
@ -147,7 +202,6 @@ class ConstantOp
public:
Attribute *getValue() const { return getAttr("value"); }
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static StringRef getOperationName() { return "constant"; }
// Hooks to customize behavior of this op.
@ -264,7 +318,6 @@ public:
return (unsigned)getAttrOfType<IntegerAttr>("index")->getValue();
}
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static StringRef getOperationName() { return "dim"; }
// Hooks to customize behavior of this op.
@ -362,7 +415,6 @@ private:
class ReturnOp
: public OpBase<ReturnOp, OpTrait::VariadicOperands, OpTrait::ZeroResult> {
public:
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static StringRef getOperationName() { return "return"; }
// Hooks to customize behavior of this op.

View File

@ -224,12 +224,22 @@ public:
static FunctionType *get(ArrayRef<Type*> inputs, ArrayRef<Type*> results,
MLIRContext *context);
// Input types.
unsigned getNumInputs() const { return getSubclassData(); }
Type *getInput(unsigned i) const { return getInputs()[i]; }
ArrayRef<Type*> getInputs() const {
return ArrayRef<Type*>(inputsAndResults, getSubclassData());
return ArrayRef<Type *>(inputsAndResults, getNumInputs());
}
// Result types.
unsigned getNumResults() const { return numResults; }
Type *getResult(unsigned i) const { return getResults()[i]; }
ArrayRef<Type*> getResults() const {
return ArrayRef<Type*>(inputsAndResults+getSubclassData(), numResults);
return ArrayRef<Type *>(inputsAndResults + getSubclassData(), numResults);
}
/// Methods for support type inquiry through isa, cast, and dyn_cast.

View File

@ -247,6 +247,7 @@ public:
}
void print(const Module *module);
void printFunctionReference(const Function *func);
void printAttribute(const Attribute *attr);
void printType(const Type *type);
void print(const Function *fn);
@ -387,6 +388,10 @@ static void printFloatValue(double value, raw_ostream &os) {
}
}
void ModulePrinter::printFunctionReference(const Function *func) {
os << '@' << func->getName();
}
void ModulePrinter::printAttribute(const Attribute *attr) {
switch (attr->getKind()) {
case Attribute::Kind::Bool:
@ -420,7 +425,8 @@ void ModulePrinter::printAttribute(const Attribute *attr) {
if (!function) {
os << "<<FUNCTION ATTR FOR DELETED FUNCTION>>";
} else {
os << '@' << function->getName() << " : ";
printFunctionReference(function);
os << " : ";
printType(function->getType());
}
break;
@ -768,6 +774,9 @@ public:
void printAffineExpr(const AffineExpr *expr) {
return ModulePrinter::printAffineExpr(expr);
}
void printFunctionReference(const Function *func) {
return ModulePrinter::printFunctionReference(func);
}
void printOperand(const SSAValue *value) { printValueID(value); }

View File

@ -111,7 +111,7 @@ TypeAttr *Builder::getTypeAttr(Type *type) {
return TypeAttr::get(type, context);
}
FunctionAttr *Builder::getFunctionAttr(Function *value) {
FunctionAttr *Builder::getFunctionAttr(const Function *value) {
return FunctionAttr::get(value, context);
}

View File

@ -255,7 +255,7 @@ public:
using AttributeListSet =
DenseSet<AttributeListStorage *, AttributeListKeyInfo>;
AttributeListSet attributeLists;
DenseMap<Function *, FunctionAttr *> functionAttrs;
DenseMap<const Function *, FunctionAttr *> functionAttrs;
public:
MLIRContextImpl() : identifiers(allocator) {
@ -648,16 +648,20 @@ TypeAttr *TypeAttr::get(Type *type, MLIRContext *context) {
return result;
}
FunctionAttr *FunctionAttr::get(Function *value, MLIRContext *context) {
FunctionAttr *FunctionAttr::get(const Function *value, MLIRContext *context) {
assert(value && "Cannot get FunctionAttr for a null function");
auto *&result = context->getImpl().functionAttrs[value];
if (result)
return result;
result = context->getImpl().allocator.Allocate<FunctionAttr>();
new (result) FunctionAttr(value);
new (result) FunctionAttr(const_cast<Function *>(value));
return result;
}
FunctionType *FunctionAttr::getType() const { return getValue()->getType(); }
/// This function is used by the internals of the Function class to null out
/// attributes refering to functions that are about to be deleted.
void FunctionAttr::dropFunctionReference(Function *value) {

View File

@ -129,7 +129,7 @@ const char *AffineApplyOp::verify() const {
// Check that affine map attribute was specified.
auto *affineMapAttr = getAttrOfType<AffineMapAttr>("map");
if (!affineMapAttr)
return "requires an affine map.";
return "requires an affine map";
// Check input and output dimensions match.
auto *map = affineMapAttr->getValue();
@ -198,7 +198,152 @@ const char *AllocOp::verify() const {
}
//===----------------------------------------------------------------------===//
// ConstantOp
// CallOp
//===----------------------------------------------------------------------===//
OperationState CallOp::build(Builder *builder, Function *callee,
ArrayRef<SSAValue *> operands) {
OperationState result(builder->getIdentifier("call"));
result.operands.append(operands.begin(), operands.end());
result.attributes.push_back(
{builder->getIdentifier("callee"), builder->getFunctionAttr(callee)});
result.types.append(callee->getType()->getResults().begin(),
callee->getType()->getResults().end());
return result;
}
bool CallOp::parse(OpAsmParser *parser, OperationState *result) {
StringRef calleeName;
llvm::SMLoc calleeLoc;
FunctionType *calleeType = nullptr;
SmallVector<OpAsmParser::OperandType, 4> operands;
Function *callee = nullptr;
if (parser->parseFunctionName(calleeName, calleeLoc) ||
parser->parseOperandList(operands, /*requiredOperandCount=*/-1,
OpAsmParser::Delimiter::Paren) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(calleeType) ||
parser->resolveFunctionName(calleeName, calleeType, calleeLoc, callee) ||
parser->addTypesToList(calleeType->getResults(), result->types) ||
parser->resolveOperands(operands, calleeType->getInputs(), calleeLoc,
result->operands))
return true;
auto &builder = parser->getBuilder();
result->attributes.push_back(
{builder.getIdentifier("callee"), builder.getFunctionAttr(callee)});
return false;
}
void CallOp::print(OpAsmPrinter *p) const {
*p << "call ";
p->printFunctionReference(getCallee());
*p << '(';
p->printOperands(getOperands());
*p << ')';
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");
*p << " : " << *getCallee()->getType();
}
const char *CallOp::verify() const {
// Check that the callee attribute was specified.
auto *fnAttr = getAttrOfType<FunctionAttr>("callee");
if (!fnAttr)
return "requires a 'callee' function attribute";
// Verify that the operand and result types match the callee.
auto *fnType = fnAttr->getValue()->getType();
if (fnType->getNumInputs() != getNumOperands())
return "incorrect number of operands for callee";
for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) {
if (getOperand(i)->getType() != fnType->getInput(i))
return "operand type mismatch";
}
if (fnType->getNumResults() != getNumResults())
return "incorrect number of results for callee";
for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) {
if (getResult(i)->getType() != fnType->getResult(i))
return "result type mismatch";
}
return nullptr;
}
//===----------------------------------------------------------------------===//
// CallIndirectOp
//===----------------------------------------------------------------------===//
OperationState CallIndirectOp::build(Builder *builder, SSAValue *callee,
ArrayRef<SSAValue *> operands) {
auto *fnType = cast<FunctionType>(callee->getType());
OperationState result(builder->getIdentifier("call_indirect"));
result.operands.push_back(callee);
result.operands.append(operands.begin(), operands.end());
result.types.append(fnType->getResults().begin(), fnType->getResults().end());
return result;
}
bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) {
FunctionType *calleeType = nullptr;
OpAsmParser::OperandType callee;
llvm::SMLoc operandsLoc;
SmallVector<OpAsmParser::OperandType, 4> operands;
return parser->parseOperand(callee) ||
parser->getCurrentLocation(&operandsLoc) ||
parser->parseOperandList(operands, /*requiredOperandCount=*/-1,
OpAsmParser::Delimiter::Paren) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(calleeType) ||
parser->resolveOperand(callee, calleeType, result->operands) ||
parser->resolveOperands(operands, calleeType->getInputs(), operandsLoc,
result->operands) ||
parser->addTypesToList(calleeType->getResults(), result->types);
}
void CallIndirectOp::print(OpAsmPrinter *p) const {
*p << "call_indirect ";
p->printOperand(getCallee());
*p << '(';
auto operandRange = getOperands();
p->printOperands(++operandRange.begin(), operandRange.end());
*p << ')';
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");
*p << " : " << *getCallee()->getType();
}
const char *CallIndirectOp::verify() const {
// The callee must be a function.
auto *fnType = dyn_cast<FunctionType>(getCallee()->getType());
if (!fnType)
return "callee must have function type";
// Verify that the operand and result types match the callee.
if (fnType->getNumInputs() != getNumOperands() - 1)
return "incorrect number of operands for callee";
for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) {
if (getOperand(i + 1)->getType() != fnType->getInput(i))
return "operand type mismatch";
}
if (fnType->getNumResults() != getNumResults())
return "incorrect number of results for callee";
for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) {
if (getResult(i)->getType() != fnType->getResult(i))
return "result type mismatch";
}
return nullptr;
}
//===----------------------------------------------------------------------===//
// Constant*Op
//===----------------------------------------------------------------------===//
void ConstantOp::print(OpAsmPrinter *p) const {
@ -444,10 +589,10 @@ const char *LoadOp::verify() const {
bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 2> opInfo;
SmallVector<Type *, 2> types;
return parser->parseOperandList(opInfo, -1, OpAsmParser::Delimiter::None) ||
llvm::SMLoc loc;
return parser->getCurrentLocation(&loc) || parser->parseOperandList(opInfo) ||
(!opInfo.empty() && parser->parseColonTypeList(types)) ||
parser->resolveOperands(opInfo, types, result->operands);
parser->resolveOperands(opInfo, types, loc, result->operands);
}
void ReturnOp::print(OpAsmPrinter *p) const {
@ -541,7 +686,7 @@ const char *StoreOp::verify() const {
/// Install the standard operations in the specified operation set.
void mlir::registerStandardOperations(OperationSet &opSet) {
opSet.addOperations<AddFOp, AffineApplyOp, AllocOp, ConstantOp, DeallocOp,
DimOp, LoadOp, ReturnOp, StoreOp>(
opSet.addOperations<AddFOp, AffineApplyOp, AllocOp, CallOp, CallIndirectOp,
ConstantOp, DeallocOp, DimOp, LoadOp, ReturnOp, StoreOp>(
/*prefix=*/"");
}

View File

@ -33,6 +33,7 @@
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/SourceMgr.h"
using namespace mlir;
using llvm::SMLoc;
@ -180,6 +181,8 @@ public:
ParseResult parseTypeList(SmallVectorImpl<Type *> &elements);
// Attribute parsing.
Function *resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
FunctionType *type);
Attribute *parseAttribute();
ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes);
@ -578,6 +581,33 @@ ParseResult Parser::parseTypeList(SmallVectorImpl<Type *> &elements) {
// Attribute parsing.
//===----------------------------------------------------------------------===//
/// Given a parsed reference to a function name like @foo and a type that it
/// corresponds to, resolve it to a concrete function object (possibly
/// synthesizing a forward reference) or emit an error and return null on
/// failure.
Function *Parser::resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
FunctionType *type) {
Identifier name = builder.getIdentifier(nameStr.drop_front());
// See if the function has already been defined in the module.
Function *function = getModule()->getNamedFunction(name);
// If not, get or create a forward reference to one.
if (!function) {
auto &entry = state.functionForwardRefs[name];
if (!entry.first) {
entry.first = new ExtFunction(name, type);
entry.second = nameLoc;
}
function = entry.first;
}
if (function->getType() != type)
return (emitError(nameLoc, "reference to function with mismatched type"),
nullptr);
return function;
}
/// Attribute parsing.
///
/// attribute-value ::= bool-literal
@ -664,7 +694,7 @@ Attribute *Parser::parseAttribute() {
case Token::at_identifier: {
auto nameLoc = getToken().getLoc();
Identifier name = builder.getIdentifier(getTokenSpelling().drop_front());
auto nameStr = getTokenSpelling();
consumeToken(Token::at_identifier);
if (parseToken(Token::colon, "expected ':' and function type"))
@ -673,28 +703,12 @@ Attribute *Parser::parseAttribute() {
Type *type = parseType();
if (!type)
return nullptr;
auto fnType = dyn_cast<FunctionType>(type);
auto *fnType = dyn_cast<FunctionType>(type);
if (!fnType)
return (emitError(typeLoc, "expected function type"), nullptr);
// See if the function has already been defined in the module.
Function *function = getModule()->getNamedFunction(name);
// If not, get or create a forward reference to one.
if (!function) {
auto &entry = state.functionForwardRefs[name];
if (!entry.first) {
entry.first = new ExtFunction(name, fnType);
entry.second = nameLoc;
}
function = entry.first;
}
if (function->getType() != type)
return (emitError(typeLoc, "reference to function with mismatched type"),
nullptr);
return builder.getFunctionAttr(function);
auto *function = resolveFunctionReference(nameStr, nameLoc, fnType);
return function ? builder.getFunctionAttr(function) : nullptr;
}
default: {
@ -1701,6 +1715,10 @@ public:
// High level parsing methods.
//===--------------------------------------------------------------------===//
bool getCurrentLocation(llvm::SMLoc *loc) override {
*loc = parser.getToken().getLoc();
return false;
}
bool parseComma(llvm::SMLoc *loc = nullptr) override {
if (loc)
*loc = parser.getToken().getLoc();
@ -1753,6 +1771,19 @@ public:
return parser.parseAttributeDict(result) == ParseFailure;
}
/// Parse a function name like '@foo' and return the name in a form that can
/// be passed to resolveFunctionName when a function type is available.
virtual bool parseFunctionName(StringRef &result, llvm::SMLoc &loc) {
loc = parser.getToken().getLoc();
if (parser.getToken().isNot(Token::at_identifier))
return emitError(loc, "expected function name");
result = parser.getTokenSpelling();
parser.consumeToken(Token::at_identifier);
return false;
}
bool parseOperand(OperandType &result) override {
FunctionParser::SSAUseInfo useInfo;
if (parser.parseSSAUse(useInfo))
@ -1822,6 +1853,13 @@ public:
return false;
}
/// Resolve a parse function name and a type into a function reference.
virtual bool resolveFunctionName(StringRef name, FunctionType *type,
llvm::SMLoc loc, Function *&result) {
result = parser.resolveFunctionReference(name, loc, type);
return result == nullptr;
}
//===--------------------------------------------------------------------===//
// Methods for interacting with the parser
//===--------------------------------------------------------------------===//
@ -1830,7 +1868,7 @@ public:
llvm::SMLoc getNameLoc() const override { return nameLoc; }
bool resolveOperand(OperandType operand, Type *type,
bool resolveOperand(const OperandType &operand, Type *type,
SmallVectorImpl<SSAValue *> &result) override {
FunctionParser::SSAUseInfo operandInfo = {operand.name, operand.number,
operand.location};
@ -1872,6 +1910,11 @@ Operation *FunctionParser::parseCustomOperation(
consumeToken();
// If the custom op parser crashes, produce some indication to help debugging.
std::string opNameStr = opName.str();
llvm::PrettyStackTraceFormat fmt("MLIR Parser: custom op parser '%s'",
opNameStr.c_str());
// Have the op implementation take a crack and parsing this.
OperationState opState(builder.getIdentifier(opName));
if (opDefinition->parseAssembly(&opAsmParser, &opState))

View File

@ -58,6 +58,9 @@ bb42(%t: tensor<4x4x?xf32>, %f: f32):
// CHECK: %f_1 = constant @affine_apply : () -> ()
%11 = constant @affine_apply : () -> ()
// CHECK: %f_2 = constant @affine_apply : () -> ()
%12 = constant @affine_apply : () -> ()
return
}
@ -101,3 +104,32 @@ mlfunc @return_op(%a : i32) -> i32 {
// CHECK: return %arg0 : i32
"return" (%a) : (i32)->()
}
// CHECK-LABEL: mlfunc @calls(%arg0 : i32) {
mlfunc @calls(%arg0 : i32) {
// CHECK: %0 = call @return_op(%arg0) : (i32) -> i32
%x = call @return_op(%arg0) : (i32) -> i32
// CHECK: %1 = call @return_op(%0) : (i32) -> i32
%y = call @return_op(%x) : (i32) -> i32
// CHECK: %2 = call @return_op(%0) : (i32) -> i32
%z = "call"(%x) {callee: @return_op : (i32) -> i32} : (i32) -> i32
// CHECK: %f = constant @affine_apply : () -> ()
%f = constant @affine_apply : () -> ()
// CHECK: call_indirect %f() : () -> ()
call_indirect %f() : () -> ()
// CHECK: %f_0 = constant @return_op : (i32) -> i32
%f_0 = constant @return_op : (i32) -> i32
// CHECK: %3 = call_indirect %f_0(%arg0) : (i32) -> i32
%2 = call_indirect %f_0(%arg0) : (i32) -> i32
// CHECK: %4 = call_indirect %f_0(%arg0) : (i32) -> i32
%3 = "call_indirect"(%f_0, %arg0) : ((i32) -> i32, i32) -> i32
return
}

View File

@ -35,7 +35,7 @@ bb:
cfgfunc @affine_apply_no_map() {
bb0:
%i = "constant"() {value: 0} : () -> affineint
%x = "affine_apply" (%i) { } : (affineint) -> (affineint) // expected-error {{'affine_apply' op requires an affine map.}}
%x = "affine_apply" (%i) { } : (affineint) -> (affineint) // expected-error {{'affine_apply' op requires an affine map}}
return
}
@ -105,3 +105,11 @@ mlfunc @mlfunc_constant() {
%x = "constant"(){value: "xyz"} : () -> i32 // expected-error {{'constant' op requires 'value' to be an integer for an integer result type}}
return
}
// -----
mlfunc @calls(%arg0 : i32) {
%x = call @calls() : () -> i32 // expected-error {{reference to function with mismatched type}}
return
}

View File

@ -410,7 +410,7 @@ bb0:
// -----
cfgfunc @foo() {
cfgfunc @undefined_function() {
bb0:
%x = constant @bar : (i32) -> () // expected-error {{reference to undefined function 'bar'}}
return