Add (de)serialization support for OpRuntimeArray.

Update the SPIR-V (de)serialization to handle RuntimeArrayType.

PiperOrigin-RevId: 269667196
This commit is contained in:
Mahesh Ravishankar 2019-09-17 15:21:20 -07:00 committed by A. Unique TensorFlower
parent af45ca844f
commit 9330c1b9a1
4 changed files with 57 additions and 17 deletions

View File

@ -87,6 +87,7 @@ def SPV_OC_OpTypeInt : I32EnumAttrCase<"OpTypeInt", 21>;
def SPV_OC_OpTypeFloat : I32EnumAttrCase<"OpTypeFloat", 22>;
def SPV_OC_OpTypeVector : I32EnumAttrCase<"OpTypeVector", 23>;
def SPV_OC_OpTypeArray : I32EnumAttrCase<"OpTypeArray", 28>;
def SPV_OC_OpTypeRuntimeArray : I32EnumAttrCase<"OpTypeRuntimeArray", 29>;
def SPV_OC_OpTypeStruct : I32EnumAttrCase<"OpTypeStruct", 30>;
def SPV_OC_OpTypePointer : I32EnumAttrCase<"OpTypePointer", 32>;
def SPV_OC_OpTypeFunction : I32EnumAttrCase<"OpTypeFunction", 33>;
@ -160,23 +161,23 @@ def SPV_OpcodeAttr :
SPV_OC_OpExtInst, SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint,
SPV_OC_OpExecutionMode, SPV_OC_OpCapability, SPV_OC_OpTypeVoid,
SPV_OC_OpTypeBool, SPV_OC_OpTypeInt, SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector,
SPV_OC_OpTypeArray, SPV_OC_OpTypeStruct, SPV_OC_OpTypePointer,
SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse,
SPV_OC_OpConstant, SPV_OC_OpConstantComposite, SPV_OC_OpConstantNull,
SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant,
SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad,
SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpIAdd,
SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul,
SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem,
SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpSelect, SPV_OC_OpIEqual,
SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual,
SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan,
SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
SPV_OC_OpTypeArray, SPV_OC_OpTypeRuntimeArray, SPV_OC_OpTypeStruct,
SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue,
SPV_OC_OpConstantFalse, SPV_OC_OpConstant, SPV_OC_OpConstantComposite,
SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse,
SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction,
SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall,
SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain,
SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract,
SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul,
SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod,
SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpSelect,
SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan,
SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
SPV_OC_OpLoopMerge, SPV_OC_OpLabel, SPV_OC_OpBranch,

View File

@ -176,6 +176,8 @@ private:
LogicalResult processFunctionType(ArrayRef<uint32_t> operands);
LogicalResult processRuntimeArrayType(ArrayRef<uint32_t> operands);
LogicalResult processStructType(ArrayRef<uint32_t> operands);
//===--------------------------------------------------------------------===//
@ -996,6 +998,8 @@ LogicalResult Deserializer::processType(spirv::Opcode opcode,
return processArrayType(operands);
case spirv::Opcode::OpTypeFunction:
return processFunctionType(operands);
case spirv::Opcode::OpTypeRuntimeArray:
return processRuntimeArrayType(operands);
case spirv::Opcode::OpTypeStruct:
return processStructType(operands);
default:
@ -1061,6 +1065,21 @@ LogicalResult Deserializer::processFunctionType(ArrayRef<uint32_t> operands) {
return success();
}
LogicalResult
Deserializer::processRuntimeArrayType(ArrayRef<uint32_t> operands) {
if (operands.size() != 2) {
return emitError(unknownLoc, "OpTypeRuntimeArray must have two operands");
}
Type memberType = getType(operands[1]);
if (!memberType) {
return emitError(unknownLoc,
"OpTypeRuntimeArray references undefined <id> ")
<< operands[1];
}
typeMap[operands[0]] = spirv::RuntimeArrayType::get(memberType);
return success();
}
LogicalResult Deserializer::processStructType(ArrayRef<uint32_t> operands) {
// TODO(ravishankarm) : Regarding to the spec spv.struct must support zero
// amount of members.
@ -1716,6 +1735,7 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
case spirv::Opcode::OpTypeVector:
case spirv::Opcode::OpTypeArray:
case spirv::Opcode::OpTypeFunction:
case spirv::Opcode::OpTypeRuntimeArray:
case spirv::Opcode::OpTypeStruct:
case spirv::Opcode::OpTypePointer:
return processType(opcode, operands);

View File

@ -766,6 +766,17 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
return success();
}
if (auto runtimeArrayType = type.dyn_cast<spirv::RuntimeArrayType>()) {
uint32_t elementTypeID = 0;
if (failed(processType(loc, runtimeArrayType.getElementType(),
elementTypeID))) {
return failure();
}
operands.push_back(elementTypeID);
typeEnum = spirv::Opcode::OpTypeRuntimeArray;
return success();
}
if (auto structType = type.dyn_cast<spirv::StructType>()) {
bool hasLayout = structType.hasLayout();
for (auto elementIndex :

View File

@ -0,0 +1,8 @@
// RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s
spv.module "Logical" "GLSL450" {
// CHECK: spv.globalVariable {{@.*}} : !spv.ptr<!spv.rtarray<f32>, StorageBuffer>
spv.globalVariable @var0 : !spv.ptr<!spv.rtarray<f32>, StorageBuffer>
// CHECK: spv.globalVariable {{@.*}} : !spv.ptr<!spv.rtarray<vector<4xf16>>, Input>
spv.globalVariable @var1 : !spv.ptr<!spv.rtarray<vector<4xf16>>, Input>
}