diff --git a/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt b/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt index dd933fe63577..af4520df1304 100644 --- a/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/SPIRV/CMakeLists.txt @@ -9,7 +9,7 @@ mlir_tablegen(SPIRVEnums.cpp.inc -gen-enum-defs) add_public_tablegen_target(MLIRSPIRVEnumsIncGen) set(LLVM_TARGET_DEFINITIONS SPIRVOps.td) -mlir_tablegen(SPIRVSerialization.inc -gen-spirv-serial) +mlir_tablegen(SPIRVSerialization.inc -gen-spirv-serialization) add_public_tablegen_target(MLIRSPIRVSerializationGen) set(LLVM_TARGET_DEFINITIONS SPIRVBase.td) diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td index d9281234821b..83f474cd7805 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -1,4 +1,4 @@ -//===-- SPIRVOps.td - MLIR SPIR-V Op Definitions Spec ------*- tablegen -*-===// +//===- SPIRVBase.td - MLIR SPIR-V Op Definitions Base file -*- tablegen -*-===// // // Copyright 2019 The MLIR Authors. // @@ -76,6 +76,8 @@ def SPV_OC_OpMemoryModel : I32EnumAttrCase<"OpMemoryModel", 14>; def SPV_OC_OpEntryPoint : I32EnumAttrCase<"OpEntryPoint", 15>; def SPV_OC_OpExecutionMode : I32EnumAttrCase<"OpExecutionMode", 16>; def SPV_OC_OpTypeVoid : I32EnumAttrCase<"OpTypeVoid", 19>; +def SPV_OC_OpTypeFloat : I32EnumAttrCase<"OpTypeFloat", 22>; +def SPV_OC_OpTypePointer : I32EnumAttrCase<"OpTypePointer", 32>; def SPV_OC_OpTypeFunction : I32EnumAttrCase<"OpTypeFunction", 33>; def SPV_OC_OpFunction : I32EnumAttrCase<"OpFunction", 54>; def SPV_OC_OpFunctionParameter : I32EnumAttrCase<"OpFunctionParameter", 55>; @@ -91,10 +93,10 @@ def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>; def SPV_OpcodeAttr : I32EnumAttr<"Opcode", "valid SPIR-V instructions", [ SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint, SPV_OC_OpExecutionMode, - SPV_OC_OpTypeVoid, SPV_OC_OpTypeFunction, SPV_OC_OpFunction, - SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpVariable, - SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpDecorate, SPV_OC_OpCompositeExtract, - SPV_OC_OpFMul, SPV_OC_OpReturn + SPV_OC_OpTypeVoid, SPV_OC_OpTypeFloat, SPV_OC_OpTypePointer, + SPV_OC_OpTypeFunction, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter, + SPV_OC_OpFunctionEnd, SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, + SPV_OC_OpDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpFMul, SPV_OC_OpReturn ]> { let returnType = "::mlir::spirv::Opcode"; let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())"; @@ -514,8 +516,6 @@ def ModuleOnly : class SPV_Op traits = []> : Op { - string spirvOpName = "Op" # mnemonic; - // For each SPIR-V op, the following static functions need to be defined // in SPVOps.cpp: // @@ -527,13 +527,34 @@ class SPV_Op traits = []> : let printer = [{ return ::print(*this, p); }]; let verifier = [{ return ::verify(*this); }]; - // By default the opcode to use for (de)serialization is obtained - // automatically from the SPIR-V spec. It assume the SPIR-V op being defined - // is ('Op' # mnemonic). The opcode value can be obtained by calling - // getOpcode(). If invoking this method is invalid or custom - // processing is needed for the op, set hasOpcode = 0 and specialize the - // getOpcode method. - int hasOpcode = 1; + // Specifies if the SPIR-V op has a direct representation of an instruction in + // SPIR-V spec. When set to 1, the ODS generates the dispatch to the + // (de)serialization methods for the op, and also generates the implementation + // of these processing methods (if autogenSerialization is also set to 1). + bit hasOpcode = 1; + + // Name of the corresponding SPIR-V op. Only valid to use when hasOpcode is 1 + string spirvOpName = "Op" # mnemonic; + + // Controls whether the (de)serialization method is generated automatically or + // not. This results in generation of the following methods: + // + // template Serialization::processOp(OpTy op) + // template Deserialization::processOp(ArrayRef) + // + // If the auto generation is disabled (set to 0), then manual implementation + // of a specialization of these methods is required. + // + // Note : + // + // 1) If hasOpcode is 1 and autogenSerialization is 0, the ODS generated + // dispatch function call the above method for the (de)serialization of the + // operation + // + // 2) If hasOpcode is 0, then ODS doesn't generate the (de)serialization + // methods, neither does it handle the dispatch. Those need to be handled + // manually. + bit autogenSerialization = 1; } #endif // SPIRV_BASE diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td index 3c1f2434c7ba..b638bc8afa0a 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td @@ -143,6 +143,7 @@ def SPV_EntryPointOp : SPV_Op<"EntryPoint", [ModuleOnly]> { ); let results = (outs SPV_EntryPoint:$id); + let autogenSerialization = 0; } // ----- @@ -344,7 +345,7 @@ def SPV_StoreOp : SPV_Op<"Store", []> { SPV_AnyPtr:$ptr, SPV_Type:$value, OptionalAttr:$memory_access, - OptionalAttr:$alignment + OptionalAttr:$alignment ); let results = (outs); diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h index 6950510f51d8..a1a7d0c820d3 100644 --- a/mlir/include/mlir/TableGen/Attribute.h +++ b/mlir/include/mlir/TableGen/Attribute.h @@ -96,6 +96,10 @@ public: // of `TypeAttrBase`). bool isTypeAttr() const; + // Returns true if this attribute is an enum attribute (i.e., a subclass of + // `EnumAttrInfo`) + bool isEnumAttr() const; + // Returns this attribute's TableGen def name. If this is an `OptionalAttr` // or `DefaultValuedAttr` without explicit name, returns the base attribute's // name. diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h index 37e7bc4e6c3e..7cad27a4300b 100644 --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -132,7 +132,7 @@ public: int getNumArgs() const { return arguments.size(); } // Op argument (attribute or operand) accessors. - Argument getArg(int index); + Argument getArg(int index) const; StringRef getArgName(int index) const; // Returns the number of `PredOpTrait` traits. diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index cdd95e768ea1..676ef8ec8712 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -88,9 +88,24 @@ private: LogicalResult processType(spirv::Opcode opcode, ArrayRef operands); LogicalResult processFunctionType(ArrayRef operands); - /// Process SPIR-V instructions that dont have any operands + /// Method to dispatch to the specialized deserialization function for an + /// operation in SPIR-V dialect that is a mirror of an operation in the SPIR-V + /// spec. This is auto-generated from ODS. Dispatch is handled for all + /// operations in SPIR-V dialect that have hasOpcode == 1 + LogicalResult dispatchToAutogenDeserialization(spirv::Opcode opcode, + ArrayRef words); + + /// Method to deserialize an operation in the SPIR-V dialect that is a mirror + /// of an instruction in the SPIR-V spec. This is auto generated if hasOpcode + /// == 1 and autogenSerialization == 1 in ODS. + template LogicalResult processOp(ArrayRef words) { + return processOpImpl(words); + } template - LogicalResult processNullaryInstruction(ArrayRef operands); + LogicalResult processOpImpl(ArrayRef words) { + return emitError(unknownLoc, "unsupported deserialization for op '") + << OpTy::getOperationName() << "')"; + } /// Process function objects in binary LogicalResult processFunction(ArrayRef operands); @@ -232,6 +247,39 @@ LogicalResult Deserializer::processType(spirv::Opcode opcode, } typeMap[operands[0]] = NoneType::get(context); break; + case spirv::Opcode::OpTypeFloat: { + if (operands.size() != 2) { + return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter"); + } + Type floatTy; + switch (operands[1]) { + case 16: + floatTy = opBuilder.getF16Type(); + break; + case 32: + floatTy = opBuilder.getF32Type(); + break; + case 64: + floatTy = opBuilder.getF64Type(); + break; + default: + return emitError(unknownLoc, "unsupported bitwdith ") + << operands[1] << " with OpTypeFloat"; + } + typeMap[operands[0]] = floatTy; + } break; + case spirv::Opcode::OpTypePointer: { + if (operands.size() != 3) { + return emitError(unknownLoc, "OpTypePointer must have two parameters"); + } + auto pointeeType = getType(operands[2]); + if (!pointeeType) { + return emitError(unknownLoc, "unknown OpTypePointer pointee type : ") + << operands[2]; + } + auto storageClass = static_cast(operands[1]); + typeMap[operands[0]] = spirv::PointerType::get(pointeeType, storageClass); + } break; case spirv::Opcode::OpTypeFunction: return processFunctionType(operands); default: @@ -240,18 +288,6 @@ LogicalResult Deserializer::processType(spirv::Opcode opcode, return success(); } -template -LogicalResult -Deserializer::processNullaryInstruction(ArrayRef operands) { - if (!operands.empty()) { - return emitError(unknownLoc) << stringifyOpcode(spirv::getOpcode()) - << " must have no operands, but found " - << operands.size() << " operands"; - } - opBuilder.create(unknownLoc); - return success(); -} - LogicalResult Deserializer::processFunction(ArrayRef operands) { // Get the result type if (operands.size() != 4) { @@ -314,7 +350,7 @@ LogicalResult Deserializer::processFunction(ArrayRef operands) { "expected result type and result for OpFunctionParameter"); } auto argDefinedType = getType(operands[0]); - if (argDefinedType || argDefinedType != argType) { + if (!argDefinedType || argDefinedType != argType) { return emitError(unknownLoc, "mismatch in argument type between function type " "definition ") @@ -352,23 +388,26 @@ LogicalResult Deserializer::processFunction(ArrayRef operands) { return success(); } +#define GET_DESERIALIZATION_FNS +#include "mlir/Dialect/SPIRV/SPIRVSerialization.inc" + LogicalResult Deserializer::processInstruction(spirv::Opcode opcode, ArrayRef operands) { + // First dispatch all the instructions whose opcode does not correspond to + // those that have a direct mirror in the SPIR-V dialect switch (opcode) { case spirv::Opcode::OpMemoryModel: return processMemoryModel(operands); - case spirv::Opcode::OpTypeVoid: + case spirv::Opcode::OpTypeFloat: case spirv::Opcode::OpTypeFunction: + case spirv::Opcode::OpTypePointer: + case spirv::Opcode::OpTypeVoid: return processType(opcode, operands); - case spirv::Opcode::OpReturn: - return processNullaryInstruction(operands); case spirv::Opcode::OpFunction: return processFunction(operands); - default: - break; + default:; } - return emitError(unknownLoc, "NYI: opcode ") - << spirv::stringifyOpcode(opcode); + return dispatchToAutogenDeserialization(opcode, operands); } LogicalResult Deserializer::processMemoryModel(ArrayRef operands) { diff --git a/mlir/lib/Dialect/SPIRV/Serialization/SPIRVBinaryUtils.h b/mlir/lib/Dialect/SPIRV/Serialization/SPIRVBinaryUtils.h index c04c725aaf06..12fb3d00be20 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/SPIRVBinaryUtils.h +++ b/mlir/lib/Dialect/SPIRV/Serialization/SPIRVBinaryUtils.h @@ -35,6 +35,7 @@ constexpr unsigned kHeaderWordCount = 5; /// SPIR-V magic number constexpr uint32_t kMagicNumber = 0x07230203; +#define GET_SPIRV_SERIALIZATION_UTILS #include "mlir/Dialect/SPIRV/SPIRVSerialization.inc" } // end namespace spirv diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index ab931cc82692..b1c5936ae1a5 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -41,7 +41,9 @@ static inline void buildInstruction(spirv::Opcode op, SmallVectorImpl &binary) { uint32_t wordCount = 1 + operands.size(); binary.push_back(getPrefixedOpcode(wordCount, op)); - binary.append(operands.begin(), operands.end()); + if (!operands.empty()) { + binary.append(operands.begin(), operands.end()); + } } namespace { @@ -90,11 +92,24 @@ private: // Main method to dispatch operation serialization LogicalResult processOperation(Operation *op); + /// Method to dispatch to the serialization function for an operation in + /// SPIR-V dialect that is a mirror of an instruction in the SPIR-V spec. This + /// is auto-generated from ODS. Dispatch is handled for all operations in + /// SPIR-V dialect that have hasOpcode == 1 + LogicalResult dispatchToAutogenSerialization(Operation *op); + + /// Method to serialize an operation in the SPIR-V dialect that is a mirror of + /// an instruction in the SPIR-V spec. This is auto generated if hasOpcode == + /// 1 and autogenSerialization == 1 in ODS + template LogicalResult processOp(OpTy op) { + return processOpImpl(op); + } + template LogicalResult processOpImpl(OpTy op) { + return op.emitError("unsupported op serialization"); + } + // Methods to serialize individual operation types LogicalResult processFuncOp(FuncOp op); - // Serialize op that dont produce a value and have no operands, like - // spirv::ReturnOp - template LogicalResult processNullaryOp(OpType op); uint32_t getNextID() { return nextID++; } @@ -103,6 +118,16 @@ private: return (it != typeIDMap.end() ? it->second : Optional(None)); } + Optional findValueID(Value *val) const { + auto it = valueIDMap.find(val); + return (it != valueIDMap.end() ? it->second : Optional(None)); + } + + Optional findFunctionID(Operation *op) const { + auto it = funcIDMap.find(op); + return (it != funcIDMap.end() ? it->second : Optional(None)); + } + Type voidType() { return mlir::NoneType::get(module.getContext()); } bool isVoidType(Type type) const { return type.isa(); } @@ -133,6 +158,9 @@ private: // Map from FuncOps to IDs DenseMap funcIDMap; + + // Map from Value to Ids + DenseMap valueIDMap; }; } // namespace @@ -245,6 +273,20 @@ Serializer::processBasicType(Location loc, Type type, spirv::Opcode &typeEnum, if (isVoidType(type)) { typeEnum = spirv::Opcode::OpTypeVoid; return success(); + } else if (type.isa()) { + typeEnum = spirv::Opcode::OpTypeFloat; + operands.push_back(type.cast().getWidth()); + return success(); + } else if (type.isa()) { + auto ptrType = type.cast(); + uint32_t pointeeTypeID = 0; + if (failed(processType(loc, ptrType.getPointeeType(), pointeeTypeID))) { + return failure(); + } + typeEnum = spirv::Opcode::OpTypePointer; + operands.push_back(static_cast(ptrType.getStorageClass())); + operands.push_back(pointeeTypeID); + return success(); } /// TODO(ravishankarm) : Handle other types return emitError(loc, "unhandled type in serialization : ") << type; @@ -275,15 +317,14 @@ Serializer::processFunctionType(Location loc, FunctionType type, } LogicalResult Serializer::processOperation(Operation *op) { + // First dispatch the methods that do not directly mirror an operation from + // the SPIR-V spec if (isa(op)) { return processFuncOp(cast(op)); - } else if (isa(op)) { - return processNullaryOp(cast(op)); } else if (isa(op)) { return success(); } - /// TODO(ravishankarm) : Handle other ops - return op->emitError("unhandled operation serialization"); + return dispatchToAutogenSerialization(op); } LogicalResult Serializer::processFuncOp(FuncOp op) { @@ -314,13 +355,15 @@ LogicalResult Serializer::processFuncOp(FuncOp op) { buildInstruction(spirv::Opcode::OpFunction, operands, functions); // Declare the parameters - for (auto argType : op.getType().getInputs()) { + for (auto arg : op.getArguments()) { uint32_t argTypeID = 0; - if (failed(processType(op.getLoc(), argType, argTypeID))) { + if (failed(processType(op.getLoc(), arg->getType(), argTypeID))) { return failure(); } + auto argValueID = getNextID(); + valueIDMap[arg] = argValueID; buildInstruction(spirv::Opcode::OpFunctionParameter, - {argTypeID, getNextID()}, functions); + {argTypeID, argValueID}, functions); } // Process the body @@ -345,11 +388,8 @@ LogicalResult Serializer::processFuncOp(FuncOp op) { return success(); } -template -LogicalResult Serializer::processNullaryOp(OpType op) { - buildInstruction(spirv::getOpcode(), ArrayRef(), functions); - return success(); -} +#define GET_SERIALIZATION_FNS +#include "mlir/Dialect/SPIRV/SPIRVSerialization.inc" LogicalResult spirv::serialize(spirv::ModuleOp module, SmallVectorImpl &binary) { diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp index 2daccc8e23fe..abf0ef06f0cb 100644 --- a/mlir/lib/TableGen/Attribute.cpp +++ b/mlir/lib/TableGen/Attribute.cpp @@ -63,6 +63,10 @@ bool tblgen::Attribute::isTypeAttr() const { return def->isSubClassOf("TypeAttrBase"); } +bool tblgen::Attribute::isEnumAttr() const { + return def->isSubClassOf("EnumAttrInfo"); +} + bool tblgen::Attribute::hasStorageType() const { const auto *init = def->getValueInit("storageType"); return !getValueAsString(init).empty(); diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp index 0a921369d6b3..ba919184efe5 100644 --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -217,7 +217,7 @@ auto tblgen::Operator::getOperands() -> value_range { return {operand_begin(), operand_end()}; } -auto tblgen::Operator::getArg(int index) -> Argument { +auto tblgen::Operator::getArg(int index) const -> Argument { return arguments[index]; } diff --git a/mlir/test/SPIRV/Serialization/load_store.mlir b/mlir/test/SPIRV/Serialization/load_store.mlir new file mode 100644 index 000000000000..a6d66c744dcb --- /dev/null +++ b/mlir/test/SPIRV/Serialization/load_store.mlir @@ -0,0 +1,16 @@ +// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s + +// CHECK: func {{@.*}}([[ARG1:%.*]]: !spv.ptr, [[ARG2:%.*]]: !spv.ptr) { +// CHECK-NEXT: [[VALUE:%.*]] = spv.Load "Input" [[ARG1]] : f32 +// CHECK-NEXT: spv.Store "Output" [[ARG2]], [[VALUE]] : f32 + +func @spirv_loadstore() -> () { + spv.module "Logical" "VulkanKHR" { + func @load_store(%arg0 : !spv.ptr, %arg1 : !spv.ptr) { + %1 = spv.Load "Input" %arg0 : f32 + spv.Store "Output" %arg1, %1 : f32 + spv.Return + } + } + return +} \ No newline at end of file diff --git a/mlir/test/SPIRV/Serialization/variables.mlir b/mlir/test/SPIRV/Serialization/variables.mlir new file mode 100644 index 000000000000..dbb1f7fd380f --- /dev/null +++ b/mlir/test/SPIRV/Serialization/variables.mlir @@ -0,0 +1,11 @@ +// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s + +// CHECK: {{%.*}} = spv.Variable : !spv.ptr +// CHECK-NEXT: {{%.*}} = spv.Variable : !spv.ptr +func @spirv_variables() -> () { + spv.module "Logical" "VulkanKHR" { + %2 = spv.Variable : !spv.ptr + %3 = spv.Variable : !spv.ptr + } + return +} \ No newline at end of file diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp index 1dcf40fda282..869fe1ad15b6 100644 --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -32,39 +32,317 @@ #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" +using llvm::ArrayRef; using llvm::formatv; using llvm::raw_ostream; +using llvm::raw_string_ostream; using llvm::Record; using llvm::RecordKeeper; +using llvm::SMLoc; +using mlir::tblgen::Attribute; using mlir::tblgen::EnumAttr; +using mlir::tblgen::NamedAttribute; +using mlir::tblgen::NamedTypeConstraint; using mlir::tblgen::Operator; // Writes the following function to `os`: // inline uint32_t getOpcode() { return ; } -static void emitGetOpcodeFunction(const llvm::Record &record, - Operator const &op, raw_ostream &os) { - if (record.getValueAsInt("hasOpcode")) { - os << formatv("template <> constexpr inline ::mlir::spirv::Opcode " - "getOpcode<{0}>()", - op.getQualCppClassName()) - << " {\n " - << formatv("return ::mlir::spirv::Opcode::Op{0};\n}\n", - record.getValueAsString("opName")); +static void emitGetOpcodeFunction(const Record *record, Operator const &op, + raw_ostream &os) { + os << formatv("template <> constexpr inline ::mlir::spirv::Opcode " + "getOpcode<{0}>()", + op.getQualCppClassName()) + << " {\n " + << formatv("return ::mlir::spirv::Opcode::{0};\n}\n", + record->getValueAsString("spirvOpName")); +} + +static void declareOpcodeFn(raw_ostream &os) { + os << "template inline constexpr ::mlir::spirv::Opcode " + "getOpcode();\n"; +} + +static void emitAttributeSerialization(const Attribute &attr, + ArrayRef loc, llvm::StringRef op, + llvm::StringRef operandList, + llvm::StringRef attrName, + raw_ostream &os) { + os << " auto attr = " << op << ".getAttr(\"" << attrName << "\");\n"; + os << " if (attr) {\n"; + if (attr.getAttrDefName() == "I32ArrayAttr") { + // Serialize all the elements of the array + os << " for (auto attrElem : attr.cast()) {\n"; + os << " " << operandList + << ".push_back(static_cast(attrElem.cast()." + "getValue().getZExtValue()));\n"; + os << " }\n"; + } else if (attr.isEnumAttr() || attr.getAttrDefName() == "I32Attr") { + os << " " << operandList + << ".push_back(static_cast(attr.cast().getValue()" + ".getZExtValue()));\n"; + } else { + PrintFatalError( + loc, + llvm::Twine( + "unhandled attribute type in SPIR-V serialization generation : '") + + attr.getAttrDefName() + llvm::Twine("'")); + } + os << " }\n"; +} + +static void emitSerializationFunction(const Record *record, const Operator &op, + raw_ostream &os) { + // If the record has 'autogenSerialization' set to 0, nothing to do + if (!record->getValueAsBit("autogenSerialization")) { + return; + } + os << formatv("template <> LogicalResult\nSerializer::processOpImpl<{0}>(\n" + " {0} op)", + op.getQualCppClassName()) + << " {\n"; + os << " SmallVector operands;\n"; + + // Serialize result information + if (op.getNumResults() == 1) { + os << " {\n"; + os << " uint32_t typeID = 0;\n"; + os << " if (failed(processType(op.getLoc(), " + "op.getResult()->getType(), typeID))) {\n"; + os << " return failure();\n"; + os << " }\n"; + os << " operands.push_back(typeID);\n"; + /// Create an SSA result for the op + os << " auto resultID = getNextID();\n"; + os << " valueIDMap[op.getResult()] = resultID;\n"; + os << " operands.push_back(resultID);\n"; + os << " }\n"; + } else if (op.getNumResults() != 0) { + PrintFatalError(record->getLoc(), "SPIR-V ops can only zero or one result"); + } + + // Process arguments + auto operandNum = 0; + for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) { + auto argument = op.getArg(i); + os << " {\n"; + if (argument.is()) { + os << " if (" << operandNum + << " < op.getOperation()->getNumOperands()) {\n"; + os << " auto arg = findValueID(op.getOperation()->getOperand(" + << operandNum << "));\n"; + os << " if (!arg) {\n"; + os << " emitError(op.getLoc(), \"operand " << operandNum + << " has a use before def\");\n"; + os << " }\n"; + os << " operands.push_back(arg.getValue());\n"; + os << " }\n"; + operandNum++; + } else { + auto attr = argument.get(); + emitAttributeSerialization( + (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr), + record->getLoc(), "op", "operands", attr->name, os); + } + os << " }\n"; + } + + os << formatv( + " buildInstruction(spirv::getOpcode<{0}>(), operands, functions);\n", + op.getQualCppClassName()); + os << " return success();\n"; + os << "}\n\n"; +} + +static void initDispatchSerializationFn(raw_ostream &os) { + os << "LogicalResult Serializer::dispatchToAutogenSerialization(Operation " + "*op) {\n "; +} + +static void emitSerializationDispatch(const Operator &op, raw_ostream &os) { + os << formatv(" if (isa<{0}>(op)) ", op.getQualCppClassName()) << "{\n"; + os << " "; + os << formatv("return processOp<{0}>(cast<{0}>(op));\n", + op.getQualCppClassName()); + os << " } else"; +} + +static void finalizeDispatchSerializationFn(raw_ostream &os) { + os << " {\n"; + os << " return op->emitError(\"unhandled operation serialization\");\n"; + os << " }\n"; + os << " return success();\n"; + os << "}\n\n"; +} + +static void emitAttributeDeserialization( + const Attribute &attr, ArrayRef loc, llvm::StringRef attrList, + llvm::StringRef attrName, llvm::StringRef operandsList, + llvm::StringRef wordIndex, llvm::StringRef wordCount, raw_ostream &os) { + if (attr.getAttrDefName() == "I32ArrayAttr") { + os << " SmallVector attrListElems;\n"; + os << " while (" << wordIndex << " < " << wordCount << ") {\n"; + os << " attrListElems.push_back(opBuilder.getI32IntegerAttr(" + << operandsList << "[" << wordIndex << "++]));\n"; + os << " }\n"; + os << " " << attrList << ".push_back(opBuilder.getNamedAttr(\"" + << attrName << "\", opBuilder.getArrayAttr(attrListElems)));\n"; + } else if (attr.isEnumAttr() || attr.getAttrDefName() == "I32Attr") { + os << " " << attrList << ".push_back(opBuilder.getNamedAttr(\"" + << attrName << "\", opBuilder.getI32IntegerAttr(" << operandsList << "[" + << wordIndex << "++])));\n"; + } else { + PrintFatalError( + loc, llvm::Twine( + "unhandled attribute type in deserialization generation : '") + + attr.getAttrDefName() + llvm::Twine("'")); } } -static bool emitSerializationUtils(const RecordKeeper &recordKeeper, - raw_ostream &os) { - llvm::emitSourceFileHeader("SPIR-V Serialization Utilities", os); +static void emitDeserializationFunction(const Record *record, + const Operator &op, raw_ostream &os) { + // If the record has 'autogenSerialization' set to 0, nothing to do + if (!record->getValueAsBit("autogenSerialization")) { + return; + } + os << formatv("template <> " + "LogicalResult\nDeserializer::processOpImpl<{0}>(ArrayRef<" + "uint32_t> words)", + op.getQualCppClassName()); + os << " {\n"; + os << " SmallVector resultTypes;\n"; + os << " size_t wordIndex = 0;\n"; - /// Define the function to get the opcode - os << "template inline constexpr ::mlir::spirv::Opcode " - "getOpcode();\n"; + // Deserialize result information if it exists + bool hasResult = false; + if (op.getNumResults() == 1) { + os << " {\n"; + os << " if (wordIndex >= words.size()) {\n"; + os << " " + << formatv("return emitError(unknownLoc, \"expected result type " + "while deserializing {0}\");\n", + op.getQualCppClassName()); + os << " }\n"; + os << " auto ty = getType(words[wordIndex]);\n"; + os << " if (!ty) {\n"; + os << " return emitError(unknownLoc, \"unknown type result : " + "\") << words[wordIndex];\n"; + os << " }\n"; + os << " resultTypes.push_back(ty);\n"; + os << " wordIndex++;\n"; + os << " }\n"; + os << " if (wordIndex >= words.size()) {\n"; + os << " " + << formatv("return emitError(unknownLoc, \"expected result while " + "deserializing {0}\");\n", + op.getQualCppClassName()); + os << " }\n"; + os << " uint32_t valueID = words[wordIndex++];\n"; + hasResult = true; + } else if (op.getNumResults() != 0) { + PrintFatalError(record->getLoc(), + "SPIR-V ops can have only zero or one result"); + } + + // Process arguments/attributes + os << " SmallVector operands;\n"; + os << " SmallVector attributes;\n"; + unsigned operandNum = 0; + for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) { + auto argument = op.getArg(i); + os << " if (wordIndex < words.size()) {\n"; + if (argument.is()) { + os << " auto arg = getValue(words[wordIndex]);\n"; + os << " if (!arg) {\n"; + os << " return emitError(unknownLoc, \"unknown result : \") << " + "words[wordIndex];\n"; + os << " }\n"; + os << " operands.push_back(arg);\n"; + os << " wordIndex++;\n"; + operandNum++; + } else { + auto attr = argument.get(); + emitAttributeDeserialization( + (attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr), + record->getLoc(), "attributes", attr->name, "words", "wordIndex", + "words.size()", os); + } + os << " }\n"; + } + + os << formatv(" auto op = opBuilder.create<{0}>(unknownLoc, resultTypes, " + "operands, attributes);\n", + op.getQualCppClassName()); + if (hasResult) { + os << " valueMap[valueID] = op.getResult();\n"; + } + os << " return success();\n"; + os << "}\n\n"; +} + +static void initDispatchDeserializationFn(raw_ostream &os) { + os << "LogicalResult " + "Deserializer::dispatchToAutogenDeserialization(spirv::Opcode " + "opcode, ArrayRef words) {\n"; + os << " switch (opcode) {\n"; +} + +static void emitDeserializationDispatch(const Operator &op, const Record *def, + raw_ostream &os) { + os << formatv(" case spirv::Opcode::{0}:\n", + def->getValueAsString("spirvOpName")); + os << formatv(" return processOp<{0}>(words);\n", + op.getQualCppClassName()); +} + +static void finalizeDispatchDeserializationFn(raw_ostream &os) { + os << " default:\n"; + os << " ;\n"; + os << " }\n"; + os << " return emitError(unknownLoc, \"unhandled deserialization of \") << " + "spirv::stringifyOpcode(opcode);\n"; + os << "}\n"; +} + +static bool emitSerializationFns(const RecordKeeper &recordKeeper, + raw_ostream &os) { + llvm::emitSourceFileHeader("SPIR-V Serialization Utilities/Functions", os); + + std::string dSerFnString, dDesFnString, serFnString, deserFnString, + utilsString; + raw_string_ostream dSerFn(dSerFnString), dDesFn(dDesFnString), + serFn(serFnString), deserFn(deserFnString), utils(utilsString); + + declareOpcodeFn(utils); + initDispatchSerializationFn(dSerFn); + initDispatchDeserializationFn(dDesFn); auto defs = recordKeeper.getAllDerivedDefinitions("SPV_Op"); for (const auto *def : defs) { + if (!def->getValueAsBit("hasOpcode")) { + continue; + } Operator op(def); - emitGetOpcodeFunction(*def, op, os); + emitGetOpcodeFunction(def, op, utils); + emitSerializationFunction(def, op, serFn); + emitSerializationDispatch(op, dSerFn); + emitDeserializationFunction(def, op, deserFn); + emitDeserializationDispatch(op, def, dDesFn); } + finalizeDispatchSerializationFn(dSerFn); + finalizeDispatchDeserializationFn(dDesFn); + + os << "#ifdef GET_SPIRV_SERIALIZATION_UTILS\n"; + os << utils.str(); + os << "#endif // GET_SPIRV_SERIALIZATION_UTILS\n\n"; + + os << "#ifdef GET_SERIALIZATION_FNS\n\n"; + os << serFn.str(); + os << dSerFn.str(); + os << "#endif // GET_SERIALIZATION_FNS\n\n"; + + os << "#ifdef GET_DESERIALIZATION_FNS\n\n"; + os << deserFn.str(); + os << dDesFn.str(); + os << "#endif // GET_DESERIALIZATION_FNS\n\n"; return false; } @@ -134,12 +412,12 @@ static bool emitOpUtils(const RecordKeeper &recordKeeper, raw_ostream &os) { } // Registers the enum utility generator to mlir-tblgen. -static mlir::GenRegistration - genSerializationDefs("gen-spirv-serial", - "Generate SPIR-V serialization utility definitions", - [](const RecordKeeper &records, raw_ostream &os) { - return emitSerializationUtils(records, os); - }); +static mlir::GenRegistration genSerialization( + "gen-spirv-serialization", + "Generate SPIR-V (de)serialization utilities and functions", + [](const RecordKeeper &records, raw_ostream &os) { + return emitSerializationFns(records, os); + }); static mlir::GenRegistration genOpUtils("gen-spirv-op-utils",