Automatically generate (de)serialization methods for SPIR-V ops

For ops in SPIR-V dialect that are a direct mirror of SPIR-V
operations, the serialization/deserialization methods can be
automatically generated from the Op specification. To enable this an
'autogenSerialization' field is added to SPV_Ops. When set to
non-zero, this will enable the automatic (de)serialization function
generation

Also adding tests that verify the spv.Load, spv.Store and spv.Variable
ops are serialized and deserialized correctly. To fully support these
tests also add serialization and deserialization of float types and
spv.ptr types

PiperOrigin-RevId: 258684764
This commit is contained in:
Mahesh Ravishankar 2019-07-17 18:41:28 -07:00 committed by Mehdi Amini
parent ec66bc57a8
commit c6cfebf1af
13 changed files with 493 additions and 78 deletions

View File

@ -9,7 +9,7 @@ mlir_tablegen(SPIRVEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRSPIRVEnumsIncGen)
set(LLVM_TARGET_DEFINITIONS SPIRVOps.td)
mlir_tablegen(SPIRVSerialization.inc -gen-spirv-serial)
mlir_tablegen(SPIRVSerialization.inc -gen-spirv-serialization)
add_public_tablegen_target(MLIRSPIRVSerializationGen)
set(LLVM_TARGET_DEFINITIONS SPIRVBase.td)

View File

@ -1,4 +1,4 @@
//===-- SPIRVOps.td - MLIR SPIR-V Op Definitions Spec ------*- tablegen -*-===//
//===- SPIRVBase.td - MLIR SPIR-V Op Definitions Base file -*- tablegen -*-===//
//
// Copyright 2019 The MLIR Authors.
//
@ -76,6 +76,8 @@ def SPV_OC_OpMemoryModel : I32EnumAttrCase<"OpMemoryModel", 14>;
def SPV_OC_OpEntryPoint : I32EnumAttrCase<"OpEntryPoint", 15>;
def SPV_OC_OpExecutionMode : I32EnumAttrCase<"OpExecutionMode", 16>;
def SPV_OC_OpTypeVoid : I32EnumAttrCase<"OpTypeVoid", 19>;
def SPV_OC_OpTypeFloat : I32EnumAttrCase<"OpTypeFloat", 22>;
def SPV_OC_OpTypePointer : I32EnumAttrCase<"OpTypePointer", 32>;
def SPV_OC_OpTypeFunction : I32EnumAttrCase<"OpTypeFunction", 33>;
def SPV_OC_OpFunction : I32EnumAttrCase<"OpFunction", 54>;
def SPV_OC_OpFunctionParameter : I32EnumAttrCase<"OpFunctionParameter", 55>;
@ -91,10 +93,10 @@ def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>;
def SPV_OpcodeAttr :
I32EnumAttr<"Opcode", "valid SPIR-V instructions", [
SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint, SPV_OC_OpExecutionMode,
SPV_OC_OpTypeVoid, SPV_OC_OpTypeFunction, SPV_OC_OpFunction,
SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpVariable,
SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpDecorate, SPV_OC_OpCompositeExtract,
SPV_OC_OpFMul, SPV_OC_OpReturn
SPV_OC_OpTypeVoid, SPV_OC_OpTypeFloat, SPV_OC_OpTypePointer,
SPV_OC_OpTypeFunction, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
SPV_OC_OpFunctionEnd, SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore,
SPV_OC_OpDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpFMul, SPV_OC_OpReturn
]> {
let returnType = "::mlir::spirv::Opcode";
let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())";
@ -514,8 +516,6 @@ def ModuleOnly :
class SPV_Op<string mnemonic, list<OpTrait> traits = []> :
Op<SPV_Dialect, mnemonic, traits> {
string spirvOpName = "Op" # mnemonic;
// For each SPIR-V op, the following static functions need to be defined
// in SPVOps.cpp:
//
@ -527,13 +527,34 @@ class SPV_Op<string mnemonic, list<OpTrait> traits = []> :
let printer = [{ return ::print(*this, p); }];
let verifier = [{ return ::verify(*this); }];
// By default the opcode to use for (de)serialization is obtained
// automatically from the SPIR-V spec. It assume the SPIR-V op being defined
// is ('Op' # mnemonic). The opcode value can be obtained by calling
// getOpcode<OpClass>(). If invoking this method is invalid or custom
// processing is needed for the op, set hasOpcode = 0 and specialize the
// getOpcode method.
int hasOpcode = 1;
// Specifies if the SPIR-V op has a direct representation of an instruction in
// SPIR-V spec. When set to 1, the ODS generates the dispatch to the
// (de)serialization methods for the op, and also generates the implementation
// of these processing methods (if autogenSerialization is also set to 1).
bit hasOpcode = 1;
// Name of the corresponding SPIR-V op. Only valid to use when hasOpcode is 1
string spirvOpName = "Op" # mnemonic;
// Controls whether the (de)serialization method is generated automatically or
// not. This results in generation of the following methods:
//
// template<typename OpTy> Serialization::processOp(OpTy op)
// template<typename OpTy> Deserialization::processOp(ArrayRef<uint32_t>)
//
// If the auto generation is disabled (set to 0), then manual implementation
// of a specialization of these methods is required.
//
// Note :
//
// 1) If hasOpcode is 1 and autogenSerialization is 0, the ODS generated
// dispatch function call the above method for the (de)serialization of the
// operation
//
// 2) If hasOpcode is 0, then ODS doesn't generate the (de)serialization
// methods, neither does it handle the dispatch. Those need to be handled
// manually.
bit autogenSerialization = 1;
}
#endif // SPIRV_BASE

View File

@ -143,6 +143,7 @@ def SPV_EntryPointOp : SPV_Op<"EntryPoint", [ModuleOnly]> {
);
let results = (outs SPV_EntryPoint:$id);
let autogenSerialization = 0;
}
// -----
@ -344,7 +345,7 @@ def SPV_StoreOp : SPV_Op<"Store", []> {
SPV_AnyPtr:$ptr,
SPV_Type:$value,
OptionalAttr<SPV_MemoryAccessAttr>:$memory_access,
OptionalAttr<APIntAttr>:$alignment
OptionalAttr<I32Attr>:$alignment
);
let results = (outs);

View File

@ -96,6 +96,10 @@ public:
// of `TypeAttrBase`).
bool isTypeAttr() const;
// Returns true if this attribute is an enum attribute (i.e., a subclass of
// `EnumAttrInfo`)
bool isEnumAttr() const;
// Returns this attribute's TableGen def name. If this is an `OptionalAttr`
// or `DefaultValuedAttr` without explicit name, returns the base attribute's
// name.

View File

@ -132,7 +132,7 @@ public:
int getNumArgs() const { return arguments.size(); }
// Op argument (attribute or operand) accessors.
Argument getArg(int index);
Argument getArg(int index) const;
StringRef getArgName(int index) const;
// Returns the number of `PredOpTrait` traits.

View File

@ -88,9 +88,24 @@ private:
LogicalResult processType(spirv::Opcode opcode, ArrayRef<uint32_t> operands);
LogicalResult processFunctionType(ArrayRef<uint32_t> operands);
/// Process SPIR-V instructions that dont have any operands
/// Method to dispatch to the specialized deserialization function for an
/// operation in SPIR-V dialect that is a mirror of an operation in the SPIR-V
/// spec. This is auto-generated from ODS. Dispatch is handled for all
/// operations in SPIR-V dialect that have hasOpcode == 1
LogicalResult dispatchToAutogenDeserialization(spirv::Opcode opcode,
ArrayRef<uint32_t> words);
/// Method to deserialize an operation in the SPIR-V dialect that is a mirror
/// of an instruction in the SPIR-V spec. This is auto generated if hasOpcode
/// == 1 and autogenSerialization == 1 in ODS.
template <typename OpTy> LogicalResult processOp(ArrayRef<uint32_t> words) {
return processOpImpl<OpTy>(words);
}
template <typename OpTy>
LogicalResult processNullaryInstruction(ArrayRef<uint32_t> operands);
LogicalResult processOpImpl(ArrayRef<uint32_t> words) {
return emitError(unknownLoc, "unsupported deserialization for op '")
<< OpTy::getOperationName() << "')";
}
/// Process function objects in binary
LogicalResult processFunction(ArrayRef<uint32_t> operands);
@ -232,6 +247,39 @@ LogicalResult Deserializer::processType(spirv::Opcode opcode,
}
typeMap[operands[0]] = NoneType::get(context);
break;
case spirv::Opcode::OpTypeFloat: {
if (operands.size() != 2) {
return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter");
}
Type floatTy;
switch (operands[1]) {
case 16:
floatTy = opBuilder.getF16Type();
break;
case 32:
floatTy = opBuilder.getF32Type();
break;
case 64:
floatTy = opBuilder.getF64Type();
break;
default:
return emitError(unknownLoc, "unsupported bitwdith ")
<< operands[1] << " with OpTypeFloat";
}
typeMap[operands[0]] = floatTy;
} break;
case spirv::Opcode::OpTypePointer: {
if (operands.size() != 3) {
return emitError(unknownLoc, "OpTypePointer must have two parameters");
}
auto pointeeType = getType(operands[2]);
if (!pointeeType) {
return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> : ")
<< operands[2];
}
auto storageClass = static_cast<spirv::StorageClass>(operands[1]);
typeMap[operands[0]] = spirv::PointerType::get(pointeeType, storageClass);
} break;
case spirv::Opcode::OpTypeFunction:
return processFunctionType(operands);
default:
@ -240,18 +288,6 @@ LogicalResult Deserializer::processType(spirv::Opcode opcode,
return success();
}
template <typename OpTy>
LogicalResult
Deserializer::processNullaryInstruction(ArrayRef<uint32_t> operands) {
if (!operands.empty()) {
return emitError(unknownLoc) << stringifyOpcode(spirv::getOpcode<OpTy>())
<< " must have no operands, but found "
<< operands.size() << " operands";
}
opBuilder.create<OpTy>(unknownLoc);
return success();
}
LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
// Get the result type
if (operands.size() != 4) {
@ -314,7 +350,7 @@ LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
"expected result type and result <id> for OpFunctionParameter");
}
auto argDefinedType = getType(operands[0]);
if (argDefinedType || argDefinedType != argType) {
if (!argDefinedType || argDefinedType != argType) {
return emitError(unknownLoc,
"mismatch in argument type between function type "
"definition ")
@ -352,23 +388,26 @@ LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
return success();
}
#define GET_DESERIALIZATION_FNS
#include "mlir/Dialect/SPIRV/SPIRVSerialization.inc"
LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
ArrayRef<uint32_t> operands) {
// First dispatch all the instructions whose opcode does not correspond to
// those that have a direct mirror in the SPIR-V dialect
switch (opcode) {
case spirv::Opcode::OpMemoryModel:
return processMemoryModel(operands);
case spirv::Opcode::OpTypeVoid:
case spirv::Opcode::OpTypeFloat:
case spirv::Opcode::OpTypeFunction:
case spirv::Opcode::OpTypePointer:
case spirv::Opcode::OpTypeVoid:
return processType(opcode, operands);
case spirv::Opcode::OpReturn:
return processNullaryInstruction<spirv::ReturnOp>(operands);
case spirv::Opcode::OpFunction:
return processFunction(operands);
default:
break;
default:;
}
return emitError(unknownLoc, "NYI: opcode ")
<< spirv::stringifyOpcode(opcode);
return dispatchToAutogenDeserialization(opcode, operands);
}
LogicalResult Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {

View File

@ -35,6 +35,7 @@ constexpr unsigned kHeaderWordCount = 5;
/// SPIR-V magic number
constexpr uint32_t kMagicNumber = 0x07230203;
#define GET_SPIRV_SERIALIZATION_UTILS
#include "mlir/Dialect/SPIRV/SPIRVSerialization.inc"
} // end namespace spirv

View File

@ -41,7 +41,9 @@ static inline void buildInstruction(spirv::Opcode op,
SmallVectorImpl<uint32_t> &binary) {
uint32_t wordCount = 1 + operands.size();
binary.push_back(getPrefixedOpcode(wordCount, op));
binary.append(operands.begin(), operands.end());
if (!operands.empty()) {
binary.append(operands.begin(), operands.end());
}
}
namespace {
@ -90,11 +92,24 @@ private:
// Main method to dispatch operation serialization
LogicalResult processOperation(Operation *op);
/// Method to dispatch to the serialization function for an operation in
/// SPIR-V dialect that is a mirror of an instruction in the SPIR-V spec. This
/// is auto-generated from ODS. Dispatch is handled for all operations in
/// SPIR-V dialect that have hasOpcode == 1
LogicalResult dispatchToAutogenSerialization(Operation *op);
/// Method to serialize an operation in the SPIR-V dialect that is a mirror of
/// an instruction in the SPIR-V spec. This is auto generated if hasOpcode ==
/// 1 and autogenSerialization == 1 in ODS
template <typename OpTy> LogicalResult processOp(OpTy op) {
return processOpImpl(op);
}
template <typename OpTy> LogicalResult processOpImpl(OpTy op) {
return op.emitError("unsupported op serialization");
}
// Methods to serialize individual operation types
LogicalResult processFuncOp(FuncOp op);
// Serialize op that dont produce a value and have no operands, like
// spirv::ReturnOp
template <typename OpType> LogicalResult processNullaryOp(OpType op);
uint32_t getNextID() { return nextID++; }
@ -103,6 +118,16 @@ private:
return (it != typeIDMap.end() ? it->second : Optional<uint32_t>(None));
}
Optional<uint32_t> findValueID(Value *val) const {
auto it = valueIDMap.find(val);
return (it != valueIDMap.end() ? it->second : Optional<uint32_t>(None));
}
Optional<uint32_t> findFunctionID(Operation *op) const {
auto it = funcIDMap.find(op);
return (it != funcIDMap.end() ? it->second : Optional<uint32_t>(None));
}
Type voidType() { return mlir::NoneType::get(module.getContext()); }
bool isVoidType(Type type) const { return type.isa<NoneType>(); }
@ -133,6 +158,9 @@ private:
// Map from FuncOps to IDs
DenseMap<Operation *, uint32_t> funcIDMap;
// Map from Value to Ids
DenseMap<Value *, uint32_t> valueIDMap;
};
} // namespace
@ -245,6 +273,20 @@ Serializer::processBasicType(Location loc, Type type, spirv::Opcode &typeEnum,
if (isVoidType(type)) {
typeEnum = spirv::Opcode::OpTypeVoid;
return success();
} else if (type.isa<FloatType>()) {
typeEnum = spirv::Opcode::OpTypeFloat;
operands.push_back(type.cast<FloatType>().getWidth());
return success();
} else if (type.isa<spirv::PointerType>()) {
auto ptrType = type.cast<spirv::PointerType>();
uint32_t pointeeTypeID = 0;
if (failed(processType(loc, ptrType.getPointeeType(), pointeeTypeID))) {
return failure();
}
typeEnum = spirv::Opcode::OpTypePointer;
operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass()));
operands.push_back(pointeeTypeID);
return success();
}
/// TODO(ravishankarm) : Handle other types
return emitError(loc, "unhandled type in serialization : ") << type;
@ -275,15 +317,14 @@ Serializer::processFunctionType(Location loc, FunctionType type,
}
LogicalResult Serializer::processOperation(Operation *op) {
// First dispatch the methods that do not directly mirror an operation from
// the SPIR-V spec
if (isa<FuncOp>(op)) {
return processFuncOp(cast<FuncOp>(op));
} else if (isa<spirv::ReturnOp>(op)) {
return processNullaryOp(cast<spirv::ReturnOp>(op));
} else if (isa<spirv::ModuleEndOp>(op)) {
return success();
}
/// TODO(ravishankarm) : Handle other ops
return op->emitError("unhandled operation serialization");
return dispatchToAutogenSerialization(op);
}
LogicalResult Serializer::processFuncOp(FuncOp op) {
@ -314,13 +355,15 @@ LogicalResult Serializer::processFuncOp(FuncOp op) {
buildInstruction(spirv::Opcode::OpFunction, operands, functions);
// Declare the parameters
for (auto argType : op.getType().getInputs()) {
for (auto arg : op.getArguments()) {
uint32_t argTypeID = 0;
if (failed(processType(op.getLoc(), argType, argTypeID))) {
if (failed(processType(op.getLoc(), arg->getType(), argTypeID))) {
return failure();
}
auto argValueID = getNextID();
valueIDMap[arg] = argValueID;
buildInstruction(spirv::Opcode::OpFunctionParameter,
{argTypeID, getNextID()}, functions);
{argTypeID, argValueID}, functions);
}
// Process the body
@ -345,11 +388,8 @@ LogicalResult Serializer::processFuncOp(FuncOp op) {
return success();
}
template <typename OpType>
LogicalResult Serializer::processNullaryOp(OpType op) {
buildInstruction(spirv::getOpcode<OpType>(), ArrayRef<uint32_t>(), functions);
return success();
}
#define GET_SERIALIZATION_FNS
#include "mlir/Dialect/SPIRV/SPIRVSerialization.inc"
LogicalResult spirv::serialize(spirv::ModuleOp module,
SmallVectorImpl<uint32_t> &binary) {

View File

@ -63,6 +63,10 @@ bool tblgen::Attribute::isTypeAttr() const {
return def->isSubClassOf("TypeAttrBase");
}
bool tblgen::Attribute::isEnumAttr() const {
return def->isSubClassOf("EnumAttrInfo");
}
bool tblgen::Attribute::hasStorageType() const {
const auto *init = def->getValueInit("storageType");
return !getValueAsString(init).empty();

View File

@ -217,7 +217,7 @@ auto tblgen::Operator::getOperands() -> value_range {
return {operand_begin(), operand_end()};
}
auto tblgen::Operator::getArg(int index) -> Argument {
auto tblgen::Operator::getArg(int index) const -> Argument {
return arguments[index];
}

View File

@ -0,0 +1,16 @@
// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s
// CHECK: func {{@.*}}([[ARG1:%.*]]: !spv.ptr<f32, Input>, [[ARG2:%.*]]: !spv.ptr<f32, Output>) {
// CHECK-NEXT: [[VALUE:%.*]] = spv.Load "Input" [[ARG1]] : f32
// CHECK-NEXT: spv.Store "Output" [[ARG2]], [[VALUE]] : f32
func @spirv_loadstore() -> () {
spv.module "Logical" "VulkanKHR" {
func @load_store(%arg0 : !spv.ptr<f32, Input>, %arg1 : !spv.ptr<f32, Output>) {
%1 = spv.Load "Input" %arg0 : f32
spv.Store "Output" %arg1, %1 : f32
spv.Return
}
}
return
}

View File

@ -0,0 +1,11 @@
// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s
// CHECK: {{%.*}} = spv.Variable : !spv.ptr<f32, Input>
// CHECK-NEXT: {{%.*}} = spv.Variable : !spv.ptr<f32, Output>
func @spirv_variables() -> () {
spv.module "Logical" "VulkanKHR" {
%2 = spv.Variable : !spv.ptr<f32, Input>
%3 = spv.Variable : !spv.ptr<f32, Output>
}
return
}

View File

@ -32,39 +32,317 @@
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
using llvm::ArrayRef;
using llvm::formatv;
using llvm::raw_ostream;
using llvm::raw_string_ostream;
using llvm::Record;
using llvm::RecordKeeper;
using llvm::SMLoc;
using mlir::tblgen::Attribute;
using mlir::tblgen::EnumAttr;
using mlir::tblgen::NamedAttribute;
using mlir::tblgen::NamedTypeConstraint;
using mlir::tblgen::Operator;
// Writes the following function to `os`:
// inline uint32_t getOpcode(<op-class-name>) { return <opcode>; }
static void emitGetOpcodeFunction(const llvm::Record &record,
Operator const &op, raw_ostream &os) {
if (record.getValueAsInt("hasOpcode")) {
os << formatv("template <> constexpr inline ::mlir::spirv::Opcode "
"getOpcode<{0}>()",
op.getQualCppClassName())
<< " {\n "
<< formatv("return ::mlir::spirv::Opcode::Op{0};\n}\n",
record.getValueAsString("opName"));
static void emitGetOpcodeFunction(const Record *record, Operator const &op,
raw_ostream &os) {
os << formatv("template <> constexpr inline ::mlir::spirv::Opcode "
"getOpcode<{0}>()",
op.getQualCppClassName())
<< " {\n "
<< formatv("return ::mlir::spirv::Opcode::{0};\n}\n",
record->getValueAsString("spirvOpName"));
}
static void declareOpcodeFn(raw_ostream &os) {
os << "template <typename OpClass> inline constexpr ::mlir::spirv::Opcode "
"getOpcode();\n";
}
static void emitAttributeSerialization(const Attribute &attr,
ArrayRef<SMLoc> loc, llvm::StringRef op,
llvm::StringRef operandList,
llvm::StringRef attrName,
raw_ostream &os) {
os << " auto attr = " << op << ".getAttr(\"" << attrName << "\");\n";
os << " if (attr) {\n";
if (attr.getAttrDefName() == "I32ArrayAttr") {
// Serialize all the elements of the array
os << " for (auto attrElem : attr.cast<ArrayAttr>()) {\n";
os << " " << operandList
<< ".push_back(static_cast<uint32_t>(attrElem.cast<IntegerAttr>()."
"getValue().getZExtValue()));\n";
os << " }\n";
} else if (attr.isEnumAttr() || attr.getAttrDefName() == "I32Attr") {
os << " " << operandList
<< ".push_back(static_cast<uint32_t>(attr.cast<IntegerAttr>().getValue()"
".getZExtValue()));\n";
} else {
PrintFatalError(
loc,
llvm::Twine(
"unhandled attribute type in SPIR-V serialization generation : '") +
attr.getAttrDefName() + llvm::Twine("'"));
}
os << " }\n";
}
static void emitSerializationFunction(const Record *record, const Operator &op,
raw_ostream &os) {
// If the record has 'autogenSerialization' set to 0, nothing to do
if (!record->getValueAsBit("autogenSerialization")) {
return;
}
os << formatv("template <> LogicalResult\nSerializer::processOpImpl<{0}>(\n"
" {0} op)",
op.getQualCppClassName())
<< " {\n";
os << " SmallVector<uint32_t, 4> operands;\n";
// Serialize result information
if (op.getNumResults() == 1) {
os << " {\n";
os << " uint32_t typeID = 0;\n";
os << " if (failed(processType(op.getLoc(), "
"op.getResult()->getType(), typeID))) {\n";
os << " return failure();\n";
os << " }\n";
os << " operands.push_back(typeID);\n";
/// Create an SSA result <id> for the op
os << " auto resultID = getNextID();\n";
os << " valueIDMap[op.getResult()] = resultID;\n";
os << " operands.push_back(resultID);\n";
os << " }\n";
} else if (op.getNumResults() != 0) {
PrintFatalError(record->getLoc(), "SPIR-V ops can only zero or one result");
}
// Process arguments
auto operandNum = 0;
for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
auto argument = op.getArg(i);
os << " {\n";
if (argument.is<NamedTypeConstraint *>()) {
os << " if (" << operandNum
<< " < op.getOperation()->getNumOperands()) {\n";
os << " auto arg = findValueID(op.getOperation()->getOperand("
<< operandNum << "));\n";
os << " if (!arg) {\n";
os << " emitError(op.getLoc(), \"operand " << operandNum
<< " has a use before def\");\n";
os << " }\n";
os << " operands.push_back(arg.getValue());\n";
os << " }\n";
operandNum++;
} else {
auto attr = argument.get<NamedAttribute *>();
emitAttributeSerialization(
(attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),
record->getLoc(), "op", "operands", attr->name, os);
}
os << " }\n";
}
os << formatv(
" buildInstruction(spirv::getOpcode<{0}>(), operands, functions);\n",
op.getQualCppClassName());
os << " return success();\n";
os << "}\n\n";
}
static void initDispatchSerializationFn(raw_ostream &os) {
os << "LogicalResult Serializer::dispatchToAutogenSerialization(Operation "
"*op) {\n ";
}
static void emitSerializationDispatch(const Operator &op, raw_ostream &os) {
os << formatv(" if (isa<{0}>(op)) ", op.getQualCppClassName()) << "{\n";
os << " ";
os << formatv("return processOp<{0}>(cast<{0}>(op));\n",
op.getQualCppClassName());
os << " } else";
}
static void finalizeDispatchSerializationFn(raw_ostream &os) {
os << " {\n";
os << " return op->emitError(\"unhandled operation serialization\");\n";
os << " }\n";
os << " return success();\n";
os << "}\n\n";
}
static void emitAttributeDeserialization(
const Attribute &attr, ArrayRef<SMLoc> loc, llvm::StringRef attrList,
llvm::StringRef attrName, llvm::StringRef operandsList,
llvm::StringRef wordIndex, llvm::StringRef wordCount, raw_ostream &os) {
if (attr.getAttrDefName() == "I32ArrayAttr") {
os << " SmallVector<Attribute, 4> attrListElems;\n";
os << " while (" << wordIndex << " < " << wordCount << ") {\n";
os << " attrListElems.push_back(opBuilder.getI32IntegerAttr("
<< operandsList << "[" << wordIndex << "++]));\n";
os << " }\n";
os << " " << attrList << ".push_back(opBuilder.getNamedAttr(\""
<< attrName << "\", opBuilder.getArrayAttr(attrListElems)));\n";
} else if (attr.isEnumAttr() || attr.getAttrDefName() == "I32Attr") {
os << " " << attrList << ".push_back(opBuilder.getNamedAttr(\""
<< attrName << "\", opBuilder.getI32IntegerAttr(" << operandsList << "["
<< wordIndex << "++])));\n";
} else {
PrintFatalError(
loc, llvm::Twine(
"unhandled attribute type in deserialization generation : '") +
attr.getAttrDefName() + llvm::Twine("'"));
}
}
static bool emitSerializationUtils(const RecordKeeper &recordKeeper,
raw_ostream &os) {
llvm::emitSourceFileHeader("SPIR-V Serialization Utilities", os);
static void emitDeserializationFunction(const Record *record,
const Operator &op, raw_ostream &os) {
// If the record has 'autogenSerialization' set to 0, nothing to do
if (!record->getValueAsBit("autogenSerialization")) {
return;
}
os << formatv("template <> "
"LogicalResult\nDeserializer::processOpImpl<{0}>(ArrayRef<"
"uint32_t> words)",
op.getQualCppClassName());
os << " {\n";
os << " SmallVector<Type, 1> resultTypes;\n";
os << " size_t wordIndex = 0;\n";
/// Define the function to get the opcode
os << "template <typename OpClass> inline constexpr ::mlir::spirv::Opcode "
"getOpcode();\n";
// Deserialize result information if it exists
bool hasResult = false;
if (op.getNumResults() == 1) {
os << " {\n";
os << " if (wordIndex >= words.size()) {\n";
os << " "
<< formatv("return emitError(unknownLoc, \"expected result type <id> "
"while deserializing {0}\");\n",
op.getQualCppClassName());
os << " }\n";
os << " auto ty = getType(words[wordIndex]);\n";
os << " if (!ty) {\n";
os << " return emitError(unknownLoc, \"unknown type result <id> : "
"\") << words[wordIndex];\n";
os << " }\n";
os << " resultTypes.push_back(ty);\n";
os << " wordIndex++;\n";
os << " }\n";
os << " if (wordIndex >= words.size()) {\n";
os << " "
<< formatv("return emitError(unknownLoc, \"expected result <id> while "
"deserializing {0}\");\n",
op.getQualCppClassName());
os << " }\n";
os << " uint32_t valueID = words[wordIndex++];\n";
hasResult = true;
} else if (op.getNumResults() != 0) {
PrintFatalError(record->getLoc(),
"SPIR-V ops can have only zero or one result");
}
// Process arguments/attributes
os << " SmallVector<Value *, 4> operands;\n";
os << " SmallVector<NamedAttribute, 4> attributes;\n";
unsigned operandNum = 0;
for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
auto argument = op.getArg(i);
os << " if (wordIndex < words.size()) {\n";
if (argument.is<NamedTypeConstraint *>()) {
os << " auto arg = getValue(words[wordIndex]);\n";
os << " if (!arg) {\n";
os << " return emitError(unknownLoc, \"unknown result <id> : \") << "
"words[wordIndex];\n";
os << " }\n";
os << " operands.push_back(arg);\n";
os << " wordIndex++;\n";
operandNum++;
} else {
auto attr = argument.get<NamedAttribute *>();
emitAttributeDeserialization(
(attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),
record->getLoc(), "attributes", attr->name, "words", "wordIndex",
"words.size()", os);
}
os << " }\n";
}
os << formatv(" auto op = opBuilder.create<{0}>(unknownLoc, resultTypes, "
"operands, attributes);\n",
op.getQualCppClassName());
if (hasResult) {
os << " valueMap[valueID] = op.getResult();\n";
}
os << " return success();\n";
os << "}\n\n";
}
static void initDispatchDeserializationFn(raw_ostream &os) {
os << "LogicalResult "
"Deserializer::dispatchToAutogenDeserialization(spirv::Opcode "
"opcode, ArrayRef<uint32_t> words) {\n";
os << " switch (opcode) {\n";
}
static void emitDeserializationDispatch(const Operator &op, const Record *def,
raw_ostream &os) {
os << formatv(" case spirv::Opcode::{0}:\n",
def->getValueAsString("spirvOpName"));
os << formatv(" return processOp<{0}>(words);\n",
op.getQualCppClassName());
}
static void finalizeDispatchDeserializationFn(raw_ostream &os) {
os << " default:\n";
os << " ;\n";
os << " }\n";
os << " return emitError(unknownLoc, \"unhandled deserialization of \") << "
"spirv::stringifyOpcode(opcode);\n";
os << "}\n";
}
static bool emitSerializationFns(const RecordKeeper &recordKeeper,
raw_ostream &os) {
llvm::emitSourceFileHeader("SPIR-V Serialization Utilities/Functions", os);
std::string dSerFnString, dDesFnString, serFnString, deserFnString,
utilsString;
raw_string_ostream dSerFn(dSerFnString), dDesFn(dDesFnString),
serFn(serFnString), deserFn(deserFnString), utils(utilsString);
declareOpcodeFn(utils);
initDispatchSerializationFn(dSerFn);
initDispatchDeserializationFn(dDesFn);
auto defs = recordKeeper.getAllDerivedDefinitions("SPV_Op");
for (const auto *def : defs) {
if (!def->getValueAsBit("hasOpcode")) {
continue;
}
Operator op(def);
emitGetOpcodeFunction(*def, op, os);
emitGetOpcodeFunction(def, op, utils);
emitSerializationFunction(def, op, serFn);
emitSerializationDispatch(op, dSerFn);
emitDeserializationFunction(def, op, deserFn);
emitDeserializationDispatch(op, def, dDesFn);
}
finalizeDispatchSerializationFn(dSerFn);
finalizeDispatchDeserializationFn(dDesFn);
os << "#ifdef GET_SPIRV_SERIALIZATION_UTILS\n";
os << utils.str();
os << "#endif // GET_SPIRV_SERIALIZATION_UTILS\n\n";
os << "#ifdef GET_SERIALIZATION_FNS\n\n";
os << serFn.str();
os << dSerFn.str();
os << "#endif // GET_SERIALIZATION_FNS\n\n";
os << "#ifdef GET_DESERIALIZATION_FNS\n\n";
os << deserFn.str();
os << dDesFn.str();
os << "#endif // GET_DESERIALIZATION_FNS\n\n";
return false;
}
@ -134,12 +412,12 @@ static bool emitOpUtils(const RecordKeeper &recordKeeper, raw_ostream &os) {
}
// Registers the enum utility generator to mlir-tblgen.
static mlir::GenRegistration
genSerializationDefs("gen-spirv-serial",
"Generate SPIR-V serialization utility definitions",
[](const RecordKeeper &records, raw_ostream &os) {
return emitSerializationUtils(records, os);
});
static mlir::GenRegistration genSerialization(
"gen-spirv-serialization",
"Generate SPIR-V (de)serialization utilities and functions",
[](const RecordKeeper &records, raw_ostream &os) {
return emitSerializationFns(records, os);
});
static mlir::GenRegistration
genOpUtils("gen-spirv-op-utils",