diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td index 8a49ae63b2a4..b2b844dca122 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -100,6 +100,7 @@ def SPV_OC_OpSpecConstantComposite : I32EnumAttrCase<"OpSpecConstantComposite", def SPV_OC_OpFunction : I32EnumAttrCase<"OpFunction", 54>; def SPV_OC_OpFunctionParameter : I32EnumAttrCase<"OpFunctionParameter", 55>; def SPV_OC_OpFunctionEnd : I32EnumAttrCase<"OpFunctionEnd", 56>; +def SPV_OC_OpFunctionCall : I32EnumAttrCase<"OpFunctionCall", 57>; def SPV_OC_OpVariable : I32EnumAttrCase<"OpVariable", 59>; def SPV_OC_OpLoad : I32EnumAttrCase<"OpLoad", 61>; def SPV_OC_OpStore : I32EnumAttrCase<"OpStore", 62>; @@ -161,13 +162,13 @@ def SPV_OpcodeAttr : SPV_OC_OpConstantFalse, SPV_OC_OpConstant, SPV_OC_OpConstantComposite, SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, - SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpVariable, - SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate, - SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpIAdd, - SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, - SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, - SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpSelect, SPV_OC_OpIEqual, - SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, + SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, + SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain, + SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, + SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, + SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, + SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpSelect, + SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual, @@ -1113,7 +1114,7 @@ def SPV_SamplerUseAttr: // Check that an op can only be used within the scope of a FuncOp. def InFunctionScope : PredOpTrait< "op must appear in a 'func' block", - CPred<"llvm::isa_and_nonnull($_op.getParentOp())">>; + CPred<"($_op.getParentOfType())">>; // Check that an op can only be used within the scope of a SPIR-V ModuleOp. def InModuleScope : PredOpTrait< diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td index 015dde8e77db..b9f86043f5c4 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td @@ -151,6 +151,52 @@ def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [Terminator]> { // ----- +def SPV_FunctionCallOp : SPV_Op<"FunctionCall", [InFunctionScope]> { + let summary = "Call a function."; + + let description = [{ + Result Type is the type of the return value of the function. It must be + the same as the Return Type operand of the Function Type operand of the + Function operand. + + Function is an OpFunction instruction. This could be a forward + reference. + + Argument N is the object to copy to parameter N of Function. + + Note: A forward call is possible because there is no missing type + information: Result Type must match the Return Type of the function, and + the calling argument types must match the formal parameter types. + + ### Custom assembly form + + ``` {.ebnf} + function-call-op ::= `spv.FunctionCall` function-id `(` ssa-use-list `)` + `:` function-type + ``` + + For example: + + ``` + spv.FunctionCall @f_void(%arg0) : (i32) -> () + %0 = spv.FunctionCall @f_iadd(%arg0, %arg1) : (i32, i32) -> i32 + ``` + }]; + + let arguments = (ins + SymbolRefAttr:$callee, + Variadic:$arguments + ); + + let results = (outs + SPV_Optional:$result + ); + + let autogenSerialization = 0; +} + +// ----- + def SPV_LoopOp : SPV_Op<"loop"> { let summary = "Define a structured loop."; diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 81860e86b7a9..9766d6c88662 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -35,6 +35,7 @@ using namespace mlir; // TODO(antiagainst): generate these strings using ODS. static constexpr const char kAlignmentAttrName[] = "alignment"; static constexpr const char kBranchWeightAttrName[] = "branch_weights"; +static constexpr const char kCallee[] = "callee"; static constexpr const char kDefaultValueAttrName[] = "default_value"; static constexpr const char kFnNameAttrName[] = "fn"; static constexpr const char kIndicesAttrName[] = "indices"; @@ -912,6 +913,108 @@ static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter *printer) { [&](Attribute a) { *printer << a.cast().getInt(); }); } +//===----------------------------------------------------------------------===// +// spv.FuncionCall +//===----------------------------------------------------------------------===// + +static ParseResult parseFunctionCallOp(OpAsmParser *parser, + OperationState *state) { + SymbolRefAttr calleeAttr; + FunctionType type; + SmallVector operands; + auto loc = parser->getNameLoc(); + if (parser->parseAttribute(calleeAttr, kCallee, state->attributes) || + parser->parseOperandList(operands, OpAsmParser::Delimiter::Paren) || + parser->parseColonType(type)) { + return failure(); + } + + auto funcType = type.dyn_cast(); + if (!funcType) { + return parser->emitError(loc, "expected function type, but provided ") + << type; + } + + if (funcType.getNumResults() > 1) { + return parser->emitError(loc, "expected callee function to have 0 or 1 " + "result, but provided ") + << funcType.getNumResults(); + } + + return failure(parser->addTypesToList(funcType.getResults(), state->types) || + parser->resolveOperands(operands, funcType.getInputs(), loc, + state->operands)); +} + +static void print(spirv::FunctionCallOp functionCallOp, OpAsmPrinter *printer) { + SmallVector argTypes(functionCallOp.getOperandTypes()); + SmallVector resultTypes(functionCallOp.getResultTypes()); + Type functionType = + FunctionType::get(argTypes, resultTypes, functionCallOp.getContext()); + + *printer << spirv::FunctionCallOp::getOperationName() << ' ' + << functionCallOp.getAttr(kCallee) << '('; + printer->printOperands(functionCallOp.arguments()); + *printer << ") : " << functionType; +} + +static LogicalResult verify(spirv::FunctionCallOp functionCallOp) { + auto fnName = functionCallOp.callee(); + + auto moduleOp = functionCallOp.getParentOfType(); + if (!moduleOp) { + return functionCallOp.emitOpError( + "must appear in a function inside 'spv.module'"); + } + + auto funcOp = moduleOp.lookupSymbol(fnName); + if (!funcOp) { + return functionCallOp.emitOpError("callee function '") + << fnName << "' not found in 'spv.module'"; + } + + auto functionType = funcOp.getType(); + + if (functionCallOp.getNumResults() > 1) { + return functionCallOp.emitOpError( + "expected callee function to have 0 or 1 result, but provided ") + << functionCallOp.getNumResults(); + } + + if (functionType.getNumInputs() != functionCallOp.getNumOperands()) { + return functionCallOp.emitOpError( + "has incorrect number of operands for callee: expected ") + << functionType.getNumInputs() << ", but provided " + << functionCallOp.getNumOperands(); + } + + for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) { + if (functionCallOp.getOperand(i)->getType() != functionType.getInput(i)) { + return functionCallOp.emitOpError( + "operand type mismatch: expected operand type ") + << functionType.getInput(i) << ", but provided " + << functionCallOp.getOperand(i)->getType() + << " for operand number " << i; + } + } + + if (functionType.getNumResults() != functionCallOp.getNumResults()) { + return functionCallOp.emitOpError( + "has incorrect number of results has for callee: expected ") + << functionType.getNumResults() << ", but provided " + << functionCallOp.getNumResults(); + } + + if (functionCallOp.getNumResults() && + (functionCallOp.getResult(0)->getType() != functionType.getResult(0))) { + return functionCallOp.emitOpError("result type mismatch: expected ") + << functionType.getResult(0) << ", but provided " + << functionCallOp.getResult(0)->getType(); + } + + return success(); +} + //===----------------------------------------------------------------------===// // spv.globalVariable //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index dcc7d19af625..e5f4e06894b3 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -128,6 +128,11 @@ private: /// Gets the constant's attribute and type associated with the given . Optional> getConstant(uint32_t id); + /// Returns a symbol to be used for the function name with the given + /// result . This tries to use the function's OpName if + /// exists; otherwise creates one based on the . + std::string getFunctionSymbol(uint32_t id); + /// Returns a symbol to be used for the specialization constant with the given /// result . This tries to use the specialization constant's OpName if /// exists; otherwise creates one based on the . @@ -637,10 +642,7 @@ LogicalResult Deserializer::processFunction(ArrayRef operands) { << functionType << " and return type " << resultType << " specified"; } - std::string fnName = nameMap.lookup(operands[1]).str(); - if (fnName.empty()) { - fnName = "spirv_fn_" + std::to_string(operands[2]); - } + std::string fnName = getFunctionSymbol(operands[1]); auto funcOp = opBuilder.create(unknownLoc, fnName, functionType, ArrayRef()); curFunction = funcMap[operands[1]] = funcOp; @@ -762,6 +764,14 @@ Optional> Deserializer::getConstant(uint32_t id) { return constIt->getSecond(); } +std::string Deserializer::getFunctionSymbol(uint32_t id) { + auto funcName = nameMap.lookup(id).str(); + if (funcName.empty()) { + funcName = "spirv_fn_" + std::to_string(id); + } + return funcName; +} + std::string Deserializer::getSpecConstantSymbol(uint32_t id) { auto constName = nameMap.lookup(id).str(); if (constName.empty()) { @@ -1779,6 +1789,50 @@ Deserializer::processOp(ArrayRef words) { return success(); } +template <> +LogicalResult +Deserializer::processOp(ArrayRef operands) { + if (operands.size() < 3) { + return emitError(unknownLoc, + "OpFunctionCall must have at least 3 operands"); + } + + Type resultType = getType(operands[0]); + if (!resultType) { + return emitError(unknownLoc, "undefined result type from ") + << operands[0]; + } + + auto resultID = operands[1]; + auto functionID = operands[2]; + + auto functionName = getFunctionSymbol(functionID); + + llvm::SmallVector arguments; + for (auto operand : llvm::drop_begin(operands, 3)) { + auto *value = getValue(operand); + if (!value) { + return emitError(unknownLoc, "unknown ") + << operand << " used by OpFunctionCall"; + } + arguments.push_back(value); + } + + SmallVector resultTypes; + if (!isVoidType(resultType)) { + resultTypes.push_back(resultType); + } + + auto opFunctionCall = opBuilder.create( + unknownLoc, resultTypes, opBuilder.getSymbolRefAttr(functionName), + arguments); + + if (!resultTypes.empty()) { + valueMap[resultID] = opFunctionCall.getResult(0); + } + return success(); +} + // Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and // various Deserializer::processOp<...>() specializations. #define GET_DESERIALIZATION_FNS diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index e05c0d4f8e6d..c31c9f31bb96 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -131,6 +131,10 @@ private: return funcIDMap.lookup(fnName); } + /// Gets the for the function with the given name. Assigns the next + /// available if the function haven't been deserialized. + uint32_t getOrCreateFunctionID(StringRef fnName); + void processCapability(); void processExtension(); @@ -392,6 +396,15 @@ void Serializer::collect(SmallVectorImpl &binary) { // Module structure //===----------------------------------------------------------------------===// +uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) { + auto funcID = funcIDMap.lookup(fnName); + if (!funcID) { + funcID = getNextID(); + funcIDMap[fnName] = funcID; + } + return funcID; +} + void Serializer::processCapability() { auto caps = module.getAttrOfType("capabilities"); if (!caps) @@ -537,8 +550,7 @@ LogicalResult Serializer::processFuncOp(FuncOp op) { return failure(); } operands.push_back(resTypeID); - auto funcID = getNextID(); - funcIDMap[op.getName()] = funcID; + auto funcID = getOrCreateFunctionID(op.getName()); operands.push_back(funcID); // TODO : Support other function control options. operands.push_back(static_cast(spirv::FunctionControl::None)); @@ -1461,6 +1473,37 @@ Serializer::processOp(spirv::ExecutionModeOp op) { operands); } +template <> +LogicalResult +Serializer::processOp(spirv::FunctionCallOp op) { + auto funcName = op.callee(); + uint32_t resTypeID = 0; + + llvm::SmallVector resultTypes(op.getResultTypes()); + if (failed(processType(op.getLoc(), + (resultTypes.empty() ? getVoidType() : resultTypes[0]), + resTypeID))) { + return failure(); + } + + auto funcID = getOrCreateFunctionID(funcName); + auto funcCallID = getNextID(); + SmallVector operands{resTypeID, funcCallID, funcID}; + + for (auto *value : op.arguments()) { + auto valueID = findValueID(value); + assert(valueID && "cannot find a value for spv.FunctionCall"); + operands.push_back(valueID); + } + + if (!resultTypes.empty()) { + valueIDMap[op.getResult(0)] = funcCallID; + } + + return encodeInstructionInto(functions, spirv::Opcode::OpFunctionCall, + operands); +} + // Pull in auto-generated Serializer::dispatchToAutogenSerialization() and // various Serializer::processOp<...>() specializations. #define GET_SERIALIZATION_FNS diff --git a/mlir/test/Dialect/SPIRV/Serialization/function_call.mlir b/mlir/test/Dialect/SPIRV/Serialization/function_call.mlir new file mode 100644 index 000000000000..59023c42f015 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/Serialization/function_call.mlir @@ -0,0 +1,53 @@ +// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s + +spv.module "Logical" "GLSL450" { + spv.globalVariable @var1 : !spv.ptr, Input> + func @fmain() -> i32 { + %0 = spv.constant 16 : i32 + %1 = spv._address_of @var1 : !spv.ptr, Input> + // CHECK: {{%.*}} = spv.FunctionCall @f_0({{%.*}}) : (i32) -> i32 + %3 = spv.FunctionCall @f_0(%0) : (i32) -> i32 + // CHECK: spv.FunctionCall @f_1({{%.*}}, {{%.*}}) : (i32, !spv.ptr, Input>) -> () + spv.FunctionCall @f_1(%3, %1) : (i32, !spv.ptr, Input>) -> () + // CHECK: {{%.*}} = spv.FunctionCall @f_2({{%.*}}) : (!spv.ptr, Input>) -> !spv.ptr, Input> + %4 = spv.FunctionCall @f_2(%1) : (!spv.ptr, Input>) -> !spv.ptr, Input> + spv.ReturnValue %3 : i32 + } + func @f_0(%arg0 : i32) -> i32 { + spv.ReturnValue %arg0 : i32 + } + func @f_1(%arg0 : i32, %arg1 : !spv.ptr, Input>) -> () { + spv.Return + } + func @f_2(%arg0 : !spv.ptr, Input>) -> !spv.ptr, Input> { + spv.ReturnValue %arg0 : !spv.ptr, Input> + } + + func @f_loop_with_function_call(%count : i32) -> () { + %zero = spv.constant 0: i32 + %var = spv.Variable init(%zero) : !spv.ptr + spv.loop { + spv.Branch ^header + ^header: + %val0 = spv.Load "Function" %var : i32 + %cmp = spv.SLessThan %val0, %count : i32 + spv.BranchConditional %cmp, ^body, ^merge + ^body: + spv.Branch ^continue + ^continue: + // CHECK: spv.FunctionCall @f_inc({{%.*}}) : (!spv.ptr) -> () + spv.FunctionCall @f_inc(%var) : (!spv.ptr) -> () + spv.Branch ^header + ^merge: + spv._merge + } + spv.Return + } + func @f_inc(%arg0 : !spv.ptr) -> () { + %one = spv.constant 1 : i32 + %0 = spv.Load "Function" %arg0 : i32 + %1 = spv.IAdd %0, %one : i32 + spv.Store "Function" %arg0, %1 : i32 + spv.Return + } +} diff --git a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir index c3638a8d0016..e83f36de4468 100644 --- a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir +++ b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir @@ -144,6 +144,111 @@ func @weights_cannot_both_be_zero() -> () { // ----- +//===----------------------------------------------------------------------===// +// spv.FunctionCall +//===----------------------------------------------------------------------===// + +spv.module "Logical" "GLSL450" { + func @fmain(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>, %arg2 : i32) -> i32 { + // CHECK: {{%.*}} = spv.FunctionCall @f_0({{%.*}}, {{%.*}}) : (vector<4xf32>, vector<4xf32>) -> vector<4xf32> + %0 = spv.FunctionCall @f_0(%arg0, %arg1) : (vector<4xf32>, vector<4xf32>) -> vector<4xf32> + // CHECK: spv.FunctionCall @f_1({{%.*}}, {{%.*}}) : (vector<4xf32>, vector<4xf32>) -> () + spv.FunctionCall @f_1(%0, %arg1) : (vector<4xf32>, vector<4xf32>) -> () + // CHECK: spv.FunctionCall @f_2() : () -> () + spv.FunctionCall @f_2() : () -> () + // CHECK: {{%.*}} = spv.FunctionCall @f_3({{%.*}}) : (i32) -> i32 + %1 = spv.FunctionCall @f_3(%arg2) : (i32) -> i32 + spv.ReturnValue %1 : i32 + } + + func @f_0(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> (vector<4xf32>) { + spv.ReturnValue %arg0 : vector<4xf32> + } + + func @f_1(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> () { + spv.Return + } + + func @f_2() -> () { + spv.Return + } + + func @f_3(%arg0 : i32) -> (i32) { + spv.ReturnValue %arg0 : i32 + } +} + +// ----- + +spv.module "Logical" "GLSL450" { + func @f_invalid_result_type(%arg0 : i32, %arg1 : i32) -> () { + // expected-error @+1 {{expected callee function to have 0 or 1 result, but provided 2}} + %0 = spv.FunctionCall @f_invalid_result_type(%arg0, %arg1) : (i32, i32) -> (i32, i32) + spv.Return + } +} + +// ----- + +spv.module "Logical" "GLSL450" { + func @f_result_type_mismatch(%arg0 : i32, %arg1 : i32) -> () { + // expected-error @+1 {{has incorrect number of results has for callee: expected 0, but provided 1}} + %1 = spv.FunctionCall @f_result_type_mismatch(%arg0, %arg0) : (i32, i32) -> (i32) + spv.Return + } +} + +// ----- + +spv.module "Logical" "GLSL450" { + func @f_type_mismatch(%arg0 : i32, %arg1 : i32) -> () { + // expected-error @+1 {{has incorrect number of operands for callee: expected 2, but provided 1}} + spv.FunctionCall @f_type_mismatch(%arg0) : (i32) -> () + spv.Return + } +} + +// ----- + +spv.module "Logical" "GLSL450" { + func @f_type_mismatch(%arg0 : i32, %arg1 : i32) -> () { + %0 = spv.constant 2.0 : f32 + // expected-error @+1 {{operand type mismatch: expected operand type 'i32', but provided 'f32' for operand number 1}} + spv.FunctionCall @f_type_mismatch(%arg0, %0) : (i32, f32) -> () + spv.Return + } +} + +// ----- + +spv.module "Logical" "GLSL450" { + func @f_type_mismatch(%arg0 : i32, %arg1 : i32) -> i32 { + // expected-error @+1 {{result type mismatch: expected 'i32', but provided 'f32'}} + %0 = spv.FunctionCall @f_type_mismatch(%arg0, %arg0) : (i32, i32) -> f32 + spv.Return + } +} + +// ----- + +spv.module "Logical" "GLSL450" { + func @f_foo(%arg0 : i32, %arg1 : i32) -> i32 { + // expected-error @+1 {{op callee function 'f_undefined' not found in 'spv.module'}} + %0 = spv.FunctionCall @f_undefined(%arg0, %arg0) : (i32, i32) -> i32 + spv.Return + } +} + +// ----- + +func @f_foo(%arg0 : i32, %arg1 : i32) -> i32 { + // expected-error @+1 {{must appear in a function inside 'spv.module'}} + %0 = spv.FunctionCall @f_foo(%arg0, %arg0) : (i32, i32) -> i32 + spv.Return +} + +// ----- + //===----------------------------------------------------------------------===// // spv.loop //===----------------------------------------------------------------------===//