diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td index 1a722f8a433c..8f07fecb9f06 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -221,6 +221,142 @@ def SPV_AddressingModelAttr : let cppNamespace = "::mlir::spirv"; } +def SPV_BI_Position : I32EnumAttrCase<"Position", 0>; +def SPV_BI_PointSize : I32EnumAttrCase<"PointSize", 1>; +def SPV_BI_ClipDistance : I32EnumAttrCase<"ClipDistance", 3>; +def SPV_BI_CullDistance : I32EnumAttrCase<"CullDistance", 4>; +def SPV_BI_VertexId : I32EnumAttrCase<"VertexId", 5>; +def SPV_BI_InstanceId : I32EnumAttrCase<"InstanceId", 6>; +def SPV_BI_PrimitiveId : I32EnumAttrCase<"PrimitiveId", 7>; +def SPV_BI_InvocationId : I32EnumAttrCase<"InvocationId", 8>; +def SPV_BI_Layer : I32EnumAttrCase<"Layer", 9>; +def SPV_BI_ViewportIndex : I32EnumAttrCase<"ViewportIndex", 10>; +def SPV_BI_TessLevelOuter : I32EnumAttrCase<"TessLevelOuter", 11>; +def SPV_BI_TessLevelInner : I32EnumAttrCase<"TessLevelInner", 12>; +def SPV_BI_TessCoord : I32EnumAttrCase<"TessCoord", 13>; +def SPV_BI_PatchVertices : I32EnumAttrCase<"PatchVertices", 14>; +def SPV_BI_FragCoord : I32EnumAttrCase<"FragCoord", 15>; +def SPV_BI_PointCoord : I32EnumAttrCase<"PointCoord", 16>; +def SPV_BI_FrontFacing : I32EnumAttrCase<"FrontFacing", 17>; +def SPV_BI_SampleId : I32EnumAttrCase<"SampleId", 18>; +def SPV_BI_SamplePosition : I32EnumAttrCase<"SamplePosition", 19>; +def SPV_BI_SampleMask : I32EnumAttrCase<"SampleMask", 20>; +def SPV_BI_FragDepth : I32EnumAttrCase<"FragDepth", 22>; +def SPV_BI_HelperInvocation : I32EnumAttrCase<"HelperInvocation", 23>; +def SPV_BI_NumWorkgroups : I32EnumAttrCase<"NumWorkgroups", 24>; +def SPV_BI_WorkgroupSize : I32EnumAttrCase<"WorkgroupSize", 25>; +def SPV_BI_WorkgroupId : I32EnumAttrCase<"WorkgroupId", 26>; +def SPV_BI_LocalInvocationId : I32EnumAttrCase<"LocalInvocationId", 27>; +def SPV_BI_GlobalInvocationId : I32EnumAttrCase<"GlobalInvocationId", 28>; +def SPV_BI_LocalInvocationIndex : I32EnumAttrCase<"LocalInvocationIndex", 29>; +def SPV_BI_WorkDim : I32EnumAttrCase<"WorkDim", 30>; +def SPV_BI_GlobalSize : I32EnumAttrCase<"GlobalSize", 31>; +def SPV_BI_EnqueuedWorkgroupSize : I32EnumAttrCase<"EnqueuedWorkgroupSize", 32>; +def SPV_BI_GlobalOffset : I32EnumAttrCase<"GlobalOffset", 33>; +def SPV_BI_GlobalLinearId : I32EnumAttrCase<"GlobalLinearId", 34>; +def SPV_BI_SubgroupSize : I32EnumAttrCase<"SubgroupSize", 36>; +def SPV_BI_SubgroupMaxSize : I32EnumAttrCase<"SubgroupMaxSize", 37>; +def SPV_BI_NumSubgroups : I32EnumAttrCase<"NumSubgroups", 38>; +def SPV_BI_NumEnqueuedSubgroups : I32EnumAttrCase<"NumEnqueuedSubgroups", 39>; +def SPV_BI_SubgroupId : I32EnumAttrCase<"SubgroupId", 40>; +def SPV_BI_SubgroupLocalInvocationId : I32EnumAttrCase<"SubgroupLocalInvocationId", 41>; +def SPV_BI_VertexIndex : I32EnumAttrCase<"VertexIndex", 42>; +def SPV_BI_InstanceIndex : I32EnumAttrCase<"InstanceIndex", 43>; +def SPV_BI_SubgroupEqMask : I32EnumAttrCase<"SubgroupEqMask", 4416>; +def SPV_BI_SubgroupGeMask : I32EnumAttrCase<"SubgroupGeMask", 4417>; +def SPV_BI_SubgroupGtMask : I32EnumAttrCase<"SubgroupGtMask", 4418>; +def SPV_BI_SubgroupLeMask : I32EnumAttrCase<"SubgroupLeMask", 4419>; +def SPV_BI_SubgroupLtMask : I32EnumAttrCase<"SubgroupLtMask", 4420>; +def SPV_BI_BaseVertex : I32EnumAttrCase<"BaseVertex", 4424>; +def SPV_BI_BaseInstance : I32EnumAttrCase<"BaseInstance", 4425>; +def SPV_BI_DrawIndex : I32EnumAttrCase<"DrawIndex", 4426>; +def SPV_BI_DeviceIndex : I32EnumAttrCase<"DeviceIndex", 4438>; +def SPV_BI_ViewIndex : I32EnumAttrCase<"ViewIndex", 4440>; +def SPV_BI_BaryCoordNoPerspAMD : I32EnumAttrCase<"BaryCoordNoPerspAMD", 4992>; +def SPV_BI_BaryCoordNoPerspCentroidAMD : I32EnumAttrCase<"BaryCoordNoPerspCentroidAMD", 4993>; +def SPV_BI_BaryCoordNoPerspSampleAMD : I32EnumAttrCase<"BaryCoordNoPerspSampleAMD", 4994>; +def SPV_BI_BaryCoordSmoothAMD : I32EnumAttrCase<"BaryCoordSmoothAMD", 4995>; +def SPV_BI_BaryCoordSmoothCentroidAMD : I32EnumAttrCase<"BaryCoordSmoothCentroidAMD", 4996>; +def SPV_BI_BaryCoordSmoothSampleAMD : I32EnumAttrCase<"BaryCoordSmoothSampleAMD", 4997>; +def SPV_BI_BaryCoordPullModelAMD : I32EnumAttrCase<"BaryCoordPullModelAMD", 4998>; +def SPV_BI_FragStencilRefEXT : I32EnumAttrCase<"FragStencilRefEXT", 5014>; +def SPV_BI_ViewportMaskNV : I32EnumAttrCase<"ViewportMaskNV", 5253>; +def SPV_BI_SecondaryPositionNV : I32EnumAttrCase<"SecondaryPositionNV", 5257>; +def SPV_BI_SecondaryViewportMaskNV : I32EnumAttrCase<"SecondaryViewportMaskNV", 5258>; +def SPV_BI_PositionPerViewNV : I32EnumAttrCase<"PositionPerViewNV", 5261>; +def SPV_BI_ViewportMaskPerViewNV : I32EnumAttrCase<"ViewportMaskPerViewNV", 5262>; +def SPV_BI_FullyCoveredEXT : I32EnumAttrCase<"FullyCoveredEXT", 5264>; +def SPV_BI_TaskCountNV : I32EnumAttrCase<"TaskCountNV", 5274>; +def SPV_BI_PrimitiveCountNV : I32EnumAttrCase<"PrimitiveCountNV", 5275>; +def SPV_BI_PrimitiveIndicesNV : I32EnumAttrCase<"PrimitiveIndicesNV", 5276>; +def SPV_BI_ClipDistancePerViewNV : I32EnumAttrCase<"ClipDistancePerViewNV", 5277>; +def SPV_BI_CullDistancePerViewNV : I32EnumAttrCase<"CullDistancePerViewNV", 5278>; +def SPV_BI_LayerPerViewNV : I32EnumAttrCase<"LayerPerViewNV", 5279>; +def SPV_BI_MeshViewCountNV : I32EnumAttrCase<"MeshViewCountNV", 5280>; +def SPV_BI_MeshViewIndicesNV : I32EnumAttrCase<"MeshViewIndicesNV", 5281>; +def SPV_BI_BaryCoordNV : I32EnumAttrCase<"BaryCoordNV", 5286>; +def SPV_BI_BaryCoordNoPerspNV : I32EnumAttrCase<"BaryCoordNoPerspNV", 5287>; +def SPV_BI_FragSizeEXT : I32EnumAttrCase<"FragSizeEXT", 5292>; +def SPV_BI_FragInvocationCountEXT : I32EnumAttrCase<"FragInvocationCountEXT", 5293>; +def SPV_BI_LaunchIdNV : I32EnumAttrCase<"LaunchIdNV", 5319>; +def SPV_BI_LaunchSizeNV : I32EnumAttrCase<"LaunchSizeNV", 5320>; +def SPV_BI_WorldRayOriginNV : I32EnumAttrCase<"WorldRayOriginNV", 5321>; +def SPV_BI_WorldRayDirectionNV : I32EnumAttrCase<"WorldRayDirectionNV", 5322>; +def SPV_BI_ObjectRayOriginNV : I32EnumAttrCase<"ObjectRayOriginNV", 5323>; +def SPV_BI_ObjectRayDirectionNV : I32EnumAttrCase<"ObjectRayDirectionNV", 5324>; +def SPV_BI_RayTminNV : I32EnumAttrCase<"RayTminNV", 5325>; +def SPV_BI_RayTmaxNV : I32EnumAttrCase<"RayTmaxNV", 5326>; +def SPV_BI_InstanceCustomIndexNV : I32EnumAttrCase<"InstanceCustomIndexNV", 5327>; +def SPV_BI_ObjectToWorldNV : I32EnumAttrCase<"ObjectToWorldNV", 5330>; +def SPV_BI_WorldToObjectNV : I32EnumAttrCase<"WorldToObjectNV", 5331>; +def SPV_BI_HitTNV : I32EnumAttrCase<"HitTNV", 5332>; +def SPV_BI_HitKindNV : I32EnumAttrCase<"HitKindNV", 5333>; +def SPV_BI_IncomingRayFlagsNV : I32EnumAttrCase<"IncomingRayFlagsNV", 5351>; +def SPV_BI_WarpsPerSMNV : I32EnumAttrCase<"WarpsPerSMNV", 5374>; +def SPV_BI_SMCountNV : I32EnumAttrCase<"SMCountNV", 5375>; +def SPV_BI_WarpIDNV : I32EnumAttrCase<"WarpIDNV", 5376>; +def SPV_BI_SMIDNV : I32EnumAttrCase<"SMIDNV", 5377>; + +def SPV_BuiltInAttr : + I32EnumAttr<"BuiltIn", "valid SPIR-V BuiltIn", [ + SPV_BI_Position, SPV_BI_PointSize, SPV_BI_ClipDistance, SPV_BI_CullDistance, + SPV_BI_VertexId, SPV_BI_InstanceId, SPV_BI_PrimitiveId, SPV_BI_InvocationId, + SPV_BI_Layer, SPV_BI_ViewportIndex, SPV_BI_TessLevelOuter, + SPV_BI_TessLevelInner, SPV_BI_TessCoord, SPV_BI_PatchVertices, + SPV_BI_FragCoord, SPV_BI_PointCoord, SPV_BI_FrontFacing, SPV_BI_SampleId, + SPV_BI_SamplePosition, SPV_BI_SampleMask, SPV_BI_FragDepth, + SPV_BI_HelperInvocation, SPV_BI_NumWorkgroups, SPV_BI_WorkgroupSize, + SPV_BI_WorkgroupId, SPV_BI_LocalInvocationId, SPV_BI_GlobalInvocationId, + SPV_BI_LocalInvocationIndex, SPV_BI_WorkDim, SPV_BI_GlobalSize, + SPV_BI_EnqueuedWorkgroupSize, SPV_BI_GlobalOffset, SPV_BI_GlobalLinearId, + SPV_BI_SubgroupSize, SPV_BI_SubgroupMaxSize, SPV_BI_NumSubgroups, + SPV_BI_NumEnqueuedSubgroups, SPV_BI_SubgroupId, + SPV_BI_SubgroupLocalInvocationId, SPV_BI_VertexIndex, SPV_BI_InstanceIndex, + SPV_BI_SubgroupEqMask, SPV_BI_SubgroupGeMask, SPV_BI_SubgroupGtMask, + SPV_BI_SubgroupLeMask, SPV_BI_SubgroupLtMask, SPV_BI_BaseVertex, + SPV_BI_BaseInstance, SPV_BI_DrawIndex, SPV_BI_DeviceIndex, SPV_BI_ViewIndex, + SPV_BI_BaryCoordNoPerspAMD, SPV_BI_BaryCoordNoPerspCentroidAMD, + SPV_BI_BaryCoordNoPerspSampleAMD, SPV_BI_BaryCoordSmoothAMD, + SPV_BI_BaryCoordSmoothCentroidAMD, SPV_BI_BaryCoordSmoothSampleAMD, + SPV_BI_BaryCoordPullModelAMD, SPV_BI_FragStencilRefEXT, SPV_BI_ViewportMaskNV, + SPV_BI_SecondaryPositionNV, SPV_BI_SecondaryViewportMaskNV, + SPV_BI_PositionPerViewNV, SPV_BI_ViewportMaskPerViewNV, SPV_BI_FullyCoveredEXT, + SPV_BI_TaskCountNV, SPV_BI_PrimitiveCountNV, SPV_BI_PrimitiveIndicesNV, + SPV_BI_ClipDistancePerViewNV, SPV_BI_CullDistancePerViewNV, + SPV_BI_LayerPerViewNV, SPV_BI_MeshViewCountNV, SPV_BI_MeshViewIndicesNV, + SPV_BI_BaryCoordNV, SPV_BI_BaryCoordNoPerspNV, SPV_BI_FragSizeEXT, + SPV_BI_FragInvocationCountEXT, SPV_BI_LaunchIdNV, SPV_BI_LaunchSizeNV, + SPV_BI_WorldRayOriginNV, SPV_BI_WorldRayDirectionNV, SPV_BI_ObjectRayOriginNV, + SPV_BI_ObjectRayDirectionNV, SPV_BI_RayTminNV, SPV_BI_RayTmaxNV, + SPV_BI_InstanceCustomIndexNV, SPV_BI_ObjectToWorldNV, SPV_BI_WorldToObjectNV, + SPV_BI_HitTNV, SPV_BI_HitKindNV, SPV_BI_IncomingRayFlagsNV, + SPV_BI_WarpsPerSMNV, SPV_BI_SMCountNV, SPV_BI_WarpIDNV, SPV_BI_SMIDNV + ]> { + let returnType = "::mlir::spirv::BuiltIn"; + let convertFromStorage = "static_cast<::mlir::spirv::BuiltIn>($_self.getInt())"; + let cppNamespace = "::mlir::spirv"; +} + def SPV_D_RelaxedPrecision : I32EnumAttrCase<"RelaxedPrecision", 0>; def SPV_D_SpecId : I32EnumAttrCase<"SpecId", 1>; def SPV_D_Block : I32EnumAttrCase<"Block", 2>; diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td index b833da5abb2d..ba95a761fbe8 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td @@ -1192,11 +1192,13 @@ def SPV_VariableOp : SPV_Op<"Variable", []> { ``` {.ebnf} variable-op ::= ssa-id `=` `spv.Variable` (`init(` ssa-use `)`)? (`bind(` integer-literal, integer-literal `)`)? + (`built_in(` string-literal `)`)? attribute-dict? `:` spirv-pointer-type ``` - where `init` specifies initializer and `bind` specifies the descriptor set - and binding number. + where `init` specifies initializer and `bind` specifies the + descriptor set and binding number. `built_in` specifies SPIR-V + BuiltIn decoration associated with the op. For example: @@ -1206,6 +1208,7 @@ def SPV_VariableOp : SPV_Op<"Variable", []> { %1 = spv.Variable : !spv.ptr %2 = spv.Variable init(%0): !spv.ptr %3 = spv.Variable init(%0) bind(1, 2): !spv.ptr + %3 = spv.Variable built_in("GlobalInvocationID") : !spv.ptr, Uniform> ``` }]; diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index cdd10137920b..4bea441c366e 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -898,13 +898,15 @@ static ParseResult parseVariableOp(OpAsmParser *parser, OperationState *state) { return failure(); } - // Parse optional descriptor binding - Attribute set, binding; - auto descriptorSetName = - convertToSnakeCase(stringifyDecoration(spirv::Decoration::DescriptorSet)); - auto bindingName = - convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding)); + auto builtInName = + convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn)); if (succeeded(parser->parseOptionalKeyword("bind"))) { + Attribute set, binding; + // Parse optional descriptor binding + auto descriptorSetName = convertToSnakeCase( + stringifyDecoration(spirv::Decoration::DescriptorSet)); + auto bindingName = + convertToSnakeCase(stringifyDecoration(spirv::Decoration::Binding)); Type i32Type = parser->getBuilder().getIntegerType(32); if (parser->parseLParen() || parser->parseAttribute(set, i32Type, descriptorSetName, @@ -912,8 +914,21 @@ static ParseResult parseVariableOp(OpAsmParser *parser, OperationState *state) { parser->parseComma() || parser->parseAttribute(binding, i32Type, bindingName, state->attributes) || - parser->parseRParen()) + parser->parseRParen()) { return failure(); + } + } else if (succeeded(parser->parseOptionalKeyword(builtInName.c_str()))) { + Attribute builtIn; + if (parser->parseLParen() || + parser->parseAttribute(builtIn, Type(), builtInName, + state->attributes) || + parser->parseRParen()) { + return failure(); + } + if (!builtIn.isa()) { + return parser->emitError(parser->getCurrentLocation(), + "expected string value for built_in attribute"); + } } // Parse other attributes @@ -975,6 +990,14 @@ static void print(spirv::VariableOp varOp, OpAsmPrinter *printer) { << ")"; } + // Print BuiltIn attribute if present + auto builtInName = + convertToSnakeCase(stringifyDecoration(spirv::Decoration::BuiltIn)); + if (auto builtin = varOp.getAttrOfType(builtInName)) { + *printer << " " << builtInName << "(\"" << builtin.getValue() << "\")"; + elidedAttrs.push_back(builtInName); + } + printer->printOptionalAttrDict(op->getAttrs(), elidedAttrs); *printer << " : " << varOp.getType(); } diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index 1fd9758bde33..217f9b190dd6 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -321,6 +321,15 @@ LogicalResult Deserializer::processDecoration(ArrayRef words) { opBuilder.getIdentifier(attrName), opBuilder.getI32IntegerAttr(static_cast(words[2]))); break; + case spirv::Decoration::BuiltIn: + if (words.size() != 3) { + return emitError(unknownLoc, "OpDecorate with ") + << decorationName << " needs a single integer literal"; + } + decorations[words[0]].set(opBuilder.getIdentifier(attrName), + opBuilder.getStringAttr(stringifyBuiltIn( + static_cast(words[2])))); + break; default: return emitError(unknownLoc, "unhandled Decoration : '") << decorationName; } diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index f9a85feb4f93..8b55873c5c0c 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -349,6 +349,17 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID, break; } return emitError(loc, "expected integer attribute for ") << attrName; + case spirv::Decoration::BuiltIn: + if (auto strAttr = attr.second.dyn_cast()) { + auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue()); + if (enumVal) { + args.push_back(static_cast(enumVal.getValue())); + break; + } + return emitError(loc, "invalid ") + << attrName << " attribute " << strAttr.getValue(); + } + return emitError(loc, "expected string attribute for ") << attrName; default: return emitError(loc, "unhandled decoration ") << decorationName; } diff --git a/mlir/test/Dialect/SPIRV/Serialization/variables.mlir b/mlir/test/Dialect/SPIRV/Serialization/variables.mlir index e0620f199424..15cc891fc808 100644 --- a/mlir/test/Dialect/SPIRV/Serialization/variables.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/variables.mlir @@ -2,10 +2,14 @@ // CHECK: {{%.*}} = spv.Variable bind(1, 0) : !spv.ptr // CHECK-NEXT: {{%.*}} = spv.Variable bind(0, 1) : !spv.ptr +// CHECK-NEXT: {{%.*}} = spv.Variable built_in("GlobalInvocationId") : !spv.ptr, Input> +// CHECK-NEXT: {{%.*}} = spv.Variable built_in("GlobalInvocationId") : !spv.ptr, Input> func @spirv_variables() -> () { spv.module "Logical" "VulkanKHR" { %2 = spv.Variable bind(1, 0) : !spv.ptr %3 = spv.Variable bind(0, 1): !spv.ptr + %4 = spv.Variable {built_in = "GlobalInvocationId"} : !spv.ptr, Input> + %5 = spv.Variable built_in("GlobalInvocationId") : !spv.ptr, Input> } return } \ No newline at end of file diff --git a/mlir/test/Dialect/SPIRV/ops.mlir b/mlir/test/Dialect/SPIRV/ops.mlir index 0592345e39d7..ac9ddfd07948 100644 --- a/mlir/test/Dialect/SPIRV/ops.mlir +++ b/mlir/test/Dialect/SPIRV/ops.mlir @@ -985,6 +985,14 @@ func @variable_init_bind() -> () { return } +func @variable_builtin() -> () { + // CHECK: spv.Variable built_in("GlobalInvocationID") : !spv.ptr, Input> + %1 = spv.Variable built_in("GlobalInvocationID") : !spv.ptr, Input> + // CHECK: spv.Variable built_in("GlobalInvocationID") : !spv.ptr, Input> + %2 = spv.Variable {built_in = "GlobalInvocationID"} : !spv.ptr, Input> + return +} + // ----- func @expect_ptr_result_type(%arg0: f32) -> () { diff --git a/mlir/utils/spirv/gen_spirv_dialect.py b/mlir/utils/spirv/gen_spirv_dialect.py index ac00179ec7a5..2017e227cc20 100755 --- a/mlir/utils/spirv/gen_spirv_dialect.py +++ b/mlir/utils/spirv/gen_spirv_dialect.py @@ -125,7 +125,7 @@ def uniquify(lst, equality_fn): unique_lst = [] for elem in lst: key = equality_fn(elem) - if equality_fn(key) not in keys: + if key not in keys: unique_lst.append(elem) keys.add(key) return unique_lst