[mlir][spirv] Migrate to use specalized enum attributes

Previously we are using IntegerAttr to back all SPIR-V enum
attributes. Therefore we all such attributes are showed like
IntegerAttr in IRs, which is barely readable and breaks
roundtripability of the IR. This commit changes to use
`EnumAttr` as the base directly so that we can have separate
attribute definitions and better IR printing.

Reviewed By: kuhar

Differential Revision: https://reviews.llvm.org/D131311
This commit is contained in:
Lei Zhang 2022-08-09 14:03:54 -04:00
parent b09f6b471f
commit a29fffc475
31 changed files with 359 additions and 411 deletions

View File

@ -56,7 +56,7 @@ class Availability {
string instance = ?;
}
class MinVersionBase<string name, I32EnumAttr scheme, I32EnumAttrCase min>
class MinVersionBase<string name, EnumAttr scheme, I32EnumAttrCase min>
: Availability {
let interfaceName = name;
@ -69,13 +69,13 @@ class MinVersionBase<string name, I32EnumAttr scheme, I32EnumAttrCase min>
"std::max(*$overall, $instance)); "
"} else { $overall = $instance; }}";
let initializer = "::llvm::None";
let instanceType = scheme.cppNamespace # "::" # scheme.className;
let instanceType = scheme.cppNamespace # "::" # scheme.enum.className;
let instance = scheme.cppNamespace # "::" # scheme.className # "::" #
let instance = scheme.cppNamespace # "::" # scheme.enum.className # "::" #
min.symbol;
}
class MaxVersionBase<string name, I32EnumAttr scheme, I32EnumAttrCase max>
class MaxVersionBase<string name, EnumAttr scheme, I32EnumAttrCase max>
: Availability {
let interfaceName = name;
@ -88,9 +88,9 @@ class MaxVersionBase<string name, I32EnumAttr scheme, I32EnumAttrCase max>
"std::min(*$overall, $instance)); "
"} else { $overall = $instance; }}";
let initializer = "::llvm::None";
let instanceType = scheme.cppNamespace # "::" # scheme.className;
let instanceType = scheme.cppNamespace # "::" # scheme.enum.className;
let instance = scheme.cppNamespace # "::" # scheme.className # "::" #
let instance = scheme.cppNamespace # "::" # scheme.enum.className # "::" #
max.symbol;
}

View File

@ -77,8 +77,6 @@ def SPV_ControlBarrierOp : SPV_Op<"ControlBarrier", []> {
let results = (outs);
let autogenSerialization = 0;
let assemblyFormat = [{
$execution_scope `,` $memory_scope `,` $memory_semantics attr-dict
}];
@ -129,8 +127,6 @@ def SPV_MemoryBarrierOp : SPV_Op<"MemoryBarrier", []> {
let results = (outs);
let autogenSerialization = 0;
let assemblyFormat = "$memory_scope `,` $memory_semantics attr-dict";
}

View File

@ -82,43 +82,31 @@ def SPIRV_Dialect : Dialect {
// Utility definitions
//===----------------------------------------------------------------------===//
// A predicate that checks whether `$_self` is a known enum case for the
// enum class with `name`.
class SPV_IsKnownEnumCaseFor<string name> :
CPred<"::mlir::spirv::symbolize" # name # "("
"$_self.cast<IntegerAttr>().getValue().getZExtValue()).has_value()">;
// Wrapper over base BitEnumAttr to set common fields.
class SPV_BitEnumAttr<string name, string description,
class SPV_BitEnum<string name, string description,
list<BitEnumAttrCaseBase> cases>
: I32BitEnumAttr<name, description, cases> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::spirv";
}
class SPV_BitEnumAttr<string name, string description, string mnemonic,
list<BitEnumAttrCaseBase> cases> :
I32BitEnumAttr<name, description, cases> {
let predicate = And<[
I32Attr.predicate,
SPV_IsKnownEnumCaseFor<name>,
]>;
let cppNamespace = "::mlir::spirv";
EnumAttr<SPIRV_Dialect, SPV_BitEnum<name, description, cases>, mnemonic> {
let assemblyFormat = "`<` $value `>`";
}
// Wrapper over base I32EnumAttr to set common fields.
class SPV_I32EnumAttr<string name, string description,
list<I32EnumAttrCase> cases> :
I32EnumAttr<name, description, cases> {
let predicate = And<[
I32Attr.predicate,
SPV_IsKnownEnumCaseFor<name>,
]>;
let cppNamespace = "::mlir::spirv";
}
// Wrapper over base I32EnumAttr to set common fields.
class SPV_Enum<string name, string description, list<I32EnumAttrCase> cases>
class SPV_I32Enum<string name, string description,
list<I32EnumAttrCase> cases>
: I32EnumAttr<name, description, cases> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::spirv";
}
class SPV_EnumAttr<string name, string description, string mnemonic,
class SPV_I32EnumAttr<string name, string description, string mnemonic,
list<I32EnumAttrCase> cases> :
EnumAttr<SPIRV_Dialect, SPV_Enum<name, description, cases>, mnemonic>;
EnumAttr<SPIRV_Dialect, SPV_I32Enum<name, description, cases>, mnemonic> {
let assemblyFormat = "`<` $value `>`";
}
//===----------------------------------------------------------------------===//
// SPIR-V availability definitions
@ -132,7 +120,8 @@ def SPV_V_1_4 : I32EnumAttrCase<"V_1_4", 4, "v1.4">;
def SPV_V_1_5 : I32EnumAttrCase<"V_1_5", 5, "v1.5">;
def SPV_V_1_6 : I32EnumAttrCase<"V_1_6", 6, "v1.6">;
def SPV_VersionAttr : SPV_I32EnumAttr<"Version", "valid SPIR-V version", [
def SPV_VersionAttr : SPV_I32EnumAttr<
"Version", "valid SPIR-V version", "version", [
SPV_V_1_0, SPV_V_1_1, SPV_V_1_2, SPV_V_1_3, SPV_V_1_4, SPV_V_1_5,
SPV_V_1_6]>;
@ -284,7 +273,7 @@ def SPV_DT_Other : I32EnumAttrCase<"Other", 3>;
// Information missing.
def SPV_DT_Unknown : I32EnumAttrCase<"Unknown", 4>;
def SPV_DeviceTypeAttr : SPV_EnumAttr<
def SPV_DeviceTypeAttr : SPV_I32EnumAttr<
"DeviceType", "valid SPIR-V device types", "device_type", [
SPV_DT_Other, SPV_DT_IntegratedGPU, SPV_DT_DiscreteGPU,
SPV_DT_CPU, SPV_DT_Unknown
@ -300,7 +289,7 @@ def SPV_V_Qualcomm : I32EnumAttrCase<"Qualcomm", 6>;
def SPV_V_SwiftShader : I32EnumAttrCase<"SwiftShader", 7>;
def SPV_V_Unknown : I32EnumAttrCase<"Unknown", 0xff>;
def SPV_VendorAttr : SPV_EnumAttr<
def SPV_VendorAttr : SPV_I32EnumAttr<
"Vendor", "recognized SPIR-V vendor strings", "vendor", [
SPV_V_AMD, SPV_V_Apple, SPV_V_ARM, SPV_V_Imagination,
SPV_V_Intel, SPV_V_NVIDIA, SPV_V_Qualcomm, SPV_V_SwiftShader,
@ -418,7 +407,7 @@ def SPV_NV_ray_tracing_motion_blur : I32EnumAttrCase<"SPV_NV_ray_tracing_m
def SPV_NVX_multiview_per_view_attributes : I32EnumAttrCase<"SPV_NVX_multiview_per_view_attributes", 5015>;
def SPV_ExtensionAttr :
SPV_EnumAttr<"Extension", "supported SPIR-V extensions", "ext", [
SPV_I32EnumAttr<"Extension", "supported SPIR-V extensions", "ext", [
SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_device_group,
SPV_KHR_float_controls, SPV_KHR_physical_storage_buffer, SPV_KHR_multiview,
SPV_KHR_no_integer_wrap_decoration, SPV_KHR_post_depth_coverage,
@ -1402,7 +1391,7 @@ def SPV_C_ShaderStereoViewNV : I32EnumAttrCase<"ShaderS
}
def SPV_CapabilityAttr :
SPV_I32EnumAttr<"Capability", "valid SPIR-V Capability", [
SPV_I32EnumAttr<"Capability", "valid SPIR-V Capability", "capability", [
SPV_C_Matrix, SPV_C_Addresses, SPV_C_Linkage, SPV_C_Kernel, SPV_C_Float16,
SPV_C_Float64, SPV_C_Int64, SPV_C_Groups, SPV_C_Int16, SPV_C_Int8,
SPV_C_Sampled1D, SPV_C_SampledBuffer, SPV_C_GroupNonUniform, SPV_C_ShaderLayer,
@ -1514,7 +1503,7 @@ def SPV_AM_PhysicalStorageBuffer64 : I32EnumAttrCase<"PhysicalStorageBuffer64",
}
def SPV_AddressingModelAttr :
SPV_I32EnumAttr<"AddressingModel", "valid SPIR-V AddressingModel", [
SPV_I32EnumAttr<"AddressingModel", "valid SPIR-V AddressingModel", "addressing_model", [
SPV_AM_Logical, SPV_AM_Physical32, SPV_AM_Physical64,
SPV_AM_PhysicalStorageBuffer64
]>;
@ -2049,7 +2038,7 @@ def SPV_BI_CullMaskKHR : I32EnumAttrCase<"CullMaskKHR", 6021> {
}
def SPV_BuiltInAttr :
SPV_I32EnumAttr<"BuiltIn", "valid SPIR-V BuiltIn", [
SPV_I32EnumAttr<"BuiltIn", "valid SPIR-V BuiltIn", "built_in", [
SPV_BI_Position, SPV_BI_PointSize, SPV_BI_ClipDistance, SPV_BI_CullDistance,
SPV_BI_VertexId, SPV_BI_InstanceId, SPV_BI_PrimitiveId, SPV_BI_InvocationId,
SPV_BI_Layer, SPV_BI_ViewportIndex, SPV_BI_TessLevelOuter,
@ -2610,7 +2599,7 @@ def SPV_D_MediaBlockIOINTEL : I32EnumAttrCase<"MediaBlockIOINTE
}
def SPV_DecorationAttr :
SPV_I32EnumAttr<"Decoration", "valid SPIR-V Decoration", [
SPV_I32EnumAttr<"Decoration", "valid SPIR-V Decoration", "decoration", [
SPV_D_RelaxedPrecision, SPV_D_SpecId, SPV_D_Block, SPV_D_BufferBlock,
SPV_D_RowMajor, SPV_D_ColMajor, SPV_D_ArrayStride, SPV_D_MatrixStride,
SPV_D_GLSLShared, SPV_D_GLSLPacked, SPV_D_CPacked, SPV_D_BuiltIn,
@ -2679,7 +2668,7 @@ def SPV_D_SubpassData : I32EnumAttrCase<"SubpassData", 6> {
}
def SPV_DimAttr :
SPV_I32EnumAttr<"Dim", "valid SPIR-V Dim", [
SPV_I32EnumAttr<"Dim", "valid SPIR-V Dim", "dim", [
SPV_D_1D, SPV_D_2D, SPV_D_3D, SPV_D_Cube, SPV_D_Rect, SPV_D_Buffer,
SPV_D_SubpassData
]>;
@ -3093,7 +3082,7 @@ def SPV_EM_NamedBarrierCountINTEL : I32EnumAttrCase<"NamedBarrierCount
}
def SPV_ExecutionModeAttr :
SPV_I32EnumAttr<"ExecutionMode", "valid SPIR-V ExecutionMode", [
SPV_I32EnumAttr<"ExecutionMode", "valid SPIR-V ExecutionMode", "execution_mode", [
SPV_EM_Invocations, SPV_EM_SpacingEqual, SPV_EM_SpacingFractionalEven,
SPV_EM_SpacingFractionalOdd, SPV_EM_VertexOrderCw, SPV_EM_VertexOrderCcw,
SPV_EM_PixelCenterInteger, SPV_EM_OriginUpperLeft, SPV_EM_OriginLowerLeft,
@ -3203,7 +3192,7 @@ def SPV_EM_CallableKHR : I32EnumAttrCase<"CallableKHR", 5318> {
}
def SPV_ExecutionModelAttr :
SPV_I32EnumAttr<"ExecutionModel", "valid SPIR-V ExecutionModel", [
SPV_I32EnumAttr<"ExecutionModel", "valid SPIR-V ExecutionModel", "execution_model", [
SPV_EM_Vertex, SPV_EM_TessellationControl, SPV_EM_TessellationEvaluation,
SPV_EM_Geometry, SPV_EM_Fragment, SPV_EM_GLCompute, SPV_EM_Kernel,
SPV_EM_TaskNV, SPV_EM_MeshNV, SPV_EM_RayGenerationKHR, SPV_EM_IntersectionKHR,
@ -3222,7 +3211,7 @@ def SPV_FC_OptNoneINTEL : I32BitEnumAttrCaseBit<"OptNoneINTEL", 16> {
}
def SPV_FunctionControlAttr :
SPV_BitEnumAttr<"FunctionControl", "valid SPIR-V FunctionControl", [
SPV_BitEnumAttr<"FunctionControl", "valid SPIR-V FunctionControl", "function_control", [
SPV_FC_None, SPV_FC_Inline, SPV_FC_DontInline, SPV_FC_Pure, SPV_FC_Const,
SPV_FC_OptNoneINTEL
]>;
@ -3268,7 +3257,7 @@ def SPV_GO_PartitionedExclusiveScanNV : I32EnumAttrCase<"PartitionedExclusiveSca
}
def SPV_GroupOperationAttr :
SPV_I32EnumAttr<"GroupOperation", "valid SPIR-V GroupOperation", [
SPV_I32EnumAttr<"GroupOperation", "valid SPIR-V GroupOperation", "group_operation", [
SPV_GO_Reduce, SPV_GO_InclusiveScan, SPV_GO_ExclusiveScan,
SPV_GO_ClusteredReduce, SPV_GO_PartitionedReduceNV,
SPV_GO_PartitionedInclusiveScanNV, SPV_GO_PartitionedExclusiveScanNV
@ -3482,7 +3471,7 @@ def SPV_IF_R64i : I32EnumAttrCase<"R64i", 41> {
}
def SPV_ImageFormatAttr :
SPV_I32EnumAttr<"ImageFormat", "valid SPIR-V ImageFormat", [
SPV_I32EnumAttr<"ImageFormat", "valid SPIR-V ImageFormat", "image_format", [
SPV_IF_Unknown, SPV_IF_Rgba32f, SPV_IF_Rgba16f, SPV_IF_R32f, SPV_IF_Rgba8,
SPV_IF_Rgba8Snorm, SPV_IF_Rg32f, SPV_IF_Rg16f, SPV_IF_R11fG11fB10f,
SPV_IF_R16f, SPV_IF_Rgba16, SPV_IF_Rgb10A2, SPV_IF_Rg16, SPV_IF_Rg8,
@ -3561,7 +3550,7 @@ def SPV_IO_Nontemporal : I32BitEnumAttrCaseBit<"Nontemporal", 14> {
}
def SPV_ImageOperandsAttr :
SPV_BitEnumAttr<"ImageOperands", "valid SPIR-V ImageOperands", [
SPV_BitEnumAttr<"ImageOperands", "valid SPIR-V ImageOperands", "image_operands", [
SPV_IO_None, SPV_IO_Bias, SPV_IO_Lod, SPV_IO_Grad, SPV_IO_ConstOffset,
SPV_IO_Offset, SPV_IO_ConstOffsets, SPV_IO_Sample, SPV_IO_MinLod,
SPV_IO_MakeTexelAvailable, SPV_IO_MakeTexelVisible, SPV_IO_NonPrivateTexel,
@ -3587,7 +3576,7 @@ def SPV_LT_LinkOnceODR : I32EnumAttrCase<"LinkOnceODR", 2> {
}
def SPV_LinkageTypeAttr :
SPV_I32EnumAttr<"LinkageType", "valid SPIR-V LinkageType", [
SPV_I32EnumAttr<"LinkageType", "valid SPIR-V LinkageType", "linkage_type", [
SPV_LT_Export, SPV_LT_Import, SPV_LT_LinkOnceODR
]>;
@ -3679,7 +3668,7 @@ def SPV_LC_NoFusionINTEL : I32BitEnumAttrCaseBit<"NoFusionINTEL", 23
}
def SPV_LoopControlAttr :
SPV_BitEnumAttr<"LoopControl", "valid SPIR-V LoopControl", [
SPV_BitEnumAttr<"LoopControl", "valid SPIR-V LoopControl", "loop_control", [
SPV_LC_None, SPV_LC_Unroll, SPV_LC_DontUnroll, SPV_LC_DependencyInfinite,
SPV_LC_DependencyLength, SPV_LC_MinIterations, SPV_LC_MaxIterations,
SPV_LC_IterationMultiple, SPV_LC_PeelCount, SPV_LC_PartialCount,
@ -3725,7 +3714,7 @@ def SPV_MA_NoAliasINTELMask : I32BitEnumAttrCaseBit<"NoAliasINTELMask", 17>
}
def SPV_MemoryAccessAttr :
SPV_BitEnumAttr<"MemoryAccess", "valid SPIR-V MemoryAccess", [
SPV_BitEnumAttr<"MemoryAccess", "valid SPIR-V MemoryAccess", "memory_access", [
SPV_MA_None, SPV_MA_Volatile, SPV_MA_Aligned, SPV_MA_Nontemporal,
SPV_MA_MakePointerAvailable, SPV_MA_MakePointerVisible,
SPV_MA_NonPrivatePointer, SPV_MA_AliasScopeINTELMask, SPV_MA_NoAliasINTELMask
@ -3754,7 +3743,7 @@ def SPV_MM_Vulkan : I32EnumAttrCase<"Vulkan", 3> {
}
def SPV_MemoryModelAttr :
SPV_I32EnumAttr<"MemoryModel", "valid SPIR-V MemoryModel", [
SPV_I32EnumAttr<"MemoryModel", "valid SPIR-V MemoryModel", "memory_model", [
SPV_MM_Simple, SPV_MM_GLSL450, SPV_MM_OpenCL, SPV_MM_Vulkan
]>;
@ -3803,7 +3792,7 @@ def SPV_MS_Volatile : I32BitEnumAttrCaseBit<"Volatile", 15> {
}
def SPV_MemorySemanticsAttr :
SPV_BitEnumAttr<"MemorySemantics", "valid SPIR-V MemorySemantics", [
SPV_BitEnumAttr<"MemorySemantics", "valid SPIR-V MemorySemantics", "memory_semantics", [
SPV_MS_None, SPV_MS_Acquire, SPV_MS_Release, SPV_MS_AcquireRelease,
SPV_MS_SequentiallyConsistent, SPV_MS_UniformMemory, SPV_MS_SubgroupMemory,
SPV_MS_WorkgroupMemory, SPV_MS_CrossWorkgroupMemory,
@ -3829,7 +3818,7 @@ def SPV_S_ShaderCallKHR : I32EnumAttrCase<"ShaderCallKHR", 6> {
}
def SPV_ScopeAttr :
SPV_I32EnumAttr<"Scope", "valid SPIR-V Scope", [
SPV_I32EnumAttr<"Scope", "valid SPIR-V Scope", "scope", [
SPV_S_CrossDevice, SPV_S_Device, SPV_S_Workgroup, SPV_S_Subgroup,
SPV_S_Invocation, SPV_S_QueueFamily, SPV_S_ShaderCallKHR
]>;
@ -3839,7 +3828,7 @@ def SPV_SC_Flatten : I32BitEnumAttrCaseBit<"Flatten", 0>;
def SPV_SC_DontFlatten : I32BitEnumAttrCaseBit<"DontFlatten", 1>;
def SPV_SelectionControlAttr :
SPV_BitEnumAttr<"SelectionControl", "valid SPIR-V SelectionControl", [
SPV_BitEnumAttr<"SelectionControl", "valid SPIR-V SelectionControl", "selection_control", [
SPV_SC_None, SPV_SC_Flatten, SPV_SC_DontFlatten
]>;
@ -3947,7 +3936,7 @@ def SPV_SC_HostOnlyINTEL : I32EnumAttrCase<"HostOnlyINTEL", 5937> {
}
def SPV_StorageClassAttr :
SPV_I32EnumAttr<"StorageClass", "valid SPIR-V StorageClass", [
SPV_I32EnumAttr<"StorageClass", "valid SPIR-V StorageClass", "storage_class", [
SPV_SC_UniformConstant, SPV_SC_Input, SPV_SC_Uniform, SPV_SC_Output,
SPV_SC_Workgroup, SPV_SC_CrossWorkgroup, SPV_SC_Private, SPV_SC_Function,
SPV_SC_Generic, SPV_SC_PushConstant, SPV_SC_AtomicCounter, SPV_SC_Image,
@ -3965,34 +3954,32 @@ def SPV_IDI_NoDepth : I32EnumAttrCase<"NoDepth", 0>;
def SPV_IDI_IsDepth : I32EnumAttrCase<"IsDepth", 1>;
def SPV_IDI_DepthUnknown : I32EnumAttrCase<"DepthUnknown", 2>;
def SPV_DepthAttr :
SPV_I32EnumAttr<"ImageDepthInfo", "valid SPIR-V Image Depth specification",
[SPV_IDI_NoDepth, SPV_IDI_IsDepth, SPV_IDI_DepthUnknown]>;
def SPV_DepthAttr : SPV_I32EnumAttr<
"ImageDepthInfo", "valid SPIR-V Image Depth specification",
"image_depth_info", [SPV_IDI_NoDepth, SPV_IDI_IsDepth, SPV_IDI_DepthUnknown]>;
def SPV_IAI_NonArrayed : I32EnumAttrCase<"NonArrayed", 0>;
def SPV_IAI_Arrayed : I32EnumAttrCase<"Arrayed", 1>;
def SPV_ArrayedAttr :
SPV_I32EnumAttr<
"ImageArrayedInfo", "valid SPIR-V Image Arrayed specification",
[SPV_IAI_NonArrayed, SPV_IAI_Arrayed]>;
def SPV_ArrayedAttr : SPV_I32EnumAttr<
"ImageArrayedInfo", "valid SPIR-V Image Arrayed specification",
"image_arrayed_info", [SPV_IAI_NonArrayed, SPV_IAI_Arrayed]>;
def SPV_ISI_SingleSampled : I32EnumAttrCase<"SingleSampled", 0>;
def SPV_ISI_MultiSampled : I32EnumAttrCase<"MultiSampled", 1>;
def SPV_SamplingAttr:
SPV_I32EnumAttr<
"ImageSamplingInfo", "valid SPIR-V Image Sampling specification",
[SPV_ISI_SingleSampled, SPV_ISI_MultiSampled]>;
def SPV_SamplingAttr: SPV_I32EnumAttr<
"ImageSamplingInfo", "valid SPIR-V Image Sampling specification",
"image_sampling_info", [SPV_ISI_SingleSampled, SPV_ISI_MultiSampled]>;
def SPV_ISUI_SamplerUnknown : I32EnumAttrCase<"SamplerUnknown", 0>;
def SPV_ISUI_NeedSampler : I32EnumAttrCase<"NeedSampler", 1>;
def SPV_ISUI_NoSampler : I32EnumAttrCase<"NoSampler", 2>;
def SPV_SamplerUseAttr:
SPV_I32EnumAttr<
"ImageSamplerUseInfo", "valid SPIR-V Sampler Use specification",
[SPV_ISUI_SamplerUnknown, SPV_ISUI_NeedSampler, SPV_ISUI_NoSampler]>;
def SPV_SamplerUseAttr: SPV_I32EnumAttr<
"ImageSamplerUseInfo", "valid SPIR-V Sampler Use specification",
"image_sampler_use_info",
[SPV_ISUI_SamplerUnknown, SPV_ISUI_NeedSampler, SPV_ISUI_NoSampler]>;
//===----------------------------------------------------------------------===//
// SPIR-V attribute definitions
@ -4326,7 +4313,7 @@ def SPV_OC_OpAssumeTrueKHR : I32EnumAttrCase<"OpAssumeTrueKHR", 5630
def SPV_OC_OpAtomicFAddEXT : I32EnumAttrCase<"OpAtomicFAddEXT", 6035>;
def SPV_OpcodeAttr :
SPV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", [
SPV_I32EnumAttr<"Opcode", "valid SPIR-V instructions", "opcode", [
SPV_OC_OpNop, SPV_OC_OpUndef, SPV_OC_OpSourceContinued, SPV_OC_OpSource,
SPV_OC_OpSourceExtension, SPV_OC_OpName, SPV_OC_OpMemberName, SPV_OC_OpString,
SPV_OC_OpLine, SPV_OC_OpExtension, SPV_OC_OpExtInstImport, SPV_OC_OpExtInst,

View File

@ -113,6 +113,9 @@ public:
// Returns the dialect for the attribute if defined.
Dialect getDialect() const;
// Returns the TableGen definition this Attribute was constructed from.
const llvm::Record &getDef() const;
};
// Wrapper class providing helper methods for accessing MLIR constant attribute

View File

@ -15,8 +15,8 @@
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/StringExtras.h"

View File

@ -15,6 +15,7 @@
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
#include "mlir/IR/BuiltinOps.h"
@ -643,15 +644,15 @@ public:
// this entry point's execution mode. We set it to be:
// __spv__{SPIR-V module name}_{function name}_execution_mode_info_{mode}
ModuleOp module = op->getParentOfType<ModuleOp>();
IntegerAttr executionModeAttr = op.execution_modeAttr();
spirv::ExecutionModeAttr executionModeAttr = op.execution_modeAttr();
std::string moduleName;
if (module.getName().has_value())
moduleName = "_" + module.getName().value().str();
else
moduleName = "";
std::string executionModeInfoName =
llvm::formatv("__spv_{0}_{1}_execution_mode_info_{2}", moduleName,
op.fn().str(), executionModeAttr.getValue());
std::string executionModeInfoName = llvm::formatv(
"__spv_{0}_{1}_execution_mode_info_{2}", moduleName, op.fn().str(),
static_cast<uint32_t>(executionModeAttr.getValue()));
MLIRContext *context = rewriter.getContext();
OpBuilder::InsertionGuard guard(rewriter);
@ -684,8 +685,10 @@ public:
// Initialize the struct and set the execution mode value.
rewriter.setInsertionPoint(block, block->begin());
Value structValue = rewriter.create<LLVM::UndefOp>(loc, structType);
Value executionMode =
rewriter.create<LLVM::ConstantOp>(loc, llvmI32Type, executionModeAttr);
Value executionMode = rewriter.create<LLVM::ConstantOp>(
loc, llvmI32Type,
rewriter.getI32IntegerAttr(
static_cast<uint32_t>(executionModeAttr.getValue())));
structValue = rewriter.create<LLVM::InsertValueOp>(
loc, structType, structValue, executionMode,
ArrayAttr::get(context,
@ -1391,8 +1394,8 @@ public:
auto llvmI32Type = IntegerType::get(context, 32);
Value targetOp = rewriter.create<LLVM::UndefOp>(loc, dstType);
for (unsigned i = 0; i < componentsArray.size(); i++) {
if (componentsArray[i].isa<IntegerAttr>())
op.emitError("unable to support non-constant component");
if (!componentsArray[i].isa<IntegerAttr>())
return op.emitError("unable to support non-constant component");
int indexVal = componentsArray[i].cast<IntegerAttr>().getInt();
if (indexVal == -1)

View File

@ -15,6 +15,7 @@
#include "mlir/Dialect/SPIRV/IR/ParserUtils.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
@ -174,19 +175,16 @@ parseEnumStrAttr(EnumClass &value, OpAsmParser &parser,
NamedAttrList attr;
auto loc = parser.getCurrentLocation();
if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(),
attrName, attr)) {
attrName, attr))
return failure();
}
if (!attrVal.isa<StringAttr>()) {
if (!attrVal.isa<StringAttr>())
return parser.emitError(loc, "expected ")
<< attrName << " attribute specified as string";
}
auto attrOptional =
spirv::symbolizeEnum<EnumClass>(attrVal.cast<StringAttr>().getValue());
if (!attrOptional) {
if (!attrOptional)
return parser.emitError(loc, "invalid ")
<< attrName << " attribute specification: " << attrVal;
}
value = *attrOptional;
return success();
}
@ -194,50 +192,52 @@ parseEnumStrAttr(EnumClass &value, OpAsmParser &parser,
/// Parses the next string attribute in `parser` as an enumerant of the given
/// `EnumClass` and inserts the enumerant into `state` as an 32-bit integer
/// attribute with the enum class's name as attribute name.
template <typename EnumClass>
template <typename EnumAttrClass,
typename EnumClass = typename EnumAttrClass::ValueType>
static ParseResult
parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state,
StringRef attrName = spirv::attributeName<EnumClass>()) {
if (parseEnumStrAttr(value, parser)) {
if (parseEnumStrAttr(value, parser))
return failure();
}
state.addAttribute(attrName, parser.getBuilder().getI32IntegerAttr(
llvm::bit_cast<int32_t>(value)));
state.addAttribute(attrName,
parser.getBuilder().getAttr<EnumAttrClass>(value));
return success();
}
/// Parses the next keyword in `parser` as an enumerant of the given `EnumClass`
/// and inserts the enumerant into `state` as an 32-bit integer attribute with
/// the enum class's name as attribute name.
template <typename EnumClass>
template <typename EnumAttrClass,
typename EnumClass = typename EnumAttrClass::ValueType>
static ParseResult
parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
OperationState &state,
StringRef attrName = spirv::attributeName<EnumClass>()) {
if (parseEnumKeywordAttr(value, parser)) {
if (parseEnumKeywordAttr(value, parser))
return failure();
}
state.addAttribute(attrName, parser.getBuilder().getI32IntegerAttr(
llvm::bit_cast<int32_t>(value)));
state.addAttribute(attrName,
parser.getBuilder().getAttr<EnumAttrClass>(value));
return success();
}
/// Parses Function, Selection and Loop control attributes. If no control is
/// specified, "None" is used as a default.
template <typename EnumClass>
template <typename EnumAttrClass, typename EnumClass>
static ParseResult
parseControlAttribute(OpAsmParser &parser, OperationState &state,
StringRef attrName = spirv::attributeName<EnumClass>()) {
if (succeeded(parser.parseOptionalKeyword(kControl))) {
EnumClass control;
if (parser.parseLParen() || parseEnumKeywordAttr(control, parser, state) ||
if (parser.parseLParen() ||
parseEnumKeywordAttr<EnumAttrClass>(control, parser, state) ||
parser.parseRParen())
return failure();
return success();
}
// Set control to "None" otherwise.
Builder builder = parser.getBuilder();
state.addAttribute(attrName, builder.getI32IntegerAttr(0));
state.addAttribute(attrName,
builder.getAttr<EnumAttrClass>(static_cast<EnumClass>(0)));
return success();
}
@ -256,10 +256,9 @@ static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
}
spirv::MemoryAccess memoryAccessAttr;
if (parseEnumStrAttr(memoryAccessAttr, parser, state,
kMemoryAccessAttrName)) {
if (parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser, state,
kMemoryAccessAttrName))
return failure();
}
if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
// Parse integer attribute for alignment.
@ -287,10 +286,9 @@ static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
}
spirv::MemoryAccess memoryAccessAttr;
if (parseEnumStrAttr(memoryAccessAttr, parser, state,
kSourceMemoryAccessAttrName)) {
if (parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser, state,
kSourceMemoryAccessAttrName))
return failure();
}
if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
// Parse integer attribute for alignment.
@ -479,15 +477,15 @@ static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
return success();
}
auto memAccessVal = memAccessAttr.template cast<IntegerAttr>();
auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt());
auto memAccess = memAccessAttr.template cast<spirv::MemoryAccessAttr>();
if (!memAccess) {
return memoryOp.emitOpError("invalid memory access specifier: ")
<< memAccessVal;
<< memAccessAttr;
}
if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
if (spirv::bitEnumContains(memAccess.getValue(),
spirv::MemoryAccess::Aligned)) {
if (!op->getAttr(kAlignmentAttrName)) {
return memoryOp.emitOpError("missing alignment value");
}
@ -523,15 +521,15 @@ static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
return success();
}
auto memAccessVal = memAccessAttr.template cast<IntegerAttr>();
auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt());
auto memAccess = memAccessAttr.template cast<spirv::MemoryAccessAttr>();
if (!memAccess) {
return memoryOp.emitOpError("invalid memory access specifier: ")
<< memAccessVal;
<< memAccess;
}
if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
if (spirv::bitEnumContains(memAccess.getValue(),
spirv::MemoryAccess::Aligned)) {
if (!op->getAttr(kSourceAlignmentAttrName)) {
return memoryOp.emitOpError("missing alignment value");
}
@ -770,8 +768,10 @@ static ParseResult parseAtomicUpdateOp(OpAsmParser &parser,
OpAsmParser::UnresolvedOperand ptrInfo, valueInfo;
Type type;
SMLoc loc;
if (parseEnumStrAttr(scope, parser, state, kMemoryScopeAttrName) ||
parseEnumStrAttr(memoryScope, parser, state, kSemanticsAttrName) ||
if (parseEnumStrAttr<spirv::ScopeAttr>(scope, parser, state,
kMemoryScopeAttrName) ||
parseEnumStrAttr<spirv::MemorySemanticsAttr>(memoryScope, parser, state,
kSemanticsAttrName) ||
parser.parseOperandList(operandInfo, (hasValue ? 2 : 1)) ||
parser.getCurrentLocation(&loc) || parser.parseColonType(type))
return failure();
@ -793,14 +793,11 @@ static ParseResult parseAtomicUpdateOp(OpAsmParser &parser,
// Prints an atomic update op.
static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) {
printer << " \"";
auto scopeAttr = op->getAttrOfType<IntegerAttr>(kMemoryScopeAttrName);
printer << spirv::stringifyScope(
static_cast<spirv::Scope>(scopeAttr.getInt()))
<< "\" \"";
auto memorySemanticsAttr = op->getAttrOfType<IntegerAttr>(kSemanticsAttrName);
printer << spirv::stringifyMemorySemantics(
static_cast<spirv::MemorySemantics>(
memorySemanticsAttr.getInt()))
auto scopeAttr = op->getAttrOfType<spirv::ScopeAttr>(kMemoryScopeAttrName);
printer << spirv::stringifyScope(scopeAttr.getValue()) << "\" \"";
auto memorySemanticsAttr =
op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName);
printer << spirv::stringifyMemorySemantics(memorySemanticsAttr.getValue())
<< "\" " << op->getOperands() << " : " << op->getOperand(0).getType();
}
@ -834,8 +831,9 @@ static LogicalResult verifyAtomicUpdateOp(Operation *op) {
"pointer operand's pointee type ")
<< elementType << ", but found " << valueType;
}
auto memorySemantics = static_cast<spirv::MemorySemantics>(
op->getAttrOfType<IntegerAttr>(kSemanticsAttrName).getInt());
auto memorySemantics =
op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName)
.getValue();
if (failed(verifyMemorySemantics(op, memorySemantics))) {
return failure();
}
@ -847,10 +845,10 @@ static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
spirv::Scope executionScope;
spirv::GroupOperation groupOperation;
OpAsmParser::UnresolvedOperand valueInfo;
if (parseEnumStrAttr(executionScope, parser, state,
kExecutionScopeAttrName) ||
parseEnumStrAttr(groupOperation, parser, state,
kGroupOperationAttrName) ||
if (parseEnumStrAttr<spirv::ScopeAttr>(executionScope, parser, state,
kExecutionScopeAttrName) ||
parseEnumStrAttr<spirv::GroupOperationAttr>(groupOperation, parser, state,
kGroupOperationAttrName) ||
parser.parseOperand(valueInfo))
return failure();
@ -880,15 +878,17 @@ static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
static void printGroupNonUniformArithmeticOp(Operation *groupOp,
OpAsmPrinter &printer) {
printer << " \""
<< stringifyScope(static_cast<spirv::Scope>(
groupOp->getAttrOfType<IntegerAttr>(kExecutionScopeAttrName)
.getInt()))
<< "\" \""
<< stringifyGroupOperation(static_cast<spirv::GroupOperation>(
groupOp->getAttrOfType<IntegerAttr>(kGroupOperationAttrName)
.getInt()))
<< "\" " << groupOp->getOperand(0);
printer
<< " \""
<< stringifyScope(
groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
.getValue())
<< "\" \""
<< stringifyGroupOperation(groupOp
->getAttrOfType<spirv::GroupOperationAttr>(
kGroupOperationAttrName)
.getValue())
<< "\" " << groupOp->getOperand(0);
if (groupOp->getNumOperands() > 1)
printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')';
@ -896,14 +896,16 @@ static void printGroupNonUniformArithmeticOp(Operation *groupOp,
}
static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
spirv::Scope scope = static_cast<spirv::Scope>(
groupOp->getAttrOfType<IntegerAttr>(kExecutionScopeAttrName).getInt());
spirv::Scope scope =
groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
.getValue();
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
return groupOp->emitOpError(
"execution scope must be 'Workgroup' or 'Subgroup'");
spirv::GroupOperation operation = static_cast<spirv::GroupOperation>(
groupOp->getAttrOfType<IntegerAttr>(kGroupOperationAttrName).getInt());
spirv::GroupOperation operation =
groupOp->getAttrOfType<spirv::GroupOperationAttr>(kGroupOperationAttrName)
.getValue();
if (operation == spirv::GroupOperation::ClusteredReduce &&
groupOp->getNumOperands() == 1)
return groupOp->emitOpError("cluster size operand must be provided for "
@ -1145,11 +1147,12 @@ static ParseResult parseAtomicCompareExchangeImpl(OpAsmParser &parser,
spirv::MemorySemantics equalSemantics, unequalSemantics;
SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
Type type;
if (parseEnumStrAttr(memoryScope, parser, state, kMemoryScopeAttrName) ||
parseEnumStrAttr(equalSemantics, parser, state,
kEqualSemanticsAttrName) ||
parseEnumStrAttr(unequalSemantics, parser, state,
kUnequalSemanticsAttrName) ||
if (parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, state,
kMemoryScopeAttrName) ||
parseEnumStrAttr<spirv::MemorySemanticsAttr>(
equalSemantics, parser, state, kEqualSemanticsAttrName) ||
parseEnumStrAttr<spirv::MemorySemanticsAttr>(
unequalSemantics, parser, state, kUnequalSemanticsAttrName) ||
parser.parseOperandList(operandInfo, 3))
return failure();
@ -1267,8 +1270,10 @@ ParseResult spirv::AtomicExchangeOp::parse(OpAsmParser &parser,
spirv::MemorySemantics semantics;
SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
Type type;
if (parseEnumStrAttr(memoryScope, parser, result, kMemoryScopeAttrName) ||
parseEnumStrAttr(semantics, parser, result, kSemanticsAttrName) ||
if (parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, result,
kMemoryScopeAttrName) ||
parseEnumStrAttr<spirv::MemorySemanticsAttr>(semantics, parser, result,
kSemanticsAttrName) ||
parser.parseOperandList(operandInfo, 2))
return failure();
@ -2075,7 +2080,7 @@ ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
SmallVector<Attribute, 4> interfaceVars;
FlatSymbolRefAttr fn;
if (parseEnumStrAttr(execModel, parser, result) ||
if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) {
return failure();
}
@ -2132,7 +2137,7 @@ ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
spirv::ExecutionMode execMode;
Attribute fn;
if (parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
parseEnumStrAttr(execMode, parser, result)) {
parseEnumStrAttr<spirv::ExecutionModeAttr>(execMode, parser, result)) {
return failure();
}
@ -2220,7 +2225,7 @@ ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
// Parse the optional function control keyword.
spirv::FunctionControl fnControl;
if (parseEnumStrAttr(fnControl, parser, result))
if (parseEnumStrAttr<spirv::FunctionControlAttr>(fnControl, parser, result))
return failure();
// If additional attributes are present, parse them.
@ -2308,7 +2313,7 @@ void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
builder.getStringAttr(name));
state.addAttribute(getTypeAttrName(), TypeAttr::get(type));
state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
builder.getI32IntegerAttr(static_cast<uint32_t>(control)));
builder.getAttr<spirv::FunctionControlAttr>(control));
state.attributes.append(attrs.begin(), attrs.end());
state.addRegion();
}
@ -2997,14 +3002,14 @@ LogicalResult spirv::LoadOp::verify() {
//===----------------------------------------------------------------------===//
void spirv::LoopOp::build(OpBuilder &builder, OperationState &state) {
state.addAttribute("loop_control",
builder.getI32IntegerAttr(
static_cast<uint32_t>(spirv::LoopControl::None)));
state.addAttribute("loop_control", builder.getAttr<spirv::LoopControlAttr>(
spirv::LoopControl::None));
state.addRegion();
}
ParseResult spirv::LoopOp::parse(OpAsmParser &parser, OperationState &result) {
if (parseControlAttribute<spirv::LoopControl>(parser, result))
if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser,
result))
return failure();
return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
}
@ -3195,9 +3200,9 @@ void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
Optional<StringRef> name) {
state.addAttribute(
"addressing_model",
builder.getI32IntegerAttr(static_cast<int32_t>(addressingModel)));
state.addAttribute("memory_model", builder.getI32IntegerAttr(
static_cast<int32_t>(memoryModel)));
builder.getAttr<spirv::AddressingModelAttr>(addressingModel));
state.addAttribute("memory_model",
builder.getAttr<spirv::MemoryModelAttr>(memoryModel));
OpBuilder::InsertionGuard guard(builder);
builder.createBlock(state.addRegion());
if (vceTriple)
@ -3219,8 +3224,10 @@ ParseResult spirv::ModuleOp::parse(OpAsmParser &parser,
// Parse attributes
spirv::AddressingModel addrModel;
spirv::MemoryModel memoryModel;
if (::parseEnumKeywordAttr(addrModel, parser, result) ||
::parseEnumKeywordAttr(memoryModel, parser, result))
if (::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
result) ||
::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
result))
return failure();
if (succeeded(parser.parseOptionalKeyword("requires"))) {
@ -3401,7 +3408,8 @@ LogicalResult spirv::SelectOp::verify() {
ParseResult spirv::SelectionOp::parse(OpAsmParser &parser,
OperationState &result) {
if (parseControlAttribute<spirv::SelectionControl>(parser, result))
if (parseControlAttribute<spirv::SelectionControlAttr,
spirv::SelectionControl>(parser, result))
return failure();
return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
}
@ -3666,8 +3674,8 @@ ParseResult spirv::VariableOp::parse(OpAsmParser &parser,
return failure();
}
auto attr = parser.getBuilder().getI32IntegerAttr(
llvm::bit_cast<int32_t>(ptrType.getStorageClass()));
auto attr = parser.getBuilder().getAttr<spirv::StorageClassAttr>(
ptrType.getStorageClass());
result.addAttribute(spirv::attributeName<spirv::StorageClass>(), attr);
return success();

View File

@ -132,6 +132,8 @@ Dialect Attribute::getDialect() const {
return Dialect(nullptr);
}
const llvm::Record &Attribute::getDef() const { return *def; }
ConstantAttr::ConstantAttr(const DefInit *init) : def(init->getDef()) {
assert(def->isSubClassOf("ConstantAttr") &&
"must be subclass of TableGen 'ConstantAttr' class");

View File

@ -12,6 +12,7 @@
#include "Deserializer.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Location.h"
@ -406,35 +407,6 @@ Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) {
return success();
}
template <>
LogicalResult
Deserializer::processOp<spirv::ControlBarrierOp>(ArrayRef<uint32_t> operands) {
if (operands.size() != 3) {
return emitError(
unknownLoc,
"OpControlBarrier must have execution scope <id>, memory scope <id> "
"and memory semantics <id>");
}
SmallVector<IntegerAttr, 3> argAttrs;
for (auto operand : operands) {
auto argAttr = getConstantInt(operand);
if (!argAttr) {
return emitError(unknownLoc,
"expected 32-bit integer constant from <id> ")
<< operand << " for OpControlBarrier";
}
argAttrs.push_back(argAttr);
}
opBuilder.create<spirv::ControlBarrierOp>(
unknownLoc, argAttrs[0].cast<spirv::ScopeAttr>(),
argAttrs[1].cast<spirv::ScopeAttr>(),
argAttrs[2].cast<spirv::MemorySemanticsAttr>());
return success();
}
template <>
LogicalResult
Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {
@ -477,31 +449,6 @@ Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {
return success();
}
template <>
LogicalResult
Deserializer::processOp<spirv::MemoryBarrierOp>(ArrayRef<uint32_t> operands) {
if (operands.size() != 2) {
return emitError(unknownLoc, "OpMemoryBarrier must have memory scope <id> "
"and memory semantics <id>");
}
SmallVector<IntegerAttr, 2> argAttrs;
for (auto operand : operands) {
auto argAttr = getConstantInt(operand);
if (!argAttr) {
return emitError(unknownLoc,
"expected 32-bit integer constant from <id> ")
<< operand << " for OpMemoryBarrier";
}
argAttrs.push_back(argAttr);
}
opBuilder.create<spirv::MemoryBarrierOp>(
unknownLoc, argAttrs[0].cast<spirv::ScopeAttr>(),
argAttrs[1].cast<spirv::MemorySemanticsAttr>());
return success();
}
template <>
LogicalResult
Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
@ -538,8 +485,9 @@ Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
if (wordIndex < words.size()) {
auto attrValue = words[wordIndex++];
attributes.push_back(opBuilder.getNamedAttr(
"memory_access", opBuilder.getI32IntegerAttr(attrValue)));
auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
static_cast<spirv::MemoryAccess>(attrValue));
attributes.push_back(opBuilder.getNamedAttr("memory_access", attr));
isAlignedAttr = (attrValue == 2);
}
@ -549,9 +497,10 @@ Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
}
if (wordIndex < words.size()) {
attributes.push_back(opBuilder.getNamedAttr(
"source_memory_access",
opBuilder.getI32IntegerAttr(words[wordIndex++])));
auto attrValue = words[wordIndex++];
auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
static_cast<spirv::MemoryAccess>(attrValue));
attributes.push_back(opBuilder.getNamedAttr("source_memory_access", attr));
}
if (wordIndex < words.size()) {

View File

@ -216,10 +216,11 @@ spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) {
(*module)->setAttr(
"addressing_model",
opBuilder.getI32IntegerAttr(llvm::bit_cast<int32_t>(operands.front())));
(*module)->setAttr(
"memory_model",
opBuilder.getI32IntegerAttr(llvm::bit_cast<int32_t>(operands.back())));
opBuilder.getAttr<spirv::AddressingModelAttr>(
static_cast<spirv::AddressingModel>(operands.front())));
(*module)->setAttr("memory_model",
opBuilder.getAttr<spirv::MemoryModelAttr>(
static_cast<spirv::MemoryModel>(operands.back())));
return success();
}

View File

@ -13,6 +13,7 @@
#include "Serializer.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/IR/RegionGraphTraits.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
@ -277,8 +278,8 @@ LogicalResult Serializer::processVariableOp(spirv::VariableOp op) {
operands.push_back(resultID);
auto attr = op->getAttr(spirv::attributeName<spirv::StorageClass>());
if (attr) {
operands.push_back(static_cast<uint32_t>(
attr.cast<IntegerAttr>().getValue().getZExtValue()));
operands.push_back(
static_cast<uint32_t>(attr.cast<spirv::StorageClassAttr>().getValue()));
}
elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
for (auto arg : op.getODSOperands(0)) {
@ -565,27 +566,6 @@ Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
return success();
}
template <>
LogicalResult
Serializer::processOp<spirv::ControlBarrierOp>(spirv::ControlBarrierOp op) {
StringRef argNames[] = {"execution_scope", "memory_scope",
"memory_semantics"};
SmallVector<uint32_t, 3> operands;
for (auto argName : argNames) {
auto argIntAttr = op->getAttrOfType<IntegerAttr>(argName);
auto operand = prepareConstantInt(op.getLoc(), argIntAttr);
if (!operand) {
return failure();
}
operands.push_back(operand);
}
encodeInstructionInto(functionBody, spirv::Opcode::OpControlBarrier,
operands);
return success();
}
template <>
LogicalResult
Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
@ -615,25 +595,6 @@ Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
return success();
}
template <>
LogicalResult
Serializer::processOp<spirv::MemoryBarrierOp>(spirv::MemoryBarrierOp op) {
StringRef argNames[] = {"memory_scope", "memory_semantics"};
SmallVector<uint32_t, 2> operands;
for (auto argName : argNames) {
auto argIntAttr = op->getAttrOfType<IntegerAttr>(argName);
auto operand = prepareConstantInt(op.getLoc(), argIntAttr);
if (!operand) {
return failure();
}
operands.push_back(operand);
}
encodeInstructionInto(functionBody, spirv::Opcode::OpMemoryBarrier, operands);
return success();
}
template <>
LogicalResult
Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
@ -674,8 +635,8 @@ Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
}
if (auto attr = op->getAttr("memory_access")) {
operands.push_back(static_cast<uint32_t>(
attr.cast<IntegerAttr>().getValue().getZExtValue()));
operands.push_back(
static_cast<uint32_t>(attr.cast<spirv::MemoryAccessAttr>().getValue()));
}
elidedAttrs.push_back("memory_access");
@ -688,8 +649,8 @@ Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
elidedAttrs.push_back("alignment");
if (auto attr = op->getAttr("source_memory_access")) {
operands.push_back(static_cast<uint32_t>(
attr.cast<IntegerAttr>().getValue().getZExtValue()));
operands.push_back(
static_cast<uint32_t>(attr.cast<spirv::MemoryAccessAttr>().getValue()));
}
elidedAttrs.push_back("source_memory_access");

View File

@ -14,6 +14,7 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
@ -23,6 +24,7 @@
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/bit.h"
#include "llvm/Support/Debug.h"
#include <cstdint>
#define DEBUG_TYPE "spirv-serialization"
@ -192,8 +194,11 @@ void Serializer::processExtension() {
}
void Serializer::processMemoryModel() {
uint32_t mm = module->getAttrOfType<IntegerAttr>("memory_model").getInt();
uint32_t am = module->getAttrOfType<IntegerAttr>("addressing_model").getInt();
auto mm = static_cast<uint32_t>(
module->getAttrOfType<spirv::MemoryModelAttr>("memory_model").getValue());
auto am = static_cast<uint32_t>(
module->getAttrOfType<spirv::AddressingModelAttr>("addressing_model")
.getValue());
encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, {am, mm});
}

View File

@ -112,7 +112,7 @@ module attributes {gpu.container_module} {
// CHECK-LABEL: spv.func @barrier
gpu.func @barrier(%arg0 : f32, %arg1 : memref<12xf32>) kernel
attributes {spv.entry_point_abi = #spv.entry_point_abi<local_size = dense<[32, 4, 1]>: vector<3xi32>>} {
// CHECK: spv.ControlBarrier Workgroup, Workgroup, "AcquireRelease|WorkgroupMemory"
// CHECK: spv.ControlBarrier <Workgroup>, <Workgroup>, <AcquireRelease|WorkgroupMemory>
gpu.barrier
gpu.return
}

View File

@ -32,7 +32,7 @@ module attributes {
// CHECK: %[[ADD:.+]] = spv.GroupNonUniformIAdd "Subgroup" "Reduce" %[[VAL]] : i32
// CHECK: %[[OUTPTR:.+]] = spv.AccessChain %[[OUTPUT]][%[[ZERO]], %[[ZERO]]]
// CHECK: %[[ELECT:.+]] = spv.GroupNonUniformElect Subgroup : i1
// CHECK: %[[ELECT:.+]] = spv.GroupNonUniformElect <Subgroup> : i1
// CHECK: spv.mlir.selection {
// CHECK: spv.BranchConditional %[[ELECT]], ^bb1, ^bb2

View File

@ -1,32 +1,30 @@
// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -map-memref-spirv-storage-class='client-api=vulkan' -verify-diagnostics %s -o - | FileCheck %s --check-prefix=VULKAN
// Vulkan Mappings:
// 0 -> StorageBuffer (12)
// 1 -> Generic (8)
// 3 -> Workgroup (4)
// 4 -> Uniform (2)
// TODO: create a StorageClass wrapper class so we can print the symbolc
// storage class (instead of the backing IntegerAttr) and be able to
// round trip the IR.
// 0 -> StorageBuffer
// 1 -> Generic
// 2 -> [null]
// 3 -> Workgroup
// 4 -> Uniform
// VULKAN-LABEL: func @operand_result
func.func @operand_result() {
// VULKAN: memref<f32, 12 : i32>
// VULKAN: memref<f32, #spv.storage_class<StorageBuffer>>
%0 = "dialect.memref_producer"() : () -> (memref<f32>)
// VULKAN: memref<4xi32, 8 : i32>
// VULKAN: memref<4xi32, #spv.storage_class<Generic>>
%1 = "dialect.memref_producer"() : () -> (memref<4xi32, 1>)
// VULKAN: memref<?x4xf16, 4 : i32>
// VULKAN: memref<?x4xf16, #spv.storage_class<Workgroup>>
%2 = "dialect.memref_producer"() : () -> (memref<?x4xf16, 3>)
// VULKAN: memref<*xf16, 2 : i32>
// VULKAN: memref<*xf16, #spv.storage_class<Uniform>>
%3 = "dialect.memref_producer"() : () -> (memref<*xf16, 4>)
"dialect.memref_consumer"(%0) : (memref<f32>) -> ()
// VULKAN: memref<4xi32, 8 : i32>
// VULKAN: memref<4xi32, #spv.storage_class<Generic>>
"dialect.memref_consumer"(%1) : (memref<4xi32, 1>) -> ()
// VULKAN: memref<?x4xf16, 4 : i32>
// VULKAN: memref<?x4xf16, #spv.storage_class<Workgroup>>
"dialect.memref_consumer"(%2) : (memref<?x4xf16, 3>) -> ()
// VULKAN: memref<*xf16, 2 : i32>
// VULKAN: memref<*xf16, #spv.storage_class<Uniform>>
"dialect.memref_consumer"(%3) : (memref<*xf16, 4>) -> ()
return
@ -36,7 +34,7 @@ func.func @operand_result() {
// VULKAN-LABEL: func @type_attribute
func.func @type_attribute() {
// VULKAN: attr = memref<i32, 8 : i32>
// VULKAN: attr = memref<i32, #spv.storage_class<Generic>>
"dialect.memref_producer"() { attr = memref<i32, 1> } : () -> ()
return
}
@ -45,9 +43,9 @@ func.func @type_attribute() {
// VULKAN-LABEL: func @function_io
func.func @function_io
// VULKAN-SAME: (%{{.+}}: memref<f64, 8 : i32>, %{{.+}}: memref<4xi32, 4 : i32>)
// VULKAN-SAME: (%{{.+}}: memref<f64, #spv.storage_class<Generic>>, %{{.+}}: memref<4xi32, #spv.storage_class<Workgroup>>)
(%arg0: memref<f64, 1>, %arg1: memref<4xi32, 3>)
// VULKAN-SAME: -> (memref<f64, 8 : i32>, memref<4xi32, 4 : i32>)
// VULKAN-SAME: -> (memref<f64, #spv.storage_class<Generic>>, memref<4xi32, #spv.storage_class<Workgroup>>)
-> (memref<f64, 1>, memref<4xi32, 3>) {
return %arg0, %arg1: memref<f64, 1>, memref<4xi32, 3>
}
@ -57,8 +55,8 @@ func.func @function_io
// VULKAN: func @region
func.func @region(%cond: i1, %arg0: memref<f32, 1>) {
scf.if %cond {
// VULKAN: "dialect.memref_consumer"(%{{.+}}) {attr = memref<i64, 4 : i32>}
// VULKAN-SAME: (memref<f32, 8 : i32>) -> memref<f32, 8 : i32>
// VULKAN: "dialect.memref_consumer"(%{{.+}}) {attr = memref<i64, #spv.storage_class<Workgroup>>}
// VULKAN-SAME: (memref<f32, #spv.storage_class<Generic>>) -> memref<f32, #spv.storage_class<Generic>>
%0 = "dialect.memref_consumer"(%arg0) { attr = memref<i64, 3> } : (memref<f32, 1>) -> (memref<f32, 1>)
}
return

View File

@ -14,7 +14,7 @@ func.func @atomic_and(%ptr : !spv.ptr<i32, StorageBuffer>, %value : i32) -> i32
func.func @atomic_and(%ptr : !spv.ptr<f32, StorageBuffer>, %value : i32) -> i32 {
// expected-error @+1 {{pointer operand must point to an integer value, found 'f32'}}
%0 = "spv.AtomicAnd"(%ptr, %value) {memory_scope = 4: i32, semantics = 0x4 : i32} : (!spv.ptr<f32, StorageBuffer>, i32) -> (i32)
%0 = "spv.AtomicAnd"(%ptr, %value) {memory_scope = #spv.scope<Workgroup>, semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<f32, StorageBuffer>, i32) -> (i32)
return %0 : i32
}
@ -23,7 +23,7 @@ func.func @atomic_and(%ptr : !spv.ptr<f32, StorageBuffer>, %value : i32) -> i32
func.func @atomic_and(%ptr : !spv.ptr<i32, StorageBuffer>, %value : i64) -> i64 {
// expected-error @+1 {{expected value to have the same type as the pointer operand's pointee type 'i32', but found 'i64'}}
%0 = "spv.AtomicAnd"(%ptr, %value) {memory_scope = 2: i32, semantics = 0x8 : i32} : (!spv.ptr<i32, StorageBuffer>, i64) -> (i64)
%0 = "spv.AtomicAnd"(%ptr, %value) {memory_scope = #spv.scope<Workgroup>, semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<i32, StorageBuffer>, i64) -> (i64)
return %0 : i64
}
@ -51,7 +51,7 @@ func.func @atomic_compare_exchange(%ptr: !spv.ptr<i32, Workgroup>, %value: i32,
func.func @atomic_compare_exchange(%ptr: !spv.ptr<i32, Workgroup>, %value: i64, %comparator: i32) -> i32 {
// expected-error @+1 {{value operand must have the same type as the op result, but found 'i64' vs 'i32'}}
%0 = "spv.AtomicCompareExchange"(%ptr, %value, %comparator) {memory_scope = 4: i32, equal_semantics = 0x4: i32, unequal_semantics = 0x2:i32} : (!spv.ptr<i32, Workgroup>, i64, i32) -> (i32)
%0 = "spv.AtomicCompareExchange"(%ptr, %value, %comparator) {memory_scope = #spv.scope<Workgroup>, equal_semantics = #spv.memory_semantics<AcquireRelease>, unequal_semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<i32, Workgroup>, i64, i32) -> (i32)
return %0: i32
}
@ -59,7 +59,7 @@ func.func @atomic_compare_exchange(%ptr: !spv.ptr<i32, Workgroup>, %value: i64,
func.func @atomic_compare_exchange(%ptr: !spv.ptr<i32, Workgroup>, %value: i32, %comparator: i16) -> i32 {
// expected-error @+1 {{comparator operand must have the same type as the op result, but found 'i16' vs 'i32'}}
%0 = "spv.AtomicCompareExchange"(%ptr, %value, %comparator) {memory_scope = 4: i32, equal_semantics = 0x4: i32, unequal_semantics = 0x2:i32} : (!spv.ptr<i32, Workgroup>, i32, i16) -> (i32)
%0 = "spv.AtomicCompareExchange"(%ptr, %value, %comparator) {memory_scope = #spv.scope<Workgroup>, equal_semantics = #spv.memory_semantics<AcquireRelease>, unequal_semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<i32, Workgroup>, i32, i16) -> (i32)
return %0: i32
}
@ -67,7 +67,7 @@ func.func @atomic_compare_exchange(%ptr: !spv.ptr<i32, Workgroup>, %value: i32,
func.func @atomic_compare_exchange(%ptr: !spv.ptr<i64, Workgroup>, %value: i32, %comparator: i32) -> i32 {
// expected-error @+1 {{pointer operand's pointee type must have the same as the op result type, but found 'i64' vs 'i32'}}
%0 = "spv.AtomicCompareExchange"(%ptr, %value, %comparator) {memory_scope = 4: i32, equal_semantics = 0x4: i32, unequal_semantics = 0x2:i32} : (!spv.ptr<i64, Workgroup>, i32, i32) -> (i32)
%0 = "spv.AtomicCompareExchange"(%ptr, %value, %comparator) {memory_scope = #spv.scope<Workgroup>, equal_semantics = #spv.memory_semantics<AcquireRelease>, unequal_semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<i64, Workgroup>, i32, i32) -> (i32)
return %0: i32
}
@ -87,7 +87,7 @@ func.func @atomic_compare_exchange_weak(%ptr: !spv.ptr<i32, Workgroup>, %value:
func.func @atomic_compare_exchange_weak(%ptr: !spv.ptr<i32, Workgroup>, %value: i64, %comparator: i32) -> i32 {
// expected-error @+1 {{value operand must have the same type as the op result, but found 'i64' vs 'i32'}}
%0 = "spv.AtomicCompareExchangeWeak"(%ptr, %value, %comparator) {memory_scope = 4: i32, equal_semantics = 0x4: i32, unequal_semantics = 0x2:i32} : (!spv.ptr<i32, Workgroup>, i64, i32) -> (i32)
%0 = "spv.AtomicCompareExchangeWeak"(%ptr, %value, %comparator) {memory_scope = #spv.scope<Workgroup>, equal_semantics = #spv.memory_semantics<AcquireRelease>, unequal_semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<i32, Workgroup>, i64, i32) -> (i32)
return %0: i32
}
@ -95,7 +95,7 @@ func.func @atomic_compare_exchange_weak(%ptr: !spv.ptr<i32, Workgroup>, %value:
func.func @atomic_compare_exchange_weak(%ptr: !spv.ptr<i32, Workgroup>, %value: i32, %comparator: i16) -> i32 {
// expected-error @+1 {{comparator operand must have the same type as the op result, but found 'i16' vs 'i32'}}
%0 = "spv.AtomicCompareExchangeWeak"(%ptr, %value, %comparator) {memory_scope = 4: i32, equal_semantics = 0x4: i32, unequal_semantics = 0x2:i32} : (!spv.ptr<i32, Workgroup>, i32, i16) -> (i32)
%0 = "spv.AtomicCompareExchangeWeak"(%ptr, %value, %comparator) {memory_scope = #spv.scope<Workgroup>, equal_semantics = #spv.memory_semantics<AcquireRelease>, unequal_semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<i32, Workgroup>, i32, i16) -> (i32)
return %0: i32
}
@ -103,7 +103,7 @@ func.func @atomic_compare_exchange_weak(%ptr: !spv.ptr<i32, Workgroup>, %value:
func.func @atomic_compare_exchange_weak(%ptr: !spv.ptr<i64, Workgroup>, %value: i32, %comparator: i32) -> i32 {
// expected-error @+1 {{pointer operand's pointee type must have the same as the op result type, but found 'i64' vs 'i32'}}
%0 = "spv.AtomicCompareExchangeWeak"(%ptr, %value, %comparator) {memory_scope = 4: i32, equal_semantics = 0x4: i32, unequal_semantics = 0x2:i32} : (!spv.ptr<i64, Workgroup>, i32, i32) -> (i32)
%0 = "spv.AtomicCompareExchangeWeak"(%ptr, %value, %comparator) {memory_scope = #spv.scope<Workgroup>, equal_semantics = #spv.memory_semantics<AcquireRelease>, unequal_semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<i64, Workgroup>, i32, i32) -> (i32)
return %0: i32
}
@ -123,7 +123,7 @@ func.func @atomic_exchange(%ptr: !spv.ptr<i32, Workgroup>, %value: i32) -> i32 {
func.func @atomic_exchange(%ptr: !spv.ptr<i32, Workgroup>, %value: i64) -> i32 {
// expected-error @+1 {{value operand must have the same type as the op result, but found 'i64' vs 'i32'}}
%0 = "spv.AtomicExchange"(%ptr, %value) {memory_scope = 4: i32, semantics = 0x4: i32} : (!spv.ptr<i32, Workgroup>, i64) -> (i32)
%0 = "spv.AtomicExchange"(%ptr, %value) {memory_scope = #spv.scope<Workgroup>, semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<i32, Workgroup>, i64) -> (i32)
return %0: i32
}
@ -131,7 +131,7 @@ func.func @atomic_exchange(%ptr: !spv.ptr<i32, Workgroup>, %value: i64) -> i32 {
func.func @atomic_exchange(%ptr: !spv.ptr<i64, Workgroup>, %value: i32) -> i32 {
// expected-error @+1 {{pointer operand's pointee type must have the same as the op result type, but found 'i64' vs 'i32'}}
%0 = "spv.AtomicExchange"(%ptr, %value) {memory_scope = 4: i32, semantics = 0x4: i32} : (!spv.ptr<i64, Workgroup>, i32) -> (i32)
%0 = "spv.AtomicExchange"(%ptr, %value) {memory_scope = #spv.scope<Workgroup>, semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<i64, Workgroup>, i32) -> (i32)
return %0: i32
}
@ -253,7 +253,7 @@ func.func @atomic_fadd(%ptr : !spv.ptr<f32, StorageBuffer>, %value : f32) -> f32
func.func @atomic_fadd(%ptr : !spv.ptr<i32, StorageBuffer>, %value : f32) -> f32 {
// expected-error @+1 {{pointer operand must point to an float value, found 'i32'}}
%0 = "spv.AtomicFAddEXT"(%ptr, %value) {memory_scope = 4: i32, semantics = 0x4 : i32} : (!spv.ptr<i32, StorageBuffer>, f32) -> (f32)
%0 = "spv.AtomicFAddEXT"(%ptr, %value) {memory_scope = #spv.scope<Workgroup>, semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<i32, StorageBuffer>, f32) -> (f32)
return %0 : f32
}
@ -261,7 +261,7 @@ func.func @atomic_fadd(%ptr : !spv.ptr<i32, StorageBuffer>, %value : f32) -> f32
func.func @atomic_fadd(%ptr : !spv.ptr<f32, StorageBuffer>, %value : f64) -> f64 {
// expected-error @+1 {{expected value to have the same type as the pointer operand's pointee type 'f32', but found 'f64'}}
%0 = "spv.AtomicFAddEXT"(%ptr, %value) {memory_scope = 2: i32, semantics = 0x8 : i32} : (!spv.ptr<f32, StorageBuffer>, f64) -> (f64)
%0 = "spv.AtomicFAddEXT"(%ptr, %value) {memory_scope = #spv.scope<Device>, semantics = #spv.memory_semantics<AcquireRelease>} : (!spv.ptr<f32, StorageBuffer>, f64) -> (f64)
return %0 : f64
}

View File

@ -26,7 +26,7 @@ func.func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
// CHECK: max version: v1.6
// CHECK: extensions: [ ]
// CHECK: capabilities: [ [GroupNonUniformBallot] ]
%0 = spv.GroupNonUniformBallot Workgroup %predicate : vector<4xi32>
%0 = spv.GroupNonUniformBallot <Workgroup> %predicate : vector<4xi32>
return %0: vector<4xi32>
}

View File

@ -5,16 +5,17 @@
//===----------------------------------------------------------------------===//
func.func @control_barrier_0() -> () {
// CHECK: spv.ControlBarrier Workgroup, Device, "Acquire|UniformMemory"
spv.ControlBarrier Workgroup, Device, "Acquire|UniformMemory"
// CHECK: spv.ControlBarrier <Workgroup>, <Device>, <Acquire|UniformMemory>
spv.ControlBarrier <Workgroup>, <Device>, <Acquire|UniformMemory>
return
}
// -----
func.func @control_barrier_1() -> () {
// expected-error @+1 {{expected string or keyword containing one of the following enum values}}
spv.ControlBarrier Something, Device, "Acquire|UniformMemory"
// expected-error @+2 {{to be one of}}
// expected-error @+1 {{failed to parse SPV_ScopeAttr}}
spv.ControlBarrier <Something>, <Device>, <Acquire|UniformMemory>
return
}
@ -26,16 +27,16 @@ func.func @control_barrier_1() -> () {
//===----------------------------------------------------------------------===//
func.func @memory_barrier_0() -> () {
// CHECK: spv.MemoryBarrier Device, "Acquire|UniformMemory"
spv.MemoryBarrier Device, "Acquire|UniformMemory"
// CHECK: spv.MemoryBarrier <Device>, <Acquire|UniformMemory>
spv.MemoryBarrier <Device>, <Acquire|UniformMemory>
return
}
// -----
func.func @memory_barrier_1() -> () {
// CHECK: spv.MemoryBarrier Workgroup, Acquire
spv.MemoryBarrier Workgroup, Acquire
// CHECK: spv.MemoryBarrier <Workgroup>, <Acquire>
spv.MemoryBarrier <Workgroup>, <Acquire>
return
}
@ -43,7 +44,7 @@ func.func @memory_barrier_1() -> () {
func.func @memory_barrier_2() -> () {
// expected-error @+1 {{expected at most one of these four memory constraints to be set: `Acquire`, `Release`,`AcquireRelease` or `SequentiallyConsistent`}}
spv.MemoryBarrier Device, "Acquire|Release"
spv.MemoryBarrier <Device>, <Acquire|Release>
return
}

View File

@ -17,24 +17,24 @@ func.func @subgroup_ballot(%predicate: i1) -> vector<4xi32> {
//===----------------------------------------------------------------------===//
func.func @group_broadcast_scalar(%value: f32, %localid: i32 ) -> f32 {
// CHECK: spv.GroupBroadcast Workgroup %{{.*}}, %{{.*}} : f32, i32
%0 = spv.GroupBroadcast Workgroup %value, %localid : f32, i32
// CHECK: spv.GroupBroadcast <Workgroup> %{{.*}}, %{{.*}} : f32, i32
%0 = spv.GroupBroadcast <Workgroup> %value, %localid : f32, i32
return %0: f32
}
// -----
func.func @group_broadcast_scalar_vector(%value: f32, %localid: vector<3xi32> ) -> f32 {
// CHECK: spv.GroupBroadcast Workgroup %{{.*}}, %{{.*}} : f32, vector<3xi32>
%0 = spv.GroupBroadcast Workgroup %value, %localid : f32, vector<3xi32>
// CHECK: spv.GroupBroadcast <Workgroup> %{{.*}}, %{{.*}} : f32, vector<3xi32>
%0 = spv.GroupBroadcast <Workgroup> %value, %localid : f32, vector<3xi32>
return %0: f32
}
// -----
func.func @group_broadcast_vector(%value: vector<4xf32>, %localid: vector<3xi32> ) -> vector<4xf32> {
// CHECK: spv.GroupBroadcast Subgroup %{{.*}}, %{{.*}} : vector<4xf32>, vector<3xi32>
%0 = spv.GroupBroadcast Subgroup %value, %localid : vector<4xf32>, vector<3xi32>
// CHECK: spv.GroupBroadcast <Subgroup> %{{.*}}, %{{.*}} : vector<4xf32>, vector<3xi32>
%0 = spv.GroupBroadcast <Subgroup> %value, %localid : vector<4xf32>, vector<3xi32>
return %0: vector<4xf32>
}
@ -42,7 +42,7 @@ func.func @group_broadcast_vector(%value: vector<4xf32>, %localid: vector<3xi32>
func.func @group_broadcast_negative_scope(%value: f32, %localid: vector<3xi32> ) -> f32 {
// expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
%0 = spv.GroupBroadcast Device %value, %localid : f32, vector<3xi32>
%0 = spv.GroupBroadcast <Device> %value, %localid : f32, vector<3xi32>
return %0: f32
}
@ -50,7 +50,7 @@ func.func @group_broadcast_negative_scope(%value: f32, %localid: vector<3xi32> )
func.func @group_broadcast_negative_locid_dtype(%value: f32, %localid: vector<3xf32> ) -> f32 {
// expected-error @+1 {{operand #1 must be 8/16/32/64-bit integer or vector of 8/16/32/64-bit integer values}}
%0 = spv.GroupBroadcast Subgroup %value, %localid : f32, vector<3xf32>
%0 = spv.GroupBroadcast <Subgroup> %value, %localid : f32, vector<3xf32>
return %0: f32
}
@ -58,7 +58,7 @@ func.func @group_broadcast_negative_locid_dtype(%value: f32, %localid: vector<3x
func.func @group_broadcast_negative_locid_vec4(%value: f32, %localid: vector<4xi32> ) -> f32 {
// expected-error @+1 {{localid is a vector and can be with only 2 or 3 components, actual number is 4}}
%0 = spv.GroupBroadcast Subgroup %value, %localid : f32, vector<4xi32>
%0 = spv.GroupBroadcast <Subgroup> %value, %localid : f32, vector<4xi32>
return %0: f32
}

View File

@ -198,7 +198,7 @@ func.func @load_none_access() -> () {
%0 = spv.Variable : !spv.ptr<f32, Function>
// CHECK: spv.Load
// CHECK-SAME: ["None"]
%1 = "spv.Load"(%0) {memory_access = 0 : i32} : (!spv.ptr<f32, Function>) -> (f32)
%1 = "spv.Load"(%0) {memory_access = #spv.memory_access<None>} : (!spv.ptr<f32, Function>) -> (f32)
return
}
@ -207,7 +207,7 @@ func.func @volatile_load() -> () {
%0 = spv.Variable : !spv.ptr<f32, Function>
// CHECK: spv.Load
// CHECK-SAME: ["Volatile"]
%1 = "spv.Load"(%0) {memory_access = 1 : i32} : (!spv.ptr<f32, Function>) -> (f32)
%1 = "spv.Load"(%0) {memory_access = #spv.memory_access<Volatile>} : (!spv.ptr<f32, Function>) -> (f32)
return
}
@ -216,7 +216,7 @@ func.func @aligned_load() -> () {
%0 = spv.Variable : !spv.ptr<f32, Function>
// CHECK: spv.Load
// CHECK-SAME: ["Aligned", 4]
%1 = "spv.Load"(%0) {memory_access = 2 : i32, alignment = 4 : i32} : (!spv.ptr<f32, Function>) -> (f32)
%1 = "spv.Load"(%0) {memory_access = #spv.memory_access<Aligned>, alignment = 4 : i32} : (!spv.ptr<f32, Function>) -> (f32)
return
}
@ -225,7 +225,7 @@ func.func @volatile_aligned_load() -> () {
%0 = spv.Variable : !spv.ptr<f32, Function>
// CHECK: spv.Load
// CHECK-SAME: ["Volatile|Aligned", 4]
%1 = "spv.Load"(%0) {memory_access = 3 : i32, alignment = 4 : i32} : (!spv.ptr<f32, Function>) -> (f32)
%1 = "spv.Load"(%0) {memory_access = #spv.memory_access<Volatile|Aligned>, alignment = 4 : i32} : (!spv.ptr<f32, Function>) -> (f32)
return
}
@ -588,7 +588,7 @@ func.func @copy_memory_invalid_maa() {
%0 = spv.Variable : !spv.ptr<f32, Function>
%1 = spv.Variable : !spv.ptr<f32, Function>
// expected-error @+1 {{missing alignment value}}
"spv.CopyMemory"(%0, %1) {memory_access=0x0002 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
"spv.CopyMemory"(%0, %1) {memory_access=#spv.memory_access<Aligned>} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
spv.Return
}
@ -598,7 +598,7 @@ func.func @copy_memory_invalid_source_maa() {
%0 = spv.Variable : !spv.ptr<f32, Function>
%1 = spv.Variable : !spv.ptr<f32, Function>
// expected-error @+1 {{invalid alignment specification with non-aligned memory access specification}}
"spv.CopyMemory"(%0, %1) {source_memory_access=0x0001 : i32, memory_access=0x0002 : i32, source_alignment=8 : i32, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
"spv.CopyMemory"(%0, %1) {source_memory_access=#spv.memory_access<Volatile>, memory_access=#spv.memory_access<Aligned>, source_alignment=8 : i32, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
spv.Return
}
@ -608,7 +608,7 @@ func.func @copy_memory_invalid_source_maa2() {
%0 = spv.Variable : !spv.ptr<f32, Function>
%1 = spv.Variable : !spv.ptr<f32, Function>
// expected-error @+1 {{missing alignment value}}
"spv.CopyMemory"(%0, %1) {source_memory_access=0x0002 : i32, memory_access=0x0002 : i32, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
"spv.CopyMemory"(%0, %1) {source_memory_access=#spv.memory_access<Aligned>, memory_access=#spv.memory_access<Aligned>, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
spv.Return
}
@ -619,16 +619,16 @@ func.func @copy_memory_print_maa() {
%1 = spv.Variable : !spv.ptr<f32, Function>
// CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Volatile"] : f32
"spv.CopyMemory"(%0, %1) {memory_access=0x0001 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
"spv.CopyMemory"(%0, %1) {memory_access=#spv.memory_access<Volatile>} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
// CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4] : f32
"spv.CopyMemory"(%0, %1) {memory_access=0x0002 : i32, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
"spv.CopyMemory"(%0, %1) {memory_access=#spv.memory_access<Aligned>, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
// CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4], ["Volatile"] : f32
"spv.CopyMemory"(%0, %1) {source_memory_access=0x0001 : i32, memory_access=0x0002 : i32, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
"spv.CopyMemory"(%0, %1) {source_memory_access=#spv.memory_access<Volatile>, memory_access=#spv.memory_access<Aligned>, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
// CHECK: spv.CopyMemory "Function" %{{.*}}, "Function" %{{.*}} ["Aligned", 4], ["Aligned", 8] : f32
"spv.CopyMemory"(%0, %1) {source_memory_access=0x0002 : i32, memory_access=0x0002 : i32, source_alignment=8 : i32, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
"spv.CopyMemory"(%0, %1) {source_memory_access=#spv.memory_access<Aligned>, memory_access=#spv.memory_access<Aligned>, source_alignment=8 : i32, alignment=4 : i32} : (!spv.ptr<f32, Function>, !spv.ptr<f32, Function>) -> ()
spv.Return
}

View File

@ -5,8 +5,8 @@
//===----------------------------------------------------------------------===//
func.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> {
// CHECK: %{{.*}} = spv.GroupNonUniformBallot Workgroup %{{.*}}: vector<4xi32>
%0 = spv.GroupNonUniformBallot Workgroup %predicate : vector<4xi32>
// CHECK: %{{.*}} = spv.GroupNonUniformBallot <Workgroup> %{{.*}}: vector<4xi32>
%0 = spv.GroupNonUniformBallot <Workgroup> %predicate : vector<4xi32>
return %0: vector<4xi32>
}
@ -14,7 +14,7 @@ func.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> {
func.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> {
// expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
%0 = spv.GroupNonUniformBallot Device %predicate : vector<4xi32>
%0 = spv.GroupNonUniformBallot <Device> %predicate : vector<4xi32>
return %0: vector<4xi32>
}
@ -22,7 +22,7 @@ func.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> {
func.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xsi32> {
// expected-error @+1 {{op result #0 must be vector of 8/16/32/64-bit signless/unsigned integer values of length 4, but got 'vector<4xsi32>'}}
%0 = spv.GroupNonUniformBallot Workgroup %predicate : vector<4xsi32>
%0 = spv.GroupNonUniformBallot <Workgroup> %predicate : vector<4xsi32>
return %0: vector<4xsi32>
}
@ -34,8 +34,8 @@ func.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xsi32> {
func.func @group_non_uniform_broadcast_scalar(%value: f32) -> f32 {
%one = spv.Constant 1 : i32
// CHECK: spv.GroupNonUniformBroadcast Workgroup %{{.*}}, %{{.*}} : f32, i32
%0 = spv.GroupNonUniformBroadcast Workgroup %value, %one : f32, i32
// CHECK: spv.GroupNonUniformBroadcast <Workgroup> %{{.*}}, %{{.*}} : f32, i32
%0 = spv.GroupNonUniformBroadcast <Workgroup> %value, %one : f32, i32
return %0: f32
}
@ -43,8 +43,8 @@ func.func @group_non_uniform_broadcast_scalar(%value: f32) -> f32 {
func.func @group_non_uniform_broadcast_vector(%value: vector<4xf32>) -> vector<4xf32> {
%one = spv.Constant 1 : i32
// CHECK: spv.GroupNonUniformBroadcast Subgroup %{{.*}}, %{{.*}} : vector<4xf32>, i32
%0 = spv.GroupNonUniformBroadcast Subgroup %value, %one : vector<4xf32>, i32
// CHECK: spv.GroupNonUniformBroadcast <Subgroup> %{{.*}}, %{{.*}} : vector<4xf32>, i32
%0 = spv.GroupNonUniformBroadcast <Subgroup> %value, %one : vector<4xf32>, i32
return %0: vector<4xf32>
}
@ -53,7 +53,7 @@ func.func @group_non_uniform_broadcast_vector(%value: vector<4xf32>) -> vector<4
func.func @group_non_uniform_broadcast_negative_scope(%value: f32, %localid: i32 ) -> f32 {
%one = spv.Constant 1 : i32
// expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
%0 = spv.GroupNonUniformBroadcast Device %value, %one : f32, i32
%0 = spv.GroupNonUniformBroadcast <Device> %value, %one : f32, i32
return %0: f32
}
@ -61,7 +61,7 @@ func.func @group_non_uniform_broadcast_negative_scope(%value: f32, %localid: i32
func.func @group_non_uniform_broadcast_negative_non_const(%value: f32, %localid: i32) -> f32 {
// expected-error @+1 {{id must be the result of a constant op}}
%0 = spv.GroupNonUniformBroadcast Subgroup %value, %localid : f32, i32
%0 = spv.GroupNonUniformBroadcast <Subgroup> %value, %localid : f32, i32
return %0: f32
}
@ -73,8 +73,8 @@ func.func @group_non_uniform_broadcast_negative_non_const(%value: f32, %localid:
// CHECK-LABEL: @group_non_uniform_elect
func.func @group_non_uniform_elect() -> i1 {
// CHECK: %{{.+}} = spv.GroupNonUniformElect Workgroup : i1
%0 = spv.GroupNonUniformElect Workgroup : i1
// CHECK: %{{.+}} = spv.GroupNonUniformElect <Workgroup> : i1
%0 = spv.GroupNonUniformElect <Workgroup> : i1
return %0: i1
}
@ -82,7 +82,7 @@ func.func @group_non_uniform_elect() -> i1 {
func.func @group_non_uniform_elect() -> i1 {
// expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
%0 = spv.GroupNonUniformElect CrossDevice : i1
%0 = spv.GroupNonUniformElect <CrossDevice> : i1
return %0: i1
}

View File

@ -819,7 +819,7 @@ spv.module Logical GLSL450 {
%0 = spv.Variable : !spv.ptr<i32, Function>
// expected-error @+1 {{invalid enclosed op}}
%1 = spv.SpecConstantOperation wraps "spv.Load"(%0) {memory_access = 0 : i32} : (!spv.ptr<i32, Function>) -> i32
%1 = spv.SpecConstantOperation wraps "spv.Load"(%0) {memory_access = #spv.memory_access<None>} : (!spv.ptr<i32, Function>) -> i32
spv.Return
}
}

View File

@ -165,11 +165,11 @@ func.func @target_env_cooperative_matrix() attributes{
// CHECK-SAME: #spv.coop_matrix_props<
// CHECK-SAME: m_size = 8, n_size = 8, k_size = 32,
// CHECK-SAME: a_type = i8, b_type = i8, c_type = i32,
// CHECK-SAME: result_type = i32, scope = 3 : i32>
// CHECK-SAME: result_type = i32, scope = <Subgroup>>
// CHECK-SAME: #spv.coop_matrix_props<
// CHECK-SAME: m_size = 8, n_size = 8, k_size = 16,
// CHECK-SAME: a_type = f16, b_type = f16, c_type = f16,
// CHECK-SAME: result_type = f16, scope = 3 : i32>
// CHECK-SAME: result_type = f16, scope = <Subgroup>>
spv.target_env = #spv.target_env<
#spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class,
SPV_NV_cooperative_matrix]>,
@ -182,7 +182,7 @@ func.func @target_env_cooperative_matrix() attributes{
b_type = i8,
c_type = i32,
result_type = i32,
scope = 3 : i32
scope = #spv.scope<Subgroup>
>, #spv.coop_matrix_props<
m_size = 8,
n_size = 8,
@ -191,7 +191,7 @@ func.func @target_env_cooperative_matrix() attributes{
b_type = f16,
c_type = f16,
result_type = f16,
scope = 3 : i32
scope = #spv.scope<Subgroup>
>]
>>
} { return }

View File

@ -59,7 +59,7 @@ func.func @cmp_exchange_weak_unsupported_version(%ptr: !spv.ptr<i32, Workgroup>,
func.func @group_non_uniform_ballot_suitable_version(%predicate: i1) -> vector<4xi32> attributes {
spv.target_env = #spv.target_env<#spv.vce<v1.4, [GroupNonUniformBallot], []>, #spv.resource_limits<>>
} {
// CHECK: spv.GroupNonUniformBallot Workgroup
// CHECK: spv.GroupNonUniformBallot <Workgroup>
%0 = "test.convert_to_group_non_uniform_ballot_op"(%predicate): (i1) -> (vector<4xi32>)
return %0: vector<4xi32>
}

View File

@ -27,7 +27,7 @@ spv.module Logical GLSL450 attributes {
#spv.vce<v1.5, [Shader, GroupNonUniformBallot], []>, #spv.resource_limits<>>
} {
spv.func @group_non_uniform_ballot(%predicate : i1) -> vector<4xi32> "None" {
%0 = spv.GroupNonUniformBallot Workgroup %predicate : vector<4xi32>
%0 = spv.GroupNonUniformBallot <Workgroup> %predicate : vector<4xi32>
spv.ReturnValue %0: vector<4xi32>
}
}

View File

@ -2,23 +2,23 @@
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
spv.func @memory_barrier_0() -> () "None" {
// CHECK: spv.MemoryBarrier Device, "Release|UniformMemory"
spv.MemoryBarrier Device, "Release|UniformMemory"
// CHECK: spv.MemoryBarrier <Device>, <Release|UniformMemory>
spv.MemoryBarrier <Device>, <Release|UniformMemory>
spv.Return
}
spv.func @memory_barrier_1() -> () "None" {
// CHECK: spv.MemoryBarrier Subgroup, "AcquireRelease|SubgroupMemory"
spv.MemoryBarrier Subgroup, "AcquireRelease|SubgroupMemory"
// CHECK: spv.MemoryBarrier <Subgroup>, <AcquireRelease|SubgroupMemory>
spv.MemoryBarrier <Subgroup>, <AcquireRelease|SubgroupMemory>
spv.Return
}
spv.func @control_barrier_0() -> () "None" {
// CHECK: spv.ControlBarrier Device, Workgroup, "Release|UniformMemory"
spv.ControlBarrier Device, Workgroup, "Release|UniformMemory"
// CHECK: spv.ControlBarrier <Device>, <Workgroup>, <Release|UniformMemory>
spv.ControlBarrier <Device>, <Workgroup>, <Release|UniformMemory>
spv.Return
}
spv.func @control_barrier_1() -> () "None" {
// CHECK: spv.ControlBarrier Workgroup, Invocation, "AcquireRelease|UniformMemory"
spv.ControlBarrier Workgroup, Invocation, "AcquireRelease|UniformMemory"
// CHECK: spv.ControlBarrier <Workgroup>, <Invocation>, <AcquireRelease|UniformMemory>
spv.ControlBarrier <Workgroup>, <Invocation>, <AcquireRelease|UniformMemory>
spv.Return
}
}

View File

@ -9,14 +9,14 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
}
// CHECK-LABEL: @group_broadcast_1
spv.func @group_broadcast_1(%value: f32, %localid: i32 ) -> f32 "None" {
// CHECK: spv.GroupBroadcast Workgroup %{{.*}}, %{{.*}} : f32, i32
%0 = spv.GroupBroadcast Workgroup %value, %localid : f32, i32
// CHECK: spv.GroupBroadcast <Workgroup> %{{.*}}, %{{.*}} : f32, i32
%0 = spv.GroupBroadcast <Workgroup> %value, %localid : f32, i32
spv.ReturnValue %0: f32
}
// CHECK-LABEL: @group_broadcast_2
spv.func @group_broadcast_2(%value: f32, %localid: vector<3xi32> ) -> f32 "None" {
// CHECK: spv.GroupBroadcast Workgroup %{{.*}}, %{{.*}} : f32, vector<3xi32>
%0 = spv.GroupBroadcast Workgroup %value, %localid : f32, vector<3xi32>
// CHECK: spv.GroupBroadcast <Workgroup> %{{.*}}, %{{.*}} : f32, vector<3xi32>
%0 = spv.GroupBroadcast <Workgroup> %value, %localid : f32, vector<3xi32>
spv.ReturnValue %0: f32
}
// CHECK-LABEL: @subgroup_block_read_intel

View File

@ -3,23 +3,23 @@
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
// CHECK-LABEL: @group_non_uniform_ballot
spv.func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> "None" {
// CHECK: %{{.*}} = spv.GroupNonUniformBallot Workgroup %{{.*}}: vector<4xi32>
%0 = spv.GroupNonUniformBallot Workgroup %predicate : vector<4xi32>
// CHECK: %{{.*}} = spv.GroupNonUniformBallot <Workgroup> %{{.*}}: vector<4xi32>
%0 = spv.GroupNonUniformBallot <Workgroup> %predicate : vector<4xi32>
spv.ReturnValue %0: vector<4xi32>
}
// CHECK-LABEL: @group_non_uniform_broadcast
spv.func @group_non_uniform_broadcast(%value: f32) -> f32 "None" {
%one = spv.Constant 1 : i32
// CHECK: spv.GroupNonUniformBroadcast Subgroup %{{.*}}, %{{.*}} : f32, i32
%0 = spv.GroupNonUniformBroadcast Subgroup %value, %one : f32, i32
// CHECK: spv.GroupNonUniformBroadcast <Subgroup> %{{.*}}, %{{.*}} : f32, i32
%0 = spv.GroupNonUniformBroadcast <Subgroup> %value, %one : f32, i32
spv.ReturnValue %0: f32
}
// CHECK-LABEL: @group_non_uniform_elect
spv.func @group_non_uniform_elect() -> i1 "None" {
// CHECK: %{{.+}} = spv.GroupNonUniformElect Workgroup : i1
%0 = spv.GroupNonUniformElect Workgroup : i1
// CHECK: %{{.+}} = spv.GroupNonUniformElect <Workgroup> : i1
%0 = spv.GroupNonUniformElect <Workgroup> : i1
spv.ReturnValue %0: i1
}

View File

@ -519,10 +519,24 @@ static void emitAttributeSerialization(const Attribute &attr,
<< formatv("if (auto attr = {0}->getAttr(\"{1}\")) {{\n", opVar, attrName);
if (attr.getAttrDefName() == "SPV_ScopeAttr" ||
attr.getAttrDefName() == "SPV_MemorySemanticsAttr") {
// These two enums are encoded as <id> to constant values in SPIR-V blob,
// but we directly use the constant value as attribute in SPIR-V dialect. So
// need to handle them separately from normal enum attributes.
EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
os << tabs
<< formatv(" {0}.push_back(prepareConstantInt({1}.getLoc(), "
"attr.cast<IntegerAttr>()));\n",
operandList, opVar);
"Builder({1}).getI32IntegerAttr(static_cast<uint32_t>("
"attr.cast<{2}::{3}Attr>().getValue()))));\n",
operandList, opVar, baseEnum.getCppNamespace(),
baseEnum.getEnumClassName());
} else if (attr.isSubClassOf("SPV_BitEnumAttr") ||
attr.isSubClassOf("SPV_I32EnumAttr")) {
EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
os << tabs
<< formatv(" {0}.push_back(static_cast<uint32_t>("
"attr.cast<{1}::{2}Attr>().getValue()));\n",
operandList, baseEnum.getCppNamespace(),
baseEnum.getEnumClassName());
} else if (attr.getAttrDefName() == "I32ArrayAttr") {
// Serialize all the elements of the array
os << tabs << " for (auto attrElem : attr.cast<ArrayAttr>()) {\n";
@ -531,7 +545,7 @@ static void emitAttributeSerialization(const Attribute &attr,
"attrElem.cast<IntegerAttr>().getValue().getZExtValue()));\n",
operandList);
os << tabs << " }\n";
} else if (attr.isEnumAttr() || attr.getAttrDefName() == "I32Attr") {
} else if (attr.getAttrDefName() == "I32Attr") {
os << tabs
<< formatv(" {0}.push_back(static_cast<uint32_t>("
"attr.cast<IntegerAttr>().getValue().getZExtValue()));\n",
@ -797,10 +811,25 @@ static void emitAttributeDeserialization(const Attribute &attr,
raw_ostream &os) {
if (attr.getAttrDefName() == "SPV_ScopeAttr" ||
attr.getAttrDefName() == "SPV_MemorySemanticsAttr") {
// These two enums are encoded as <id> to constant values in SPIR-V blob,
// but we directly use the constant value as attribute in SPIR-V dialect. So
// need to handle them separately from normal enum attributes.
EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
os << tabs
<< formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
"getConstantInt({2}[{3}++])));\n",
attrList, attrName, words, wordIndex);
"opBuilder.getAttr<{2}::{3}Attr>(static_cast<{2}::{3}>("
"getConstantInt({4}[{5}++]).getValue().getZExtValue()))));\n",
attrList, attrName, baseEnum.getCppNamespace(),
baseEnum.getEnumClassName(), words, wordIndex);
} else if (attr.isSubClassOf("SPV_BitEnumAttr") ||
attr.isSubClassOf("SPV_I32EnumAttr")) {
EnumAttr baseEnum(attr.getDef().getValueAsDef("enum"));
os << tabs
<< formatv(" {0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
"opBuilder.getAttr<{2}::{3}Attr>("
"static_cast<{2}::{3}>({4}[{5}++]))));\n",
attrList, attrName, baseEnum.getCppNamespace(),
baseEnum.getEnumClassName(), words, wordIndex);
} else if (attr.getAttrDefName() == "I32ArrayAttr") {
os << tabs << "SmallVector<Attribute, 4> attrListElems;\n";
os << tabs << formatv("while ({0} < {1}.size()) {{\n", wordIndex, words);
@ -815,7 +844,7 @@ static void emitAttributeDeserialization(const Attribute &attr,
<< formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
"opBuilder.getArrayAttr(attrListElems)));\n",
attrList, attrName);
} else if (attr.isEnumAttr() || attr.getAttrDefName() == "I32Attr") {
} else if (attr.getAttrDefName() == "I32Attr") {
os << tabs
<< formatv("{0}.push_back(opBuilder.getNamedAttr(\"{1}\", "
"opBuilder.getI32IntegerAttr({2}[{3}++])));\n",
@ -1257,11 +1286,12 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
for (const Availability &avail : opAvailabilities)
availClasses.try_emplace(avail.getClass(), avail);
for (const NamedAttribute &namedAttr : srcOp.getAttributes()) {
const auto *enumAttr = llvm::dyn_cast<EnumAttr>(&namedAttr.attr);
if (!enumAttr)
if (!namedAttr.attr.isSubClassOf("SPV_BitEnumAttr") &&
!namedAttr.attr.isSubClassOf("SPV_I32EnumAttr"))
continue;
EnumAttr enumAttr(namedAttr.attr.getDef().getValueAsDef("enum"));
for (const EnumAttrCase &enumerant : enumAttr->getAllCases())
for (const EnumAttrCase &enumerant : enumAttr.getAllCases())
for (const Availability &caseAvail :
getAvailabilities(enumerant.getDef()))
availClasses.try_emplace(caseAvail.getClass(), caseAvail);
@ -1298,16 +1328,17 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
// Update with enum attributes' specific availability spec.
for (const NamedAttribute &namedAttr : srcOp.getAttributes()) {
const auto *enumAttr = llvm::dyn_cast<EnumAttr>(&namedAttr.attr);
if (!enumAttr)
if (!namedAttr.attr.isSubClassOf("SPV_BitEnumAttr") &&
!namedAttr.attr.isSubClassOf("SPV_I32EnumAttr"))
continue;
EnumAttr enumAttr(namedAttr.attr.getDef().getValueAsDef("enum"));
// (enumerant, availability specification) pairs for this availability
// class.
SmallVector<std::pair<EnumAttrCase, Availability>, 1> caseSpecs;
// Collect all cases' availability specs.
for (const EnumAttrCase &enumerant : enumAttr->getAllCases())
for (const EnumAttrCase &enumerant : enumAttr.getAllCases())
for (const Availability &caseAvail :
getAvailabilities(enumerant.getDef()))
if (availClassName == caseAvail.getClass())
@ -1318,19 +1349,19 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
if (caseSpecs.empty())
continue;
if (enumAttr->isBitEnum()) {
if (enumAttr.isBitEnum()) {
// For BitEnumAttr, we need to iterate over each bit to query its
// availability spec.
os << formatv(" for (unsigned i = 0; "
"i < std::numeric_limits<{0}>::digits; ++i) {{\n",
enumAttr->getUnderlyingType());
enumAttr.getUnderlyingType());
os << formatv(" {0}::{1} tblgen_attrVal = this->{2}() & "
"static_cast<{0}::{1}>(1 << i);\n",
enumAttr->getCppNamespace(), enumAttr->getEnumClassName(),
enumAttr.getCppNamespace(), enumAttr.getEnumClassName(),
namedAttr.name);
os << formatv(
" if (static_cast<{0}>(tblgen_attrVal) == 0) continue;\n",
enumAttr->getUnderlyingType());
enumAttr.getUnderlyingType());
} else {
// For IntEnumAttr, we just need to query the value as a whole.
os << " {\n";
@ -1338,7 +1369,7 @@ static void emitAvailabilityImpl(const Operator &srcOp, raw_ostream &os) {
namedAttr.name);
}
os << formatv(" auto tblgen_instance = {0}::{1}(tblgen_attrVal);\n",
enumAttr->getCppNamespace(), avail.getQueryFnName());
enumAttr.getCppNamespace(), avail.getQueryFnName());
os << " if (tblgen_instance) "
// TODO` here once ODS supports
// dialect-specific contents so that we can use not implementing the
@ -1385,7 +1416,8 @@ static bool emitCapabilityImplication(const RecordKeeper &recordKeeper,
raw_ostream &os) {
llvm::emitSourceFileHeader("SPIR-V Capability Implication", os);
EnumAttr enumAttr(recordKeeper.getDef("SPV_CapabilityAttr"));
EnumAttr enumAttr(
recordKeeper.getDef("SPV_CapabilityAttr")->getValueAsDef("enum"));
os << "ArrayRef<spirv::Capability> "
"spirv::getDirectImpliedCapabilities(spirv::Capability cap) {\n"

View File

@ -14,6 +14,7 @@
#include "mlir/Target/SPIRV/Serialization.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/IR/Builders.h"
@ -46,11 +47,10 @@ protected:
OperationState state(UnknownLoc::get(&context),
spirv::ModuleOp::getOperationName());
state.addAttribute("addressing_model",
builder.getI32IntegerAttr(static_cast<uint32_t>(
spirv::AddressingModel::Logical)));
state.addAttribute("memory_model",
builder.getI32IntegerAttr(
static_cast<uint32_t>(spirv::MemoryModel::GLSL450)));
builder.getAttr<spirv::AddressingModelAttr>(
spirv::AddressingModel::Logical));
state.addAttribute("memory_model", builder.getAttr<spirv::MemoryModelAttr>(
spirv::MemoryModel::GLSL450));
state.addAttribute("vce_triple",
spirv::VerCapExtAttr::get(
spirv::Version::V_1_0, ArrayRef<spirv::Capability>(),

View File

@ -437,10 +437,13 @@ def gen_operand_kind_enum_attr(operand_kind, capability_mapping):
# Generate the enum attribute definition
kind_category = 'Bit' if is_bit_enum else 'I32'
enum_attr = '''def SPV_{name}Attr :
SPV_{category}EnumAttr<"{name}", "valid SPIR-V {name}", [
SPV_{category}EnumAttr<"{name}", "valid SPIR-V {name}", "{snake_name}", [
{cases}
]>;'''.format(
name=kind_name, category=kind_category, cases=case_names)
name=kind_name,
snake_name=snake_casify(kind_name),
category=kind_category,
cases=case_names)
return kind_name, case_defs + '\n\n' + enum_attr
@ -473,7 +476,8 @@ def gen_opcode(instructions):
]
opcode_list = ',\n'.join(opcode_list)
enum_attr = 'def SPV_OpcodeAttr :\n'\
' SPV_I32EnumAttr<"{name}", "valid SPIR-V instructions", [\n'\
' SPV_I32EnumAttr<"{name}", "valid SPIR-V instructions", '\
'"opcode", [\n'\
'{lst}\n'\
' ]>;'.format(name='Opcode', lst=opcode_list)
return opcode_str + '\n\n' + enum_attr
@ -630,9 +634,7 @@ def update_td_enum_attrs(path, operand_kinds, filter_list):
def snake_casify(name):
"""Turns the given name to follow snake_case convention."""
name = re.sub('\W+', '', name).split()
name = [s.lower() for s in name]
return '_'.join(name)
return re.sub(r'(?<!^)(?=[A-Z])', '_', name).lower()
def map_spec_operand_to_ods_argument(operand):