[spirv] Support (de)serialization of spv.struct

Support (de)serialization of spv.struct with offset decorations.

Closes tensorflow/mlir#94

PiperOrigin-RevId: 264421427
This commit is contained in:
Denis Khalikov 2019-08-20 11:02:57 -07:00 committed by A. Unique TensorFlower
parent 9e6cf0d025
commit 82cf6051ee
4 changed files with 152 additions and 13 deletions

View File

@ -83,6 +83,7 @@ def SPV_OC_OpTypeInt : I32EnumAttrCase<"OpTypeInt", 21>;
def SPV_OC_OpTypeFloat : I32EnumAttrCase<"OpTypeFloat", 22>;
def SPV_OC_OpTypeVector : I32EnumAttrCase<"OpTypeVector", 23>;
def SPV_OC_OpTypeArray : I32EnumAttrCase<"OpTypeArray", 28>;
def SPV_OC_OpTypeStruct : I32EnumAttrCase<"OpTypeStruct", 30>;
def SPV_OC_OpTypePointer : I32EnumAttrCase<"OpTypePointer", 32>;
def SPV_OC_OpTypeFunction : I32EnumAttrCase<"OpTypeFunction", 33>;
def SPV_OC_OpConstantTrue : I32EnumAttrCase<"OpConstantTrue", 41>;
@ -102,6 +103,7 @@ def SPV_OC_OpLoad : I32EnumAttrCase<"OpLoad", 61>;
def SPV_OC_OpStore : I32EnumAttrCase<"OpStore", 62>;
def SPV_OC_OpAccessChain : I32EnumAttrCase<"OpAccessChain", 65>;
def SPV_OC_OpDecorate : I32EnumAttrCase<"OpDecorate", 71>;
def SPV_OC_OpMemberDecorate : I32EnumAttrCase<"OpMemberDecorate", 72>;
def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>;
def SPV_OC_OpIAdd : I32EnumAttrCase<"OpIAdd", 128>;
def SPV_OC_OpFAdd : I32EnumAttrCase<"OpFAdd", 129>;
@ -135,19 +137,20 @@ def SPV_OpcodeAttr :
SPV_OC_OpNop, SPV_OC_OpName, SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint,
SPV_OC_OpExecutionMode, SPV_OC_OpTypeVoid, SPV_OC_OpTypeBool, SPV_OC_OpTypeInt,
SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, SPV_OC_OpTypeArray,
SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue,
SPV_OC_OpConstantFalse, SPV_OC_OpConstant, SPV_OC_OpConstantComposite,
SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse,
SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction,
SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpVariable,
SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
SPV_OC_OpCompositeExtract, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub,
SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv,
SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem,
SPV_OC_OpFMod, SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
SPV_OC_OpSLessThanEqual, SPV_OC_OpReturn, SPV_OC_OpReturnValue
SPV_OC_OpTypeStruct, SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction,
SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse, SPV_OC_OpConstant,
SPV_OC_OpConstantComposite, SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue,
SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite,
SPV_OC_OpFunction, SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd,
SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain,
SPV_OC_OpDecorate,SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract,
SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul,
SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod,
SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpIEqual,
SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
SPV_OC_OpReturn, SPV_OC_OpReturnValue
]> {
let returnType = "::mlir::spirv::Opcode";
let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())";

View File

@ -28,6 +28,7 @@
#include "mlir/IR/Location.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/StringExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/bit.h"
@ -84,6 +85,9 @@ private:
/// Method to process an OpDecorate instruction.
LogicalResult processDecoration(ArrayRef<uint32_t> words);
// Method to process an OpMemberDecorate instruction.
LogicalResult processMemberDecoration(ArrayRef<uint32_t> words);
/// Processes the SPIR-V function at the current `offset` into `binary`.
/// The operands to the OpFunction instruction is passed in as ``operands`.
/// This method processes each instruction inside the function and dispatches
@ -122,6 +126,8 @@ private:
LogicalResult processFunctionType(ArrayRef<uint32_t> operands);
LogicalResult processStructType(ArrayRef<uint32_t> operands);
//===--------------------------------------------------------------------===//
// Constant
//===--------------------------------------------------------------------===//
@ -232,6 +238,9 @@ private:
// Result <id> to type decorations.
DenseMap<uint32_t, uint32_t> typeDecorations;
// Result <id> to member decorations.
DenseMap<uint32_t, DenseMap<uint32_t, uint32_t>> memberDecorationMap;
// List of instructions that are processed in a defered fashion (after an
// initial processing of the entire binary). Some operations like
// OpEntryPoint, and OpExecutionMode use forward references to function
@ -368,6 +377,23 @@ LogicalResult Deserializer::processDecoration(ArrayRef<uint32_t> words) {
return success();
}
LogicalResult Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
// The binary layout of OpMemberDecorate is different comparing to OpDecorate
if (words.size() != 4) {
return emitError(unknownLoc, "OpMemberDecorate must have 4 operands");
}
switch (static_cast<spirv::Decoration>(words[2])) {
case spirv::Decoration::Offset:
memberDecorationMap[words[0]][words[1]] = words[3];
break;
default:
return emitError(unknownLoc, "unhandled OpMemberDecoration case: ")
<< words[2];
}
return success();
}
LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
// Get the result type
if (operands.size() != 4) {
@ -653,6 +679,8 @@ LogicalResult Deserializer::processType(spirv::Opcode opcode,
return processArrayType(operands);
case spirv::Opcode::OpTypeFunction:
return processFunctionType(operands);
case spirv::Opcode::OpTypeStruct:
return processStructType(operands);
default:
return emitError(unknownLoc, "unhandled type instruction");
}
@ -722,6 +750,46 @@ LogicalResult Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
return success();
}
LogicalResult Deserializer::processStructType(ArrayRef<uint32_t> operands) {
// TODO(ravishankarm) : Regarding to the spec spv.struct must support zero
// amount of members.
if (operands.size() < 2) {
return emitError(unknownLoc, "OpTypeStruct must have at least 2 operand");
}
SmallVector<Type, 0> memberTypes;
for (auto op : llvm::drop_begin(operands, 1)) {
Type memberType = getType(op);
if (!memberType) {
return emitError(unknownLoc, "OpTypeStruct references undefined <id> ")
<< op;
}
memberTypes.push_back(memberType);
}
SmallVector<spirv::StructType::LayoutInfo, 0> layoutInfo;
// Check for layoutinfo
auto memberDecorationIt = memberDecorationMap.find(operands[0]);
if (memberDecorationIt != memberDecorationMap.end()) {
// Each member must have an offset
const auto &offsetDecorationMap = memberDecorationIt->second;
auto offsetDecorationMapEnd = offsetDecorationMap.end();
for (auto memberIndex : llvm::seq<uint32_t>(0, memberTypes.size())) {
// Check that specific member has an offset
auto offsetIt = offsetDecorationMap.find(memberIndex);
if (offsetIt == offsetDecorationMapEnd) {
return emitError(unknownLoc, "OpTypeStruct with <id> ")
<< operands[0] << " must have an offset for " << memberIndex
<< "-th member";
}
layoutInfo.push_back(
static_cast<spirv::StructType::LayoutInfo>(offsetIt->second));
}
}
typeMap[operands[0]] = spirv::StructType::get(memberTypes, layoutInfo);
return success();
}
//===----------------------------------------------------------------------===//
// Constant
//===----------------------------------------------------------------------===//
@ -993,6 +1061,7 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
case spirv::Opcode::OpTypeVector:
case spirv::Opcode::OpTypeArray:
case spirv::Opcode::OpTypeFunction:
case spirv::Opcode::OpTypeStruct:
case spirv::Opcode::OpTypePointer:
return processType(opcode, operands);
case spirv::Opcode::OpConstant:
@ -1015,6 +1084,8 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
return processConstantNull(operands);
case spirv::Opcode::OpDecorate:
return processDecoration(operands);
case spirv::Opcode::OpMemberDecorate:
return processMemberDecoration(operands);
case spirv::Opcode::OpFunction:
return processFunction(operands);
default:

View File

@ -28,6 +28,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/StringExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/bit.h"
#include "llvm/Support/raw_ostream.h"
@ -148,6 +149,11 @@ private:
return emitError(loc, "unhandled decoraion for type:") << type;
}
/// Process member decoration
LogicalResult processMemberDecoration(uint32_t structID, uint32_t memberNum,
spirv::Decoration decorationType,
uint32_t value);
//===--------------------------------------------------------------------===//
// Types
//===--------------------------------------------------------------------===//
@ -411,6 +417,16 @@ LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
}
return success();
}
LogicalResult
Serializer::processMemberDecoration(uint32_t structID, uint32_t memberIndex,
spirv::Decoration decorationType,
uint32_t value) {
SmallVector<uint32_t, 4> args(
{structID, memberIndex, static_cast<uint32_t>(decorationType), value});
return encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate,
args);
}
} // namespace
LogicalResult Serializer::processFuncOp(FuncOp op) {
@ -618,6 +634,31 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
return success();
}
if (auto structType = type.dyn_cast<spirv::StructType>()) {
bool hasLayout = structType.hasLayout();
for (auto elementIndex :
llvm::seq<uint32_t>(0, structType.getNumElements())) {
uint32_t elementTypeID = 0;
if (failed(processType(loc, structType.getElementType(elementIndex),
elementTypeID))) {
return failure();
}
operands.push_back(elementTypeID);
if (hasLayout) {
// Decorate each struct member with an offset
if (failed(processMemberDecoration(
resultID, elementIndex, spirv::Decoration::Offset,
static_cast<uint32_t>(structType.getOffset(elementIndex))))) {
return emitError(loc, "cannot decorate ")
<< elementIndex << "-th member of : " << structType
<< "with its offset";
}
}
}
typeEnum = spirv::Opcode::OpTypeStruct;
return success();
}
// TODO(ravishankarm) : Handle other types.
return emitError(loc, "unhandled type in serialization: ") << type;
}

View File

@ -0,0 +1,24 @@
// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s
func @spirvmodule() -> () {
spv.module "Logical" "VulkanKHR" {
// CHECK: !spv.ptr<!spv.struct<!spv.array<128 x f32 [4]> [0]>, Input>
spv.globalVariable @var0 bind(0, 1) : !spv.ptr<!spv.struct<!spv.array<128 x f32 [4]> [0]>, Input>
// CHECK: !spv.ptr<!spv.struct<f32 [0], !spv.struct<f32 [0], !spv.array<16 x f32 [4]> [4]> [4]>, Input>
spv.globalVariable @var1 bind(0, 2) : !spv.ptr<!spv.struct<f32 [0], !spv.struct<f32 [0], !spv.array<16 x f32 [4]> [4]> [4]>, Input>
// CHECK: !spv.ptr<!spv.struct<f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38]>, StorageBuffer>
spv.globalVariable @var2 : !spv.ptr<!spv.struct<f32 [0], i32 [4], f64 [8], i64 [16], f32 [24], i32 [30], f32 [34], i32 [38]>, StorageBuffer>
// CHECK: !spv.ptr<!spv.struct<!spv.array<128 x !spv.struct<!spv.array<128 x f32 [4]> [0]> [4]> [0]>, StorageBuffer>
spv.globalVariable @var3 : !spv.ptr<!spv.struct<!spv.array<128 x !spv.struct<!spv.array<128 x f32 [4]> [0]> [4]> [0]>, StorageBuffer>
// CHECK: !spv.ptr<!spv.struct<!spv.array<128 x f32 [4]> [0]>, Input>,
// CHECK-SAME: !spv.ptr<!spv.struct<!spv.array<128 x f32 [4]> [0]>, Output>
func @kernel_1(%arg0: !spv.ptr<!spv.struct<!spv.array<128 x f32 [4]> [0]>, Input>, %arg1: !spv.ptr<!spv.struct<!spv.array<128 x f32 [4]> [0]>, Output>) -> () {
spv.Return
}
}
return
}