forked from OSchip/llvm-project
Automatically generate (de)serialization methods for SPIR-V ops
For ops in SPIR-V dialect that are a direct mirror of SPIR-V operations, the serialization/deserialization methods can be automatically generated from the Op specification. To enable this an 'autogenSerialization' field is added to SPV_Ops. When set to non-zero, this will enable the automatic (de)serialization function generation Also adding tests that verify the spv.Load, spv.Store and spv.Variable ops are serialized and deserialized correctly. To fully support these tests also add serialization and deserialization of float types and spv.ptr types PiperOrigin-RevId: 258684764
This commit is contained in:
parent
ec66bc57a8
commit
c6cfebf1af
|
@ -9,7 +9,7 @@ mlir_tablegen(SPIRVEnums.cpp.inc -gen-enum-defs)
|
|||
add_public_tablegen_target(MLIRSPIRVEnumsIncGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS SPIRVOps.td)
|
||||
mlir_tablegen(SPIRVSerialization.inc -gen-spirv-serial)
|
||||
mlir_tablegen(SPIRVSerialization.inc -gen-spirv-serialization)
|
||||
add_public_tablegen_target(MLIRSPIRVSerializationGen)
|
||||
|
||||
set(LLVM_TARGET_DEFINITIONS SPIRVBase.td)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
//===-- SPIRVOps.td - MLIR SPIR-V Op Definitions Spec ------*- tablegen -*-===//
|
||||
//===- SPIRVBase.td - MLIR SPIR-V Op Definitions Base file -*- tablegen -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
|
@ -76,6 +76,8 @@ def SPV_OC_OpMemoryModel : I32EnumAttrCase<"OpMemoryModel", 14>;
|
|||
def SPV_OC_OpEntryPoint : I32EnumAttrCase<"OpEntryPoint", 15>;
|
||||
def SPV_OC_OpExecutionMode : I32EnumAttrCase<"OpExecutionMode", 16>;
|
||||
def SPV_OC_OpTypeVoid : I32EnumAttrCase<"OpTypeVoid", 19>;
|
||||
def SPV_OC_OpTypeFloat : I32EnumAttrCase<"OpTypeFloat", 22>;
|
||||
def SPV_OC_OpTypePointer : I32EnumAttrCase<"OpTypePointer", 32>;
|
||||
def SPV_OC_OpTypeFunction : I32EnumAttrCase<"OpTypeFunction", 33>;
|
||||
def SPV_OC_OpFunction : I32EnumAttrCase<"OpFunction", 54>;
|
||||
def SPV_OC_OpFunctionParameter : I32EnumAttrCase<"OpFunctionParameter", 55>;
|
||||
|
@ -91,10 +93,10 @@ def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>;
|
|||
def SPV_OpcodeAttr :
|
||||
I32EnumAttr<"Opcode", "valid SPIR-V instructions", [
|
||||
SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint, SPV_OC_OpExecutionMode,
|
||||
SPV_OC_OpTypeVoid, SPV_OC_OpTypeFunction, SPV_OC_OpFunction,
|
||||
SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpVariable,
|
||||
SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpDecorate, SPV_OC_OpCompositeExtract,
|
||||
SPV_OC_OpFMul, SPV_OC_OpReturn
|
||||
SPV_OC_OpTypeVoid, SPV_OC_OpTypeFloat, SPV_OC_OpTypePointer,
|
||||
SPV_OC_OpTypeFunction, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
|
||||
SPV_OC_OpFunctionEnd, SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore,
|
||||
SPV_OC_OpDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpFMul, SPV_OC_OpReturn
|
||||
]> {
|
||||
let returnType = "::mlir::spirv::Opcode";
|
||||
let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())";
|
||||
|
@ -514,8 +516,6 @@ def ModuleOnly :
|
|||
class SPV_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<SPV_Dialect, mnemonic, traits> {
|
||||
|
||||
string spirvOpName = "Op" # mnemonic;
|
||||
|
||||
// For each SPIR-V op, the following static functions need to be defined
|
||||
// in SPVOps.cpp:
|
||||
//
|
||||
|
@ -527,13 +527,34 @@ class SPV_Op<string mnemonic, list<OpTrait> traits = []> :
|
|||
let printer = [{ return ::print(*this, p); }];
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
|
||||
// By default the opcode to use for (de)serialization is obtained
|
||||
// automatically from the SPIR-V spec. It assume the SPIR-V op being defined
|
||||
// is ('Op' # mnemonic). The opcode value can be obtained by calling
|
||||
// getOpcode<OpClass>(). If invoking this method is invalid or custom
|
||||
// processing is needed for the op, set hasOpcode = 0 and specialize the
|
||||
// getOpcode method.
|
||||
int hasOpcode = 1;
|
||||
// Specifies if the SPIR-V op has a direct representation of an instruction in
|
||||
// SPIR-V spec. When set to 1, the ODS generates the dispatch to the
|
||||
// (de)serialization methods for the op, and also generates the implementation
|
||||
// of these processing methods (if autogenSerialization is also set to 1).
|
||||
bit hasOpcode = 1;
|
||||
|
||||
// Name of the corresponding SPIR-V op. Only valid to use when hasOpcode is 1
|
||||
string spirvOpName = "Op" # mnemonic;
|
||||
|
||||
// Controls whether the (de)serialization method is generated automatically or
|
||||
// not. This results in generation of the following methods:
|
||||
//
|
||||
// template<typename OpTy> Serialization::processOp(OpTy op)
|
||||
// template<typename OpTy> Deserialization::processOp(ArrayRef<uint32_t>)
|
||||
//
|
||||
// If the auto generation is disabled (set to 0), then manual implementation
|
||||
// of a specialization of these methods is required.
|
||||
//
|
||||
// Note :
|
||||
//
|
||||
// 1) If hasOpcode is 1 and autogenSerialization is 0, the ODS generated
|
||||
// dispatch function call the above method for the (de)serialization of the
|
||||
// operation
|
||||
//
|
||||
// 2) If hasOpcode is 0, then ODS doesn't generate the (de)serialization
|
||||
// methods, neither does it handle the dispatch. Those need to be handled
|
||||
// manually.
|
||||
bit autogenSerialization = 1;
|
||||
}
|
||||
|
||||
#endif // SPIRV_BASE
|
||||
|
|
|
@ -143,6 +143,7 @@ def SPV_EntryPointOp : SPV_Op<"EntryPoint", [ModuleOnly]> {
|
|||
);
|
||||
|
||||
let results = (outs SPV_EntryPoint:$id);
|
||||
let autogenSerialization = 0;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
@ -344,7 +345,7 @@ def SPV_StoreOp : SPV_Op<"Store", []> {
|
|||
SPV_AnyPtr:$ptr,
|
||||
SPV_Type:$value,
|
||||
OptionalAttr<SPV_MemoryAccessAttr>:$memory_access,
|
||||
OptionalAttr<APIntAttr>:$alignment
|
||||
OptionalAttr<I32Attr>:$alignment
|
||||
);
|
||||
|
||||
let results = (outs);
|
||||
|
|
|
@ -96,6 +96,10 @@ public:
|
|||
// of `TypeAttrBase`).
|
||||
bool isTypeAttr() const;
|
||||
|
||||
// Returns true if this attribute is an enum attribute (i.e., a subclass of
|
||||
// `EnumAttrInfo`)
|
||||
bool isEnumAttr() const;
|
||||
|
||||
// Returns this attribute's TableGen def name. If this is an `OptionalAttr`
|
||||
// or `DefaultValuedAttr` without explicit name, returns the base attribute's
|
||||
// name.
|
||||
|
|
|
@ -132,7 +132,7 @@ public:
|
|||
int getNumArgs() const { return arguments.size(); }
|
||||
|
||||
// Op argument (attribute or operand) accessors.
|
||||
Argument getArg(int index);
|
||||
Argument getArg(int index) const;
|
||||
StringRef getArgName(int index) const;
|
||||
|
||||
// Returns the number of `PredOpTrait` traits.
|
||||
|
|
|
@ -88,9 +88,24 @@ private:
|
|||
LogicalResult processType(spirv::Opcode opcode, ArrayRef<uint32_t> operands);
|
||||
LogicalResult processFunctionType(ArrayRef<uint32_t> operands);
|
||||
|
||||
/// Process SPIR-V instructions that dont have any operands
|
||||
/// Method to dispatch to the specialized deserialization function for an
|
||||
/// operation in SPIR-V dialect that is a mirror of an operation in the SPIR-V
|
||||
/// spec. This is auto-generated from ODS. Dispatch is handled for all
|
||||
/// operations in SPIR-V dialect that have hasOpcode == 1
|
||||
LogicalResult dispatchToAutogenDeserialization(spirv::Opcode opcode,
|
||||
ArrayRef<uint32_t> words);
|
||||
|
||||
/// Method to deserialize an operation in the SPIR-V dialect that is a mirror
|
||||
/// of an instruction in the SPIR-V spec. This is auto generated if hasOpcode
|
||||
/// == 1 and autogenSerialization == 1 in ODS.
|
||||
template <typename OpTy> LogicalResult processOp(ArrayRef<uint32_t> words) {
|
||||
return processOpImpl<OpTy>(words);
|
||||
}
|
||||
template <typename OpTy>
|
||||
LogicalResult processNullaryInstruction(ArrayRef<uint32_t> operands);
|
||||
LogicalResult processOpImpl(ArrayRef<uint32_t> words) {
|
||||
return emitError(unknownLoc, "unsupported deserialization for op '")
|
||||
<< OpTy::getOperationName() << "')";
|
||||
}
|
||||
|
||||
/// Process function objects in binary
|
||||
LogicalResult processFunction(ArrayRef<uint32_t> operands);
|
||||
|
@ -232,6 +247,39 @@ LogicalResult Deserializer::processType(spirv::Opcode opcode,
|
|||
}
|
||||
typeMap[operands[0]] = NoneType::get(context);
|
||||
break;
|
||||
case spirv::Opcode::OpTypeFloat: {
|
||||
if (operands.size() != 2) {
|
||||
return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter");
|
||||
}
|
||||
Type floatTy;
|
||||
switch (operands[1]) {
|
||||
case 16:
|
||||
floatTy = opBuilder.getF16Type();
|
||||
break;
|
||||
case 32:
|
||||
floatTy = opBuilder.getF32Type();
|
||||
break;
|
||||
case 64:
|
||||
floatTy = opBuilder.getF64Type();
|
||||
break;
|
||||
default:
|
||||
return emitError(unknownLoc, "unsupported bitwdith ")
|
||||
<< operands[1] << " with OpTypeFloat";
|
||||
}
|
||||
typeMap[operands[0]] = floatTy;
|
||||
} break;
|
||||
case spirv::Opcode::OpTypePointer: {
|
||||
if (operands.size() != 3) {
|
||||
return emitError(unknownLoc, "OpTypePointer must have two parameters");
|
||||
}
|
||||
auto pointeeType = getType(operands[2]);
|
||||
if (!pointeeType) {
|
||||
return emitError(unknownLoc, "unknown OpTypePointer pointee type <id> : ")
|
||||
<< operands[2];
|
||||
}
|
||||
auto storageClass = static_cast<spirv::StorageClass>(operands[1]);
|
||||
typeMap[operands[0]] = spirv::PointerType::get(pointeeType, storageClass);
|
||||
} break;
|
||||
case spirv::Opcode::OpTypeFunction:
|
||||
return processFunctionType(operands);
|
||||
default:
|
||||
|
@ -240,18 +288,6 @@ LogicalResult Deserializer::processType(spirv::Opcode opcode,
|
|||
return success();
|
||||
}
|
||||
|
||||
template <typename OpTy>
|
||||
LogicalResult
|
||||
Deserializer::processNullaryInstruction(ArrayRef<uint32_t> operands) {
|
||||
if (!operands.empty()) {
|
||||
return emitError(unknownLoc) << stringifyOpcode(spirv::getOpcode<OpTy>())
|
||||
<< " must have no operands, but found "
|
||||
<< operands.size() << " operands";
|
||||
}
|
||||
opBuilder.create<OpTy>(unknownLoc);
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
|
||||
// Get the result type
|
||||
if (operands.size() != 4) {
|
||||
|
@ -314,7 +350,7 @@ LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
|
|||
"expected result type and result <id> for OpFunctionParameter");
|
||||
}
|
||||
auto argDefinedType = getType(operands[0]);
|
||||
if (argDefinedType || argDefinedType != argType) {
|
||||
if (!argDefinedType || argDefinedType != argType) {
|
||||
return emitError(unknownLoc,
|
||||
"mismatch in argument type between function type "
|
||||
"definition ")
|
||||
|
@ -352,23 +388,26 @@ LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
|
|||
return success();
|
||||
}
|
||||
|
||||
#define GET_DESERIALIZATION_FNS
|
||||
#include "mlir/Dialect/SPIRV/SPIRVSerialization.inc"
|
||||
|
||||
LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
|
||||
ArrayRef<uint32_t> operands) {
|
||||
// First dispatch all the instructions whose opcode does not correspond to
|
||||
// those that have a direct mirror in the SPIR-V dialect
|
||||
switch (opcode) {
|
||||
case spirv::Opcode::OpMemoryModel:
|
||||
return processMemoryModel(operands);
|
||||
case spirv::Opcode::OpTypeVoid:
|
||||
case spirv::Opcode::OpTypeFloat:
|
||||
case spirv::Opcode::OpTypeFunction:
|
||||
case spirv::Opcode::OpTypePointer:
|
||||
case spirv::Opcode::OpTypeVoid:
|
||||
return processType(opcode, operands);
|
||||
case spirv::Opcode::OpReturn:
|
||||
return processNullaryInstruction<spirv::ReturnOp>(operands);
|
||||
case spirv::Opcode::OpFunction:
|
||||
return processFunction(operands);
|
||||
default:
|
||||
break;
|
||||
default:;
|
||||
}
|
||||
return emitError(unknownLoc, "NYI: opcode ")
|
||||
<< spirv::stringifyOpcode(opcode);
|
||||
return dispatchToAutogenDeserialization(opcode, operands);
|
||||
}
|
||||
|
||||
LogicalResult Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
|
||||
|
|
|
@ -35,6 +35,7 @@ constexpr unsigned kHeaderWordCount = 5;
|
|||
/// SPIR-V magic number
|
||||
constexpr uint32_t kMagicNumber = 0x07230203;
|
||||
|
||||
#define GET_SPIRV_SERIALIZATION_UTILS
|
||||
#include "mlir/Dialect/SPIRV/SPIRVSerialization.inc"
|
||||
|
||||
} // end namespace spirv
|
||||
|
|
|
@ -41,7 +41,9 @@ static inline void buildInstruction(spirv::Opcode op,
|
|||
SmallVectorImpl<uint32_t> &binary) {
|
||||
uint32_t wordCount = 1 + operands.size();
|
||||
binary.push_back(getPrefixedOpcode(wordCount, op));
|
||||
binary.append(operands.begin(), operands.end());
|
||||
if (!operands.empty()) {
|
||||
binary.append(operands.begin(), operands.end());
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -90,11 +92,24 @@ private:
|
|||
// Main method to dispatch operation serialization
|
||||
LogicalResult processOperation(Operation *op);
|
||||
|
||||
/// Method to dispatch to the serialization function for an operation in
|
||||
/// SPIR-V dialect that is a mirror of an instruction in the SPIR-V spec. This
|
||||
/// is auto-generated from ODS. Dispatch is handled for all operations in
|
||||
/// SPIR-V dialect that have hasOpcode == 1
|
||||
LogicalResult dispatchToAutogenSerialization(Operation *op);
|
||||
|
||||
/// Method to serialize an operation in the SPIR-V dialect that is a mirror of
|
||||
/// an instruction in the SPIR-V spec. This is auto generated if hasOpcode ==
|
||||
/// 1 and autogenSerialization == 1 in ODS
|
||||
template <typename OpTy> LogicalResult processOp(OpTy op) {
|
||||
return processOpImpl(op);
|
||||
}
|
||||
template <typename OpTy> LogicalResult processOpImpl(OpTy op) {
|
||||
return op.emitError("unsupported op serialization");
|
||||
}
|
||||
|
||||
// Methods to serialize individual operation types
|
||||
LogicalResult processFuncOp(FuncOp op);
|
||||
// Serialize op that dont produce a value and have no operands, like
|
||||
// spirv::ReturnOp
|
||||
template <typename OpType> LogicalResult processNullaryOp(OpType op);
|
||||
|
||||
uint32_t getNextID() { return nextID++; }
|
||||
|
||||
|
@ -103,6 +118,16 @@ private:
|
|||
return (it != typeIDMap.end() ? it->second : Optional<uint32_t>(None));
|
||||
}
|
||||
|
||||
Optional<uint32_t> findValueID(Value *val) const {
|
||||
auto it = valueIDMap.find(val);
|
||||
return (it != valueIDMap.end() ? it->second : Optional<uint32_t>(None));
|
||||
}
|
||||
|
||||
Optional<uint32_t> findFunctionID(Operation *op) const {
|
||||
auto it = funcIDMap.find(op);
|
||||
return (it != funcIDMap.end() ? it->second : Optional<uint32_t>(None));
|
||||
}
|
||||
|
||||
Type voidType() { return mlir::NoneType::get(module.getContext()); }
|
||||
bool isVoidType(Type type) const { return type.isa<NoneType>(); }
|
||||
|
||||
|
@ -133,6 +158,9 @@ private:
|
|||
|
||||
// Map from FuncOps to IDs
|
||||
DenseMap<Operation *, uint32_t> funcIDMap;
|
||||
|
||||
// Map from Value to Ids
|
||||
DenseMap<Value *, uint32_t> valueIDMap;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
|
@ -245,6 +273,20 @@ Serializer::processBasicType(Location loc, Type type, spirv::Opcode &typeEnum,
|
|||
if (isVoidType(type)) {
|
||||
typeEnum = spirv::Opcode::OpTypeVoid;
|
||||
return success();
|
||||
} else if (type.isa<FloatType>()) {
|
||||
typeEnum = spirv::Opcode::OpTypeFloat;
|
||||
operands.push_back(type.cast<FloatType>().getWidth());
|
||||
return success();
|
||||
} else if (type.isa<spirv::PointerType>()) {
|
||||
auto ptrType = type.cast<spirv::PointerType>();
|
||||
uint32_t pointeeTypeID = 0;
|
||||
if (failed(processType(loc, ptrType.getPointeeType(), pointeeTypeID))) {
|
||||
return failure();
|
||||
}
|
||||
typeEnum = spirv::Opcode::OpTypePointer;
|
||||
operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass()));
|
||||
operands.push_back(pointeeTypeID);
|
||||
return success();
|
||||
}
|
||||
/// TODO(ravishankarm) : Handle other types
|
||||
return emitError(loc, "unhandled type in serialization : ") << type;
|
||||
|
@ -275,15 +317,14 @@ Serializer::processFunctionType(Location loc, FunctionType type,
|
|||
}
|
||||
|
||||
LogicalResult Serializer::processOperation(Operation *op) {
|
||||
// First dispatch the methods that do not directly mirror an operation from
|
||||
// the SPIR-V spec
|
||||
if (isa<FuncOp>(op)) {
|
||||
return processFuncOp(cast<FuncOp>(op));
|
||||
} else if (isa<spirv::ReturnOp>(op)) {
|
||||
return processNullaryOp(cast<spirv::ReturnOp>(op));
|
||||
} else if (isa<spirv::ModuleEndOp>(op)) {
|
||||
return success();
|
||||
}
|
||||
/// TODO(ravishankarm) : Handle other ops
|
||||
return op->emitError("unhandled operation serialization");
|
||||
return dispatchToAutogenSerialization(op);
|
||||
}
|
||||
|
||||
LogicalResult Serializer::processFuncOp(FuncOp op) {
|
||||
|
@ -314,13 +355,15 @@ LogicalResult Serializer::processFuncOp(FuncOp op) {
|
|||
buildInstruction(spirv::Opcode::OpFunction, operands, functions);
|
||||
|
||||
// Declare the parameters
|
||||
for (auto argType : op.getType().getInputs()) {
|
||||
for (auto arg : op.getArguments()) {
|
||||
uint32_t argTypeID = 0;
|
||||
if (failed(processType(op.getLoc(), argType, argTypeID))) {
|
||||
if (failed(processType(op.getLoc(), arg->getType(), argTypeID))) {
|
||||
return failure();
|
||||
}
|
||||
auto argValueID = getNextID();
|
||||
valueIDMap[arg] = argValueID;
|
||||
buildInstruction(spirv::Opcode::OpFunctionParameter,
|
||||
{argTypeID, getNextID()}, functions);
|
||||
{argTypeID, argValueID}, functions);
|
||||
}
|
||||
|
||||
// Process the body
|
||||
|
@ -345,11 +388,8 @@ LogicalResult Serializer::processFuncOp(FuncOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
template <typename OpType>
|
||||
LogicalResult Serializer::processNullaryOp(OpType op) {
|
||||
buildInstruction(spirv::getOpcode<OpType>(), ArrayRef<uint32_t>(), functions);
|
||||
return success();
|
||||
}
|
||||
#define GET_SERIALIZATION_FNS
|
||||
#include "mlir/Dialect/SPIRV/SPIRVSerialization.inc"
|
||||
|
||||
LogicalResult spirv::serialize(spirv::ModuleOp module,
|
||||
SmallVectorImpl<uint32_t> &binary) {
|
||||
|
|
|
@ -63,6 +63,10 @@ bool tblgen::Attribute::isTypeAttr() const {
|
|||
return def->isSubClassOf("TypeAttrBase");
|
||||
}
|
||||
|
||||
bool tblgen::Attribute::isEnumAttr() const {
|
||||
return def->isSubClassOf("EnumAttrInfo");
|
||||
}
|
||||
|
||||
bool tblgen::Attribute::hasStorageType() const {
|
||||
const auto *init = def->getValueInit("storageType");
|
||||
return !getValueAsString(init).empty();
|
||||
|
|
|
@ -217,7 +217,7 @@ auto tblgen::Operator::getOperands() -> value_range {
|
|||
return {operand_begin(), operand_end()};
|
||||
}
|
||||
|
||||
auto tblgen::Operator::getArg(int index) -> Argument {
|
||||
auto tblgen::Operator::getArg(int index) const -> Argument {
|
||||
return arguments[index];
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s
|
||||
|
||||
// CHECK: func {{@.*}}([[ARG1:%.*]]: !spv.ptr<f32, Input>, [[ARG2:%.*]]: !spv.ptr<f32, Output>) {
|
||||
// CHECK-NEXT: [[VALUE:%.*]] = spv.Load "Input" [[ARG1]] : f32
|
||||
// CHECK-NEXT: spv.Store "Output" [[ARG2]], [[VALUE]] : f32
|
||||
|
||||
func @spirv_loadstore() -> () {
|
||||
spv.module "Logical" "VulkanKHR" {
|
||||
func @load_store(%arg0 : !spv.ptr<f32, Input>, %arg1 : !spv.ptr<f32, Output>) {
|
||||
%1 = spv.Load "Input" %arg0 : f32
|
||||
spv.Store "Output" %arg1, %1 : f32
|
||||
spv.Return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
|
@ -0,0 +1,11 @@
|
|||
// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s
|
||||
|
||||
// CHECK: {{%.*}} = spv.Variable : !spv.ptr<f32, Input>
|
||||
// CHECK-NEXT: {{%.*}} = spv.Variable : !spv.ptr<f32, Output>
|
||||
func @spirv_variables() -> () {
|
||||
spv.module "Logical" "VulkanKHR" {
|
||||
%2 = spv.Variable : !spv.ptr<f32, Input>
|
||||
%3 = spv.Variable : !spv.ptr<f32, Output>
|
||||
}
|
||||
return
|
||||
}
|
|
@ -32,39 +32,317 @@
|
|||
#include "llvm/TableGen/Record.h"
|
||||
#include "llvm/TableGen/TableGenBackend.h"
|
||||
|
||||
using llvm::ArrayRef;
|
||||
using llvm::formatv;
|
||||
using llvm::raw_ostream;
|
||||
using llvm::raw_string_ostream;
|
||||
using llvm::Record;
|
||||
using llvm::RecordKeeper;
|
||||
using llvm::SMLoc;
|
||||
using mlir::tblgen::Attribute;
|
||||
using mlir::tblgen::EnumAttr;
|
||||
using mlir::tblgen::NamedAttribute;
|
||||
using mlir::tblgen::NamedTypeConstraint;
|
||||
using mlir::tblgen::Operator;
|
||||
|
||||
// Writes the following function to `os`:
|
||||
// inline uint32_t getOpcode(<op-class-name>) { return <opcode>; }
|
||||
static void emitGetOpcodeFunction(const llvm::Record &record,
|
||||
Operator const &op, raw_ostream &os) {
|
||||
if (record.getValueAsInt("hasOpcode")) {
|
||||
os << formatv("template <> constexpr inline ::mlir::spirv::Opcode "
|
||||
"getOpcode<{0}>()",
|
||||
op.getQualCppClassName())
|
||||
<< " {\n "
|
||||
<< formatv("return ::mlir::spirv::Opcode::Op{0};\n}\n",
|
||||
record.getValueAsString("opName"));
|
||||
static void emitGetOpcodeFunction(const Record *record, Operator const &op,
|
||||
raw_ostream &os) {
|
||||
os << formatv("template <> constexpr inline ::mlir::spirv::Opcode "
|
||||
"getOpcode<{0}>()",
|
||||
op.getQualCppClassName())
|
||||
<< " {\n "
|
||||
<< formatv("return ::mlir::spirv::Opcode::{0};\n}\n",
|
||||
record->getValueAsString("spirvOpName"));
|
||||
}
|
||||
|
||||
static void declareOpcodeFn(raw_ostream &os) {
|
||||
os << "template <typename OpClass> inline constexpr ::mlir::spirv::Opcode "
|
||||
"getOpcode();\n";
|
||||
}
|
||||
|
||||
static void emitAttributeSerialization(const Attribute &attr,
|
||||
ArrayRef<SMLoc> loc, llvm::StringRef op,
|
||||
llvm::StringRef operandList,
|
||||
llvm::StringRef attrName,
|
||||
raw_ostream &os) {
|
||||
os << " auto attr = " << op << ".getAttr(\"" << attrName << "\");\n";
|
||||
os << " if (attr) {\n";
|
||||
if (attr.getAttrDefName() == "I32ArrayAttr") {
|
||||
// Serialize all the elements of the array
|
||||
os << " for (auto attrElem : attr.cast<ArrayAttr>()) {\n";
|
||||
os << " " << operandList
|
||||
<< ".push_back(static_cast<uint32_t>(attrElem.cast<IntegerAttr>()."
|
||||
"getValue().getZExtValue()));\n";
|
||||
os << " }\n";
|
||||
} else if (attr.isEnumAttr() || attr.getAttrDefName() == "I32Attr") {
|
||||
os << " " << operandList
|
||||
<< ".push_back(static_cast<uint32_t>(attr.cast<IntegerAttr>().getValue()"
|
||||
".getZExtValue()));\n";
|
||||
} else {
|
||||
PrintFatalError(
|
||||
loc,
|
||||
llvm::Twine(
|
||||
"unhandled attribute type in SPIR-V serialization generation : '") +
|
||||
attr.getAttrDefName() + llvm::Twine("'"));
|
||||
}
|
||||
os << " }\n";
|
||||
}
|
||||
|
||||
static void emitSerializationFunction(const Record *record, const Operator &op,
|
||||
raw_ostream &os) {
|
||||
// If the record has 'autogenSerialization' set to 0, nothing to do
|
||||
if (!record->getValueAsBit("autogenSerialization")) {
|
||||
return;
|
||||
}
|
||||
os << formatv("template <> LogicalResult\nSerializer::processOpImpl<{0}>(\n"
|
||||
" {0} op)",
|
||||
op.getQualCppClassName())
|
||||
<< " {\n";
|
||||
os << " SmallVector<uint32_t, 4> operands;\n";
|
||||
|
||||
// Serialize result information
|
||||
if (op.getNumResults() == 1) {
|
||||
os << " {\n";
|
||||
os << " uint32_t typeID = 0;\n";
|
||||
os << " if (failed(processType(op.getLoc(), "
|
||||
"op.getResult()->getType(), typeID))) {\n";
|
||||
os << " return failure();\n";
|
||||
os << " }\n";
|
||||
os << " operands.push_back(typeID);\n";
|
||||
/// Create an SSA result <id> for the op
|
||||
os << " auto resultID = getNextID();\n";
|
||||
os << " valueIDMap[op.getResult()] = resultID;\n";
|
||||
os << " operands.push_back(resultID);\n";
|
||||
os << " }\n";
|
||||
} else if (op.getNumResults() != 0) {
|
||||
PrintFatalError(record->getLoc(), "SPIR-V ops can only zero or one result");
|
||||
}
|
||||
|
||||
// Process arguments
|
||||
auto operandNum = 0;
|
||||
for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
|
||||
auto argument = op.getArg(i);
|
||||
os << " {\n";
|
||||
if (argument.is<NamedTypeConstraint *>()) {
|
||||
os << " if (" << operandNum
|
||||
<< " < op.getOperation()->getNumOperands()) {\n";
|
||||
os << " auto arg = findValueID(op.getOperation()->getOperand("
|
||||
<< operandNum << "));\n";
|
||||
os << " if (!arg) {\n";
|
||||
os << " emitError(op.getLoc(), \"operand " << operandNum
|
||||
<< " has a use before def\");\n";
|
||||
os << " }\n";
|
||||
os << " operands.push_back(arg.getValue());\n";
|
||||
os << " }\n";
|
||||
operandNum++;
|
||||
} else {
|
||||
auto attr = argument.get<NamedAttribute *>();
|
||||
emitAttributeSerialization(
|
||||
(attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),
|
||||
record->getLoc(), "op", "operands", attr->name, os);
|
||||
}
|
||||
os << " }\n";
|
||||
}
|
||||
|
||||
os << formatv(
|
||||
" buildInstruction(spirv::getOpcode<{0}>(), operands, functions);\n",
|
||||
op.getQualCppClassName());
|
||||
os << " return success();\n";
|
||||
os << "}\n\n";
|
||||
}
|
||||
|
||||
static void initDispatchSerializationFn(raw_ostream &os) {
|
||||
os << "LogicalResult Serializer::dispatchToAutogenSerialization(Operation "
|
||||
"*op) {\n ";
|
||||
}
|
||||
|
||||
static void emitSerializationDispatch(const Operator &op, raw_ostream &os) {
|
||||
os << formatv(" if (isa<{0}>(op)) ", op.getQualCppClassName()) << "{\n";
|
||||
os << " ";
|
||||
os << formatv("return processOp<{0}>(cast<{0}>(op));\n",
|
||||
op.getQualCppClassName());
|
||||
os << " } else";
|
||||
}
|
||||
|
||||
static void finalizeDispatchSerializationFn(raw_ostream &os) {
|
||||
os << " {\n";
|
||||
os << " return op->emitError(\"unhandled operation serialization\");\n";
|
||||
os << " }\n";
|
||||
os << " return success();\n";
|
||||
os << "}\n\n";
|
||||
}
|
||||
|
||||
static void emitAttributeDeserialization(
|
||||
const Attribute &attr, ArrayRef<SMLoc> loc, llvm::StringRef attrList,
|
||||
llvm::StringRef attrName, llvm::StringRef operandsList,
|
||||
llvm::StringRef wordIndex, llvm::StringRef wordCount, raw_ostream &os) {
|
||||
if (attr.getAttrDefName() == "I32ArrayAttr") {
|
||||
os << " SmallVector<Attribute, 4> attrListElems;\n";
|
||||
os << " while (" << wordIndex << " < " << wordCount << ") {\n";
|
||||
os << " attrListElems.push_back(opBuilder.getI32IntegerAttr("
|
||||
<< operandsList << "[" << wordIndex << "++]));\n";
|
||||
os << " }\n";
|
||||
os << " " << attrList << ".push_back(opBuilder.getNamedAttr(\""
|
||||
<< attrName << "\", opBuilder.getArrayAttr(attrListElems)));\n";
|
||||
} else if (attr.isEnumAttr() || attr.getAttrDefName() == "I32Attr") {
|
||||
os << " " << attrList << ".push_back(opBuilder.getNamedAttr(\""
|
||||
<< attrName << "\", opBuilder.getI32IntegerAttr(" << operandsList << "["
|
||||
<< wordIndex << "++])));\n";
|
||||
} else {
|
||||
PrintFatalError(
|
||||
loc, llvm::Twine(
|
||||
"unhandled attribute type in deserialization generation : '") +
|
||||
attr.getAttrDefName() + llvm::Twine("'"));
|
||||
}
|
||||
}
|
||||
|
||||
static bool emitSerializationUtils(const RecordKeeper &recordKeeper,
|
||||
raw_ostream &os) {
|
||||
llvm::emitSourceFileHeader("SPIR-V Serialization Utilities", os);
|
||||
static void emitDeserializationFunction(const Record *record,
|
||||
const Operator &op, raw_ostream &os) {
|
||||
// If the record has 'autogenSerialization' set to 0, nothing to do
|
||||
if (!record->getValueAsBit("autogenSerialization")) {
|
||||
return;
|
||||
}
|
||||
os << formatv("template <> "
|
||||
"LogicalResult\nDeserializer::processOpImpl<{0}>(ArrayRef<"
|
||||
"uint32_t> words)",
|
||||
op.getQualCppClassName());
|
||||
os << " {\n";
|
||||
os << " SmallVector<Type, 1> resultTypes;\n";
|
||||
os << " size_t wordIndex = 0;\n";
|
||||
|
||||
/// Define the function to get the opcode
|
||||
os << "template <typename OpClass> inline constexpr ::mlir::spirv::Opcode "
|
||||
"getOpcode();\n";
|
||||
// Deserialize result information if it exists
|
||||
bool hasResult = false;
|
||||
if (op.getNumResults() == 1) {
|
||||
os << " {\n";
|
||||
os << " if (wordIndex >= words.size()) {\n";
|
||||
os << " "
|
||||
<< formatv("return emitError(unknownLoc, \"expected result type <id> "
|
||||
"while deserializing {0}\");\n",
|
||||
op.getQualCppClassName());
|
||||
os << " }\n";
|
||||
os << " auto ty = getType(words[wordIndex]);\n";
|
||||
os << " if (!ty) {\n";
|
||||
os << " return emitError(unknownLoc, \"unknown type result <id> : "
|
||||
"\") << words[wordIndex];\n";
|
||||
os << " }\n";
|
||||
os << " resultTypes.push_back(ty);\n";
|
||||
os << " wordIndex++;\n";
|
||||
os << " }\n";
|
||||
os << " if (wordIndex >= words.size()) {\n";
|
||||
os << " "
|
||||
<< formatv("return emitError(unknownLoc, \"expected result <id> while "
|
||||
"deserializing {0}\");\n",
|
||||
op.getQualCppClassName());
|
||||
os << " }\n";
|
||||
os << " uint32_t valueID = words[wordIndex++];\n";
|
||||
hasResult = true;
|
||||
} else if (op.getNumResults() != 0) {
|
||||
PrintFatalError(record->getLoc(),
|
||||
"SPIR-V ops can have only zero or one result");
|
||||
}
|
||||
|
||||
// Process arguments/attributes
|
||||
os << " SmallVector<Value *, 4> operands;\n";
|
||||
os << " SmallVector<NamedAttribute, 4> attributes;\n";
|
||||
unsigned operandNum = 0;
|
||||
for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
|
||||
auto argument = op.getArg(i);
|
||||
os << " if (wordIndex < words.size()) {\n";
|
||||
if (argument.is<NamedTypeConstraint *>()) {
|
||||
os << " auto arg = getValue(words[wordIndex]);\n";
|
||||
os << " if (!arg) {\n";
|
||||
os << " return emitError(unknownLoc, \"unknown result <id> : \") << "
|
||||
"words[wordIndex];\n";
|
||||
os << " }\n";
|
||||
os << " operands.push_back(arg);\n";
|
||||
os << " wordIndex++;\n";
|
||||
operandNum++;
|
||||
} else {
|
||||
auto attr = argument.get<NamedAttribute *>();
|
||||
emitAttributeDeserialization(
|
||||
(attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),
|
||||
record->getLoc(), "attributes", attr->name, "words", "wordIndex",
|
||||
"words.size()", os);
|
||||
}
|
||||
os << " }\n";
|
||||
}
|
||||
|
||||
os << formatv(" auto op = opBuilder.create<{0}>(unknownLoc, resultTypes, "
|
||||
"operands, attributes);\n",
|
||||
op.getQualCppClassName());
|
||||
if (hasResult) {
|
||||
os << " valueMap[valueID] = op.getResult();\n";
|
||||
}
|
||||
os << " return success();\n";
|
||||
os << "}\n\n";
|
||||
}
|
||||
|
||||
static void initDispatchDeserializationFn(raw_ostream &os) {
|
||||
os << "LogicalResult "
|
||||
"Deserializer::dispatchToAutogenDeserialization(spirv::Opcode "
|
||||
"opcode, ArrayRef<uint32_t> words) {\n";
|
||||
os << " switch (opcode) {\n";
|
||||
}
|
||||
|
||||
static void emitDeserializationDispatch(const Operator &op, const Record *def,
|
||||
raw_ostream &os) {
|
||||
os << formatv(" case spirv::Opcode::{0}:\n",
|
||||
def->getValueAsString("spirvOpName"));
|
||||
os << formatv(" return processOp<{0}>(words);\n",
|
||||
op.getQualCppClassName());
|
||||
}
|
||||
|
||||
static void finalizeDispatchDeserializationFn(raw_ostream &os) {
|
||||
os << " default:\n";
|
||||
os << " ;\n";
|
||||
os << " }\n";
|
||||
os << " return emitError(unknownLoc, \"unhandled deserialization of \") << "
|
||||
"spirv::stringifyOpcode(opcode);\n";
|
||||
os << "}\n";
|
||||
}
|
||||
|
||||
static bool emitSerializationFns(const RecordKeeper &recordKeeper,
|
||||
raw_ostream &os) {
|
||||
llvm::emitSourceFileHeader("SPIR-V Serialization Utilities/Functions", os);
|
||||
|
||||
std::string dSerFnString, dDesFnString, serFnString, deserFnString,
|
||||
utilsString;
|
||||
raw_string_ostream dSerFn(dSerFnString), dDesFn(dDesFnString),
|
||||
serFn(serFnString), deserFn(deserFnString), utils(utilsString);
|
||||
|
||||
declareOpcodeFn(utils);
|
||||
initDispatchSerializationFn(dSerFn);
|
||||
initDispatchDeserializationFn(dDesFn);
|
||||
auto defs = recordKeeper.getAllDerivedDefinitions("SPV_Op");
|
||||
for (const auto *def : defs) {
|
||||
if (!def->getValueAsBit("hasOpcode")) {
|
||||
continue;
|
||||
}
|
||||
Operator op(def);
|
||||
emitGetOpcodeFunction(*def, op, os);
|
||||
emitGetOpcodeFunction(def, op, utils);
|
||||
emitSerializationFunction(def, op, serFn);
|
||||
emitSerializationDispatch(op, dSerFn);
|
||||
emitDeserializationFunction(def, op, deserFn);
|
||||
emitDeserializationDispatch(op, def, dDesFn);
|
||||
}
|
||||
finalizeDispatchSerializationFn(dSerFn);
|
||||
finalizeDispatchDeserializationFn(dDesFn);
|
||||
|
||||
os << "#ifdef GET_SPIRV_SERIALIZATION_UTILS\n";
|
||||
os << utils.str();
|
||||
os << "#endif // GET_SPIRV_SERIALIZATION_UTILS\n\n";
|
||||
|
||||
os << "#ifdef GET_SERIALIZATION_FNS\n\n";
|
||||
os << serFn.str();
|
||||
os << dSerFn.str();
|
||||
os << "#endif // GET_SERIALIZATION_FNS\n\n";
|
||||
|
||||
os << "#ifdef GET_DESERIALIZATION_FNS\n\n";
|
||||
os << deserFn.str();
|
||||
os << dDesFn.str();
|
||||
os << "#endif // GET_DESERIALIZATION_FNS\n\n";
|
||||
|
||||
return false;
|
||||
}
|
||||
|
@ -134,12 +412,12 @@ static bool emitOpUtils(const RecordKeeper &recordKeeper, raw_ostream &os) {
|
|||
}
|
||||
|
||||
// Registers the enum utility generator to mlir-tblgen.
|
||||
static mlir::GenRegistration
|
||||
genSerializationDefs("gen-spirv-serial",
|
||||
"Generate SPIR-V serialization utility definitions",
|
||||
[](const RecordKeeper &records, raw_ostream &os) {
|
||||
return emitSerializationUtils(records, os);
|
||||
});
|
||||
static mlir::GenRegistration genSerialization(
|
||||
"gen-spirv-serialization",
|
||||
"Generate SPIR-V (de)serialization utilities and functions",
|
||||
[](const RecordKeeper &records, raw_ostream &os) {
|
||||
return emitSerializationFns(records, os);
|
||||
});
|
||||
|
||||
static mlir::GenRegistration
|
||||
genOpUtils("gen-spirv-op-utils",
|
||||
|
|
Loading…
Reference in New Issue