diff --git a/mlir/g3doc/Dialects/SPIR-V.md b/mlir/g3doc/Dialects/SPIR-V.md index 58400ef58bbc..19442f27ed0b 100644 --- a/mlir/g3doc/Dialects/SPIR-V.md +++ b/mlir/g3doc/Dialects/SPIR-V.md @@ -87,6 +87,35 @@ For example, !spv.array<16 x vector<4 x f32>> ``` +### Image type + +This corresponds to SPIR-V [image_type][ImageType]. Its syntax is + +``` {.ebnf} +dim ::= `1D` | `2D` | `3D` | `Cube` | + +depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown` + +arrayed-info ::= `NonArrayed` | `Arrayed` + +sampling-info ::= `SingleSampled` | `MultiSampled` + +sampler-use-info ::= `SamplerUnknown` | `NeedSampler` | `NoSampler` + +format ::= `Unknown` | `Rgba32f` | + +image-type ::= `!spv.image<` element-type `,` dim `,` depth-info `,` + arrayed-info `,` sampling-info `,` + sampler-use-info `,` format `>` +``` + +For example, + +``` {.mlir} +!spv.image +!spv.image +``` + ### Pointer type This corresponds to SPIR-V [pointer type][PointerType]. Its syntax is @@ -122,8 +151,8 @@ For example, !spv.rtarray> ``` - [SPIR-V]: https://www.khronos.org/registry/spir-v/ [ArrayType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeArray [PointerType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypePointer [RuntimeArrayType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeRuntimeArray +[ImageType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeImage diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 7f4b261af4a6..6f199123b0f4 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -688,6 +688,14 @@ class EnumAttr cases> : // llvm::StringRef (); // ``` string symbolToStringFnName = "stringify" # name; + + // The name of the utility function that returns the max enum value used + // within the enum class. It will have the following signature: + // + // ```c++ + // static constexpr unsigned (); + // ``` + string maxEnumValFnName = "getMaxEnumValFor" # name; } class ElementsAttrBase : diff --git a/mlir/include/mlir/SPIRV/SPIRVBase.td b/mlir/include/mlir/SPIRV/SPIRVBase.td index 50ac64af2f01..0a8b576743ae 100644 --- a/mlir/include/mlir/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/SPIRV/SPIRVBase.td @@ -116,6 +116,80 @@ def SPV_AddressingModelAttr : let underlyingType = "uint32_t"; } +def SPV_D_1D : EnumAttrCase<"1D", 0>; +def SPV_D_2D : EnumAttrCase<"2D", 1>; +def SPV_D_3D : EnumAttrCase<"3D", 2>; +def SPV_D_Cube : EnumAttrCase<"Cube", 3>; +def SPV_D_Rect : EnumAttrCase<"Rect", 4>; +def SPV_D_Buffer : EnumAttrCase<"Buffer", 5>; +def SPV_D_SubpassData : EnumAttrCase<"SubpassData", 6>; + +def SPV_DimAttr : + EnumAttr<"Dim", "valid SPIR-V Dim", [ + SPV_D_1D, SPV_D_2D, SPV_D_3D, SPV_D_Cube, SPV_D_Rect, SPV_D_Buffer, + SPV_D_SubpassData + ]> { + let cppNamespace = "::mlir::spirv"; + let underlyingType = "uint32_t"; +} + +def SPV_IF_Unknown : EnumAttrCase<"Unknown", 0>; +def SPV_IF_Rgba32f : EnumAttrCase<"Rgba32f", 1>; +def SPV_IF_Rgba16f : EnumAttrCase<"Rgba16f", 2>; +def SPV_IF_R32f : EnumAttrCase<"R32f", 3>; +def SPV_IF_Rgba8 : EnumAttrCase<"Rgba8", 4>; +def SPV_IF_Rgba8Snorm : EnumAttrCase<"Rgba8Snorm", 5>; +def SPV_IF_Rg32f : EnumAttrCase<"Rg32f", 6>; +def SPV_IF_Rg16f : EnumAttrCase<"Rg16f", 7>; +def SPV_IF_R11fG11fB10f : EnumAttrCase<"R11fG11fB10f", 8>; +def SPV_IF_R16f : EnumAttrCase<"R16f", 9>; +def SPV_IF_Rgba16 : EnumAttrCase<"Rgba16", 10>; +def SPV_IF_Rgb10A2 : EnumAttrCase<"Rgb10A2", 11>; +def SPV_IF_Rg16 : EnumAttrCase<"Rg16", 12>; +def SPV_IF_Rg8 : EnumAttrCase<"Rg8", 13>; +def SPV_IF_R16 : EnumAttrCase<"R16", 14>; +def SPV_IF_R8 : EnumAttrCase<"R8", 15>; +def SPV_IF_Rgba16Snorm : EnumAttrCase<"Rgba16Snorm", 16>; +def SPV_IF_Rg16Snorm : EnumAttrCase<"Rg16Snorm", 17>; +def SPV_IF_Rg8Snorm : EnumAttrCase<"Rg8Snorm", 18>; +def SPV_IF_R16Snorm : EnumAttrCase<"R16Snorm", 19>; +def SPV_IF_R8Snorm : EnumAttrCase<"R8Snorm", 20>; +def SPV_IF_Rgba32i : EnumAttrCase<"Rgba32i", 21>; +def SPV_IF_Rgba16i : EnumAttrCase<"Rgba16i", 22>; +def SPV_IF_Rgba8i : EnumAttrCase<"Rgba8i", 23>; +def SPV_IF_R32i : EnumAttrCase<"R32i", 24>; +def SPV_IF_Rg32i : EnumAttrCase<"Rg32i", 25>; +def SPV_IF_Rg16i : EnumAttrCase<"Rg16i", 26>; +def SPV_IF_Rg8i : EnumAttrCase<"Rg8i", 27>; +def SPV_IF_R16i : EnumAttrCase<"R16i", 28>; +def SPV_IF_R8i : EnumAttrCase<"R8i", 29>; +def SPV_IF_Rgba32ui : EnumAttrCase<"Rgba32ui", 30>; +def SPV_IF_Rgba16ui : EnumAttrCase<"Rgba16ui", 31>; +def SPV_IF_Rgba8ui : EnumAttrCase<"Rgba8ui", 32>; +def SPV_IF_R32ui : EnumAttrCase<"R32ui", 33>; +def SPV_IF_Rgb10a2ui : EnumAttrCase<"Rgb10a2ui", 34>; +def SPV_IF_Rg32ui : EnumAttrCase<"Rg32ui", 35>; +def SPV_IF_Rg16ui : EnumAttrCase<"Rg16ui", 36>; +def SPV_IF_Rg8ui : EnumAttrCase<"Rg8ui", 37>; +def SPV_IF_R16ui : EnumAttrCase<"R16ui", 38>; +def SPV_IF_R8ui : EnumAttrCase<"R8ui", 39>; + +def SPV_ImageFormatAttr : + EnumAttr<"ImageFormat", "valid SPIR-V ImageFormat", [ + SPV_IF_Unknown, SPV_IF_Rgba32f, SPV_IF_Rgba16f, SPV_IF_R32f, SPV_IF_Rgba8, + SPV_IF_Rgba8Snorm, SPV_IF_Rg32f, SPV_IF_Rg16f, SPV_IF_R11fG11fB10f, + SPV_IF_R16f, SPV_IF_Rgba16, SPV_IF_Rgb10A2, SPV_IF_Rg16, SPV_IF_Rg8, + SPV_IF_R16, SPV_IF_R8, SPV_IF_Rgba16Snorm, SPV_IF_Rg16Snorm, SPV_IF_Rg8Snorm, + SPV_IF_R16Snorm, SPV_IF_R8Snorm, SPV_IF_Rgba32i, SPV_IF_Rgba16i, SPV_IF_Rgba8i, + SPV_IF_R32i, SPV_IF_Rg32i, SPV_IF_Rg16i, SPV_IF_Rg8i, SPV_IF_R16i, SPV_IF_R8i, + SPV_IF_Rgba32ui, SPV_IF_Rgba16ui, SPV_IF_Rgba8ui, SPV_IF_R32ui, + SPV_IF_Rgb10a2ui, SPV_IF_Rg32ui, SPV_IF_Rg16ui, SPV_IF_Rg8ui, SPV_IF_R16ui, + SPV_IF_R8ui + ]> { + let cppNamespace = "::mlir::spirv"; + let underlyingType = "uint32_t"; +} + def SPV_MM_Simple : EnumAttrCase<"Simple", 0>; def SPV_MM_GLSL450 : EnumAttrCase<"GLSL450", 1>; def SPV_MM_OpenCL : EnumAttrCase<"OpenCL", 2>; @@ -165,6 +239,50 @@ def SPV_StorageClassAttr : // End enum section. Generated from SPIR-V spec; DO NOT MODIFY! +// Enums added manually that are not part of SPIRV spec + +def SPV_IDI_NoDepth : EnumAttrCase<"NoDepth", 0>; +def SPV_IDI_IsDepth : EnumAttrCase<"IsDepth", 1>; +def SPV_IDI_DepthUnknown : EnumAttrCase<"DepthUnknown", 2>; + +def SPV_DepthAttr : + EnumAttr<"ImageDepthInfo", "valid SPIR-V Image Depth specification",[ + SPV_IDI_NoDepth, SPV_IDI_IsDepth, SPV_IDI_DepthUnknown]> { + let cppNamespace = "::mlir::spirv"; + let underlyingType = "uint32_t"; +} + +def SPV_IAI_NonArrayed : EnumAttrCase<"NonArrayed", 0>; +def SPV_IAI_Arrayed : EnumAttrCase<"Arrayed", 1>; + +def SPV_ArrayedAttr : + EnumAttr<"ImageArrayedInfo", "valid SPIR-V Image Arrayed specification", [ + SPV_IAI_NonArrayed, SPV_IAI_Arrayed]> { + let cppNamespace = "::mlir::spirv"; + let underlyingType = "uint32_t"; +} + +def SPV_ISI_SingleSampled : EnumAttrCase<"SingleSampled", 0>; +def SPV_ISI_MultiSampled : EnumAttrCase<"MultiSampled", 1>; + +def SPV_SamplingAttr: + EnumAttr<"ImageSamplingInfo", "valid SPIR-V Image Sampling specification", [ + SPV_ISI_SingleSampled, SPV_ISI_MultiSampled]> { + let cppNamespace = "::mlir::spirv"; + let underlyingType = "uint32_t"; +} + +def SPV_ISUI_SamplerUnknown : EnumAttrCase<"SamplerUnknown", 0>; +def SPV_ISUI_NeedSampler : EnumAttrCase<"NeedSampler", 1>; +def SPV_ISUI_NoSampler : EnumAttrCase<"NoSampler", 2>; + +def SPV_SamplerUseAttr: + EnumAttr<"ImageSamplerUseInfo", "valid SPIR-V Sampler Use specification", [ + SPV_ISUI_SamplerUnknown, SPV_ISUI_NeedSampler, SPV_ISUI_NoSampler]> { + let cppNamespace = "::mlir::spirv"; + let underlyingType = "uint32_t"; +} + //===----------------------------------------------------------------------===// // SPIR-V op definitions //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/SPIRV/SPIRVDialect.h b/mlir/include/mlir/SPIRV/SPIRVDialect.h index 18667da374bc..4272a7220874 100644 --- a/mlir/include/mlir/SPIRV/SPIRVDialect.h +++ b/mlir/include/mlir/SPIRV/SPIRVDialect.h @@ -51,6 +51,9 @@ private: /// Parses `spec` as a SPIR-V run-time array type. Type parseRuntimeArrayType(StringRef spec, Location loc) const; + + /// Parses `spec` as a SPIR-V image type + Type parseImageType(StringRef spec, Location loc) const; }; } // end namespace spirv diff --git a/mlir/include/mlir/SPIRV/SPIRVTypes.h b/mlir/include/mlir/SPIRV/SPIRVTypes.h index ddab2de84941..80370e87becb 100644 --- a/mlir/include/mlir/SPIRV/SPIRVTypes.h +++ b/mlir/include/mlir/SPIRV/SPIRVTypes.h @@ -27,11 +27,14 @@ // Pull in all enum type definitions and utility function declarations #include "mlir/SPIRV/SPIRVEnums.h.inc" +#include + namespace mlir { namespace spirv { namespace detail { struct ArrayTypeStorage; +struct ImageTypeStorage; struct PointerTypeStorage; struct RuntimeArrayTypeStorage; } // namespace detail @@ -39,6 +42,7 @@ struct RuntimeArrayTypeStorage; namespace TypeKind { enum Kind { Array = Type::FIRST_SPIRV_TYPE, + ImageType, Pointer, RuntimeArray, }; @@ -89,6 +93,42 @@ public: Type getElementType(); }; +// SPIR-V image type +class ImageType + : public Type::TypeBase { +public: + using Base::Base; + + static bool kindof(unsigned kind) { return kind == TypeKind::ImageType; } + + static ImageType + get(Type elementType, Dim dim, + ImageDepthInfo depth = ImageDepthInfo::DepthUnknown, + ImageArrayedInfo arrayed = ImageArrayedInfo::NonArrayed, + ImageSamplingInfo samplingInfo = ImageSamplingInfo::SingleSampled, + ImageSamplerUseInfo samplerUse = ImageSamplerUseInfo::SamplerUnknown, + ImageFormat format = ImageFormat::Unknown) { + return ImageType::get( + std::tuple( + elementType, dim, depth, arrayed, samplingInfo, samplerUse, + format)); + } + + static ImageType + get(std::tuple); + + Type getElementType(); + Dim getDim(); + ImageDepthInfo getDepthInfo(); + ImageArrayedInfo getArrayedInfo(); + ImageSamplingInfo getSamplingInfo(); + ImageSamplerUseInfo getSamplerUseInfo(); + ImageFormat getImageFormat(); + // TODO(ravishankarm): Add support for Access qualifier +}; + } // end namespace spirv } // end namespace mlir diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h index f69961ad8e6c..f5a8764d11bd 100644 --- a/mlir/include/mlir/TableGen/Attribute.h +++ b/mlir/include/mlir/TableGen/Attribute.h @@ -160,6 +160,10 @@ public: // corresponding string. StringRef getSymbolToStringFnName() const; + // Returns the name of the utilit function that returns the max enum value + // used within the enum class. + StringRef getMaxEnumValFnName() const; + // Returns all allowed cases for this enum attribute. std::vector getAllCases() const; }; diff --git a/mlir/lib/SPIRV/SPIRVDialect.cpp b/mlir/lib/SPIRV/SPIRVDialect.cpp index bca27b0b438a..f2885d4ce7ac 100644 --- a/mlir/lib/SPIRV/SPIRVDialect.cpp +++ b/mlir/lib/SPIRV/SPIRVDialect.cpp @@ -26,10 +26,14 @@ #include "mlir/Parser.h" #include "mlir/SPIRV/SPIRVOps.h" #include "mlir/SPIRV/SPIRVTypes.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/raw_ostream.h" +#include + using namespace mlir; using namespace mlir::spirv; @@ -39,7 +43,7 @@ using namespace mlir::spirv; SPIRVDialect::SPIRVDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) { - addTypes(); + addTypes(); addOperations< #define GET_OP_LIST @@ -73,8 +77,9 @@ static bool parseNumberX(StringRef &spec, int64_t &number) { return true; } -Type SPIRVDialect::parseAndVerifyType(StringRef spec, Location loc) const { - auto *context = getContext(); +static Type parseAndVerifyTypeImpl(SPIRVDialect const &dialect, Location loc, + StringRef spec) { + auto *context = dialect.getContext(); auto type = mlir::parseType(spec, context); if (!type) { context->emitError(loc, "cannot parse type: ") << spec; @@ -82,7 +87,7 @@ Type SPIRVDialect::parseAndVerifyType(StringRef spec, Location loc) const { } // Allow SPIR-V dialect types - if (&type.getDialect() == this) + if (&type.getDialect() == &dialect) return type; // Check other allowed types @@ -113,6 +118,10 @@ Type SPIRVDialect::parseAndVerifyType(StringRef spec, Location loc) const { return type; } +Type SPIRVDialect::parseAndVerifyType(StringRef spec, Location loc) const { + return parseAndVerifyTypeImpl(*this, loc, spec); +} + // element-type ::= integer-type // | floating-point-type // | vector-type @@ -209,10 +218,186 @@ Type SPIRVDialect::parseRuntimeArrayType(StringRef spec, Location loc) const { return RuntimeArrayType::get(elementType); } -Type SPIRVDialect::parseType(StringRef spec, Location loc) const { +// Specialize this function to parse each of the parameters that define an +// ImageType +template +Optional parseAndVerify(SPIRVDialect const &dialect, Location loc, + StringRef spec) { + auto *context = dialect.getContext(); + context->emitError(loc, "unexpected parameter while parsing '") + << spec << "'"; + return llvm::None; +} +template <> +Optional parseAndVerify(SPIRVDialect const &dialect, Location loc, + StringRef spec) { + // TODO(ravishankarm): Further verify that the element type can be sampled + return parseAndVerifyTypeImpl(dialect, loc, spec); +} + +template <> +Optional parseAndVerify(SPIRVDialect const &dialect, Location loc, + StringRef spec) { + auto dim = symbolizeDim(spec); + if (!dim) { + auto *context = dialect.getContext(); + context->emitError(loc, "unknown Dim in Image type: '") << spec << "'"; + } + return dim; +} + +template <> +Optional +parseAndVerify(SPIRVDialect const &dialect, Location loc, + StringRef spec) { + auto depth = symbolizeImageDepthInfo(spec); + if (!depth) { + auto *context = dialect.getContext(); + context->emitError(loc, "unknown ImageDepthInfo in Image type: '") + << spec << "'"; + } + return depth; +} + +template <> +Optional +parseAndVerify(SPIRVDialect const &dialect, Location loc, + StringRef spec) { + auto arrayedInfo = symbolizeImageArrayedInfo(spec); + if (!arrayedInfo) { + auto *context = dialect.getContext(); + context->emitError(loc, "unknown ImageArrayedInfo in Image type: '") + << spec << "'"; + } + return arrayedInfo; +} + +template <> +Optional +parseAndVerify(SPIRVDialect const &dialect, Location loc, + StringRef spec) { + auto samplingInfo = symbolizeImageSamplingInfo(spec); + if (!samplingInfo) { + auto *context = dialect.getContext(); + context->emitError(loc, "unknown ImageSamplingInfo in Image type: '") + << spec << "'"; + } + return samplingInfo; +} + +template <> +Optional +parseAndVerify(SPIRVDialect const &dialect, Location loc, + StringRef spec) { + auto samplerUseInfo = symbolizeImageSamplerUseInfo(spec); + if (!samplerUseInfo) { + auto *context = dialect.getContext(); + context->emitError(loc, "unknown ImageSamplerUseInfo in Image type: '") + << spec << "'"; + } + return samplerUseInfo; +} + +template <> +Optional parseAndVerify(SPIRVDialect const &dialect, + Location loc, + StringRef spec) { + auto format = symbolizeImageFormat(spec); + if (!format) { + auto *context = dialect.getContext(); + context->emitError(loc, "unknown ImageFormat in Image type: '") + << spec << "'"; + } + return format; +} + +// Functor object to parse a comma separated list of specs. The function +// parseAndVerify does the actual parsing and verification of individual +// elements. This is a functor since parsing the last element of the list +// (termination condition) needs partial specialization. +template struct parseCommaSeparatedList { + Optional> + operator()(SPIRVDialect const &dialect, Location loc, StringRef spec) const { + auto numArgs = std::tuple_size>::value; + StringRef parseSpec, restSpec; + auto *context = dialect.getContext(); + std::tie(parseSpec, restSpec) = spec.split(','); + + parseSpec = parseSpec.trim(); + if (numArgs != 0 && restSpec.empty()) { + context->emitError(loc, "expected more parameters for image type '") + << parseSpec << "'"; + return llvm::None; + } + + auto parseVal = parseAndVerify(dialect, loc, parseSpec); + if (!parseVal) { + return llvm::None; + } + + auto remainingValues = + parseCommaSeparatedList{}(dialect, loc, restSpec); + if (!remainingValues) { + return llvm::None; + } + return std::tuple_cat(std::tuple(parseVal.getValue()), + remainingValues.getValue()); + } +}; + +// Partial specialization of the function to parse a comma separated list of +// specs to parse the last element of the list. +template struct parseCommaSeparatedList { + Optional> + operator()(SPIRVDialect const &dialect, Location loc, StringRef spec) const { + spec = spec.trim(); + auto value = parseAndVerify(dialect, loc, spec); + if (!value) { + return llvm::None; + } + return std::tuple(value.getValue()); + } +}; + +// dim ::= `1D` | `2D` | `3D` | `Cube` | +// +// depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown` +// +// arrayed-info ::= `NonArrayed` | `Arrayed` +// +// sampling-info ::= `SingleSampled` | `MultiSampled` +// +// sampler-use-info ::= `SamplerUnknown` | `NeedSampler` | `NoSampler` +// +// format ::= `Unknown` | `Rgba32f` | +// +// image-type ::= `!spv.image<` element-type `,` dim `,` depth-info `,` +// arrayed-info `,` sampling-info `,` +// sampler-use-info `,` format `>` +Type SPIRVDialect::parseImageType(StringRef spec, Location loc) const { + auto *context = getContext(); + if (!spec.consume_front("image<") || !spec.consume_back(">")) { + context->emitError(loc, "spv.image delimiter <...> mismatch"); + return Type(); + } + + auto value = + parseCommaSeparatedList{}(*this, loc, spec); + if (!value) { + return Type(); + } + + return ImageType::get(value.getValue()); +} + +Type SPIRVDialect::parseType(StringRef spec, Location loc) const { if (spec.startswith("array")) return parseArrayType(spec, loc); + if (spec.startswith("image")) + return parseImageType(spec, loc); if (spec.startswith("ptr")) return parsePointerType(spec, loc); if (spec.startswith("rtarray")) @@ -240,6 +425,15 @@ static void print(PointerType type, llvm::raw_ostream &os) { << stringifyStorageClass(type.getStorageClass()) << ">"; } +static void print(ImageType type, llvm::raw_ostream &os) { + os << "image<" << type.getElementType() << ", " << stringifyDim(type.getDim()) + << ", " << stringifyImageDepthInfo(type.getDepthInfo()) << ", " + << stringifyImageArrayedInfo(type.getArrayedInfo()) << ", " + << stringifyImageSamplingInfo(type.getSamplingInfo()) << ", " + << stringifyImageSamplerUseInfo(type.getSamplerUseInfo()) << ", " + << stringifyImageFormat(type.getImageFormat()) << ">"; +} + void SPIRVDialect::printType(Type type, llvm::raw_ostream &os) const { switch (type.getKind()) { case TypeKind::Array: @@ -251,6 +445,9 @@ void SPIRVDialect::printType(Type type, llvm::raw_ostream &os) const { case TypeKind::RuntimeArray: print(type.cast(), os); return; + case TypeKind::ImageType: + print(type.cast(), os); + return; default: llvm_unreachable("unhandled SPIR-V type"); } diff --git a/mlir/lib/SPIRV/SPIRVTypes.cpp b/mlir/lib/SPIRV/SPIRVTypes.cpp index b273cdc99d5e..d8c648be141a 100644 --- a/mlir/lib/SPIRV/SPIRVTypes.cpp +++ b/mlir/lib/SPIRV/SPIRVTypes.cpp @@ -60,6 +60,185 @@ Type ArrayType::getElementType() { return getImpl()->elementType; } int64_t ArrayType::getElementCount() { return getImpl()->elementCount; } +//===----------------------------------------------------------------------===// +// ImageType +//===----------------------------------------------------------------------===// + +template static constexpr unsigned getNumBits() { return 0; } +template <> constexpr unsigned getNumBits() { + static_assert((1 << 3) > getMaxEnumValForDim(), + "Not enough bits to encode Dim value"); + return 3; +} +template <> constexpr unsigned getNumBits() { + static_assert((1 << 2) > getMaxEnumValForImageDepthInfo(), + "Not enough bits to encode ImageDepthInfo value"); + return 2; +} +template <> constexpr unsigned getNumBits() { + static_assert((1 << 1) > getMaxEnumValForImageArrayedInfo(), + "Not enough bits to encode ImageArrayedInfo value"); + return 1; +} +template <> constexpr unsigned getNumBits() { + static_assert((1 << 1) > getMaxEnumValForImageSamplingInfo(), + "Not enough bits to encode ImageSamplingInfo value"); + return 1; +} +template <> constexpr unsigned getNumBits() { + static_assert((1 << 2) > getMaxEnumValForImageSamplerUseInfo(), + "Not enough bits to encode ImageSamplerUseInfo value"); + return 2; +} +template <> constexpr unsigned getNumBits() { + static_assert((1 << 6) > getMaxEnumValForImageFormat(), + "Not enough bits to encode ImageFormat value"); + return 6; +} + +struct spirv::detail::ImageTypeStorage : public TypeStorage { +private: + /// Define a bit-field struct to pack the enum values + union EnumPack { + struct { + Dim dim : getNumBits(); + ImageDepthInfo depthInfo : getNumBits(); + ImageArrayedInfo arrayedInfo : getNumBits(); + ImageSamplingInfo samplingInfo : getNumBits(); + ImageSamplerUseInfo samplerUseInfo : getNumBits(); + ImageFormat format : getNumBits(); + } data; + unsigned storage; + }; + +public: + using KeyTy = std::tuple; + + static ImageTypeStorage *construct(TypeStorageAllocator &allocator, + const KeyTy &key) { + return new (allocator.allocate()) ImageTypeStorage(key); + } + + bool operator==(const KeyTy &key) const { + return key == KeyTy(elementType, getDim(), getDepthInfo(), getArrayedInfo(), + getSamplingInfo(), getSamplerUseInfo(), + getImageFormat()); + } + + Dim getDim() const { + EnumPack v; + v.storage = getSubclassData(); + return v.data.dim; + } + void setDim(Dim dim) { + EnumPack v; + v.storage = getSubclassData(); + v.data.dim = dim; + setSubclassData(v.storage); + } + + ImageDepthInfo getDepthInfo() const { + EnumPack v; + v.storage = getSubclassData(); + return v.data.depthInfo; + } + void setDepthInfo(ImageDepthInfo depthInfo) { + EnumPack v; + v.storage = getSubclassData(); + v.data.depthInfo = depthInfo; + setSubclassData(v.storage); + } + + ImageArrayedInfo getArrayedInfo() const { + EnumPack v; + v.storage = getSubclassData(); + return v.data.arrayedInfo; + } + void setArrayedInfo(ImageArrayedInfo arrayedInfo) { + EnumPack v; + v.storage = getSubclassData(); + v.data.arrayedInfo = arrayedInfo; + setSubclassData(v.storage); + } + + ImageSamplingInfo getSamplingInfo() const { + EnumPack v; + v.storage = getSubclassData(); + return v.data.samplingInfo; + } + void setSamplingInfo(ImageSamplingInfo samplingInfo) { + EnumPack v; + v.storage = getSubclassData(); + v.data.samplingInfo = samplingInfo; + setSubclassData(v.storage); + } + + ImageSamplerUseInfo getSamplerUseInfo() const { + EnumPack v; + v.storage = getSubclassData(); + return v.data.samplerUseInfo; + } + void setSamplerUseInfo(ImageSamplerUseInfo samplerUseInfo) { + EnumPack v; + v.storage = getSubclassData(); + v.data.samplerUseInfo = samplerUseInfo; + setSubclassData(v.storage); + } + + ImageFormat getImageFormat() const { + EnumPack v; + v.storage = getSubclassData(); + return v.data.format; + } + void setImageFormat(ImageFormat format) { + EnumPack v; + v.storage = getSubclassData(); + v.data.format = format; + setSubclassData(v.storage); + } + + ImageTypeStorage(const KeyTy &key) : elementType(std::get<0>(key)) { + static_assert(sizeof(EnumPack) <= sizeof(getSubclassData()), + "EnumPack size greater than subClassData type size"); + setDim(std::get<1>(key)); + setDepthInfo(std::get<2>(key)); + setArrayedInfo(std::get<3>(key)); + setSamplingInfo(std::get<4>(key)); + setSamplerUseInfo(std::get<5>(key)); + setImageFormat(std::get<6>(key)); + } + + Type elementType; +}; + +ImageType +ImageType::get(std::tuple + value) { + return Base::get(std::get<0>(value).getContext(), TypeKind::ImageType, value); +} + +Type ImageType::getElementType() { return getImpl()->elementType; } + +Dim ImageType::getDim() { return getImpl()->getDim(); } + +ImageDepthInfo ImageType::getDepthInfo() { return getImpl()->getDepthInfo(); } + +ImageArrayedInfo ImageType::getArrayedInfo() { + return getImpl()->getArrayedInfo(); +} + +ImageSamplingInfo ImageType::getSamplingInfo() { + return getImpl()->getSamplingInfo(); +} + +ImageSamplerUseInfo ImageType::getSamplerUseInfo() { + return getImpl()->getSamplerUseInfo(); +} + +ImageFormat ImageType::getImageFormat() { return getImpl()->getImageFormat(); } + //===----------------------------------------------------------------------===// // PointerType //===----------------------------------------------------------------------===// diff --git a/mlir/lib/TableGen/Attribute.cpp b/mlir/lib/TableGen/Attribute.cpp index 29259bec69a8..0bd72ead67c2 100644 --- a/mlir/lib/TableGen/Attribute.cpp +++ b/mlir/lib/TableGen/Attribute.cpp @@ -178,6 +178,10 @@ StringRef tblgen::EnumAttr::getSymbolToStringFnName() const { return def->getValueAsString("symbolToStringFnName"); } +StringRef tblgen::EnumAttr::getMaxEnumValFnName() const { + return def->getValueAsString("maxEnumValFnName"); +} + std::vector tblgen::EnumAttr::getAllCases() const { const auto *inits = def->getValueAsListInit("enumerants"); diff --git a/mlir/test/SPIRV/types.mlir b/mlir/test/SPIRV/types.mlir index 0e87d0b02bee..857871a00d7d 100644 --- a/mlir/test/SPIRV/types.mlir +++ b/mlir/test/SPIRV/types.mlir @@ -130,3 +130,73 @@ func @missing_element_type(!spv.rtarray<>) -> () // expected-error @+1 {{cannot parse type: 4xf32}} func @redundant_count(!spv.rtarray<4xf32>) -> () + +// ----- + +//===----------------------------------------------------------------------===// +// ImageType +//===----------------------------------------------------------------------===// + +// CHECK: func @image_parameters_1D(!spv.image) +func @image_parameters_1D(!spv.image) -> () + +// ----- + +// expected-error @+1 {{expected more parameters for image type 'f32'}} +func @image_parameters_one_element(!spv.image) -> () + +// ----- + +// expected-error @+1 {{expected more parameters for image type '1D'}} +func @image_parameters_two_elements(!spv.image) -> () + +// ----- + +// expected-error @+1 {{expected more parameters for image type 'NoDepth'}} +func @image_parameters_three_elements(!spv.image) -> () + +// ----- + +// expected-error @+1 {{expected more parameters for image type 'NonArrayed'}} +func @image_parameters_four_elements(!spv.image) -> () + +// ----- + +// expected-error @+1 {{expected more parameters for image type 'SingleSampled'}} +func @image_parameters_five_elements(!spv.image) -> () + +// ----- + +// expected-error @+1 {{expected more parameters for image type 'SamplerUnknown'}} +func @image_parameters_six_elements(!spv.image) -> () + +// ----- + +// expected-error @+1 {{spv.image delimiter <...> mismatch}} +func @image_parameters_delimiter(!spv.image f32, 1D, NoDepth, NonArrayed, SingleSampled, SamplerUnknown, Unkown>) -> () + +// ----- + +// expected-error @+1 {{unknown Dim in Image type: '1D NoDepth'}} +func @image_parameters_nocomma_1(!spv.image) -> () + +// ----- + +// expected-error @+1 {{unknown ImageDepthInfo in Image type: 'NoDepth NonArrayed'}} +func @image_parameters_nocomma_2(!spv.image) -> () + +// ----- + +// expected-error @+1 {{unknown ImageArrayedInfo in Image type: 'NonArrayed SingleSampled'}} +func @image_parameters_nocomma_3(!spv.image) -> () + +// ----- + +// expected-error @+1 {{unknown ImageSamplingInfo in Image type: 'SingleSampled SamplerUnknown'}} +func @image_parameters_nocomma_4(!spv.image) -> () + +// ----- + +// expected-error @+1 {{expected more parameters for image type 'SamplerUnknown Unknown'}} +func @image_parameters_nocomma_5(!spv.image) -> () + diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp index ab86c9dd8cc4..e9a70f3131e2 100644 --- a/mlir/tools/mlir-tblgen/EnumsGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp @@ -30,6 +30,7 @@ #include "llvm/TableGen/TableGenBackend.h" using llvm::formatv; +using llvm::isDigit; using llvm::raw_ostream; using llvm::Record; using llvm::RecordKeeper; @@ -37,6 +38,14 @@ using llvm::StringRef; using mlir::tblgen::EnumAttr; using mlir::tblgen::EnumAttrCase; +static std::string makeIdentifier(StringRef str) { + if (!str.empty() && isDigit(static_cast(str.front()))) { + std::string newStr = std::string("_") + str.str(); + return newStr; + } + return str.str(); +} + static void emitEnumClass(const Record &enumDef, StringRef enumName, StringRef underlyingType, StringRef description, const std::vector &enumerants, @@ -49,7 +58,7 @@ static void emitEnumClass(const Record &enumDef, StringRef enumName, os << " {\n"; for (const auto &enumerant : enumerants) { - auto symbol = enumerant.getSymbol(); + auto symbol = makeIdentifier(enumerant.getSymbol()); auto value = enumerant.getValue(); if (value < 0) { llvm::PrintFatalError(enumDef.getLoc(), @@ -100,6 +109,7 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { StringRef description = enumAttr.getDescription(); StringRef strToSymFnName = enumAttr.getStringToSymbolFnName(); StringRef symToStrFnName = enumAttr.getSymbolToStringFnName(); + StringRef maxEnumValFnName = enumAttr.getMaxEnumValFnName(); auto enumerants = enumAttr.getAllCases(); llvm::SmallVector namespaces; @@ -119,6 +129,17 @@ static void emitEnumDecl(const Record &enumDef, raw_ostream &os) { for (auto ns : llvm::reverse(namespaces)) os << "} // namespace " << ns << "\n"; + // Emit the function to return the max enum value + unsigned maxEnumVal = 0; + for (const auto &enumerant : enumerants) { + auto value = enumerant.getValue(); + // Already checked that the value is non-negetive. + maxEnumVal = std::max(maxEnumVal, static_cast(value)); + } + os << formatv("inline constexpr unsigned {0}() {{\n", maxEnumValFnName); + os << formatv(" return {0};\n", maxEnumVal); + os << "}\n\n"; + // Emit DenseMapInfo for this enum class emitDenseMapInfo(enumName, underlyingType, cppNamespace, os); } @@ -151,7 +172,8 @@ static void emitEnumDef(const Record &enumDef, raw_ostream &os) { os << " switch (val) {\n"; for (const auto &enumerant : enumerants) { auto symbol = enumerant.getSymbol(); - os << formatv(" case {0}::{1}: return \"{1}\";\n", enumName, symbol); + os << formatv(" case {0}::{1}: return \"{2}\";\n", enumName, + makeIdentifier(symbol), symbol); } os << " }\n"; os << " return \"\";\n"; @@ -163,7 +185,8 @@ static void emitEnumDef(const Record &enumDef, raw_ostream &os) { enumName); for (const auto &enumerant : enumerants) { auto symbol = enumerant.getSymbol(); - os << formatv(" .Case(\"{1}\", {0}::{1})\n", enumName, symbol); + os << formatv(" .Case(\"{1}\", {0}::{2})\n", enumName, symbol, + makeIdentifier(symbol)); } os << " .Default(llvm::None);\n"; os << "}\n";