forked from OSchip/llvm-project
[spirv] Add support for function calls.
Add spv.FunctionCall operation and (de)serialization. Closes tensorflow/mlir#137 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/137 from denis0x0D:sandbox/function_call_op e2e6f07d21e7f23e8b44c7df8a8ab784f3356ce4 PiperOrigin-RevId: 269437167
This commit is contained in:
parent
9619ba10d4
commit
8a34d5d18c
|
@ -100,6 +100,7 @@ def SPV_OC_OpSpecConstantComposite : I32EnumAttrCase<"OpSpecConstantComposite",
|
|||
def SPV_OC_OpFunction : I32EnumAttrCase<"OpFunction", 54>;
|
||||
def SPV_OC_OpFunctionParameter : I32EnumAttrCase<"OpFunctionParameter", 55>;
|
||||
def SPV_OC_OpFunctionEnd : I32EnumAttrCase<"OpFunctionEnd", 56>;
|
||||
def SPV_OC_OpFunctionCall : I32EnumAttrCase<"OpFunctionCall", 57>;
|
||||
def SPV_OC_OpVariable : I32EnumAttrCase<"OpVariable", 59>;
|
||||
def SPV_OC_OpLoad : I32EnumAttrCase<"OpLoad", 61>;
|
||||
def SPV_OC_OpStore : I32EnumAttrCase<"OpStore", 62>;
|
||||
|
@ -161,13 +162,13 @@ def SPV_OpcodeAttr :
|
|||
SPV_OC_OpConstantFalse, SPV_OC_OpConstant, SPV_OC_OpConstantComposite,
|
||||
SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse,
|
||||
SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction,
|
||||
SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpVariable,
|
||||
SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
|
||||
SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpIAdd,
|
||||
SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul,
|
||||
SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem,
|
||||
SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpSelect, SPV_OC_OpIEqual,
|
||||
SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
|
||||
SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall,
|
||||
SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain,
|
||||
SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract,
|
||||
SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul,
|
||||
SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod,
|
||||
SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpSelect,
|
||||
SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
|
||||
SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
|
||||
SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
|
||||
SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual,
|
||||
|
@ -1113,7 +1114,7 @@ def SPV_SamplerUseAttr:
|
|||
// Check that an op can only be used within the scope of a FuncOp.
|
||||
def InFunctionScope : PredOpTrait<
|
||||
"op must appear in a 'func' block",
|
||||
CPred<"llvm::isa_and_nonnull<FuncOp>($_op.getParentOp())">>;
|
||||
CPred<"($_op.getParentOfType<FuncOp>())">>;
|
||||
|
||||
// Check that an op can only be used within the scope of a SPIR-V ModuleOp.
|
||||
def InModuleScope : PredOpTrait<
|
||||
|
|
|
@ -151,6 +151,52 @@ def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [Terminator]> {
|
|||
|
||||
// -----
|
||||
|
||||
def SPV_FunctionCallOp : SPV_Op<"FunctionCall", [InFunctionScope]> {
|
||||
let summary = "Call a function.";
|
||||
|
||||
let description = [{
|
||||
Result Type is the type of the return value of the function. It must be
|
||||
the same as the Return Type operand of the Function Type operand of the
|
||||
Function operand.
|
||||
|
||||
Function is an OpFunction instruction. This could be a forward
|
||||
reference.
|
||||
|
||||
Argument N is the object to copy to parameter N of Function.
|
||||
|
||||
Note: A forward call is possible because there is no missing type
|
||||
information: Result Type must match the Return Type of the function, and
|
||||
the calling argument types must match the formal parameter types.
|
||||
|
||||
### Custom assembly form
|
||||
|
||||
``` {.ebnf}
|
||||
function-call-op ::= `spv.FunctionCall` function-id `(` ssa-use-list `)`
|
||||
`:` function-type
|
||||
```
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
spv.FunctionCall @f_void(%arg0) : (i32) -> ()
|
||||
%0 = spv.FunctionCall @f_iadd(%arg0, %arg1) : (i32, i32) -> i32
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
SymbolRefAttr:$callee,
|
||||
Variadic<SPV_Type>:$arguments
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
SPV_Optional<SPV_Type>:$result
|
||||
);
|
||||
|
||||
let autogenSerialization = 0;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
def SPV_LoopOp : SPV_Op<"loop"> {
|
||||
let summary = "Define a structured loop.";
|
||||
|
||||
|
|
|
@ -35,6 +35,7 @@ using namespace mlir;
|
|||
// TODO(antiagainst): generate these strings using ODS.
|
||||
static constexpr const char kAlignmentAttrName[] = "alignment";
|
||||
static constexpr const char kBranchWeightAttrName[] = "branch_weights";
|
||||
static constexpr const char kCallee[] = "callee";
|
||||
static constexpr const char kDefaultValueAttrName[] = "default_value";
|
||||
static constexpr const char kFnNameAttrName[] = "fn";
|
||||
static constexpr const char kIndicesAttrName[] = "indices";
|
||||
|
@ -912,6 +913,108 @@ static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter *printer) {
|
|||
[&](Attribute a) { *printer << a.cast<IntegerAttr>().getInt(); });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.FuncionCall
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parseFunctionCallOp(OpAsmParser *parser,
|
||||
OperationState *state) {
|
||||
SymbolRefAttr calleeAttr;
|
||||
FunctionType type;
|
||||
SmallVector<OpAsmParser::OperandType, 4> operands;
|
||||
auto loc = parser->getNameLoc();
|
||||
if (parser->parseAttribute(calleeAttr, kCallee, state->attributes) ||
|
||||
parser->parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
|
||||
parser->parseColonType(type)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto funcType = type.dyn_cast<FunctionType>();
|
||||
if (!funcType) {
|
||||
return parser->emitError(loc, "expected function type, but provided ")
|
||||
<< type;
|
||||
}
|
||||
|
||||
if (funcType.getNumResults() > 1) {
|
||||
return parser->emitError(loc, "expected callee function to have 0 or 1 "
|
||||
"result, but provided ")
|
||||
<< funcType.getNumResults();
|
||||
}
|
||||
|
||||
return failure(parser->addTypesToList(funcType.getResults(), state->types) ||
|
||||
parser->resolveOperands(operands, funcType.getInputs(), loc,
|
||||
state->operands));
|
||||
}
|
||||
|
||||
static void print(spirv::FunctionCallOp functionCallOp, OpAsmPrinter *printer) {
|
||||
SmallVector<Type, 4> argTypes(functionCallOp.getOperandTypes());
|
||||
SmallVector<Type, 1> resultTypes(functionCallOp.getResultTypes());
|
||||
Type functionType =
|
||||
FunctionType::get(argTypes, resultTypes, functionCallOp.getContext());
|
||||
|
||||
*printer << spirv::FunctionCallOp::getOperationName() << ' '
|
||||
<< functionCallOp.getAttr(kCallee) << '(';
|
||||
printer->printOperands(functionCallOp.arguments());
|
||||
*printer << ") : " << functionType;
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::FunctionCallOp functionCallOp) {
|
||||
auto fnName = functionCallOp.callee();
|
||||
|
||||
auto moduleOp = functionCallOp.getParentOfType<spirv::ModuleOp>();
|
||||
if (!moduleOp) {
|
||||
return functionCallOp.emitOpError(
|
||||
"must appear in a function inside 'spv.module'");
|
||||
}
|
||||
|
||||
auto funcOp = moduleOp.lookupSymbol<FuncOp>(fnName);
|
||||
if (!funcOp) {
|
||||
return functionCallOp.emitOpError("callee function '")
|
||||
<< fnName << "' not found in 'spv.module'";
|
||||
}
|
||||
|
||||
auto functionType = funcOp.getType();
|
||||
|
||||
if (functionCallOp.getNumResults() > 1) {
|
||||
return functionCallOp.emitOpError(
|
||||
"expected callee function to have 0 or 1 result, but provided ")
|
||||
<< functionCallOp.getNumResults();
|
||||
}
|
||||
|
||||
if (functionType.getNumInputs() != functionCallOp.getNumOperands()) {
|
||||
return functionCallOp.emitOpError(
|
||||
"has incorrect number of operands for callee: expected ")
|
||||
<< functionType.getNumInputs() << ", but provided "
|
||||
<< functionCallOp.getNumOperands();
|
||||
}
|
||||
|
||||
for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
|
||||
if (functionCallOp.getOperand(i)->getType() != functionType.getInput(i)) {
|
||||
return functionCallOp.emitOpError(
|
||||
"operand type mismatch: expected operand type ")
|
||||
<< functionType.getInput(i) << ", but provided "
|
||||
<< functionCallOp.getOperand(i)->getType()
|
||||
<< " for operand number " << i;
|
||||
}
|
||||
}
|
||||
|
||||
if (functionType.getNumResults() != functionCallOp.getNumResults()) {
|
||||
return functionCallOp.emitOpError(
|
||||
"has incorrect number of results has for callee: expected ")
|
||||
<< functionType.getNumResults() << ", but provided "
|
||||
<< functionCallOp.getNumResults();
|
||||
}
|
||||
|
||||
if (functionCallOp.getNumResults() &&
|
||||
(functionCallOp.getResult(0)->getType() != functionType.getResult(0))) {
|
||||
return functionCallOp.emitOpError("result type mismatch: expected ")
|
||||
<< functionType.getResult(0) << ", but provided "
|
||||
<< functionCallOp.getResult(0)->getType();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.globalVariable
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -128,6 +128,11 @@ private:
|
|||
/// Gets the constant's attribute and type associated with the given <id>.
|
||||
Optional<std::pair<Attribute, Type>> getConstant(uint32_t id);
|
||||
|
||||
/// Returns a symbol to be used for the function name with the given
|
||||
/// result <id>. This tries to use the function's OpName if
|
||||
/// exists; otherwise creates one based on the <id>.
|
||||
std::string getFunctionSymbol(uint32_t id);
|
||||
|
||||
/// Returns a symbol to be used for the specialization constant with the given
|
||||
/// result <id>. This tries to use the specialization constant's OpName if
|
||||
/// exists; otherwise creates one based on the <id>.
|
||||
|
@ -637,10 +642,7 @@ LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
|
|||
<< functionType << " and return type " << resultType << " specified";
|
||||
}
|
||||
|
||||
std::string fnName = nameMap.lookup(operands[1]).str();
|
||||
if (fnName.empty()) {
|
||||
fnName = "spirv_fn_" + std::to_string(operands[2]);
|
||||
}
|
||||
std::string fnName = getFunctionSymbol(operands[1]);
|
||||
auto funcOp = opBuilder.create<FuncOp>(unknownLoc, fnName, functionType,
|
||||
ArrayRef<NamedAttribute>());
|
||||
curFunction = funcMap[operands[1]] = funcOp;
|
||||
|
@ -762,6 +764,14 @@ Optional<std::pair<Attribute, Type>> Deserializer::getConstant(uint32_t id) {
|
|||
return constIt->getSecond();
|
||||
}
|
||||
|
||||
std::string Deserializer::getFunctionSymbol(uint32_t id) {
|
||||
auto funcName = nameMap.lookup(id).str();
|
||||
if (funcName.empty()) {
|
||||
funcName = "spirv_fn_" + std::to_string(id);
|
||||
}
|
||||
return funcName;
|
||||
}
|
||||
|
||||
std::string Deserializer::getSpecConstantSymbol(uint32_t id) {
|
||||
auto constName = nameMap.lookup(id).str();
|
||||
if (constName.empty()) {
|
||||
|
@ -1779,6 +1789,50 @@ Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) {
|
|||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult
|
||||
Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {
|
||||
if (operands.size() < 3) {
|
||||
return emitError(unknownLoc,
|
||||
"OpFunctionCall must have at least 3 operands");
|
||||
}
|
||||
|
||||
Type resultType = getType(operands[0]);
|
||||
if (!resultType) {
|
||||
return emitError(unknownLoc, "undefined result type from <id> ")
|
||||
<< operands[0];
|
||||
}
|
||||
|
||||
auto resultID = operands[1];
|
||||
auto functionID = operands[2];
|
||||
|
||||
auto functionName = getFunctionSymbol(functionID);
|
||||
|
||||
llvm::SmallVector<Value *, 4> arguments;
|
||||
for (auto operand : llvm::drop_begin(operands, 3)) {
|
||||
auto *value = getValue(operand);
|
||||
if (!value) {
|
||||
return emitError(unknownLoc, "unknown <id> ")
|
||||
<< operand << " used by OpFunctionCall";
|
||||
}
|
||||
arguments.push_back(value);
|
||||
}
|
||||
|
||||
SmallVector<Type, 1> resultTypes;
|
||||
if (!isVoidType(resultType)) {
|
||||
resultTypes.push_back(resultType);
|
||||
}
|
||||
|
||||
auto opFunctionCall = opBuilder.create<spirv::FunctionCallOp>(
|
||||
unknownLoc, resultTypes, opBuilder.getSymbolRefAttr(functionName),
|
||||
arguments);
|
||||
|
||||
if (!resultTypes.empty()) {
|
||||
valueMap[resultID] = opFunctionCall.getResult(0);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
// Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and
|
||||
// various Deserializer::processOp<...>() specializations.
|
||||
#define GET_DESERIALIZATION_FNS
|
||||
|
|
|
@ -131,6 +131,10 @@ private:
|
|||
return funcIDMap.lookup(fnName);
|
||||
}
|
||||
|
||||
/// Gets the <id> for the function with the given name. Assigns the next
|
||||
/// available <id> if the function haven't been deserialized.
|
||||
uint32_t getOrCreateFunctionID(StringRef fnName);
|
||||
|
||||
void processCapability();
|
||||
|
||||
void processExtension();
|
||||
|
@ -392,6 +396,15 @@ void Serializer::collect(SmallVectorImpl<uint32_t> &binary) {
|
|||
// Module structure
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
|
||||
auto funcID = funcIDMap.lookup(fnName);
|
||||
if (!funcID) {
|
||||
funcID = getNextID();
|
||||
funcIDMap[fnName] = funcID;
|
||||
}
|
||||
return funcID;
|
||||
}
|
||||
|
||||
void Serializer::processCapability() {
|
||||
auto caps = module.getAttrOfType<ArrayAttr>("capabilities");
|
||||
if (!caps)
|
||||
|
@ -537,8 +550,7 @@ LogicalResult Serializer::processFuncOp(FuncOp op) {
|
|||
return failure();
|
||||
}
|
||||
operands.push_back(resTypeID);
|
||||
auto funcID = getNextID();
|
||||
funcIDMap[op.getName()] = funcID;
|
||||
auto funcID = getOrCreateFunctionID(op.getName());
|
||||
operands.push_back(funcID);
|
||||
// TODO : Support other function control options.
|
||||
operands.push_back(static_cast<uint32_t>(spirv::FunctionControl::None));
|
||||
|
@ -1461,6 +1473,37 @@ Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
|
|||
operands);
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult
|
||||
Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
|
||||
auto funcName = op.callee();
|
||||
uint32_t resTypeID = 0;
|
||||
|
||||
llvm::SmallVector<Type, 1> resultTypes(op.getResultTypes());
|
||||
if (failed(processType(op.getLoc(),
|
||||
(resultTypes.empty() ? getVoidType() : resultTypes[0]),
|
||||
resTypeID))) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto funcID = getOrCreateFunctionID(funcName);
|
||||
auto funcCallID = getNextID();
|
||||
SmallVector<uint32_t, 8> operands{resTypeID, funcCallID, funcID};
|
||||
|
||||
for (auto *value : op.arguments()) {
|
||||
auto valueID = findValueID(value);
|
||||
assert(valueID && "cannot find a value for spv.FunctionCall");
|
||||
operands.push_back(valueID);
|
||||
}
|
||||
|
||||
if (!resultTypes.empty()) {
|
||||
valueIDMap[op.getResult(0)] = funcCallID;
|
||||
}
|
||||
|
||||
return encodeInstructionInto(functions, spirv::Opcode::OpFunctionCall,
|
||||
operands);
|
||||
}
|
||||
|
||||
// Pull in auto-generated Serializer::dispatchToAutogenSerialization() and
|
||||
// various Serializer::processOp<...>() specializations.
|
||||
#define GET_SERIALIZATION_FNS
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s
|
||||
|
||||
spv.module "Logical" "GLSL450" {
|
||||
spv.globalVariable @var1 : !spv.ptr<!spv.array<4xf32>, Input>
|
||||
func @fmain() -> i32 {
|
||||
%0 = spv.constant 16 : i32
|
||||
%1 = spv._address_of @var1 : !spv.ptr<!spv.array<4xf32>, Input>
|
||||
// CHECK: {{%.*}} = spv.FunctionCall @f_0({{%.*}}) : (i32) -> i32
|
||||
%3 = spv.FunctionCall @f_0(%0) : (i32) -> i32
|
||||
// CHECK: spv.FunctionCall @f_1({{%.*}}, {{%.*}}) : (i32, !spv.ptr<!spv.array<4 x f32>, Input>) -> ()
|
||||
spv.FunctionCall @f_1(%3, %1) : (i32, !spv.ptr<!spv.array<4xf32>, Input>) -> ()
|
||||
// CHECK: {{%.*}} = spv.FunctionCall @f_2({{%.*}}) : (!spv.ptr<!spv.array<4 x f32>, Input>) -> !spv.ptr<!spv.array<4 x f32>, Input>
|
||||
%4 = spv.FunctionCall @f_2(%1) : (!spv.ptr<!spv.array<4xf32>, Input>) -> !spv.ptr<!spv.array<4xf32>, Input>
|
||||
spv.ReturnValue %3 : i32
|
||||
}
|
||||
func @f_0(%arg0 : i32) -> i32 {
|
||||
spv.ReturnValue %arg0 : i32
|
||||
}
|
||||
func @f_1(%arg0 : i32, %arg1 : !spv.ptr<!spv.array<4xf32>, Input>) -> () {
|
||||
spv.Return
|
||||
}
|
||||
func @f_2(%arg0 : !spv.ptr<!spv.array<4xf32>, Input>) -> !spv.ptr<!spv.array<4xf32>, Input> {
|
||||
spv.ReturnValue %arg0 : !spv.ptr<!spv.array<4xf32>, Input>
|
||||
}
|
||||
|
||||
func @f_loop_with_function_call(%count : i32) -> () {
|
||||
%zero = spv.constant 0: i32
|
||||
%var = spv.Variable init(%zero) : !spv.ptr<i32, Function>
|
||||
spv.loop {
|
||||
spv.Branch ^header
|
||||
^header:
|
||||
%val0 = spv.Load "Function" %var : i32
|
||||
%cmp = spv.SLessThan %val0, %count : i32
|
||||
spv.BranchConditional %cmp, ^body, ^merge
|
||||
^body:
|
||||
spv.Branch ^continue
|
||||
^continue:
|
||||
// CHECK: spv.FunctionCall @f_inc({{%.*}}) : (!spv.ptr<i32, Function>) -> ()
|
||||
spv.FunctionCall @f_inc(%var) : (!spv.ptr<i32, Function>) -> ()
|
||||
spv.Branch ^header
|
||||
^merge:
|
||||
spv._merge
|
||||
}
|
||||
spv.Return
|
||||
}
|
||||
func @f_inc(%arg0 : !spv.ptr<i32, Function>) -> () {
|
||||
%one = spv.constant 1 : i32
|
||||
%0 = spv.Load "Function" %arg0 : i32
|
||||
%1 = spv.IAdd %0, %one : i32
|
||||
spv.Store "Function" %arg0, %1 : i32
|
||||
spv.Return
|
||||
}
|
||||
}
|
|
@ -144,6 +144,111 @@ func @weights_cannot_both_be_zero() -> () {
|
|||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.FunctionCall
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
spv.module "Logical" "GLSL450" {
|
||||
func @fmain(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>, %arg2 : i32) -> i32 {
|
||||
// CHECK: {{%.*}} = spv.FunctionCall @f_0({{%.*}}, {{%.*}}) : (vector<4xf32>, vector<4xf32>) -> vector<4xf32>
|
||||
%0 = spv.FunctionCall @f_0(%arg0, %arg1) : (vector<4xf32>, vector<4xf32>) -> vector<4xf32>
|
||||
// CHECK: spv.FunctionCall @f_1({{%.*}}, {{%.*}}) : (vector<4xf32>, vector<4xf32>) -> ()
|
||||
spv.FunctionCall @f_1(%0, %arg1) : (vector<4xf32>, vector<4xf32>) -> ()
|
||||
// CHECK: spv.FunctionCall @f_2() : () -> ()
|
||||
spv.FunctionCall @f_2() : () -> ()
|
||||
// CHECK: {{%.*}} = spv.FunctionCall @f_3({{%.*}}) : (i32) -> i32
|
||||
%1 = spv.FunctionCall @f_3(%arg2) : (i32) -> i32
|
||||
spv.ReturnValue %1 : i32
|
||||
}
|
||||
|
||||
func @f_0(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> (vector<4xf32>) {
|
||||
spv.ReturnValue %arg0 : vector<4xf32>
|
||||
}
|
||||
|
||||
func @f_1(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> () {
|
||||
spv.Return
|
||||
}
|
||||
|
||||
func @f_2() -> () {
|
||||
spv.Return
|
||||
}
|
||||
|
||||
func @f_3(%arg0 : i32) -> (i32) {
|
||||
spv.ReturnValue %arg0 : i32
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
spv.module "Logical" "GLSL450" {
|
||||
func @f_invalid_result_type(%arg0 : i32, %arg1 : i32) -> () {
|
||||
// expected-error @+1 {{expected callee function to have 0 or 1 result, but provided 2}}
|
||||
%0 = spv.FunctionCall @f_invalid_result_type(%arg0, %arg1) : (i32, i32) -> (i32, i32)
|
||||
spv.Return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
spv.module "Logical" "GLSL450" {
|
||||
func @f_result_type_mismatch(%arg0 : i32, %arg1 : i32) -> () {
|
||||
// expected-error @+1 {{has incorrect number of results has for callee: expected 0, but provided 1}}
|
||||
%1 = spv.FunctionCall @f_result_type_mismatch(%arg0, %arg0) : (i32, i32) -> (i32)
|
||||
spv.Return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
spv.module "Logical" "GLSL450" {
|
||||
func @f_type_mismatch(%arg0 : i32, %arg1 : i32) -> () {
|
||||
// expected-error @+1 {{has incorrect number of operands for callee: expected 2, but provided 1}}
|
||||
spv.FunctionCall @f_type_mismatch(%arg0) : (i32) -> ()
|
||||
spv.Return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
spv.module "Logical" "GLSL450" {
|
||||
func @f_type_mismatch(%arg0 : i32, %arg1 : i32) -> () {
|
||||
%0 = spv.constant 2.0 : f32
|
||||
// expected-error @+1 {{operand type mismatch: expected operand type 'i32', but provided 'f32' for operand number 1}}
|
||||
spv.FunctionCall @f_type_mismatch(%arg0, %0) : (i32, f32) -> ()
|
||||
spv.Return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
spv.module "Logical" "GLSL450" {
|
||||
func @f_type_mismatch(%arg0 : i32, %arg1 : i32) -> i32 {
|
||||
// expected-error @+1 {{result type mismatch: expected 'i32', but provided 'f32'}}
|
||||
%0 = spv.FunctionCall @f_type_mismatch(%arg0, %arg0) : (i32, i32) -> f32
|
||||
spv.Return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
spv.module "Logical" "GLSL450" {
|
||||
func @f_foo(%arg0 : i32, %arg1 : i32) -> i32 {
|
||||
// expected-error @+1 {{op callee function 'f_undefined' not found in 'spv.module'}}
|
||||
%0 = spv.FunctionCall @f_undefined(%arg0, %arg0) : (i32, i32) -> i32
|
||||
spv.Return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @f_foo(%arg0 : i32, %arg1 : i32) -> i32 {
|
||||
// expected-error @+1 {{must appear in a function inside 'spv.module'}}
|
||||
%0 = spv.FunctionCall @f_foo(%arg0, %arg0) : (i32, i32) -> i32
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.loop
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
Loading…
Reference in New Issue