diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td index d1abf9f8af0c..b8a5f3cc98af 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -87,6 +87,7 @@ def SPV_OC_OpTypeInt : I32EnumAttrCase<"OpTypeInt", 21>; def SPV_OC_OpTypeFloat : I32EnumAttrCase<"OpTypeFloat", 22>; def SPV_OC_OpTypeVector : I32EnumAttrCase<"OpTypeVector", 23>; def SPV_OC_OpTypeArray : I32EnumAttrCase<"OpTypeArray", 28>; +def SPV_OC_OpTypeRuntimeArray : I32EnumAttrCase<"OpTypeRuntimeArray", 29>; def SPV_OC_OpTypeStruct : I32EnumAttrCase<"OpTypeStruct", 30>; def SPV_OC_OpTypePointer : I32EnumAttrCase<"OpTypePointer", 32>; def SPV_OC_OpTypeFunction : I32EnumAttrCase<"OpTypeFunction", 33>; @@ -160,23 +161,23 @@ def SPV_OpcodeAttr : SPV_OC_OpExtInst, SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint, SPV_OC_OpExecutionMode, SPV_OC_OpCapability, SPV_OC_OpTypeVoid, SPV_OC_OpTypeBool, SPV_OC_OpTypeInt, SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, - SPV_OC_OpTypeArray, SPV_OC_OpTypeStruct, SPV_OC_OpTypePointer, - SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse, - SPV_OC_OpConstant, SPV_OC_OpConstantComposite, SPV_OC_OpConstantNull, - SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant, - SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter, - SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad, - SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate, - SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpIAdd, - SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, - SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, - SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpSelect, SPV_OC_OpIEqual, - SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, - SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, - SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, - SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual, - SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan, - SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan, + SPV_OC_OpTypeArray, SPV_OC_OpTypeRuntimeArray, SPV_OC_OpTypeStruct, + SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue, + SPV_OC_OpConstantFalse, SPV_OC_OpConstant, SPV_OC_OpConstantComposite, + SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse, + SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, + SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, + SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain, + SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, + SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, + SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, + SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpSelect, + SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, + SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, + SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, + SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, + SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, + SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan, SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual, SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual, SPV_OC_OpLoopMerge, SPV_OC_OpLabel, SPV_OC_OpBranch, diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index 23cd60e38c67..7c62dca0665f 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -176,6 +176,8 @@ private: LogicalResult processFunctionType(ArrayRef operands); + LogicalResult processRuntimeArrayType(ArrayRef operands); + LogicalResult processStructType(ArrayRef operands); //===--------------------------------------------------------------------===// @@ -996,6 +998,8 @@ LogicalResult Deserializer::processType(spirv::Opcode opcode, return processArrayType(operands); case spirv::Opcode::OpTypeFunction: return processFunctionType(operands); + case spirv::Opcode::OpTypeRuntimeArray: + return processRuntimeArrayType(operands); case spirv::Opcode::OpTypeStruct: return processStructType(operands); default: @@ -1061,6 +1065,21 @@ LogicalResult Deserializer::processFunctionType(ArrayRef operands) { return success(); } +LogicalResult +Deserializer::processRuntimeArrayType(ArrayRef operands) { + if (operands.size() != 2) { + return emitError(unknownLoc, "OpTypeRuntimeArray must have two operands"); + } + Type memberType = getType(operands[1]); + if (!memberType) { + return emitError(unknownLoc, + "OpTypeRuntimeArray references undefined ") + << operands[1]; + } + typeMap[operands[0]] = spirv::RuntimeArrayType::get(memberType); + return success(); +} + LogicalResult Deserializer::processStructType(ArrayRef operands) { // TODO(ravishankarm) : Regarding to the spec spv.struct must support zero // amount of members. @@ -1716,6 +1735,7 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode, case spirv::Opcode::OpTypeVector: case spirv::Opcode::OpTypeArray: case spirv::Opcode::OpTypeFunction: + case spirv::Opcode::OpTypeRuntimeArray: case spirv::Opcode::OpTypeStruct: case spirv::Opcode::OpTypePointer: return processType(opcode, operands); diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index ea506492e45e..28e68eac15d3 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -766,6 +766,17 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID, return success(); } + if (auto runtimeArrayType = type.dyn_cast()) { + uint32_t elementTypeID = 0; + if (failed(processType(loc, runtimeArrayType.getElementType(), + elementTypeID))) { + return failure(); + } + operands.push_back(elementTypeID); + typeEnum = spirv::Opcode::OpTypeRuntimeArray; + return success(); + } + if (auto structType = type.dyn_cast()) { bool hasLayout = structType.hasLayout(); for (auto elementIndex : diff --git a/mlir/test/Dialect/SPIRV/Serialization/rtarray.mlir b/mlir/test/Dialect/SPIRV/Serialization/rtarray.mlir new file mode 100644 index 000000000000..91bc802ed4c0 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/Serialization/rtarray.mlir @@ -0,0 +1,8 @@ +// RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s + +spv.module "Logical" "GLSL450" { + // CHECK: spv.globalVariable {{@.*}} : !spv.ptr, StorageBuffer> + spv.globalVariable @var0 : !spv.ptr, StorageBuffer> + // CHECK: spv.globalVariable {{@.*}} : !spv.ptr>, Input> + spv.globalVariable @var1 : !spv.ptr>, Input> +} \ No newline at end of file