forked from OSchip/llvm-project
Add serialization and deserialization of FuncOps. To support this the
following SPIRV Instructions serializaiton/deserialization are added as well OpFunction OpFunctionParameter OpFunctionEnd OpReturn PiperOrigin-RevId: 257869806
This commit is contained in:
parent
63bc37c9c0
commit
9af156757d
|
@ -72,23 +72,29 @@ class SPV_OpCode<string name, int val> {
|
|||
|
||||
// Begin opcode section. Generated from SPIR-V spec; DO NOT MODIFY!
|
||||
|
||||
def SPV_OC_OpMemoryModel : I32EnumAttrCase<"OpMemoryModel", 14>;
|
||||
def SPV_OC_OpEntryPoint : I32EnumAttrCase<"OpEntryPoint", 15>;
|
||||
def SPV_OC_OpExecutionMode : I32EnumAttrCase<"OpExecutionMode", 16>;
|
||||
def SPV_OC_OpTypeVoid : I32EnumAttrCase<"OpTypeVoid", 19>;
|
||||
def SPV_OC_OpTypeFunction : I32EnumAttrCase<"OpTypeFunction", 33>;
|
||||
def SPV_OC_OpVariable : I32EnumAttrCase<"OpVariable", 59>;
|
||||
def SPV_OC_OpLoad : I32EnumAttrCase<"OpLoad", 61>;
|
||||
def SPV_OC_OpStore : I32EnumAttrCase<"OpStore", 62>;
|
||||
def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>;
|
||||
def SPV_OC_OpFMul : I32EnumAttrCase<"OpFMul", 133>;
|
||||
def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>;
|
||||
def SPV_OC_OpMemoryModel : I32EnumAttrCase<"OpMemoryModel", 14>;
|
||||
def SPV_OC_OpEntryPoint : I32EnumAttrCase<"OpEntryPoint", 15>;
|
||||
def SPV_OC_OpExecutionMode : I32EnumAttrCase<"OpExecutionMode", 16>;
|
||||
def SPV_OC_OpTypeVoid : I32EnumAttrCase<"OpTypeVoid", 19>;
|
||||
def SPV_OC_OpTypeFunction : I32EnumAttrCase<"OpTypeFunction", 33>;
|
||||
def SPV_OC_OpFunction : I32EnumAttrCase<"OpFunction", 54>;
|
||||
def SPV_OC_OpFunctionParameter : I32EnumAttrCase<"OpFunctionParameter", 55>;
|
||||
def SPV_OC_OpFunctionEnd : I32EnumAttrCase<"OpFunctionEnd", 56>;
|
||||
def SPV_OC_OpVariable : I32EnumAttrCase<"OpVariable", 59>;
|
||||
def SPV_OC_OpLoad : I32EnumAttrCase<"OpLoad", 61>;
|
||||
def SPV_OC_OpStore : I32EnumAttrCase<"OpStore", 62>;
|
||||
def SPV_OC_OpDecorate : I32EnumAttrCase<"OpDecorate", 71>;
|
||||
def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>;
|
||||
def SPV_OC_OpFMul : I32EnumAttrCase<"OpFMul", 133>;
|
||||
def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>;
|
||||
|
||||
def SPV_OpcodeAttr :
|
||||
I32EnumAttr<"Opcode", "valid SPIR-V instructions", [
|
||||
SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint, SPV_OC_OpExecutionMode,
|
||||
SPV_OC_OpTypeVoid, SPV_OC_OpTypeFunction, SPV_OC_OpVariable, SPV_OC_OpLoad,
|
||||
SPV_OC_OpStore, SPV_OC_OpCompositeExtract, SPV_OC_OpFMul, SPV_OC_OpReturn
|
||||
SPV_OC_OpTypeVoid, SPV_OC_OpTypeFunction, SPV_OC_OpFunction,
|
||||
SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpVariable,
|
||||
SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpDecorate, SPV_OC_OpCompositeExtract,
|
||||
SPV_OC_OpFMul, SPV_OC_OpReturn
|
||||
]> {
|
||||
let returnType = "::mlir::spirv::Opcode";
|
||||
let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())";
|
||||
|
@ -294,6 +300,21 @@ def SPV_ExecutionModelAttr :
|
|||
let cppNamespace = "::mlir::spirv";
|
||||
}
|
||||
|
||||
def SPV_FC_None : I32EnumAttrCase<"None", 0x0000>;
|
||||
def SPV_FC_Inline : I32EnumAttrCase<"Inline", 0x0001>;
|
||||
def SPV_FC_DontInline : I32EnumAttrCase<"DontInline", 0x0002>;
|
||||
def SPV_FC_Pure : I32EnumAttrCase<"Pure", 0x0004>;
|
||||
def SPV_FC_Const : I32EnumAttrCase<"Const", 0x0008>;
|
||||
|
||||
def SPV_FunctionControlAttr :
|
||||
I32EnumAttr<"FunctionControl", "valid SPIR-V FunctionControl", [
|
||||
SPV_FC_None, SPV_FC_Inline, SPV_FC_DontInline, SPV_FC_Pure, SPV_FC_Const
|
||||
]> {
|
||||
let returnType = "::mlir::spirv::FunctionControl";
|
||||
let convertFromStorage = "static_cast<::mlir::spirv::FunctionControl>($_self.getInt())";
|
||||
let cppNamespace = "::mlir::spirv";
|
||||
}
|
||||
|
||||
def SPV_IF_Unknown : I32EnumAttrCase<"Unknown", 0>;
|
||||
def SPV_IF_Rgba32f : I32EnumAttrCase<"Rgba32f", 1>;
|
||||
def SPV_IF_Rgba16f : I32EnumAttrCase<"Rgba16f", 2>;
|
||||
|
@ -352,6 +373,18 @@ def SPV_ImageFormatAttr :
|
|||
let cppNamespace = "::mlir::spirv";
|
||||
}
|
||||
|
||||
def SPV_LT_Export : I32EnumAttrCase<"Export", 0>;
|
||||
def SPV_LT_Import : I32EnumAttrCase<"Import", 1>;
|
||||
|
||||
def SPV_LinkageTypeAttr :
|
||||
I32EnumAttr<"LinkageType", "valid SPIR-V LinkageType", [
|
||||
SPV_LT_Export, SPV_LT_Import
|
||||
]> {
|
||||
let returnType = "::mlir::spirv::LinkageType";
|
||||
let convertFromStorage = "static_cast<::mlir::spirv::LinkageType>($_self.getInt())";
|
||||
let cppNamespace = "::mlir::spirv";
|
||||
}
|
||||
|
||||
def SPV_MA_None : I32EnumAttrCase<"None", 0x0000>;
|
||||
def SPV_MA_Volatile : I32EnumAttrCase<"Volatile", 0x0001>;
|
||||
def SPV_MA_Aligned : I32EnumAttrCase<"Aligned", 0x0002>;
|
||||
|
|
|
@ -66,9 +66,20 @@ private:
|
|||
/// Get type for a given result <id>
|
||||
Type getType(uint32_t id) { return typeMap.lookup(id); }
|
||||
|
||||
/// Get Value associated with a result <id>
|
||||
Value *getValue(uint32_t id) { return valueMap.lookup(id); }
|
||||
|
||||
// Check if a type is void
|
||||
bool isVoidType(Type type) const { return type.isa<NoneType>(); }
|
||||
|
||||
/// Processes SPIR-V module header.
|
||||
LogicalResult processHeader();
|
||||
|
||||
/// Deserialize a single instruction. The |opcode| and |operands| are returned
|
||||
/// after deserialization to the caller.
|
||||
LogicalResult deserializeInstruction(spirv::Opcode &opcode,
|
||||
ArrayRef<uint32_t> &operands);
|
||||
|
||||
/// Processes a SPIR-V instruction with the given `opcode` and `operands`.
|
||||
LogicalResult processInstruction(spirv::Opcode opcode,
|
||||
ArrayRef<uint32_t> operands);
|
||||
|
@ -77,6 +88,13 @@ private:
|
|||
LogicalResult processType(spirv::Opcode opcode, ArrayRef<uint32_t> operands);
|
||||
LogicalResult processFunctionType(ArrayRef<uint32_t> operands);
|
||||
|
||||
/// Process SPIR-V instructions that dont have any operands
|
||||
template <typename OpTy>
|
||||
LogicalResult processNullaryInstruction(ArrayRef<uint32_t> operands);
|
||||
|
||||
/// Process function objects in binary
|
||||
LogicalResult processFunction(ArrayRef<uint32_t> operands);
|
||||
|
||||
LogicalResult processMemoryModel(ArrayRef<uint32_t> operands);
|
||||
|
||||
/// Initializes the `module` ModuleOp in this deserializer instance.
|
||||
|
@ -102,6 +120,12 @@ private:
|
|||
|
||||
// result <id> to type mapping
|
||||
DenseMap<uint32_t, Type> typeMap;
|
||||
|
||||
// result <id> to function mapping
|
||||
DenseMap<uint32_t, Operation *> funcMap;
|
||||
|
||||
// result <id> to value mapping
|
||||
DenseMap<uint32_t, Value *> valueMap;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
|
@ -114,30 +138,11 @@ LogicalResult Deserializer::deserialize() {
|
|||
if (failed(processHeader()))
|
||||
return failure();
|
||||
|
||||
auto binarySize = binary.size();
|
||||
curOffset = spirv::kHeaderWordCount;
|
||||
|
||||
while (curOffset < binarySize) {
|
||||
// For each instruction, get its word count from the first word to slice it
|
||||
// from the stream properly, and then dispatch to the instruction handler.
|
||||
|
||||
uint32_t wordCount = binary[curOffset] >> 16;
|
||||
uint32_t opcode = binary[curOffset] & 0xffff;
|
||||
|
||||
if (wordCount == 0)
|
||||
return emitError(unknownLoc, "word count cannot be zero");
|
||||
|
||||
uint32_t nextOffset = curOffset + wordCount;
|
||||
if (nextOffset > binarySize)
|
||||
return emitError(unknownLoc,
|
||||
"insufficient words for the last instruction");
|
||||
|
||||
auto operands = binary.slice(curOffset + 1, wordCount - 1);
|
||||
if (failed(
|
||||
processInstruction(static_cast<spirv::Opcode>(opcode), operands)))
|
||||
spirv::Opcode opcode;
|
||||
ArrayRef<uint32_t> operands;
|
||||
while (succeeded(deserializeInstruction(opcode, operands))) {
|
||||
if (failed(processInstruction(opcode, operands)))
|
||||
return failure();
|
||||
|
||||
curOffset = nextOffset;
|
||||
}
|
||||
|
||||
return success();
|
||||
|
@ -154,6 +159,32 @@ LogicalResult Deserializer::processHeader() {
|
|||
return emitError(unknownLoc, "incorrect magic number");
|
||||
|
||||
// TODO(antiagainst): generator number, bound, schema
|
||||
curOffset = spirv::kHeaderWordCount;
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
Deserializer::deserializeInstruction(spirv::Opcode &opcode,
|
||||
ArrayRef<uint32_t> &operands) {
|
||||
auto binarySize = binary.size();
|
||||
if (curOffset >= binarySize) {
|
||||
return failure();
|
||||
}
|
||||
// For each instruction, get its word count from the first word to slice it
|
||||
// from the stream properly, and then dispatch to the instruction handler.
|
||||
|
||||
uint32_t wordCount = binary[curOffset] >> 16;
|
||||
opcode = static_cast<spirv::Opcode>(binary[curOffset] & 0xffff);
|
||||
|
||||
if (wordCount == 0)
|
||||
return emitError(unknownLoc, "word count cannot be zero");
|
||||
|
||||
uint32_t nextOffset = curOffset + wordCount;
|
||||
if (nextOffset > binarySize)
|
||||
return emitError(unknownLoc, "insufficient words for the last instruction");
|
||||
|
||||
operands = binary.slice(curOffset + 1, wordCount - 1);
|
||||
curOffset = nextOffset;
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -174,7 +205,11 @@ LogicalResult Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
|
|||
}
|
||||
argTypes.push_back(ty);
|
||||
}
|
||||
typeMap[operands[0]] = FunctionType::get(argTypes, {returnType}, context);
|
||||
ArrayRef<Type> returnTypes;
|
||||
if (!isVoidType(returnType)) {
|
||||
returnTypes = llvm::makeArrayRef(returnType);
|
||||
}
|
||||
typeMap[operands[0]] = FunctionType::get(argTypes, returnTypes, context);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -205,6 +240,118 @@ LogicalResult Deserializer::processType(spirv::Opcode opcode,
|
|||
return success();
|
||||
}
|
||||
|
||||
template <typename OpTy>
|
||||
LogicalResult
|
||||
Deserializer::processNullaryInstruction(ArrayRef<uint32_t> operands) {
|
||||
if (!operands.empty()) {
|
||||
return emitError(unknownLoc) << stringifyOpcode(spirv::getOpcode<OpTy>())
|
||||
<< " must have no operands, but found "
|
||||
<< operands.size() << " operands";
|
||||
}
|
||||
opBuilder.create<OpTy>(unknownLoc);
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
|
||||
// Get the result type
|
||||
if (operands.size() != 4) {
|
||||
return emitError(unknownLoc, "OpFunction must have 4 parameters");
|
||||
}
|
||||
Type resultType = getType(operands[0]);
|
||||
if (!resultType) {
|
||||
return emitError(unknownLoc, "unknown result type from <id> ")
|
||||
<< operands[0];
|
||||
}
|
||||
if (funcMap.count(operands[1])) {
|
||||
return emitError(unknownLoc, "duplicate function definition/declaration");
|
||||
}
|
||||
auto functionControl = spirv::symbolizeFunctionControl(operands[2]);
|
||||
if (!functionControl) {
|
||||
return emitError(unknownLoc, "unknown Function Control : ") << operands[2];
|
||||
}
|
||||
if (functionControl.getValue() != spirv::FunctionControl::None) {
|
||||
/// TODO : Handle different function controls
|
||||
return emitError(unknownLoc, "unhandled Function Control : '")
|
||||
<< spirv::stringifyFunctionControl(functionControl.getValue())
|
||||
<< "'";
|
||||
}
|
||||
Type fnType = getType(operands[3]);
|
||||
if (!fnType || !fnType.isa<FunctionType>()) {
|
||||
return emitError(unknownLoc, "unknown function type from <id> ")
|
||||
<< operands[3];
|
||||
}
|
||||
auto functionType = fnType.cast<FunctionType>();
|
||||
if ((isVoidType(resultType) && functionType.getNumResults() != 0) ||
|
||||
(functionType.getNumResults() == 1 &&
|
||||
functionType.getResult(0) != resultType)) {
|
||||
return emitError(unknownLoc, "mismatch in function type ")
|
||||
<< functionType << " and return type " << resultType << " specified";
|
||||
}
|
||||
/// TODO : The function name must be obtained from OpName eventually
|
||||
std::string fnName = "spirv_fn_" + std::to_string(operands[2]);
|
||||
auto funcOp = opBuilder.create<FuncOp>(unknownLoc, fnName, functionType,
|
||||
ArrayRef<NamedAttribute>());
|
||||
funcOp.addEntryBlock();
|
||||
|
||||
// Parse the op argument instructions
|
||||
if (functionType.getNumInputs()) {
|
||||
for (size_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
|
||||
auto argType = functionType.getInput(i);
|
||||
spirv::Opcode opcode;
|
||||
ArrayRef<uint32_t> operands;
|
||||
if (failed(deserializeInstruction(opcode, operands))) {
|
||||
return failure();
|
||||
}
|
||||
if (opcode != spirv::Opcode::OpFunctionParameter) {
|
||||
return emitError(
|
||||
unknownLoc,
|
||||
"missing OpFunctionParameter instruction for argument ")
|
||||
<< i;
|
||||
}
|
||||
if (operands.size() != 2) {
|
||||
return emitError(
|
||||
unknownLoc,
|
||||
"expected result type and result <id> for OpFunctionParameter");
|
||||
}
|
||||
auto argDefinedType = getType(operands[0]);
|
||||
if (argDefinedType || argDefinedType != argType) {
|
||||
return emitError(unknownLoc,
|
||||
"mismatch in argument type between function type "
|
||||
"definition ")
|
||||
<< functionType << " and argument type definition "
|
||||
<< argDefinedType << " at argument " << i;
|
||||
}
|
||||
if (getValue(operands[1])) {
|
||||
return emitError(unknownLoc, "duplicate definition of result <id> '")
|
||||
<< operands[1];
|
||||
}
|
||||
auto argValue = funcOp.getArgument(i);
|
||||
valueMap[operands[1]] = argValue;
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new builder for building the body
|
||||
OpBuilder funcBody(funcOp.getBody());
|
||||
std::swap(funcBody, opBuilder);
|
||||
|
||||
spirv::Opcode opcode;
|
||||
ArrayRef<uint32_t> instOperands;
|
||||
while (succeeded(deserializeInstruction(opcode, instOperands)) &&
|
||||
opcode != spirv::Opcode::OpFunctionEnd) {
|
||||
if (failed(processInstruction(opcode, instOperands))) {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
std::swap(funcBody, opBuilder);
|
||||
if (opcode != spirv::Opcode::OpFunctionEnd) {
|
||||
return failure();
|
||||
}
|
||||
if (!instOperands.empty()) {
|
||||
return emitError(unknownLoc, "unexpected operands for OpFunctionEnd");
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
|
||||
ArrayRef<uint32_t> operands) {
|
||||
switch (opcode) {
|
||||
|
@ -213,6 +360,10 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
|
|||
case spirv::Opcode::OpTypeVoid:
|
||||
case spirv::Opcode::OpTypeFunction:
|
||||
return processType(opcode, operands);
|
||||
case spirv::Opcode::OpReturn:
|
||||
return processNullaryInstruction<spirv::ReturnOp>(operands);
|
||||
case spirv::Opcode::OpFunction:
|
||||
return processFunction(operands);
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "mlir/SPIRV/Serialization.h"
|
||||
|
||||
#include "SPIRVBinaryUtils.h"
|
||||
#include "mlir/SPIRV/SPIRVDialect.h"
|
||||
#include "mlir/SPIRV/SPIRVOps.h"
|
||||
#include "mlir/SPIRV/SPIRVTypes.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
|
@ -87,13 +88,24 @@ private:
|
|||
SmallVectorImpl<uint32_t> &operands);
|
||||
|
||||
// Main method to dispatch operation serialization
|
||||
LogicalResult processOperation(Operation *op, uint32_t &opID);
|
||||
LogicalResult processOperation(Operation *op);
|
||||
|
||||
// Methods to serialize individual operation types
|
||||
LogicalResult processFuncOp(FuncOp op, uint32_t &funcID);
|
||||
LogicalResult processFuncOp(FuncOp op);
|
||||
// Serialize op that dont produce a value and have no operands, like
|
||||
// spirv::ReturnOp
|
||||
template <typename OpType> LogicalResult processNullaryOp(OpType op);
|
||||
|
||||
uint32_t getNextID() { return nextID++; }
|
||||
|
||||
Optional<uint32_t> findTypeID(Type type) const {
|
||||
auto it = typeIDMap.find(type);
|
||||
return (it != typeIDMap.end() ? it->second : Optional<uint32_t>(None));
|
||||
}
|
||||
|
||||
Type voidType() { return mlir::NoneType::get(module.getContext()); }
|
||||
bool isVoidType(Type type) const { return type.isa<NoneType>(); }
|
||||
|
||||
private:
|
||||
/// The SPIR-V module to be serialized.
|
||||
spirv::ModuleOp module;
|
||||
|
@ -114,11 +126,13 @@ private:
|
|||
// TODO(antiagainst): debug instructions
|
||||
SmallVector<uint32_t, 0> decorations;
|
||||
SmallVector<uint32_t, 0> typesGlobalValues;
|
||||
SmallVector<uint32_t, 0> functionDecls;
|
||||
SmallVector<uint32_t, 0> functionDefns;
|
||||
SmallVector<uint32_t, 0> functions;
|
||||
|
||||
// Map from type used in SPIR-V module to their IDs
|
||||
DenseMap<Type, uint32_t> typeIDMap;
|
||||
|
||||
// Map from FuncOps to IDs
|
||||
DenseMap<Operation *, uint32_t> funcIDMap;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
|
@ -132,8 +146,7 @@ LogicalResult Serializer::serialize() {
|
|||
// Iterate over the module body to serialze it. Assumptions are that there is
|
||||
// only one basic block in the moduleOp
|
||||
for (auto &op : module.getBlock()) {
|
||||
uint32_t opID = 0;
|
||||
if (failed(processOperation(&op, opID))) {
|
||||
if (failed(processOperation(&op))) {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
@ -147,7 +160,7 @@ void Serializer::collect(SmallVectorImpl<uint32_t> &binary) {
|
|||
extendedSets.size() + memoryModel.size() +
|
||||
entryPoints.size() + executionModes.size() +
|
||||
decorations.size() + typesGlobalValues.size() +
|
||||
functionDecls.size() + functionDefns.size();
|
||||
functions.size();
|
||||
|
||||
binary.clear();
|
||||
binary.reserve(moduleSize);
|
||||
|
@ -162,8 +175,7 @@ void Serializer::collect(SmallVectorImpl<uint32_t> &binary) {
|
|||
binary.append(executionModes.begin(), executionModes.end());
|
||||
binary.append(decorations.begin(), decorations.end());
|
||||
binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
|
||||
binary.append(functionDecls.begin(), functionDecls.end());
|
||||
binary.append(functionDefns.begin(), functionDefns.end());
|
||||
binary.append(functions.begin(), functions.end());
|
||||
}
|
||||
|
||||
LogicalResult Serializer::processHeader() {
|
||||
|
@ -207,9 +219,9 @@ LogicalResult Serializer::processMemoryModel() {
|
|||
|
||||
LogicalResult Serializer::processType(Location loc, Type type,
|
||||
uint32_t &typeID) {
|
||||
auto it = typeIDMap.find(type);
|
||||
if (it != typeIDMap.end()) {
|
||||
typeID = it->second;
|
||||
auto id = findTypeID(type);
|
||||
if (id) {
|
||||
typeID = id.getValue();
|
||||
return success();
|
||||
}
|
||||
typeID = getNextID();
|
||||
|
@ -230,7 +242,7 @@ LogicalResult Serializer::processType(Location loc, Type type,
|
|||
LogicalResult
|
||||
Serializer::processBasicType(Location loc, Type type, spirv::Opcode &typeEnum,
|
||||
SmallVectorImpl<uint32_t> &operands) {
|
||||
if (type.isa<NoneType>()) {
|
||||
if (isVoidType(type)) {
|
||||
typeEnum = spirv::Opcode::OpTypeVoid;
|
||||
return success();
|
||||
}
|
||||
|
@ -246,11 +258,9 @@ Serializer::processFunctionType(Location loc, FunctionType type,
|
|||
assert(type.getNumResults() <= 1 &&
|
||||
"Serialization supports only a single return value");
|
||||
uint32_t resultID = 0;
|
||||
if (failed(processType(loc,
|
||||
type.getNumResults() == 1
|
||||
? type.getResult(0)
|
||||
: mlir::NoneType::get(module.getContext()),
|
||||
resultID))) {
|
||||
if (failed(processType(
|
||||
loc, type.getNumResults() == 1 ? type.getResult(0) : voidType(),
|
||||
resultID))) {
|
||||
return failure();
|
||||
}
|
||||
operands.push_back(resultID);
|
||||
|
@ -264,21 +274,80 @@ Serializer::processFunctionType(Location loc, FunctionType type,
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult Serializer::processOperation(Operation *op, uint32_t &opID) {
|
||||
opID = getNextID();
|
||||
if ((isa<FuncOp>(op) && succeeded(processFuncOp(cast<FuncOp>(op), opID))) ||
|
||||
isa<spirv::ModuleEndOp>(op)) {
|
||||
LogicalResult Serializer::processOperation(Operation *op) {
|
||||
if (isa<FuncOp>(op)) {
|
||||
return processFuncOp(cast<FuncOp>(op));
|
||||
} else if (isa<spirv::ReturnOp>(op)) {
|
||||
return processNullaryOp(cast<spirv::ReturnOp>(op));
|
||||
} else if (isa<spirv::ModuleEndOp>(op)) {
|
||||
return success();
|
||||
}
|
||||
/// TODO(ravishankarm) : Handle other ops
|
||||
return op->emitError("unhandled operation serialization");
|
||||
}
|
||||
|
||||
LogicalResult Serializer::processFuncOp(FuncOp op, uint32_t &funcID) {
|
||||
uint32_t typeID = 0;
|
||||
LogicalResult Serializer::processFuncOp(FuncOp op) {
|
||||
uint32_t fnTypeID = 0;
|
||||
// Generate type of the function
|
||||
processType(op.getLoc(), op.getType(), typeID);
|
||||
// TODO(ravishankarm) : Process Function body
|
||||
processType(op.getLoc(), op.getType(), fnTypeID);
|
||||
|
||||
/// Add the function definition
|
||||
SmallVector<uint32_t, 4> operands;
|
||||
uint32_t resTypeID = 0;
|
||||
auto resultTypes = op.getType().getResults();
|
||||
if (resultTypes.size() > 1) {
|
||||
return emitError(op.getLoc(),
|
||||
"cannot serialize function with multiple return types");
|
||||
}
|
||||
if (failed(processType(op.getLoc(),
|
||||
(resultTypes.empty() ? voidType() : resultTypes[0]),
|
||||
resTypeID))) {
|
||||
return failure();
|
||||
}
|
||||
operands.push_back(resTypeID);
|
||||
auto funcID = getNextID();
|
||||
funcIDMap[op.getOperation()] = funcID;
|
||||
operands.push_back(funcID);
|
||||
/// TODO : Support other function control options
|
||||
operands.push_back(static_cast<uint32_t>(spirv::FunctionControl::None));
|
||||
operands.push_back(fnTypeID);
|
||||
buildInstruction(spirv::Opcode::OpFunction, operands, functions);
|
||||
|
||||
// Declare the parameters
|
||||
for (auto argType : op.getType().getInputs()) {
|
||||
uint32_t argTypeID = 0;
|
||||
if (failed(processType(op.getLoc(), argType, argTypeID))) {
|
||||
return failure();
|
||||
}
|
||||
buildInstruction(spirv::Opcode::OpFunctionParameter,
|
||||
{argTypeID, getNextID()}, functions);
|
||||
}
|
||||
|
||||
// Process the body
|
||||
if (!op.empty()) {
|
||||
for (auto &b : op) {
|
||||
for (auto &op : b) {
|
||||
if (failed(processOperation(&op))) {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Insert Function End
|
||||
buildInstruction(spirv::Opcode::OpFunctionEnd, {}, functions);
|
||||
|
||||
// If the function body is empty return an error
|
||||
// TODO : Handle external functions
|
||||
if (op.empty()) {
|
||||
return emitError(op.getLoc(), "external function is unhandled");
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
template <typename OpType>
|
||||
LogicalResult Serializer::processNullaryOp(OpType op) {
|
||||
buildInstruction(spirv::getOpcode<OpType>(), ArrayRef<uint32_t>(), functions);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -2,13 +2,11 @@
|
|||
|
||||
// CHECK-LABEL: func @spirv_module
|
||||
// CHECK: spv.module "Logical" "VulkanKHR" {
|
||||
// CHECK-NEXT: func @spirv_fn_0() {
|
||||
// CHECK-NEXT: spv.Return
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: } attributes {major_version = 1 : i32, minor_version = 0 : i32}
|
||||
|
||||
// TODO(ravishankarm) : The output produced is not correct, since it
|
||||
// doesnt get the function body. The serialization doesnt handle
|
||||
// functions yet. Change the CHECK once it does, to make sure the
|
||||
// function is reproduced
|
||||
|
||||
func @spirv_module() -> () {
|
||||
spv.module "Logical" "VulkanKHR" {
|
||||
func @foo() -> () {
|
||||
|
|
|
@ -44,10 +44,12 @@ using mlir::tblgen::Operator;
|
|||
static void emitGetOpcodeFunction(const llvm::Record &record,
|
||||
Operator const &op, raw_ostream &os) {
|
||||
if (record.getValueAsInt("hasOpcode")) {
|
||||
os << formatv("template <> constexpr inline uint32_t getOpcode<{0}>()",
|
||||
os << formatv("template <> constexpr inline ::mlir::spirv::Opcode "
|
||||
"getOpcode<{0}>()",
|
||||
op.getQualCppClassName())
|
||||
<< " {\n return static_cast<uint32_t>("
|
||||
<< formatv("Opcode::Op{0});\n}\n", record.getValueAsString("opName"));
|
||||
<< " {\n "
|
||||
<< formatv("return ::mlir::spirv::Opcode::Op{0};\n}\n",
|
||||
record.getValueAsString("opName"));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -56,7 +58,8 @@ static bool emitSerializationUtils(const RecordKeeper &recordKeeper,
|
|||
llvm::emitSourceFileHeader("SPIR-V Serialization Utilities", os);
|
||||
|
||||
/// Define the function to get the opcode
|
||||
os << "template <typename OpClass> inline constexpr uint32_t getOpcode();\n";
|
||||
os << "template <typename OpClass> inline constexpr ::mlir::spirv::Opcode "
|
||||
"getOpcode();\n";
|
||||
auto defs = recordKeeper.getAllDerivedDefinitions("SPV_Op");
|
||||
for (const auto *def : defs) {
|
||||
Operator op(def);
|
||||
|
|
|
@ -29,7 +29,6 @@
|
|||
# in SPIR-V
|
||||
|
||||
set -e
|
||||
set -x
|
||||
|
||||
current_file="$(readlink -f "$0")"
|
||||
current_dir="$(dirname "$current_file")"
|
||||
|
|
Loading…
Reference in New Issue