[SPIR-V](5/6) Add LegalizerInfo, InstructionSelector and utilities

The patch adds SPIRVLegalizerInfo, SPIRVInstructionSelector and
SPIRV-specific utilities.

Differential Revision: https://reviews.llvm.org/D116464

Authors: Aleksandr Bezzubikov, Lewis Crawford, Ilia Diachkov,
Michal Paszkowski, Andrey Tretyakov, Konrad Trifunovic

Co-authored-by: Aleksandr Bezzubikov <zuban32s@gmail.com>
Co-authored-by: Ilia Diachkov <iliya.diyachkov@intel.com>
Co-authored-by: Michal Paszkowski <michal.paszkowski@outlook.com>
Co-authored-by: Andrey Tretyakov <andrey1.tretyakov@intel.com>
Co-authored-by: Konrad Trifunovic <konrad.trifunovic@intel.com>
This commit is contained in:
Ilia Diachkov 2022-04-14 01:11:15 +03:00 committed by Michal Paszkowski
parent ec2590362e
commit eab7d3639b
17 changed files with 4323 additions and 20 deletions

View File

@ -15,13 +15,17 @@ add_public_tablegen_target(SPIRVCommonTableGen)
add_llvm_target(SPIRVCodeGen
SPIRVAsmPrinter.cpp
SPIRVCallLowering.cpp
SPIRVGlobalRegistry.cpp
SPIRVInstrInfo.cpp
SPIRVInstructionSelector.cpp
SPIRVISelLowering.cpp
SPIRVLegalizerInfo.cpp
SPIRVMCInstLower.cpp
SPIRVRegisterBankInfo.cpp
SPIRVRegisterInfo.cpp
SPIRVSubtarget.cpp
SPIRVTargetMachine.cpp
SPIRVUtils.cpp
LINK_COMPONENTS
Analysis

View File

@ -1,4 +1,5 @@
add_llvm_component_library(LLVMSPIRVDesc
SPIRVBaseInfo.cpp
SPIRVMCAsmInfo.cpp
SPIRVMCTargetDesc.cpp
SPIRVTargetStreamer.cpp

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,739 @@
//===-- SPIRVBaseInfo.h - Top level definitions for SPIRV ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains small standalone helper functions and enum definitions for
// the SPIRV target useful for the compiler back-end and the MC libraries.
// As such, it deliberately does not include references to LLVM core
// code gen types, passes, etc..
//
//===----------------------------------------------------------------------===//
#ifndef LLVM_LIB_TARGET_SPIRV_MCTARGETDESC_SPIRVBASEINFO_H
#define LLVM_LIB_TARGET_SPIRV_MCTARGETDESC_SPIRVBASEINFO_H
#include "llvm/ADT/StringRef.h"
#include <string>
namespace llvm {
namespace SPIRV {
enum class Capability : uint32_t {
Matrix = 0,
Shader = 1,
Geometry = 2,
Tessellation = 3,
Addresses = 4,
Linkage = 5,
Kernel = 6,
Vector16 = 7,
Float16Buffer = 8,
Float16 = 9,
Float64 = 10,
Int64 = 11,
Int64Atomics = 12,
ImageBasic = 13,
ImageReadWrite = 14,
ImageMipmap = 15,
Pipes = 17,
Groups = 18,
DeviceEnqueue = 19,
LiteralSampler = 20,
AtomicStorage = 21,
Int16 = 22,
TessellationPointSize = 23,
GeometryPointSize = 24,
ImageGatherExtended = 25,
StorageImageMultisample = 27,
UniformBufferArrayDynamicIndexing = 28,
SampledImageArrayDymnamicIndexing = 29,
ClipDistance = 32,
CullDistance = 33,
ImageCubeArray = 34,
SampleRateShading = 35,
ImageRect = 36,
SampledRect = 37,
GenericPointer = 38,
Int8 = 39,
InputAttachment = 40,
SparseResidency = 41,
MinLod = 42,
Sampled1D = 43,
Image1D = 44,
SampledCubeArray = 45,
SampledBuffer = 46,
ImageBuffer = 47,
ImageMSArray = 48,
StorageImageExtendedFormats = 49,
ImageQuery = 50,
DerivativeControl = 51,
InterpolationFunction = 52,
TransformFeedback = 53,
GeometryStreams = 54,
StorageImageReadWithoutFormat = 55,
StorageImageWriteWithoutFormat = 56,
MultiViewport = 57,
SubgroupDispatch = 58,
NamedBarrier = 59,
PipeStorage = 60,
GroupNonUniform = 61,
GroupNonUniformVote = 62,
GroupNonUniformArithmetic = 63,
GroupNonUniformBallot = 64,
GroupNonUniformShuffle = 65,
GroupNonUniformShuffleRelative = 66,
GroupNonUniformClustered = 67,
GroupNonUniformQuad = 68,
SubgroupBallotKHR = 4423,
DrawParameters = 4427,
SubgroupVoteKHR = 4431,
StorageBuffer16BitAccess = 4433,
StorageUniform16 = 4434,
StoragePushConstant16 = 4435,
StorageInputOutput16 = 4436,
DeviceGroup = 4437,
MultiView = 4439,
VariablePointersStorageBuffer = 4441,
VariablePointers = 4442,
AtomicStorageOps = 4445,
SampleMaskPostDepthCoverage = 4447,
StorageBuffer8BitAccess = 4448,
UniformAndStorageBuffer8BitAccess = 4449,
StoragePushConstant8 = 4450,
DenormPreserve = 4464,
DenormFlushToZero = 4465,
SignedZeroInfNanPreserve = 4466,
RoundingModeRTE = 4467,
RoundingModeRTZ = 4468,
Float16ImageAMD = 5008,
ImageGatherBiasLodAMD = 5009,
FragmentMaskAMD = 5010,
StencilExportEXT = 5013,
ImageReadWriteLodAMD = 5015,
SampleMaskOverrideCoverageNV = 5249,
GeometryShaderPassthroughNV = 5251,
ShaderViewportIndexLayerEXT = 5254,
ShaderViewportMaskNV = 5255,
ShaderStereoViewNV = 5259,
PerViewAttributesNV = 5260,
FragmentFullyCoveredEXT = 5265,
MeshShadingNV = 5266,
ShaderNonUniformEXT = 5301,
RuntimeDescriptorArrayEXT = 5302,
InputAttachmentArrayDynamicIndexingEXT = 5303,
UniformTexelBufferArrayDynamicIndexingEXT = 5304,
StorageTexelBufferArrayDynamicIndexingEXT = 5305,
UniformBufferArrayNonUniformIndexingEXT = 5306,
SampledImageArrayNonUniformIndexingEXT = 5307,
StorageBufferArrayNonUniformIndexingEXT = 5308,
StorageImageArrayNonUniformIndexingEXT = 5309,
InputAttachmentArrayNonUniformIndexingEXT = 5310,
UniformTexelBufferArrayNonUniformIndexingEXT = 5311,
StorageTexelBufferArrayNonUniformIndexingEXT = 5312,
RayTracingNV = 5340,
SubgroupShuffleINTEL = 5568,
SubgroupBufferBlockIOINTEL = 5569,
SubgroupImageBlockIOINTEL = 5570,
SubgroupImageMediaBlockIOINTEL = 5579,
SubgroupAvcMotionEstimationINTEL = 5696,
SubgroupAvcMotionEstimationIntraINTEL = 5697,
SubgroupAvcMotionEstimationChromaINTEL = 5698,
GroupNonUniformPartitionedNV = 5297,
VulkanMemoryModelKHR = 5345,
VulkanMemoryModelDeviceScopeKHR = 5346,
ImageFootprintNV = 5282,
FragmentBarycentricNV = 5284,
ComputeDerivativeGroupQuadsNV = 5288,
ComputeDerivativeGroupLinearNV = 5350,
FragmentDensityEXT = 5291,
PhysicalStorageBufferAddressesEXT = 5347,
CooperativeMatrixNV = 5357,
};
StringRef getCapabilityName(Capability e);
enum class SourceLanguage : uint32_t {
Unknown = 0,
ESSL = 1,
GLSL = 2,
OpenCL_C = 3,
OpenCL_CPP = 4,
HLSL = 5,
};
StringRef getSourceLanguageName(SourceLanguage e);
enum class AddressingModel : uint32_t {
Logical = 0,
Physical32 = 1,
Physical64 = 2,
PhysicalStorageBuffer64EXT = 5348,
};
StringRef getAddressingModelName(AddressingModel e);
enum class ExecutionModel : uint32_t {
Vertex = 0,
TessellationControl = 1,
TessellationEvaluation = 2,
Geometry = 3,
Fragment = 4,
GLCompute = 5,
Kernel = 6,
TaskNV = 5267,
MeshNV = 5268,
RayGenerationNV = 5313,
IntersectionNV = 5314,
AnyHitNV = 5315,
ClosestHitNV = 5316,
MissNV = 5317,
CallableNV = 5318,
};
StringRef getExecutionModelName(ExecutionModel e);
enum class MemoryModel : uint32_t {
Simple = 0,
GLSL450 = 1,
OpenCL = 2,
VulkanKHR = 3,
};
StringRef getMemoryModelName(MemoryModel e);
enum class ExecutionMode : uint32_t {
Invocations = 0,
SpacingEqual = 1,
SpacingFractionalEven = 2,
SpacingFractionalOdd = 3,
VertexOrderCw = 4,
VertexOrderCcw = 5,
PixelCenterInteger = 6,
OriginUpperLeft = 7,
OriginLowerLeft = 8,
EarlyFragmentTests = 9,
PointMode = 10,
Xfb = 11,
DepthReplacing = 12,
DepthGreater = 14,
DepthLess = 15,
DepthUnchanged = 16,
LocalSize = 17,
LocalSizeHint = 18,
InputPoints = 19,
InputLines = 20,
InputLinesAdjacency = 21,
Triangles = 22,
InputTrianglesAdjacency = 23,
Quads = 24,
Isolines = 25,
OutputVertices = 26,
OutputPoints = 27,
OutputLineStrip = 28,
OutputTriangleStrip = 29,
VecTypeHint = 30,
ContractionOff = 31,
Initializer = 33,
Finalizer = 34,
SubgroupSize = 35,
SubgroupsPerWorkgroup = 36,
SubgroupsPerWorkgroupId = 37,
LocalSizeId = 38,
LocalSizeHintId = 39,
PostDepthCoverage = 4446,
DenormPreserve = 4459,
DenormFlushToZero = 4460,
SignedZeroInfNanPreserve = 4461,
RoundingModeRTE = 4462,
RoundingModeRTZ = 4463,
StencilRefReplacingEXT = 5027,
OutputLinesNV = 5269,
DerivativeGroupQuadsNV = 5289,
DerivativeGroupLinearNV = 5290,
OutputTrianglesNV = 5298,
};
StringRef getExecutionModeName(ExecutionMode e);
enum class StorageClass : uint32_t {
UniformConstant = 0,
Input = 1,
Uniform = 2,
Output = 3,
Workgroup = 4,
CrossWorkgroup = 5,
Private = 6,
Function = 7,
Generic = 8,
PushConstant = 9,
AtomicCounter = 10,
Image = 11,
StorageBuffer = 12,
CallableDataNV = 5328,
IncomingCallableDataNV = 5329,
RayPayloadNV = 5338,
HitAttributeNV = 5339,
IncomingRayPayloadNV = 5342,
ShaderRecordBufferNV = 5343,
PhysicalStorageBufferEXT = 5349,
};
StringRef getStorageClassName(StorageClass e);
enum class Dim : uint32_t {
DIM_1D = 0,
DIM_2D = 1,
DIM_3D = 2,
DIM_Cube = 3,
DIM_Rect = 4,
DIM_Buffer = 5,
DIM_SubpassData = 6,
};
StringRef getDimName(Dim e);
enum class SamplerAddressingMode : uint32_t {
None = 0,
ClampToEdge = 1,
Clamp = 2,
Repeat = 3,
RepeatMirrored = 4,
};
StringRef getSamplerAddressingModeName(SamplerAddressingMode e);
enum class SamplerFilterMode : uint32_t {
Nearest = 0,
Linear = 1,
};
StringRef getSamplerFilterModeName(SamplerFilterMode e);
enum class ImageFormat : uint32_t {
Unknown = 0,
Rgba32f = 1,
Rgba16f = 2,
R32f = 3,
Rgba8 = 4,
Rgba8Snorm = 5,
Rg32f = 6,
Rg16f = 7,
R11fG11fB10f = 8,
R16f = 9,
Rgba16 = 10,
Rgb10A2 = 11,
Rg16 = 12,
Rg8 = 13,
R16 = 14,
R8 = 15,
Rgba16Snorm = 16,
Rg16Snorm = 17,
Rg8Snorm = 18,
R16Snorm = 19,
R8Snorm = 20,
Rgba32i = 21,
Rgba16i = 22,
Rgba8i = 23,
R32i = 24,
Rg32i = 25,
Rg16i = 26,
Rg8i = 27,
R16i = 28,
R8i = 29,
Rgba32ui = 30,
Rgba16ui = 31,
Rgba8ui = 32,
R32ui = 33,
Rgb10a2ui = 34,
Rg32ui = 35,
Rg16ui = 36,
Rg8ui = 37,
R16ui = 38,
R8ui = 39,
};
StringRef getImageFormatName(ImageFormat e);
enum class ImageChannelOrder : uint32_t {
R = 0,
A = 1,
RG = 2,
RA = 3,
RGB = 4,
RGBA = 5,
BGRA = 6,
ARGB = 7,
Intensity = 8,
Luminance = 9,
Rx = 10,
RGx = 11,
RGBx = 12,
Depth = 13,
DepthStencil = 14,
sRGB = 15,
sRGBx = 16,
sRGBA = 17,
sBGRA = 18,
ABGR = 19,
};
StringRef getImageChannelOrderName(ImageChannelOrder e);
enum class ImageChannelDataType : uint32_t {
SnormInt8 = 0,
SnormInt16 = 1,
UnormInt8 = 2,
UnormInt16 = 3,
UnormShort565 = 4,
UnormShort555 = 5,
UnormInt101010 = 6,
SignedInt8 = 7,
SignedInt16 = 8,
SignedInt32 = 9,
UnsignedInt8 = 10,
UnsignedInt16 = 11,
UnsigendInt32 = 12,
HalfFloat = 13,
Float = 14,
UnormInt24 = 15,
UnormInt101010_2 = 16,
};
StringRef getImageChannelDataTypeName(ImageChannelDataType e);
enum class ImageOperand : uint32_t {
None = 0x0,
Bias = 0x1,
Lod = 0x2,
Grad = 0x4,
ConstOffset = 0x8,
Offset = 0x10,
ConstOffsets = 0x20,
Sample = 0x40,
MinLod = 0x80,
MakeTexelAvailableKHR = 0x100,
MakeTexelVisibleKHR = 0x200,
NonPrivateTexelKHR = 0x400,
VolatileTexelKHR = 0x800,
SignExtend = 0x1000,
ZeroExtend = 0x2000,
};
std::string getImageOperandName(uint32_t e);
enum class FPFastMathMode : uint32_t {
None = 0x0,
NotNaN = 0x1,
NotInf = 0x2,
NSZ = 0x4,
AllowRecip = 0x8,
Fast = 0x10,
};
std::string getFPFastMathModeName(uint32_t e);
enum class FPRoundingMode : uint32_t {
RTE = 0,
RTZ = 1,
RTP = 2,
RTN = 3,
};
StringRef getFPRoundingModeName(FPRoundingMode e);
enum class LinkageType : uint32_t {
Export = 0,
Import = 1,
};
StringRef getLinkageTypeName(LinkageType e);
enum class AccessQualifier : uint32_t {
ReadOnly = 0,
WriteOnly = 1,
ReadWrite = 2,
};
StringRef getAccessQualifierName(AccessQualifier e);
enum class FunctionParameterAttribute : uint32_t {
Zext = 0,
Sext = 1,
ByVal = 2,
Sret = 3,
NoAlias = 4,
NoCapture = 5,
NoWrite = 6,
NoReadWrite = 7,
};
StringRef getFunctionParameterAttributeName(FunctionParameterAttribute e);
enum class Decoration : uint32_t {
RelaxedPrecision = 0,
SpecId = 1,
Block = 2,
BufferBlock = 3,
RowMajor = 4,
ColMajor = 5,
ArrayStride = 6,
MatrixStride = 7,
GLSLShared = 8,
GLSLPacked = 9,
CPacked = 10,
BuiltIn = 11,
NoPerspective = 13,
Flat = 14,
Patch = 15,
Centroid = 16,
Sample = 17,
Invariant = 18,
Restrict = 19,
Aliased = 20,
Volatile = 21,
Constant = 22,
Coherent = 23,
NonWritable = 24,
NonReadable = 25,
Uniform = 26,
UniformId = 27,
SaturatedConversion = 28,
Stream = 29,
Location = 30,
Component = 31,
Index = 32,
Binding = 33,
DescriptorSet = 34,
Offset = 35,
XfbBuffer = 36,
XfbStride = 37,
FuncParamAttr = 38,
FPRoundingMode = 39,
FPFastMathMode = 40,
LinkageAttributes = 41,
NoContraction = 42,
InputAttachmentIndex = 43,
Alignment = 44,
MaxByteOffset = 45,
AlignmentId = 46,
MaxByteOffsetId = 47,
NoSignedWrap = 4469,
NoUnsignedWrap = 4470,
ExplicitInterpAMD = 4999,
OverrideCoverageNV = 5248,
PassthroughNV = 5250,
ViewportRelativeNV = 5252,
SecondaryViewportRelativeNV = 5256,
PerPrimitiveNV = 5271,
PerViewNV = 5272,
PerVertexNV = 5273,
NonUniformEXT = 5300,
CountBuffer = 5634,
UserSemantic = 5635,
RestrictPointerEXT = 5355,
AliasedPointerEXT = 5356,
};
StringRef getDecorationName(Decoration e);
enum class BuiltIn : uint32_t {
Position = 0,
PointSize = 1,
ClipDistance = 3,
CullDistance = 4,
VertexId = 5,
InstanceId = 6,
PrimitiveId = 7,
InvocationId = 8,
Layer = 9,
ViewportIndex = 10,
TessLevelOuter = 11,
TessLevelInner = 12,
TessCoord = 13,
PatchVertices = 14,
FragCoord = 15,
PointCoord = 16,
FrontFacing = 17,
SampleId = 18,
SamplePosition = 19,
SampleMask = 20,
FragDepth = 22,
HelperInvocation = 23,
NumWorkgroups = 24,
WorkgroupSize = 25,
WorkgroupId = 26,
LocalInvocationId = 27,
GlobalInvocationId = 28,
LocalInvocationIndex = 29,
WorkDim = 30,
GlobalSize = 31,
EnqueuedWorkgroupSize = 32,
GlobalOffset = 33,
GlobalLinearId = 34,
SubgroupSize = 36,
SubgroupMaxSize = 37,
NumSubgroups = 38,
NumEnqueuedSubgroups = 39,
SubgroupId = 40,
SubgroupLocalInvocationId = 41,
VertexIndex = 42,
InstanceIndex = 43,
SubgroupEqMask = 4416,
SubgroupGeMask = 4417,
SubgroupGtMask = 4418,
SubgroupLeMask = 4419,
SubgroupLtMask = 4420,
BaseVertex = 4424,
BaseInstance = 4425,
DrawIndex = 4426,
DeviceIndex = 4438,
ViewIndex = 4440,
BaryCoordNoPerspAMD = 4492,
BaryCoordNoPerspCentroidAMD = 4493,
BaryCoordNoPerspSampleAMD = 4494,
BaryCoordSmoothAMD = 4495,
BaryCoordSmoothCentroid = 4496,
BaryCoordSmoothSample = 4497,
BaryCoordPullModel = 4498,
FragStencilRefEXT = 5014,
ViewportMaskNV = 5253,
SecondaryPositionNV = 5257,
SecondaryViewportMaskNV = 5258,
PositionPerViewNV = 5261,
ViewportMaskPerViewNV = 5262,
FullyCoveredEXT = 5264,
TaskCountNV = 5274,
PrimitiveCountNV = 5275,
PrimitiveIndicesNV = 5276,
ClipDistancePerViewNV = 5277,
CullDistancePerViewNV = 5278,
LayerPerViewNV = 5279,
MeshViewCountNV = 5280,
MeshViewIndices = 5281,
BaryCoordNV = 5286,
BaryCoordNoPerspNV = 5287,
FragSizeEXT = 5292,
FragInvocationCountEXT = 5293,
LaunchIdNV = 5319,
LaunchSizeNV = 5320,
WorldRayOriginNV = 5321,
WorldRayDirectionNV = 5322,
ObjectRayOriginNV = 5323,
ObjectRayDirectionNV = 5324,
RayTminNV = 5325,
RayTmaxNV = 5326,
InstanceCustomIndexNV = 5327,
ObjectToWorldNV = 5330,
WorldToObjectNV = 5331,
HitTNV = 5332,
HitKindNV = 5333,
IncomingRayFlagsNV = 5351,
};
StringRef getBuiltInName(BuiltIn e);
enum class SelectionControl : uint32_t {
None = 0x0,
Flatten = 0x1,
DontFlatten = 0x2,
};
std::string getSelectionControlName(uint32_t e);
enum class LoopControl : uint32_t {
None = 0x0,
Unroll = 0x1,
DontUnroll = 0x2,
DependencyInfinite = 0x4,
DependencyLength = 0x8,
MinIterations = 0x10,
MaxIterations = 0x20,
IterationMultiple = 0x40,
PeelCount = 0x80,
PartialCount = 0x100,
};
std::string getLoopControlName(uint32_t e);
enum class FunctionControl : uint32_t {
None = 0x0,
Inline = 0x1,
DontInline = 0x2,
Pure = 0x4,
Const = 0x8,
};
std::string getFunctionControlName(uint32_t e);
enum class MemorySemantics : uint32_t {
None = 0x0,
Acquire = 0x2,
Release = 0x4,
AcquireRelease = 0x8,
SequentiallyConsistent = 0x10,
UniformMemory = 0x40,
SubgroupMemory = 0x80,
WorkgroupMemory = 0x100,
CrossWorkgroupMemory = 0x200,
AtomicCounterMemory = 0x400,
ImageMemory = 0x800,
OutputMemoryKHR = 0x1000,
MakeAvailableKHR = 0x2000,
MakeVisibleKHR = 0x4000,
};
std::string getMemorySemanticsName(uint32_t e);
enum class MemoryOperand : uint32_t {
None = 0x0,
Volatile = 0x1,
Aligned = 0x2,
Nontemporal = 0x4,
MakePointerAvailableKHR = 0x8,
MakePointerVisibleKHR = 0x10,
NonPrivatePointerKHR = 0x20,
};
std::string getMemoryOperandName(uint32_t e);
enum class Scope : uint32_t {
CrossDevice = 0,
Device = 1,
Workgroup = 2,
Subgroup = 3,
Invocation = 4,
QueueFamilyKHR = 5,
};
StringRef getScopeName(Scope e);
enum class GroupOperation : uint32_t {
Reduce = 0,
InclusiveScan = 1,
ExclusiveScan = 2,
ClusteredReduce = 3,
PartitionedReduceNV = 6,
PartitionedInclusiveScanNV = 7,
PartitionedExclusiveScanNV = 8,
};
StringRef getGroupOperationName(GroupOperation e);
enum class KernelEnqueueFlags : uint32_t {
NoWait = 0,
WaitKernel = 1,
WaitWorkGroup = 2,
};
StringRef getKernelEnqueueFlagsName(KernelEnqueueFlags e);
enum class KernelProfilingInfo : uint32_t {
None = 0x0,
CmdExecTime = 0x1,
};
StringRef getKernelProfilingInfoName(KernelProfilingInfo e);
} // namespace SPIRV
} // namespace llvm
// Return a string representation of the operands from startIndex onwards.
// Templated to allow both MachineInstr and MCInst to use the same logic.
template <class InstType>
std::string getSPIRVStringOperand(const InstType &MI, unsigned StartIndex) {
std::string s; // Iteratively append to this string.
const unsigned NumOps = MI.getNumOperands();
bool IsFinished = false;
for (unsigned i = StartIndex; i < NumOps && !IsFinished; ++i) {
const auto &Op = MI.getOperand(i);
if (!Op.isImm()) // Stop if we hit a register operand.
break;
assert((Op.getImm() >> 32) == 0 && "Imm operand should be i32 word");
const uint32_t Imm = Op.getImm(); // Each i32 word is up to 4 characters.
for (unsigned ShiftAmount = 0; ShiftAmount < 32; ShiftAmount += 8) {
char c = (Imm >> ShiftAmount) & 0xff;
if (c == 0) { // Stop if we hit a null-terminator character.
IsFinished = true;
break;
} else {
s += c; // Otherwise, append the character to the result string.
}
}
}
return s;
}
#endif // LLVM_LIB_TARGET_SPIRV_MCTARGETDESC_SPIRVBASEINFO_H

View File

@ -16,6 +16,13 @@
namespace llvm {
class SPIRVTargetMachine;
class SPIRVSubtarget;
class InstructionSelector;
class RegisterBankInfo;
InstructionSelector *
createSPIRVInstructionSelector(const SPIRVTargetMachine &TM,
const SPIRVSubtarget &Subtarget,
const RegisterBankInfo &RBI);
} // namespace llvm
#endif // LLVM_LIB_TARGET_SPIRV_SPIRV_H

View File

@ -12,17 +12,21 @@
//===----------------------------------------------------------------------===//
#include "SPIRVCallLowering.h"
#include "MCTargetDesc/SPIRVBaseInfo.h"
#include "SPIRV.h"
#include "SPIRVGlobalRegistry.h"
#include "SPIRVISelLowering.h"
#include "SPIRVRegisterInfo.h"
#include "SPIRVSubtarget.h"
#include "SPIRVUtils.h"
#include "llvm/CodeGen/FunctionLoweringInfo.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
using namespace llvm;
SPIRVCallLowering::SPIRVCallLowering(const SPIRVTargetLowering &TLI)
: CallLowering(&TLI) {}
SPIRVCallLowering::SPIRVCallLowering(const SPIRVTargetLowering &TLI,
const SPIRVSubtarget &ST,
SPIRVGlobalRegistry *GR)
: CallLowering(&TLI), ST(ST), GR(GR) {}
bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
const Value *Val, ArrayRef<Register> VRegs,
@ -32,19 +36,39 @@ bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
// TODO: handle the case of multiple registers.
if (VRegs.size() > 1)
return false;
if (Val) {
MIRBuilder.buildInstr(SPIRV::OpReturnValue).addUse(VRegs[0]);
return true;
}
if (Val)
return MIRBuilder.buildInstr(SPIRV::OpReturnValue)
.addUse(VRegs[0])
.constrainAllUses(MIRBuilder.getTII(), *ST.getRegisterInfo(),
*ST.getRegBankInfo());
MIRBuilder.buildInstr(SPIRV::OpReturn);
return true;
}
// Based on the LLVM function attributes, get a SPIR-V FunctionControl.
static uint32_t getFunctionControl(const Function &F) {
uint32_t FuncControl = static_cast<uint32_t>(SPIRV::FunctionControl::None);
if (F.hasFnAttribute(Attribute::AttrKind::AlwaysInline)) {
FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Inline);
}
if (F.hasFnAttribute(Attribute::AttrKind::ReadNone)) {
FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Pure);
}
if (F.hasFnAttribute(Attribute::AttrKind::ReadOnly)) {
FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Const);
}
if (F.hasFnAttribute(Attribute::AttrKind::NoInline)) {
FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::DontInline);
}
return FuncControl;
}
bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
const Function &F,
ArrayRef<ArrayRef<Register>> VRegs,
FunctionLoweringInfo &FLI) const {
auto MRI = MIRBuilder.getMRI();
assert(GR && "Must initialize the SPIRV type registry before lowering args.");
// Assign types and names to all args, and store their types for later.
SmallVector<Register, 4> ArgTypeVRegs;
if (VRegs.size() > 0) {
@ -54,21 +78,55 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
// TODO: handle the case of multiple registers.
if (VRegs[i].size() > 1)
return false;
ArgTypeVRegs.push_back(
MRI->createGenericVirtualRegister(LLT::scalar(32)));
auto *SpirvTy =
GR->assignTypeToVReg(Arg.getType(), VRegs[i][0], MIRBuilder);
ArgTypeVRegs.push_back(GR->getSPIRVTypeID(SpirvTy));
if (Arg.hasName())
buildOpName(VRegs[i][0], Arg.getName(), MIRBuilder);
if (Arg.getType()->isPointerTy()) {
auto DerefBytes = static_cast<unsigned>(Arg.getDereferenceableBytes());
if (DerefBytes != 0)
buildOpDecorate(VRegs[i][0], MIRBuilder,
SPIRV::Decoration::MaxByteOffset, {DerefBytes});
}
if (Arg.hasAttribute(Attribute::Alignment)) {
buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::Alignment,
{static_cast<unsigned>(Arg.getParamAlignment())});
}
if (Arg.hasAttribute(Attribute::ReadOnly)) {
auto Attr =
static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoWrite);
buildOpDecorate(VRegs[i][0], MIRBuilder,
SPIRV::Decoration::FuncParamAttr, {Attr});
}
if (Arg.hasAttribute(Attribute::ZExt)) {
auto Attr =
static_cast<unsigned>(SPIRV::FunctionParameterAttribute::Zext);
buildOpDecorate(VRegs[i][0], MIRBuilder,
SPIRV::Decoration::FuncParamAttr, {Attr});
}
++i;
}
}
// Generate a SPIR-V type for the function.
auto MRI = MIRBuilder.getMRI();
Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass);
auto *FTy = F.getFunctionType();
auto FuncTy = GR->assignTypeToVReg(FTy, FuncVReg, MIRBuilder);
// Build the OpTypeFunction declaring it.
Register ReturnTypeID = FuncTy->getOperand(1).getReg();
uint32_t FuncControl = getFunctionControl(F);
MIRBuilder.buildInstr(SPIRV::OpFunction)
.addDef(FuncVReg)
.addUse(MRI->createGenericVirtualRegister(LLT::scalar(32)))
.addImm(0)
.addUse(MRI->createGenericVirtualRegister(LLT::scalar(32)));
.addUse(ReturnTypeID)
.addImm(FuncControl)
.addUse(GR->getSPIRVTypeID(FuncTy));
// Add OpFunctionParameters.
const unsigned NumArgs = ArgTypeVRegs.size();
@ -79,6 +137,24 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
.addDef(VRegs[i][0])
.addUse(ArgTypeVRegs[i]);
}
// Name the function.
if (F.hasName())
buildOpName(FuncVReg, F.getName(), MIRBuilder);
// Handle entry points and function linkage.
if (F.getCallingConv() == CallingConv::SPIR_KERNEL) {
auto MIB = MIRBuilder.buildInstr(SPIRV::OpEntryPoint)
.addImm(static_cast<uint32_t>(SPIRV::ExecutionModel::Kernel))
.addUse(FuncVReg);
addStringImm(F.getName(), MIB);
} else if (F.getLinkage() == GlobalValue::LinkageTypes::ExternalLinkage ||
F.getLinkage() == GlobalValue::LinkOnceODRLinkage) {
auto LnkTy = F.isDeclaration() ? SPIRV::LinkageType::Import
: SPIRV::LinkageType::Export;
buildOpDecorate(FuncVReg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,
{static_cast<uint32_t>(LnkTy)}, F.getGlobalIdentifier());
}
return true;
}
@ -91,15 +167,49 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
Register ResVReg =
Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0];
// Emit a regular OpFunctionCall. If it's an externally declared function,
// be sure to emit its type and function declaration here. It will be
// hoisted globally later.
if (Info.Callee.isGlobal()) {
auto *CF = dyn_cast_or_null<const Function>(Info.Callee.getGlobal());
// TODO: support constexpr casts and indirect calls.
if (CF == nullptr)
return false;
if (CF->isDeclaration()) {
// Emit the type info and forward function declaration to the first MBB
// to ensure VReg definition dependencies are valid across all MBBs.
MachineBasicBlock::iterator OldII = MIRBuilder.getInsertPt();
MachineBasicBlock &OldBB = MIRBuilder.getMBB();
MachineBasicBlock &FirstBB = *MIRBuilder.getMF().getBlockNumbered(0);
MIRBuilder.setInsertPt(FirstBB, FirstBB.instr_end());
SmallVector<ArrayRef<Register>, 8> VRegArgs;
SmallVector<SmallVector<Register, 1>, 8> ToInsert;
for (const Argument &Arg : CF->args()) {
if (MIRBuilder.getDataLayout().getTypeStoreSize(Arg.getType()).isZero())
continue; // Don't handle zero sized types.
ToInsert.push_back({MIRBuilder.getMRI()->createGenericVirtualRegister(
LLT::scalar(32))});
VRegArgs.push_back(ToInsert.back());
}
// TODO: Reuse FunctionLoweringInfo.
FunctionLoweringInfo FuncInfo;
lowerFormalArguments(MIRBuilder, *CF, VRegArgs, FuncInfo);
MIRBuilder.setInsertPt(OldBB, OldII);
}
}
// Make sure there's a valid return reg, even for functions returning void.
if (!ResVReg.isValid()) {
ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
}
SPIRVType *RetType =
GR->assignTypeToVReg(Info.OrigRet.Ty, ResVReg, MIRBuilder);
// Emit the OpFunctionCall and its args.
auto MIB = MIRBuilder.buildInstr(SPIRV::OpFunctionCall)
.addDef(ResVReg)
.addUse(MIRBuilder.getMRI()->createVirtualRegister(
&SPIRV::IDRegClass))
.addUse(GR->getSPIRVTypeID(RetType))
.add(Info.Callee);
for (const auto &Arg : Info.OrigArgs) {
@ -108,5 +218,6 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
return false;
MIB.addUse(Arg.Regs[0]);
}
return true;
return MIB.constrainAllUses(MIRBuilder.getTII(), *ST.getRegisterInfo(),
*ST.getRegBankInfo());
}

View File

@ -17,12 +17,19 @@
namespace llvm {
class SPIRVGlobalRegistry;
class SPIRVSubtarget;
class SPIRVTargetLowering;
class SPIRVCallLowering : public CallLowering {
private:
const SPIRVSubtarget &ST;
// Used to create and assign function, argument, and return type information.
SPIRVGlobalRegistry *GR;
public:
SPIRVCallLowering(const SPIRVTargetLowering &TLI);
SPIRVCallLowering(const SPIRVTargetLowering &TLI, const SPIRVSubtarget &ST,
SPIRVGlobalRegistry *GR);
// Built OpReturn or OpReturnValue.
bool lowerReturn(MachineIRBuilder &MIRBuiler, const Value *Val,

View File

@ -0,0 +1,453 @@
//===-- SPIRVGlobalRegistry.cpp - SPIR-V Global Registry --------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the implementation of the SPIRVGlobalRegistry class,
// which is used to maintain rich type information required for SPIR-V even
// after lowering from LLVM IR to GMIR. It can convert an llvm::Type into
// an OpTypeXXX instruction, and map it to a virtual register. Also it builds
// and supports consistency of constants and global variables.
//
//===----------------------------------------------------------------------===//
#include "SPIRVGlobalRegistry.h"
#include "SPIRV.h"
#include "SPIRVSubtarget.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
using namespace llvm;
SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize)
: PointerSize(PointerSize) {}
SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg(
const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier AccessQual, bool EmitIR) {
SPIRVType *SpirvType =
getOrCreateSPIRVType(Type, MIRBuilder, AccessQual, EmitIR);
assignSPIRVTypeToVReg(SpirvType, VReg, MIRBuilder);
return SpirvType;
}
void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType,
Register VReg,
MachineIRBuilder &MIRBuilder) {
VRegToTypeMap[&MIRBuilder.getMF()][VReg] = SpirvType;
}
static Register createTypeVReg(MachineIRBuilder &MIRBuilder) {
auto &MRI = MIRBuilder.getMF().getRegInfo();
auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32));
MRI.setRegClass(Res, &SPIRV::TYPERegClass);
return Res;
}
static Register createTypeVReg(MachineRegisterInfo &MRI) {
auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32));
MRI.setRegClass(Res, &SPIRV::TYPERegClass);
return Res;
}
SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
return MIRBuilder.buildInstr(SPIRV::OpTypeBool)
.addDef(createTypeVReg(MIRBuilder));
}
SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width,
MachineIRBuilder &MIRBuilder,
bool IsSigned) {
auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeInt)
.addDef(createTypeVReg(MIRBuilder))
.addImm(Width)
.addImm(IsSigned ? 1 : 0);
return MIB;
}
SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
MachineIRBuilder &MIRBuilder) {
auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
.addDef(createTypeVReg(MIRBuilder))
.addImm(Width);
return MIB;
}
SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) {
return MIRBuilder.buildInstr(SPIRV::OpTypeVoid)
.addDef(createTypeVReg(MIRBuilder));
}
SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
SPIRVType *ElemType,
MachineIRBuilder &MIRBuilder) {
auto EleOpc = ElemType->getOpcode();
assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat ||
EleOpc == SPIRV::OpTypeBool) &&
"Invalid vector element type");
auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeVector)
.addDef(createTypeVReg(MIRBuilder))
.addUse(getSPIRVTypeID(ElemType))
.addImm(NumElems);
return MIB;
}
Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType,
bool EmitIR) {
auto &MF = MIRBuilder.getMF();
Register Res;
const IntegerType *LLVMIntTy;
if (SpvType)
LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType));
else
LLVMIntTy = IntegerType::getInt32Ty(MF.getFunction().getContext());
// Find a constant in DT or build a new one.
const auto ConstInt =
ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
assignTypeToVReg(LLVMIntTy, Res, MIRBuilder);
if (EmitIR)
MIRBuilder.buildConstant(Res, *ConstInt);
else
MIRBuilder.buildInstr(SPIRV::OpConstantI)
.addDef(Res)
.addImm(ConstInt->getSExtValue());
return Res;
}
Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType) {
auto &MF = MIRBuilder.getMF();
Register Res;
const Type *LLVMFPTy;
if (SpvType) {
LLVMFPTy = getTypeForSPIRVType(SpvType);
assert(LLVMFPTy->isFloatingPointTy());
} else {
LLVMFPTy = IntegerType::getFloatTy(MF.getFunction().getContext());
}
// Find a constant in DT or build a new one.
const auto ConstFP = ConstantFP::get(LLVMFPTy->getContext(), Val);
unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
assignTypeToVReg(LLVMFPTy, Res, MIRBuilder);
MIRBuilder.buildFConstant(Res, *ConstFP);
return Res;
}
Register SPIRVGlobalRegistry::buildGlobalVariable(
Register ResVReg, SPIRVType *BaseType, StringRef Name,
const GlobalValue *GV, SPIRV::StorageClass Storage,
const MachineInstr *Init, bool IsConst, bool HasLinkageTy,
SPIRV::LinkageType LinkageType, MachineIRBuilder &MIRBuilder,
bool IsInstSelector) {
const GlobalVariable *GVar = nullptr;
if (GV)
GVar = cast<const GlobalVariable>(GV);
else {
// If GV is not passed explicitly, use the name to find or construct
// the global variable.
Module *M = MIRBuilder.getMF().getFunction().getParent();
GVar = M->getGlobalVariable(Name);
if (GVar == nullptr) {
const Type *Ty = getTypeForSPIRVType(BaseType); // TODO: check type.
GVar = new GlobalVariable(*M, const_cast<Type *>(Ty), false,
GlobalValue::ExternalLinkage, nullptr,
Twine(Name));
}
GV = GVar;
}
Register Reg;
auto MIB = MIRBuilder.buildInstr(SPIRV::OpVariable)
.addDef(ResVReg)
.addUse(getSPIRVTypeID(BaseType))
.addImm(static_cast<uint32_t>(Storage));
if (Init != 0) {
MIB.addUse(Init->getOperand(0).getReg());
}
// ISel may introduce a new register on this step, so we need to add it to
// DT and correct its type avoiding fails on the next stage.
if (IsInstSelector) {
const auto &Subtarget = CurMF->getSubtarget();
constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
*Subtarget.getRegisterInfo(),
*Subtarget.getRegBankInfo());
}
Reg = MIB->getOperand(0).getReg();
// Set to Reg the same type as ResVReg has.
auto MRI = MIRBuilder.getMRI();
assert(MRI->getType(ResVReg).isPointer() && "Pointer type is expected");
if (Reg != ResVReg) {
LLT RegLLTy = LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), 32);
MRI->setType(Reg, RegLLTy);
assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder);
}
// If it's a global variable with name, output OpName for it.
if (GVar && GVar->hasName())
buildOpName(Reg, GVar->getName(), MIRBuilder);
// Output decorations for the GV.
// TODO: maybe move to GenerateDecorations pass.
if (IsConst)
buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Constant, {});
if (GVar && GVar->getAlign().valueOrOne().value() != 1)
buildOpDecorate(
Reg, MIRBuilder, SPIRV::Decoration::Alignment,
{static_cast<uint32_t>(GVar->getAlign().valueOrOne().value())});
if (HasLinkageTy)
buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,
{static_cast<uint32_t>(LinkageType)}, Name);
return Reg;
}
SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems,
SPIRVType *ElemType,
MachineIRBuilder &MIRBuilder,
bool EmitIR) {
assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) &&
"Invalid array element type");
Register NumElementsVReg =
buildConstantInt(NumElems, MIRBuilder, nullptr, EmitIR);
auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeArray)
.addDef(createTypeVReg(MIRBuilder))
.addUse(getSPIRVTypeID(ElemType))
.addUse(NumElementsVReg);
return MIB;
}
SPIRVType *SPIRVGlobalRegistry::getOpTypePointer(SPIRV::StorageClass SC,
SPIRVType *ElemType,
MachineIRBuilder &MIRBuilder) {
auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypePointer)
.addDef(createTypeVReg(MIRBuilder))
.addImm(static_cast<uint32_t>(SC))
.addUse(getSPIRVTypeID(ElemType));
return MIB;
}
SPIRVType *SPIRVGlobalRegistry::getOpTypeFunction(
SPIRVType *RetType, const SmallVectorImpl<SPIRVType *> &ArgTypes,
MachineIRBuilder &MIRBuilder) {
auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFunction)
.addDef(createTypeVReg(MIRBuilder))
.addUse(getSPIRVTypeID(RetType));
for (const SPIRVType *ArgType : ArgTypes)
MIB.addUse(getSPIRVTypeID(ArgType));
return MIB;
}
SPIRVType *SPIRVGlobalRegistry::createSPIRVType(const Type *Ty,
MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier AccQual,
bool EmitIR) {
if (auto IType = dyn_cast<IntegerType>(Ty)) {
const unsigned Width = IType->getBitWidth();
return Width == 1 ? getOpTypeBool(MIRBuilder)
: getOpTypeInt(Width, MIRBuilder, false);
}
if (Ty->isFloatingPointTy())
return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
if (Ty->isVoidTy())
return getOpTypeVoid(MIRBuilder);
if (Ty->isVectorTy()) {
auto El = getOrCreateSPIRVType(cast<FixedVectorType>(Ty)->getElementType(),
MIRBuilder);
return getOpTypeVector(cast<FixedVectorType>(Ty)->getNumElements(), El,
MIRBuilder);
}
if (Ty->isArrayTy()) {
auto *El = getOrCreateSPIRVType(Ty->getArrayElementType(), MIRBuilder);
return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder, EmitIR);
}
assert(!isa<StructType>(Ty) && "Unsupported StructType");
if (auto FType = dyn_cast<FunctionType>(Ty)) {
SPIRVType *RetTy = getOrCreateSPIRVType(FType->getReturnType(), MIRBuilder);
SmallVector<SPIRVType *, 4> ParamTypes;
for (const auto &t : FType->params()) {
ParamTypes.push_back(getOrCreateSPIRVType(t, MIRBuilder));
}
return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder);
}
if (auto PType = dyn_cast<PointerType>(Ty)) {
Type *ElemType = PType->getPointerElementType();
// Some OpenCL and SPIRV builtins like image2d_t are passed in as pointers,
// but should be treated as custom types like OpTypeImage.
assert(!isa<StructType>(ElemType) && "Unsupported StructType pointer");
// Otherwise, treat it as a regular pointer type.
auto SC = addressSpaceToStorageClass(PType->getAddressSpace());
SPIRVType *SpvElementType = getOrCreateSPIRVType(
ElemType, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, EmitIR);
return getOpTypePointer(SC, SpvElementType, MIRBuilder);
}
llvm_unreachable("Unable to convert LLVM type to SPIRVType");
}
SPIRVType *SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg) const {
auto t = VRegToTypeMap.find(CurMF);
if (t != VRegToTypeMap.end()) {
auto tt = t->second.find(VReg);
if (tt != t->second.end())
return tt->second;
}
return nullptr;
}
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
const Type *Type, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier AccessQual, bool EmitIR) {
Register Reg;
SPIRVType *SpirvType = createSPIRVType(Type, MIRBuilder, AccessQual, EmitIR);
VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
SPIRVToLLVMType[SpirvType] = Type;
return SpirvType;
}
bool SPIRVGlobalRegistry::isScalarOfType(Register VReg,
unsigned TypeOpcode) const {
SPIRVType *Type = getSPIRVTypeForVReg(VReg);
assert(Type && "isScalarOfType VReg has no type assigned");
return Type->getOpcode() == TypeOpcode;
}
bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg,
unsigned TypeOpcode) const {
SPIRVType *Type = getSPIRVTypeForVReg(VReg);
assert(Type && "isScalarOrVectorOfType VReg has no type assigned");
if (Type->getOpcode() == TypeOpcode)
return true;
if (Type->getOpcode() == SPIRV::OpTypeVector) {
Register ScalarTypeVReg = Type->getOperand(1).getReg();
SPIRVType *ScalarType = getSPIRVTypeForVReg(ScalarTypeVReg);
return ScalarType->getOpcode() == TypeOpcode;
}
return false;
}
unsigned
SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const {
assert(Type && "Invalid Type pointer");
if (Type->getOpcode() == SPIRV::OpTypeVector) {
auto EleTypeReg = Type->getOperand(1).getReg();
Type = getSPIRVTypeForVReg(EleTypeReg);
}
if (Type->getOpcode() == SPIRV::OpTypeInt ||
Type->getOpcode() == SPIRV::OpTypeFloat)
return Type->getOperand(1).getImm();
if (Type->getOpcode() == SPIRV::OpTypeBool)
return 1;
llvm_unreachable("Attempting to get bit width of non-integer/float type.");
}
bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const {
assert(Type && "Invalid Type pointer");
if (Type->getOpcode() == SPIRV::OpTypeVector) {
auto EleTypeReg = Type->getOperand(1).getReg();
Type = getSPIRVTypeForVReg(EleTypeReg);
}
if (Type->getOpcode() == SPIRV::OpTypeInt)
return Type->getOperand(2).getImm() != 0;
llvm_unreachable("Attempting to get sign of non-integer type.");
}
SPIRV::StorageClass
SPIRVGlobalRegistry::getPointerStorageClass(Register VReg) const {
SPIRVType *Type = getSPIRVTypeForVReg(VReg);
assert(Type && Type->getOpcode() == SPIRV::OpTypePointer &&
Type->getOperand(1).isImm() && "Pointer type is expected");
return static_cast<SPIRV::StorageClass>(Type->getOperand(1).getImm());
}
SPIRVType *
SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth,
MachineIRBuilder &MIRBuilder) {
return getOrCreateSPIRVType(
IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), BitWidth),
MIRBuilder);
}
SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(Type *LLVMTy,
MachineInstrBuilder MIB) {
SPIRVType *SpirvType = MIB;
VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType;
SPIRVToLLVMType[SpirvType] = LLVMTy;
return SpirvType;
}
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth);
MachineBasicBlock &BB = *I.getParent();
auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeInt))
.addDef(createTypeVReg(CurMF->getRegInfo()))
.addImm(BitWidth)
.addImm(0);
return restOfCreateSPIRVType(LLVMTy, MIB);
}
SPIRVType *
SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder) {
return getOrCreateSPIRVType(
IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), 1),
MIRBuilder);
}
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder) {
return getOrCreateSPIRVType(
FixedVectorType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
NumElements),
MIRBuilder);
}
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
SPIRVType *BaseType, unsigned NumElements, MachineInstr &I,
const SPIRVInstrInfo &TII) {
Type *LLVMTy = FixedVectorType::get(
const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements);
MachineBasicBlock &BB = *I.getParent();
auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeVector))
.addDef(createTypeVReg(CurMF->getRegInfo()))
.addUse(getSPIRVTypeID(BaseType))
.addImm(NumElements);
return restOfCreateSPIRVType(LLVMTy, MIB);
}
SPIRVType *
SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(SPIRVType *BaseType,
MachineIRBuilder &MIRBuilder,
SPIRV::StorageClass SClass) {
return getOrCreateSPIRVType(
PointerType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
storageClassToAddressSpace(SClass)),
MIRBuilder);
}
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &TII,
SPIRV::StorageClass SC) {
Type *LLVMTy =
PointerType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
storageClassToAddressSpace(SC));
MachineBasicBlock &BB = *I.getParent();
auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypePointer))
.addDef(createTypeVReg(CurMF->getRegInfo()))
.addImm(static_cast<uint32_t>(SC))
.addUse(getSPIRVTypeID(BaseType));
return restOfCreateSPIRVType(LLVMTy, MIB);
}

View File

@ -0,0 +1,174 @@
//===-- SPIRVGlobalRegistry.h - SPIR-V Global Registry ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// SPIRVGlobalRegistry is used to maintain rich type information required for
// SPIR-V even after lowering from LLVM IR to GMIR. It can convert an llvm::Type
// into an OpTypeXXX instruction, and map it to a virtual register. Also it
// builds and supports consistency of constants and global variables.
//
//===----------------------------------------------------------------------===//
#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVTYPEMANAGER_H
#define LLVM_LIB_TARGET_SPIRV_SPIRVTYPEMANAGER_H
#include "MCTargetDesc/SPIRVBaseInfo.h"
#include "SPIRVInstrInfo.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
namespace llvm {
using SPIRVType = const MachineInstr;
class SPIRVGlobalRegistry {
// Registers holding values which have types associated with them.
// Initialized upon VReg definition in IRTranslator.
// Do not confuse this with DuplicatesTracker as DT maps Type* to <MF, Reg>
// where Reg = OpType...
// while VRegToTypeMap tracks SPIR-V type assigned to other regs (i.e. not
// type-declaring ones)
DenseMap<MachineFunction *, DenseMap<Register, SPIRVType *>> VRegToTypeMap;
DenseMap<SPIRVType *, const Type *> SPIRVToLLVMType;
// Number of bits pointers and size_t integers require.
const unsigned PointerSize;
// Add a new OpTypeXXX instruction without checking for duplicates.
SPIRVType *
createSPIRVType(const Type *Type, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier AQ = SPIRV::AccessQualifier::ReadWrite,
bool EmitIR = true);
public:
SPIRVGlobalRegistry(unsigned PointerSize);
MachineFunction *CurMF;
// Get or create a SPIR-V type corresponding the given LLVM IR type,
// and map it to the given VReg by creating an ASSIGN_TYPE instruction.
SPIRVType *assignTypeToVReg(
const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier AQ = SPIRV::AccessQualifier::ReadWrite,
bool EmitIR = true);
// In cases where the SPIR-V type is already known, this function can be
// used to map it to the given VReg via an ASSIGN_TYPE instruction.
void assignSPIRVTypeToVReg(SPIRVType *Type, Register VReg,
MachineIRBuilder &MIRBuilder);
// Either generate a new OpTypeXXX instruction or return an existing one
// corresponding to the given LLVM IR type.
// EmitIR controls if we emit GMIR or SPV constants (e.g. for array sizes)
// because this method may be called from InstructionSelector and we don't
// want to emit extra IR instructions there.
SPIRVType *getOrCreateSPIRVType(
const Type *Type, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier AQ = SPIRV::AccessQualifier::ReadWrite,
bool EmitIR = true);
const Type *getTypeForSPIRVType(const SPIRVType *Ty) const {
auto Res = SPIRVToLLVMType.find(Ty);
assert(Res != SPIRVToLLVMType.end());
return Res->second;
}
// Return the SPIR-V type instruction corresponding to the given VReg, or
// nullptr if no such type instruction exists.
SPIRVType *getSPIRVTypeForVReg(Register VReg) const;
// Whether the given VReg has a SPIR-V type mapped to it yet.
bool hasSPIRVTypeForVReg(Register VReg) const {
return getSPIRVTypeForVReg(VReg) != nullptr;
}
// Return the VReg holding the result of the given OpTypeXXX instruction.
Register getSPIRVTypeID(const SPIRVType *SpirvType) const {
assert(SpirvType && "Attempting to get type id for nullptr type.");
return SpirvType->defs().begin()->getReg();
}
void setCurrentFunc(MachineFunction &MF) { CurMF = &MF; }
// Whether the given VReg has an OpTypeXXX instruction mapped to it with the
// given opcode (e.g. OpTypeFloat).
bool isScalarOfType(Register VReg, unsigned TypeOpcode) const;
// Return true if the given VReg's assigned SPIR-V type is either a scalar
// matching the given opcode, or a vector with an element type matching that
// opcode (e.g. OpTypeBool, or OpTypeVector %x 4, where %x is OpTypeBool).
bool isScalarOrVectorOfType(Register VReg, unsigned TypeOpcode) const;
// For vectors or scalars of ints/floats, return the scalar type's bitwidth.
unsigned getScalarOrVectorBitWidth(const SPIRVType *Type) const;
// For integer vectors or scalars, return whether the integers are signed.
bool isScalarOrVectorSigned(const SPIRVType *Type) const;
// Gets the storage class of the pointer type assigned to this vreg.
SPIRV::StorageClass getPointerStorageClass(Register VReg) const;
// Return the number of bits SPIR-V pointers and size_t variables require.
unsigned getPointerSize() const { return PointerSize; }
private:
SPIRVType *getOpTypeBool(MachineIRBuilder &MIRBuilder);
SPIRVType *getOpTypeInt(uint32_t Width, MachineIRBuilder &MIRBuilder,
bool IsSigned = false);
SPIRVType *getOpTypeFloat(uint32_t Width, MachineIRBuilder &MIRBuilder);
SPIRVType *getOpTypeVoid(MachineIRBuilder &MIRBuilder);
SPIRVType *getOpTypeVector(uint32_t NumElems, SPIRVType *ElemType,
MachineIRBuilder &MIRBuilder);
SPIRVType *getOpTypeArray(uint32_t NumElems, SPIRVType *ElemType,
MachineIRBuilder &MIRBuilder, bool EmitIR = true);
SPIRVType *getOpTypePointer(SPIRV::StorageClass SC, SPIRVType *ElemType,
MachineIRBuilder &MIRBuilder);
SPIRVType *getOpTypeFunction(SPIRVType *RetType,
const SmallVectorImpl<SPIRVType *> &ArgTypes,
MachineIRBuilder &MIRBuilder);
SPIRVType *restOfCreateSPIRVType(Type *LLVMTy, MachineInstrBuilder MIB);
public:
Register buildConstantInt(uint64_t Val, MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType = nullptr, bool EmitIR = true);
Register buildConstantFP(APFloat Val, MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType = nullptr);
Register
buildGlobalVariable(Register Reg, SPIRVType *BaseType, StringRef Name,
const GlobalValue *GV, SPIRV::StorageClass Storage,
const MachineInstr *Init, bool IsConst, bool HasLinkageTy,
SPIRV::LinkageType LinkageType,
MachineIRBuilder &MIRBuilder, bool IsInstSelector);
// Convenient helpers for getting types with check for duplicates.
SPIRVType *getOrCreateSPIRVIntegerType(unsigned BitWidth,
MachineIRBuilder &MIRBuilder);
SPIRVType *getOrCreateSPIRVIntegerType(unsigned BitWidth, MachineInstr &I,
const SPIRVInstrInfo &TII);
SPIRVType *getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder);
SPIRVType *getOrCreateSPIRVVectorType(SPIRVType *BaseType,
unsigned NumElements,
MachineIRBuilder &MIRBuilder);
SPIRVType *getOrCreateSPIRVVectorType(SPIRVType *BaseType,
unsigned NumElements, MachineInstr &I,
const SPIRVInstrInfo &TII);
SPIRVType *getOrCreateSPIRVPointerType(
SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
SPIRV::StorageClass SClass = SPIRV::StorageClass::Function);
SPIRVType *getOrCreateSPIRVPointerType(
SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &TII,
SPIRV::StorageClass SClass = SPIRV::StorageClass::Function);
};
} // end namespace llvm
#endif // LLLVM_LIB_TARGET_SPIRV_SPIRVTYPEMANAGER_H

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,301 @@
//===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the targeting of the Machinelegalizer class for SPIR-V.
//
//===----------------------------------------------------------------------===//
#include "SPIRVLegalizerInfo.h"
#include "SPIRV.h"
#include "SPIRVGlobalRegistry.h"
#include "SPIRVSubtarget.h"
#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetOpcodes.h"
using namespace llvm;
using namespace llvm::LegalizeActions;
using namespace llvm::LegalityPredicates;
static const std::set<unsigned> TypeFoldingSupportingOpcs = {
TargetOpcode::G_ADD,
TargetOpcode::G_FADD,
TargetOpcode::G_SUB,
TargetOpcode::G_FSUB,
TargetOpcode::G_MUL,
TargetOpcode::G_FMUL,
TargetOpcode::G_SDIV,
TargetOpcode::G_UDIV,
TargetOpcode::G_FDIV,
TargetOpcode::G_SREM,
TargetOpcode::G_UREM,
TargetOpcode::G_FREM,
TargetOpcode::G_FNEG,
TargetOpcode::G_CONSTANT,
TargetOpcode::G_FCONSTANT,
TargetOpcode::G_AND,
TargetOpcode::G_OR,
TargetOpcode::G_XOR,
TargetOpcode::G_SHL,
TargetOpcode::G_ASHR,
TargetOpcode::G_LSHR,
TargetOpcode::G_SELECT,
TargetOpcode::G_EXTRACT_VECTOR_ELT,
};
bool isTypeFoldingSupported(unsigned Opcode) {
return TypeFoldingSupportingOpcs.count(Opcode) > 0;
}
SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
using namespace TargetOpcode;
this->ST = &ST;
GR = ST.getSPIRVGlobalRegistry();
const LLT s1 = LLT::scalar(1);
const LLT s8 = LLT::scalar(8);
const LLT s16 = LLT::scalar(16);
const LLT s32 = LLT::scalar(32);
const LLT s64 = LLT::scalar(64);
const LLT v16s64 = LLT::fixed_vector(16, 64);
const LLT v16s32 = LLT::fixed_vector(16, 32);
const LLT v16s16 = LLT::fixed_vector(16, 16);
const LLT v16s8 = LLT::fixed_vector(16, 8);
const LLT v16s1 = LLT::fixed_vector(16, 1);
const LLT v8s64 = LLT::fixed_vector(8, 64);
const LLT v8s32 = LLT::fixed_vector(8, 32);
const LLT v8s16 = LLT::fixed_vector(8, 16);
const LLT v8s8 = LLT::fixed_vector(8, 8);
const LLT v8s1 = LLT::fixed_vector(8, 1);
const LLT v4s64 = LLT::fixed_vector(4, 64);
const LLT v4s32 = LLT::fixed_vector(4, 32);
const LLT v4s16 = LLT::fixed_vector(4, 16);
const LLT v4s8 = LLT::fixed_vector(4, 8);
const LLT v4s1 = LLT::fixed_vector(4, 1);
const LLT v3s64 = LLT::fixed_vector(3, 64);
const LLT v3s32 = LLT::fixed_vector(3, 32);
const LLT v3s16 = LLT::fixed_vector(3, 16);
const LLT v3s8 = LLT::fixed_vector(3, 8);
const LLT v3s1 = LLT::fixed_vector(3, 1);
const LLT v2s64 = LLT::fixed_vector(2, 64);
const LLT v2s32 = LLT::fixed_vector(2, 32);
const LLT v2s16 = LLT::fixed_vector(2, 16);
const LLT v2s8 = LLT::fixed_vector(2, 8);
const LLT v2s1 = LLT::fixed_vector(2, 1);
const unsigned PSize = ST.getPointerSize();
const LLT p0 = LLT::pointer(0, PSize); // Function
const LLT p1 = LLT::pointer(1, PSize); // CrossWorkgroup
const LLT p2 = LLT::pointer(2, PSize); // UniformConstant
const LLT p3 = LLT::pointer(3, PSize); // Workgroup
const LLT p4 = LLT::pointer(4, PSize); // Generic
const LLT p5 = LLT::pointer(5, PSize); // Input
// TODO: remove copy-pasting here by using concatenation in some way.
auto allPtrsScalarsAndVectors = {
p0, p1, p2, p3, p4, p5, s1, s8, s16,
s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8,
v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1,
v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
auto allScalarsAndVectors = {
s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64,
v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64,
v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};
auto allIntScalarsAndVectors = {s8, s16, s32, s64, v2s8, v2s16,
v2s32, v2s64, v3s8, v3s16, v3s32, v3s64,
v4s8, v4s16, v4s32, v4s64, v8s8, v8s16,
v8s32, v8s64, v16s8, v16s16, v16s32, v16s64};
auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1};
auto allIntScalars = {s8, s16, s32, s64};
auto allFloatScalarsAndVectors = {
s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64,
v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};
auto allFloatAndIntScalars = allIntScalars;
auto allPtrs = {p0, p1, p2, p3, p4, p5};
auto allWritablePtrs = {p0, p1, p3, p4};
for (auto Opc : TypeFoldingSupportingOpcs)
getActionDefinitionsBuilder(Opc).custom();
getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal();
// TODO: add proper rules for vectors legalization.
getActionDefinitionsBuilder({G_BUILD_VECTOR, G_SHUFFLE_VECTOR}).alwaysLegal();
getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
.legalIf(all(typeInSet(0, allWritablePtrs), typeInSet(1, allPtrs)));
getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
.legalForCartesianProduct(allPtrs, allPtrs);
getActionDefinitionsBuilder({G_LOAD, G_STORE}).legalIf(typeInSet(1, allPtrs));
getActionDefinitionsBuilder(G_BITREVERSE).legalFor(allFloatScalarsAndVectors);
getActionDefinitionsBuilder(G_FMA).legalFor(allFloatScalarsAndVectors);
getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
.legalForCartesianProduct(allIntScalarsAndVectors,
allFloatScalarsAndVectors);
getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
.legalForCartesianProduct(allFloatScalarsAndVectors,
allScalarsAndVectors);
getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS})
.legalFor(allIntScalarsAndVectors);
getActionDefinitionsBuilder(G_CTPOP).legalForCartesianProduct(
allIntScalarsAndVectors, allIntScalarsAndVectors);
getActionDefinitionsBuilder(G_PHI).legalFor(allPtrsScalarsAndVectors);
getActionDefinitionsBuilder(G_BITCAST).legalIf(all(
typeInSet(0, allPtrsScalarsAndVectors),
typeInSet(1, allPtrsScalarsAndVectors),
LegalityPredicate(([=](const LegalityQuery &Query) {
return Query.Types[0].getSizeInBits() == Query.Types[1].getSizeInBits();
}))));
getActionDefinitionsBuilder(G_IMPLICIT_DEF).alwaysLegal();
getActionDefinitionsBuilder(G_INTTOPTR)
.legalForCartesianProduct(allPtrs, allIntScalars);
getActionDefinitionsBuilder(G_PTRTOINT)
.legalForCartesianProduct(allIntScalars, allPtrs);
getActionDefinitionsBuilder(G_PTR_ADD).legalForCartesianProduct(
allPtrs, allIntScalars);
// ST.canDirectlyComparePointers() for pointer args is supported in
// legalizeCustom().
getActionDefinitionsBuilder(G_ICMP).customIf(
all(typeInSet(0, allBoolScalarsAndVectors),
typeInSet(1, allPtrsScalarsAndVectors)));
getActionDefinitionsBuilder(G_FCMP).legalIf(
all(typeInSet(0, allBoolScalarsAndVectors),
typeInSet(1, allFloatScalarsAndVectors)));
getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND,
G_ATOMICRMW_MAX, G_ATOMICRMW_MIN,
G_ATOMICRMW_SUB, G_ATOMICRMW_XOR,
G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
.legalForCartesianProduct(allIntScalars, allWritablePtrs);
getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
.legalForCartesianProduct(allFloatAndIntScalars, allWritablePtrs);
getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower();
// TODO: add proper legalization rules.
getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal();
getActionDefinitionsBuilder({G_UADDO, G_USUBO, G_SMULO, G_UMULO})
.alwaysLegal();
// Extensions.
getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT})
.legalForCartesianProduct(allScalarsAndVectors);
// FP conversions.
getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT})
.legalForCartesianProduct(allFloatScalarsAndVectors);
// Pointer-handling.
getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});
// Control-flow.
getActionDefinitionsBuilder(G_BRCOND).legalFor({s1});
getActionDefinitionsBuilder({G_FPOW,
G_FEXP,
G_FEXP2,
G_FLOG,
G_FLOG2,
G_FABS,
G_FMINNUM,
G_FMAXNUM,
G_FCEIL,
G_FCOS,
G_FSIN,
G_FSQRT,
G_FFLOOR,
G_FRINT,
G_FNEARBYINT,
G_INTRINSIC_ROUND,
G_INTRINSIC_TRUNC,
G_FMINIMUM,
G_FMAXIMUM,
G_INTRINSIC_ROUNDEVEN})
.legalFor(allFloatScalarsAndVectors);
getActionDefinitionsBuilder(G_FCOPYSIGN)
.legalForCartesianProduct(allFloatScalarsAndVectors,
allFloatScalarsAndVectors);
getActionDefinitionsBuilder(G_FPOWI).legalForCartesianProduct(
allFloatScalarsAndVectors, allIntScalarsAndVectors);
getLegacyLegalizerInfo().computeTables();
verify(*ST.getInstrInfo());
}
static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType,
LegalizerHelper &Helper,
MachineRegisterInfo &MRI,
SPIRVGlobalRegistry *GR) {
Register ConvReg = MRI.createGenericVirtualRegister(ConvTy);
GR->assignSPIRVTypeToVReg(SpirvType, ConvReg, Helper.MIRBuilder);
Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)
.addDef(ConvReg)
.addUse(Reg);
return ConvReg;
}
bool SPIRVLegalizerInfo::legalizeCustom(LegalizerHelper &Helper,
MachineInstr &MI) const {
auto Opc = MI.getOpcode();
MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
if (!isTypeFoldingSupported(Opc)) {
assert(Opc == TargetOpcode::G_ICMP);
assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
auto &Op0 = MI.getOperand(2);
auto &Op1 = MI.getOperand(3);
Register Reg0 = Op0.getReg();
Register Reg1 = Op1.getReg();
CmpInst::Predicate Cond =
static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
if ((!ST->canDirectlyComparePointers() ||
(Cond != CmpInst::ICMP_EQ && Cond != CmpInst::ICMP_NE)) &&
MRI.getType(Reg0).isPointer() && MRI.getType(Reg1).isPointer()) {
LLT ConvT = LLT::scalar(ST->getPointerSize());
Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(),
ST->getPointerSize());
SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, Helper.MIRBuilder);
Op0.setReg(convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR));
Op1.setReg(convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR));
}
return true;
}
// TODO: implement legalization for other opcodes.
return true;
}

View File

@ -0,0 +1,36 @@
//===- SPIRVLegalizerInfo.h --- SPIR-V Legalization Rules --------*- C++ -*-==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares the targeting of the MachineLegalizer class for SPIR-V.
//
//===----------------------------------------------------------------------===//
#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVMACHINELEGALIZER_H
#define LLVM_LIB_TARGET_SPIRV_SPIRVMACHINELEGALIZER_H
#include "SPIRVGlobalRegistry.h"
#include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
bool isTypeFoldingSupported(unsigned Opcode);
namespace llvm {
class LLVMContext;
class SPIRVSubtarget;
// This class provides the information for legalizing SPIR-V instructions.
class SPIRVLegalizerInfo : public LegalizerInfo {
const SPIRVSubtarget *ST;
SPIRVGlobalRegistry *GR;
public:
bool legalizeCustom(LegalizerHelper &Helper, MachineInstr &MI) const override;
SPIRVLegalizerInfo(const SPIRVSubtarget &ST);
};
} // namespace llvm
#endif // LLVM_LIB_TARGET_SPIRV_SPIRVMACHINELEGALIZER_H

View File

@ -12,6 +12,8 @@
#include "SPIRVSubtarget.h"
#include "SPIRV.h"
#include "SPIRVGlobalRegistry.h"
#include "SPIRVLegalizerInfo.h"
#include "SPIRVRegisterBankInfo.h"
#include "SPIRVTargetMachine.h"
#include "llvm/MC/TargetRegistry.h"
@ -43,8 +45,13 @@ SPIRVSubtarget::SPIRVSubtarget(const Triple &TT, const std::string &CPU,
: SPIRVGenSubtargetInfo(TT, CPU, /*TuneCPU=*/CPU, FS),
PointerSize(computePointerSize(TT)), SPIRVVersion(0), InstrInfo(),
FrameLowering(initSubtargetDependencies(CPU, FS)), TLInfo(TM, *this) {
CallLoweringInfo = std::make_unique<SPIRVCallLowering>(TLInfo);
GR = std::make_unique<SPIRVGlobalRegistry>(PointerSize);
CallLoweringInfo =
std::make_unique<SPIRVCallLowering>(TLInfo, *this, GR.get());
Legalizer = std::make_unique<SPIRVLegalizerInfo>(*this);
RegBankInfo = std::make_unique<SPIRVRegisterBankInfo>();
InstSelector.reset(
createSPIRVInstructionSelector(TM, *this, *RegBankInfo.get()));
}
SPIRVSubtarget &SPIRVSubtarget::initSubtargetDependencies(StringRef CPU,

View File

@ -30,7 +30,7 @@
namespace llvm {
class StringRef;
class SPIRVGlobalRegistry;
class SPIRVTargetMachine;
class SPIRVSubtarget : public SPIRVGenSubtargetInfo {
@ -38,6 +38,8 @@ private:
const unsigned PointerSize;
uint32_t SPIRVVersion;
std::unique_ptr<SPIRVGlobalRegistry> GR;
SPIRVInstrInfo InstrInfo;
SPIRVFrameLowering FrameLowering;
SPIRVTargetLowering TLInfo;
@ -45,6 +47,8 @@ private:
// GlobalISel related APIs.
std::unique_ptr<CallLowering> CallLoweringInfo;
std::unique_ptr<RegisterBankInfo> RegBankInfo;
std::unique_ptr<LegalizerInfo> Legalizer;
std::unique_ptr<InstructionSelector> InstSelector;
public:
// This constructor initializes the data members to match that
@ -59,6 +63,7 @@ public:
unsigned getPointerSize() const { return PointerSize; }
bool canDirectlyComparePointers() const;
uint32_t getSPIRVVersion() const { return SPIRVVersion; };
SPIRVGlobalRegistry *getSPIRVGlobalRegistry() const { return GR.get(); }
const CallLowering *getCallLowering() const override {
return CallLoweringInfo.get();
@ -66,6 +71,12 @@ public:
const RegisterBankInfo *getRegBankInfo() const override {
return RegBankInfo.get();
}
const LegalizerInfo *getLegalizerInfo() const override {
return Legalizer.get();
}
InstructionSelector *getInstructionSelector() const override {
return InstSelector.get();
}
const SPIRVInstrInfo *getInstrInfo() const override { return &InstrInfo; }
const SPIRVFrameLowering *getFrameLowering() const override {
return &FrameLowering;

View File

@ -12,6 +12,9 @@
#include "SPIRVTargetMachine.h"
#include "SPIRV.h"
#include "SPIRVCallLowering.h"
#include "SPIRVGlobalRegistry.h"
#include "SPIRVLegalizerInfo.h"
#include "SPIRVTargetObjectFile.h"
#include "SPIRVTargetTransformInfo.h"
#include "TargetInfo/SPIRVTargetInfo.h"
@ -34,6 +37,9 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeSPIRVTarget() {
// Register the target.
RegisterTargetMachine<SPIRVTargetMachine> X(getTheSPIRV32Target());
RegisterTargetMachine<SPIRVTargetMachine> Y(getTheSPIRV64Target());
PassRegistry &PR = *PassRegistry::getPassRegistry();
initializeGlobalISel(PR);
}
static std::string computeDataLayout(const Triple &TT) {
@ -155,7 +161,19 @@ bool SPIRVPassConfig::addRegBankSelect() {
return false;
}
namespace {
// A custom subclass of InstructionSelect, which is mostly the same except from
// not requiring RegBankSelect to occur previously.
class SPIRVInstructionSelect : public InstructionSelect {
// We don't use register banks, so unset the requirement for them
MachineFunctionProperties getRequiredProperties() const override {
return InstructionSelect::getRequiredProperties().reset(
MachineFunctionProperties::Property::RegBankSelected);
}
};
} // namespace
bool SPIRVPassConfig::addGlobalInstructionSelect() {
addPass(new InstructionSelect(getOptLevel()));
addPass(new SPIRVInstructionSelect());
return false;
}

View File

@ -0,0 +1,182 @@
//===--- SPIRVUtils.cpp ---- SPIR-V Utility Functions -----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains miscellaneous utility functions.
//
//===----------------------------------------------------------------------===//
#include "SPIRVUtils.h"
#include "MCTargetDesc/SPIRVBaseInfo.h"
#include "SPIRV.h"
#include "SPIRVInstrInfo.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
using namespace llvm;
// The following functions are used to add these string literals as a series of
// 32-bit integer operands with the correct format, and unpack them if necessary
// when making string comparisons in compiler passes.
// SPIR-V requires null-terminated UTF-8 strings padded to 32-bit alignment.
static uint32_t convertCharsToWord(const StringRef &Str, unsigned i) {
uint32_t Word = 0u; // Build up this 32-bit word from 4 8-bit chars.
for (unsigned WordIndex = 0; WordIndex < 4; ++WordIndex) {
unsigned StrIndex = i + WordIndex;
uint8_t CharToAdd = 0; // Initilize char as padding/null.
if (StrIndex < Str.size()) { // If it's within the string, get a real char.
CharToAdd = Str[StrIndex];
}
Word |= (CharToAdd << (WordIndex * 8));
}
return Word;
}
// Get length including padding and null terminator.
static size_t getPaddedLen(const StringRef &Str) {
const size_t Len = Str.size() + 1;
return (Len % 4 == 0) ? Len : Len + (4 - (Len % 4));
}
void addStringImm(const StringRef &Str, MachineInstrBuilder &MIB) {
const size_t PaddedLen = getPaddedLen(Str);
for (unsigned i = 0; i < PaddedLen; i += 4) {
// Add an operand for the 32-bits of chars or padding.
MIB.addImm(convertCharsToWord(Str, i));
}
}
void addStringImm(const StringRef &Str, IRBuilder<> &B,
std::vector<Value *> &Args) {
const size_t PaddedLen = getPaddedLen(Str);
for (unsigned i = 0; i < PaddedLen; i += 4) {
// Add a vector element for the 32-bits of chars or padding.
Args.push_back(B.getInt32(convertCharsToWord(Str, i)));
}
}
std::string getStringImm(const MachineInstr &MI, unsigned StartIndex) {
return getSPIRVStringOperand(MI, StartIndex);
}
void addNumImm(const APInt &Imm, MachineInstrBuilder &MIB) {
const auto Bitwidth = Imm.getBitWidth();
switch (Bitwidth) {
case 1:
break; // Already handled.
case 8:
case 16:
case 32:
MIB.addImm(Imm.getZExtValue());
break;
case 64: {
uint64_t FullImm = Imm.getZExtValue();
uint32_t LowBits = FullImm & 0xffffffff;
uint32_t HighBits = (FullImm >> 32) & 0xffffffff;
MIB.addImm(LowBits).addImm(HighBits);
break;
}
default:
report_fatal_error("Unsupported constant bitwidth");
}
}
void buildOpName(Register Target, const StringRef &Name,
MachineIRBuilder &MIRBuilder) {
if (!Name.empty()) {
auto MIB = MIRBuilder.buildInstr(SPIRV::OpName).addUse(Target);
addStringImm(Name, MIB);
}
}
static void finishBuildOpDecorate(MachineInstrBuilder &MIB,
const std::vector<uint32_t> &DecArgs,
StringRef StrImm) {
if (!StrImm.empty())
addStringImm(StrImm, MIB);
for (const auto &DecArg : DecArgs)
MIB.addImm(DecArg);
}
void buildOpDecorate(Register Reg, MachineIRBuilder &MIRBuilder,
llvm::SPIRV::Decoration Dec,
const std::vector<uint32_t> &DecArgs, StringRef StrImm) {
auto MIB = MIRBuilder.buildInstr(SPIRV::OpDecorate)
.addUse(Reg)
.addImm(static_cast<uint32_t>(Dec));
finishBuildOpDecorate(MIB, DecArgs, StrImm);
}
void buildOpDecorate(Register Reg, MachineInstr &I, const SPIRVInstrInfo &TII,
llvm::SPIRV::Decoration Dec,
const std::vector<uint32_t> &DecArgs, StringRef StrImm) {
MachineBasicBlock &MBB = *I.getParent();
auto MIB = BuildMI(MBB, I, I.getDebugLoc(), TII.get(SPIRV::OpDecorate))
.addUse(Reg)
.addImm(static_cast<uint32_t>(Dec));
finishBuildOpDecorate(MIB, DecArgs, StrImm);
}
// TODO: maybe the following two functions should be handled in the subtarget
// to allow for different OpenCL vs Vulkan handling.
unsigned storageClassToAddressSpace(SPIRV::StorageClass SC) {
switch (SC) {
case SPIRV::StorageClass::Function:
return 0;
case SPIRV::StorageClass::CrossWorkgroup:
return 1;
case SPIRV::StorageClass::UniformConstant:
return 2;
case SPIRV::StorageClass::Workgroup:
return 3;
case SPIRV::StorageClass::Generic:
return 4;
case SPIRV::StorageClass::Input:
return 7;
default:
llvm_unreachable("Unable to get address space id");
}
}
SPIRV::StorageClass addressSpaceToStorageClass(unsigned AddrSpace) {
switch (AddrSpace) {
case 0:
return SPIRV::StorageClass::Function;
case 1:
return SPIRV::StorageClass::CrossWorkgroup;
case 2:
return SPIRV::StorageClass::UniformConstant;
case 3:
return SPIRV::StorageClass::Workgroup;
case 4:
return SPIRV::StorageClass::Generic;
case 7:
return SPIRV::StorageClass::Input;
default:
llvm_unreachable("Unknown address space");
}
}
SPIRV::MemorySemantics getMemSemanticsForStorageClass(SPIRV::StorageClass SC) {
switch (SC) {
case SPIRV::StorageClass::StorageBuffer:
case SPIRV::StorageClass::Uniform:
return SPIRV::MemorySemantics::UniformMemory;
case SPIRV::StorageClass::Workgroup:
return SPIRV::MemorySemantics::WorkgroupMemory;
case SPIRV::StorageClass::CrossWorkgroup:
return SPIRV::MemorySemantics::CrossWorkgroupMemory;
case SPIRV::StorageClass::AtomicCounter:
return SPIRV::MemorySemantics::AtomicCounterMemory;
case SPIRV::StorageClass::Image:
return SPIRV::MemorySemantics::ImageMemory;
default:
return SPIRV::MemorySemantics::None;
}
}

View File

@ -0,0 +1,69 @@
//===--- SPIRVUtils.h ---- SPIR-V Utility Functions -------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains miscellaneous utility functions.
//
//===----------------------------------------------------------------------===//
#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H
#define LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H
#include "MCTargetDesc/SPIRVBaseInfo.h"
#include "llvm/IR/IRBuilder.h"
#include <string>
namespace llvm {
class MCInst;
class MachineFunction;
class MachineInstr;
class MachineInstrBuilder;
class MachineIRBuilder;
class MachineRegisterInfo;
class Register;
class StringRef;
class SPIRVInstrInfo;
} // namespace llvm
// Add the given string as a series of integer operand, inserting null
// terminators and padding to make sure the operands all have 32-bit
// little-endian words.
void addStringImm(const llvm::StringRef &Str, llvm::MachineInstrBuilder &MIB);
void addStringImm(const llvm::StringRef &Str, llvm::IRBuilder<> &B,
std::vector<llvm::Value *> &Args);
// Read the series of integer operands back as a null-terminated string using
// the reverse of the logic in addStringImm.
std::string getStringImm(const llvm::MachineInstr &MI, unsigned StartIndex);
// Add the given numerical immediate to MIB.
void addNumImm(const llvm::APInt &Imm, llvm::MachineInstrBuilder &MIB);
// Add an OpName instruction for the given target register.
void buildOpName(llvm::Register Target, const llvm::StringRef &Name,
llvm::MachineIRBuilder &MIRBuilder);
// Add an OpDecorate instruction for the given Reg.
void buildOpDecorate(llvm::Register Reg, llvm::MachineIRBuilder &MIRBuilder,
llvm::SPIRV::Decoration Dec,
const std::vector<uint32_t> &DecArgs,
llvm::StringRef StrImm = "");
void buildOpDecorate(llvm::Register Reg, llvm::MachineInstr &I,
const llvm::SPIRVInstrInfo &TII,
llvm::SPIRV::Decoration Dec,
const std::vector<uint32_t> &DecArgs,
llvm::StringRef StrImm = "");
// Convert a SPIR-V storage class to the corresponding LLVM IR address space.
unsigned storageClassToAddressSpace(llvm::SPIRV::StorageClass SC);
// Convert an LLVM IR address space to a SPIR-V storage class.
llvm::SPIRV::StorageClass addressSpaceToStorageClass(unsigned AddrSpace);
llvm::SPIRV::MemorySemantics
getMemSemanticsForStorageClass(llvm::SPIRV::StorageClass SC);
#endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H