Add support for SPIR-V Struct Types. Current support is limited to

supporting only Offset decorations

PiperOrigin-RevId: 256216704
This commit is contained in:
Mahesh Ravishankar 2019-07-02 12:30:34 -07:00 committed by Mehdi Amini
parent 08927308b7
commit c73edeec13
9 changed files with 458 additions and 95 deletions

View File

@ -152,6 +152,24 @@ For example,
!spv.rtarray<vector<4 x f32>>
```
### Struct type
This corresponds to SPIR-V [struct type][StructType]. Its syntax is
``` {.ebnf}
struct-type ::= `!spv.struct<` spirv-type (` [` integer-literal `]` )?
(`, ` spirv-type ( ` [` integer-literal `] ` )? )* `>`
```
For Example,
``` {.mlir}
!spv.struct<f32>
!spv.struct<f32 [0]>
!spv.struct<f32, !spv.image<f32, 1D, NoDepth, NonArrayed, SingleSampled, SamplerUnknown, Unknown>>
!spv.struct<f32 [0], i32 [4]>
```
## Serialization
The serialization library provides two entry points, `mlir::spirv::serialize()`
@ -168,7 +186,8 @@ for now). For the latter, please use the assembler/disassembler in the
[SPIR-V]: https://www.khronos.org/registry/spir-v/
[ArrayType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeArray
[ImageType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeImage
[PointerType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypePointer
[RuntimeArrayType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeRuntimeArray
[ImageType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#OpTypeImage
[StructType]: https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#Structure
[SPIRV-Tools]: https://github.com/KhronosGroup/SPIRV-Tools

View File

@ -145,10 +145,10 @@ public:
unsigned getKind() const;
/// Return the LLVMContext in which this type was uniqued.
MLIRContext *getContext();
MLIRContext *getContext() const;
/// Get the dialect this type is registered to.
Dialect &getDialect();
Dialect &getDialect() const;
// Convenience predicates. This is only for floating point types,
// derived types should use isa/dyn_cast.

View File

@ -38,22 +38,6 @@ public:
/// Prints a type registered to this dialect.
void printType(Type type, llvm::raw_ostream &os) const override;
private:
/// Parses `spec` as a type and verifies it can be used in SPIR-V types.
Type parseAndVerifyType(StringRef spec, Location loc) const;
/// Parses `spec` as a SPIR-V array type.
Type parseArrayType(StringRef spec, Location loc) const;
/// Parses `spec` as a SPIR-V pointer type.
Type parsePointerType(StringRef spec, Location loc) const;
/// Parses `spec` as a SPIR-V run-time array type.
Type parseRuntimeArrayType(StringRef spec, Location loc) const;
/// Parses `spec` as a SPIR-V image type
Type parseImageType(StringRef spec, Location loc) const;
};
} // end namespace spirv

View File

@ -83,7 +83,7 @@ def SPV_LoadOp : SPV_Op<"Load"> {
### Custom assembly form
``` {.ebnf}
memory-access ::= `"None"` | `"Volatile"` | `"Aligned"` integer-literal
memory-access ::= `"None"` | `"Volatile"` | `"Aligned", ` integer-literal
| `"NonTemporal"`
load-op ::= ssa-id ` = spv.Load ` storage-class ssa-use
@ -118,6 +118,8 @@ def SPV_LoadOp : SPV_Op<"Load"> {
return "alignment";
}
}];
let opcode = 61;
}
def SPV_ReturnOp : SPV_Op<"Return", [Terminator]> {
@ -157,7 +159,7 @@ def SPV_StoreOp : SPV_Op<"Store"> {
``` {.ebnf}
store-op ::= `spv.Store ` storage-class ssa-use `, ` ssa-use `, `
(memory-access)? : spirv-element-type
(`[` memory-access `]`)? `:` spirv-element-type
```
For example:
@ -185,6 +187,8 @@ def SPV_StoreOp : SPV_Op<"Store"> {
return "alignment";
}
}];
let opcode = 62;
}
def SPV_VariableOp : SPV_Op<"Variable"> {

View File

@ -37,14 +37,16 @@ struct ArrayTypeStorage;
struct ImageTypeStorage;
struct PointerTypeStorage;
struct RuntimeArrayTypeStorage;
struct StructTypeStorage;
} // namespace detail
namespace TypeKind {
enum Kind {
Array = Type::FIRST_SPIRV_TYPE,
ImageType,
Image,
Pointer,
RuntimeArray,
Struct,
};
}
@ -58,9 +60,9 @@ public:
static ArrayType get(Type elementType, int64_t elementCount);
Type getElementType();
Type getElementType() const;
int64_t getElementCount();
int64_t getElementCount() const;
};
// SPIR-V pointer type
@ -73,9 +75,10 @@ public:
static PointerType get(Type pointeeType, StorageClass storageClass);
Type getPointeeType();
Type getPointeeType() const;
StorageClass getStorageClass();
StorageClass getStorageClass() const;
StringRef getStorageClassStr() const;
};
// SPIR-V run-time array type
@ -89,16 +92,17 @@ public:
static RuntimeArrayType get(Type elementType);
Type getElementType();
Type getElementType() const;
};
// SPIR-V image type
// TODO(ravishankarm) : Move this in alphabetical order
class ImageType
: public Type::TypeBase<ImageType, Type, detail::ImageTypeStorage> {
public:
using Base::Base;
static bool kindof(unsigned kind) { return kind == TypeKind::ImageType; }
static bool kindof(unsigned kind) { return kind == TypeKind::Image; }
static ImageType
get(Type elementType, Dim dim,
@ -118,16 +122,45 @@ public:
get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>);
Type getElementType();
Dim getDim();
ImageDepthInfo getDepthInfo();
ImageArrayedInfo getArrayedInfo();
ImageSamplingInfo getSamplingInfo();
ImageSamplerUseInfo getSamplerUseInfo();
ImageFormat getImageFormat();
Type getElementType() const;
Dim getDim() const;
ImageDepthInfo getDepthInfo() const;
ImageArrayedInfo getArrayedInfo() const;
ImageSamplingInfo getSamplingInfo() const;
ImageSamplerUseInfo getSamplerUseInfo() const;
ImageFormat getImageFormat() const;
// TODO(ravishankarm): Add support for Access qualifier
};
// SPIR-V struct type
class StructType
: public Type::TypeBase<StructType, Type, detail::StructTypeStorage> {
public:
using Base::Base;
// Layout information used for members in a struct in SPIR-V
//
// TODO(ravishankarm) : For now this only supports the offset type, so uses
// uint64_t value to represent the offset, with
// std::numeric_limit<uint64_t>::max indicating no offset. Change this to
// something that can hold all the information needed for different member
// types
using LayoutInfo = uint64_t;
static bool kindof(unsigned kind) { return kind == TypeKind::Struct; }
static StructType get(ArrayRef<Type> memberTypes);
static StructType get(ArrayRef<Type> memberTypes,
ArrayRef<LayoutInfo> layoutInfo);
size_t getNumMembers() const;
Type getMemberType(size_t) const;
bool hasLayout() const;
uint64_t getOffset(size_t) const;
};
} // end namespace spirv
} // end namespace mlir

View File

@ -27,9 +27,9 @@ using namespace mlir::detail;
unsigned Type::getKind() const { return impl->getKind(); }
/// Get the dialect this type is registered to.
Dialect &Type::getDialect() { return impl->getDialect(); }
Dialect &Type::getDialect() const { return impl->getDialect(); }
MLIRContext *Type::getContext() { return getDialect().getContext(); }
MLIRContext *Type::getContext() const { return getDialect().getContext(); }
unsigned Type::getSubclassData() const { return impl->getSubclassData(); }
void Type::setSubclassData(unsigned val) { impl->setSubclassData(val); }

View File

@ -1,26 +1,16 @@
//===- LLVMDialect.cpp - MLIR SPIR-V dialect ------------------------------===//
//
// Copyright 2019 The MLIR Authors.
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//===----------------------------------------------------------------------===//
//
// This file defines the SPIR-V dialect in MLIR.
//
//===----------------------------------------------------------------------===//
#include "mlir/SPIRV/SPIRVDialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Parser.h"
@ -32,8 +22,6 @@
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
using namespace mlir;
using namespace mlir::spirv;
@ -43,7 +31,7 @@ using namespace mlir::spirv;
SPIRVDialect::SPIRVDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
addTypes<ArrayType, ImageType, PointerType, RuntimeArrayType>();
addTypes<ArrayType, ImageType, PointerType, RuntimeArrayType, StructType>();
addOperations<
#define GET_OP_LIST
@ -77,8 +65,9 @@ static bool parseNumberX(StringRef &spec, int64_t &number) {
return true;
}
static Type parseAndVerifyTypeImpl(SPIRVDialect const &dialect, Location loc,
StringRef spec) {
static Type parseAndVerifyType(SPIRVDialect const &dialect, StringRef spec,
Location loc) {
spec = spec.trim();
auto *context = dialect.getContext();
auto type = mlir::parseType(spec.trim(), context);
if (!type) {
@ -116,17 +105,14 @@ static Type parseAndVerifyTypeImpl(SPIRVDialect const &dialect, Location loc,
return type;
}
Type SPIRVDialect::parseAndVerifyType(StringRef spec, Location loc) const {
return parseAndVerifyTypeImpl(*this, loc, spec);
}
// element-type ::= integer-type
// | floating-point-type
// | vector-type
// | spirv-type
//
// array-type ::= `!spv.array<` integer-literal `x` element-type `>`
Type SPIRVDialect::parseArrayType(StringRef spec, Location loc) const {
static Type parseArrayType(SPIRVDialect const &dialect, StringRef spec,
Location loc) {
if (!spec.consume_front("array<") || !spec.consume_back(">")) {
emitError(loc, "spv.array delimiter <...> mismatch");
return Type();
@ -145,20 +131,24 @@ Type SPIRVDialect::parseArrayType(StringRef spec, Location loc) const {
return Type();
}
Type elementType = parseAndVerifyType(spec, loc);
Type elementType = parseAndVerifyType(dialect, spec, loc);
if (!elementType)
return Type();
return ArrayType::get(elementType, count);
}
// TODO(ravishankarm) : Reorder methods to be utilities first and parse*Type
// methods in alphabetical order
//
// storage-class ::= `UniformConstant`
// | `Uniform`
// | `Workgroup`
// | <and other storage classes...>
//
// pointer-type ::= `!spv.ptr<` element-type `,` storage-class `>`
Type SPIRVDialect::parsePointerType(StringRef spec, Location loc) const {
static Type parsePointerType(SPIRVDialect const &dialect, StringRef spec,
Location loc) {
if (!spec.consume_front("ptr<") || !spec.consume_back(">")) {
emitError(loc, "spv.ptr delimiter <...> mismatch");
return Type();
@ -186,7 +176,7 @@ Type SPIRVDialect::parsePointerType(StringRef spec, Location loc) const {
return Type();
}
auto pointeeType = parseAndVerifyType(ptSpec, loc);
auto pointeeType = parseAndVerifyType(dialect, ptSpec, loc);
if (!pointeeType)
return Type();
@ -194,7 +184,8 @@ Type SPIRVDialect::parsePointerType(StringRef spec, Location loc) const {
}
// runtime-array-type ::= `!spv.rtarray<` element-type `>`
Type SPIRVDialect::parseRuntimeArrayType(StringRef spec, Location loc) const {
static Type parseRuntimeArrayType(SPIRVDialect const &dialect, StringRef spec,
Location loc) {
if (!spec.consume_front("rtarray<") || !spec.consume_back(">")) {
emitError(loc, "spv.rtarray delimiter <...> mismatch");
return Type();
@ -205,7 +196,7 @@ Type SPIRVDialect::parseRuntimeArrayType(StringRef spec, Location loc) const {
return Type();
}
Type elementType = parseAndVerifyType(spec, loc);
Type elementType = parseAndVerifyType(dialect, spec, loc);
if (!elementType)
return Type();
@ -215,8 +206,8 @@ Type SPIRVDialect::parseRuntimeArrayType(StringRef spec, Location loc) const {
// Specialize this function to parse each of the parameters that define an
// ImageType
template <typename ValTy>
Optional<ValTy> parseAndVerify(SPIRVDialect const &dialect, Location loc,
StringRef spec) {
static Optional<ValTy> parseAndVerify(SPIRVDialect const &dialect, Location loc,
StringRef spec) {
emitError(loc, "unexpected parameter while parsing '") << spec << "'";
return llvm::None;
}
@ -225,15 +216,20 @@ template <>
Optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect, Location loc,
StringRef spec) {
// TODO(ravishankarm): Further verify that the element type can be sampled
return parseAndVerifyTypeImpl(dialect, loc, spec);
auto ty = parseAndVerifyType(dialect, spec, loc);
if (!ty) {
return llvm::None;
}
return ty;
}
template <>
Optional<Dim> parseAndVerify<Dim>(SPIRVDialect const &dialect, Location loc,
StringRef spec) {
auto dim = symbolizeDim(spec);
if (!dim)
if (!dim) {
emitError(loc, "unknown Dim in Image type: '") << spec << "'";
}
return dim;
}
@ -242,8 +238,9 @@ Optional<ImageDepthInfo>
parseAndVerify<ImageDepthInfo>(SPIRVDialect const &dialect, Location loc,
StringRef spec) {
auto depth = symbolizeImageDepthInfo(spec);
if (!depth)
if (!depth) {
emitError(loc, "unknown ImageDepthInfo in Image type: '") << spec << "'";
}
return depth;
}
@ -252,8 +249,9 @@ Optional<ImageArrayedInfo>
parseAndVerify<ImageArrayedInfo>(SPIRVDialect const &dialect, Location loc,
StringRef spec) {
auto arrayedInfo = symbolizeImageArrayedInfo(spec);
if (!arrayedInfo)
if (!arrayedInfo) {
emitError(loc, "unknown ImageArrayedInfo in Image type: '") << spec << "'";
}
return arrayedInfo;
}
@ -262,8 +260,9 @@ Optional<ImageSamplingInfo>
parseAndVerify<ImageSamplingInfo>(SPIRVDialect const &dialect, Location loc,
StringRef spec) {
auto samplingInfo = symbolizeImageSamplingInfo(spec);
if (!samplingInfo)
if (!samplingInfo) {
emitError(loc, "unknown ImageSamplingInfo in Image type: '") << spec << "'";
}
return samplingInfo;
}
@ -272,9 +271,10 @@ Optional<ImageSamplerUseInfo>
parseAndVerify<ImageSamplerUseInfo>(SPIRVDialect const &dialect, Location loc,
StringRef spec) {
auto samplerUseInfo = symbolizeImageSamplerUseInfo(spec);
if (!samplerUseInfo)
if (!samplerUseInfo) {
emitError(loc, "unknown ImageSamplerUseInfo in Image type: '")
<< spec << "'";
}
return samplerUseInfo;
}
@ -283,11 +283,41 @@ Optional<ImageFormat> parseAndVerify<ImageFormat>(SPIRVDialect const &dialect,
Location loc,
StringRef spec) {
auto format = symbolizeImageFormat(spec);
if (!format)
if (!format) {
emitError(loc, "unknown ImageFormat in Image type: '") << spec << "'";
}
return format;
}
template <>
Optional<spirv::StructType::LayoutInfo>
parseAndVerify(SPIRVDialect const &dialect, Location loc, StringRef spec) {
uint64_t offsetVal = std::numeric_limits<uint64_t>::max();
if (!spec.consume_front("[")) {
emitError(loc, "expected '[' while parsing layout specification in '")
<< spec << "'";
return llvm::None;
}
if (spec.consumeInteger(10, offsetVal)) {
emitError(
loc,
"expected unsigned integer to specify offset of member in struct: '")
<< spec << "'";
return llvm::None;
}
spec = spec.trim();
if (!spec.consume_front("]")) {
emitError(loc, "missing ']' in decorations spec: '") << spec << "'";
return llvm::None;
}
if (spec != "") {
emitError(loc, "unexpected extra tokens in layout information: '")
<< spec << "'";
return llvm::None;
}
return spirv::StructType::LayoutInfo{offsetVal};
}
// Functor object to parse a comma separated list of specs. The function
// parseAndVerify does the actual parsing and verification of individual
// elements. This is a functor since parsing the last element of the list
@ -350,7 +380,8 @@ template <typename ParseType> struct parseCommaSeparatedList<ParseType> {
// image-type ::= `!spv.image<` element-type `,` dim `,` depth-info `,`
// arrayed-info `,` sampling-info `,`
// sampler-use-info `,` format `>`
Type SPIRVDialect::parseImageType(StringRef spec, Location loc) const {
static Type parseImageType(SPIRVDialect const &dialect, StringRef spec,
Location loc) {
if (!spec.consume_front("image<") || !spec.consume_back(">")) {
emitError(loc, "spv.image delimiter <...> mismatch");
return Type();
@ -359,7 +390,7 @@ Type SPIRVDialect::parseImageType(StringRef spec, Location loc) const {
auto value =
parseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
ImageSamplingInfo, ImageSamplerUseInfo,
ImageFormat>{}(*this, loc, spec);
ImageFormat>{}(dialect, loc, spec);
if (!value) {
return Type();
}
@ -367,15 +398,151 @@ Type SPIRVDialect::parseImageType(StringRef spec, Location loc) const {
return ImageType::get(value.getValue());
}
// Method to parse one member of a struct (including Layout information)
static ParseResult
parseStructElement(SPIRVDialect const &dialect, StringRef spec, Location loc,
SmallVectorImpl<Type> &memberTypes,
SmallVectorImpl<StructType::LayoutInfo> &layoutInfo) {
// Check for a '[' <layoutInfo> ']'
auto lastLSquare = spec.rfind('[');
auto typeSpec = spec.substr(0, lastLSquare);
auto layoutSpec = (lastLSquare == StringRef::npos ? StringRef("")
: spec.substr(lastLSquare));
auto type = parseAndVerify<Type>(dialect, loc, typeSpec);
if (!type) {
return failure();
}
memberTypes.push_back(type.getValue());
if (layoutSpec.empty()) {
return success();
}
if (layoutInfo.size() != memberTypes.size() - 1) {
emitError(loc, "layout specification must be given for all members");
return failure();
}
auto layout =
parseAndVerify<StructType::LayoutInfo>(dialect, loc, layoutSpec);
if (!layout) {
return failure();
}
layoutInfo.push_back(layout.getValue());
return success();
}
// Helper method to record the position of the corresponding '>' for every '<'
// encountered when parsing the string left to right. The relative position of
// '>' w.r.t to the '<' is recorded.
static bool
computeMatchingRAngles(Location loc, StringRef const &spec,
SmallVectorImpl<size_t> &matchingRAngleOffset) {
SmallVector<size_t, 4> openBrackets;
for (size_t i = 0, e = spec.size(); i != e; ++i) {
if (spec[i] == '<') {
openBrackets.push_back(i);
} else if (spec[i] == '>') {
if (openBrackets.empty()) {
emitError(loc, "unbalanced '<' in '") << spec << "'";
return false;
}
matchingRAngleOffset.push_back(i - openBrackets.pop_back_val());
}
}
return true;
}
static ParseResult
parseStructHelper(SPIRVDialect const &dialect, StringRef spec, Location loc,
ArrayRef<size_t> matchingRAngleOffset,
SmallVectorImpl<Type> &memberTypes,
SmallVectorImpl<StructType::LayoutInfo> &layoutInfo) {
// Check if the occurrence of ',' or '<' is before. If former, split using
// ','. If latter, split using matching '>' to get the entire type
// description
auto firstComma = spec.find(',');
auto firstLAngle = spec.find('<');
if (firstLAngle == StringRef::npos && firstComma == StringRef::npos) {
return parseStructElement(dialect, spec, loc, memberTypes, layoutInfo);
}
if (firstLAngle == StringRef::npos || firstComma < firstLAngle) {
// Parse the type before the ','
if (parseStructElement(dialect, spec.substr(0, firstComma), loc,
memberTypes, layoutInfo)) {
return failure();
}
return parseStructHelper(dialect, spec.substr(firstComma + 1).ltrim(), loc,
matchingRAngleOffset, memberTypes, layoutInfo);
}
auto matchingRAngle = matchingRAngleOffset.front() + firstLAngle;
// Find the next ',' or '>'
auto endLoc = std::min(spec.find(',', matchingRAngle + 1), spec.size());
if (parseStructElement(dialect, spec.substr(0, endLoc), loc, memberTypes,
layoutInfo)) {
return failure();
}
auto rest = spec.substr(endLoc + 1).ltrim();
if (rest.empty()) {
return success();
}
if (rest.front() == ',') {
return parseStructHelper(
dialect, rest.drop_front().trim(), loc,
ArrayRef<size_t>(std::next(matchingRAngleOffset.begin()),
matchingRAngleOffset.end()),
memberTypes, layoutInfo);
}
emitError(loc, "unexpected string : '") << rest << "'";
return failure();
}
// struct-type ::= `!spv.struct<` spirv-type (` [` integer-literal `]`)?
// (`, ` spirv-type ( ` [` integer-literal `] ` )? )*
static Type parseStructType(SPIRVDialect const &dialect, StringRef spec,
Location loc) {
if (!spec.consume_front("struct<") || !spec.consume_back(">")) {
emitError(loc, "spv.struct delimiter <...> mismatch");
return Type();
}
if (spec.trim().empty()) {
emitError(loc, "expected SPIR-V type");
return Type();
}
SmallVector<Type, 4> memberTypes;
SmallVector<StructType::LayoutInfo, 4> layoutInfo;
SmallVector<size_t, 4> matchingRAngleOffset;
if (!computeMatchingRAngles(loc, spec, matchingRAngleOffset) ||
parseStructHelper(dialect, spec, loc, matchingRAngleOffset, memberTypes,
layoutInfo)) {
return Type();
}
if (layoutInfo.empty()) {
return StructType::get(memberTypes);
}
if (memberTypes.size() != layoutInfo.size()) {
emitError(loc, "layout specification must be given for all members");
return Type();
}
return StructType::get(memberTypes, layoutInfo);
}
// spirv-type ::= array-type
// | element-type
// | image-type
// | pointer-type
// | runtime-array-type
// | struct-type
Type SPIRVDialect::parseType(StringRef spec, Location loc) const {
if (spec.startswith("array"))
return parseArrayType(spec, loc);
return parseArrayType(*this, spec, loc);
if (spec.startswith("image"))
return parseImageType(spec, loc);
return parseImageType(*this, spec, loc);
if (spec.startswith("ptr"))
return parsePointerType(spec, loc);
return parsePointerType(*this, spec, loc);
if (spec.startswith("rtarray"))
return parseRuntimeArrayType(spec, loc);
return parseRuntimeArrayType(*this, spec, loc);
if (spec.startswith("struct"))
return parseStructType(*this, spec, loc);
emitError(loc, "unknown SPIR-V type: ") << spec;
return Type();
@ -408,6 +575,19 @@ static void print(ImageType type, llvm::raw_ostream &os) {
<< stringifyImageFormat(type.getImageFormat()) << ">";
}
static void print(StructType type, llvm::raw_ostream &os) {
os << "struct<";
std::string sep = "";
for (size_t i = 0, e = type.getNumMembers(); i != e; ++i) {
os << sep << type.getMemberType(i);
if (type.hasLayout()) {
os << " [" << type.getOffset(i) << "]";
}
sep = ", ";
}
os << ">";
}
void SPIRVDialect::printType(Type type, llvm::raw_ostream &os) const {
switch (type.getKind()) {
case TypeKind::Array:
@ -419,9 +599,12 @@ void SPIRVDialect::printType(Type type, llvm::raw_ostream &os) const {
case TypeKind::RuntimeArray:
print(type.cast<RuntimeArrayType>(), os);
return;
case TypeKind::ImageType:
case TypeKind::Image:
print(type.cast<ImageType>(), os);
return;
case TypeKind::Struct:
print(type.cast<StructType>(), os);
return;
default:
llvm_unreachable("unhandled SPIR-V type");
}

View File

@ -56,9 +56,9 @@ ArrayType ArrayType::get(Type elementType, int64_t elementCount) {
elementCount);
}
Type ArrayType::getElementType() { return getImpl()->elementType; }
Type ArrayType::getElementType() const { return getImpl()->elementType; }
int64_t ArrayType::getElementCount() { return getImpl()->elementCount; }
int64_t ArrayType::getElementCount() const { return getImpl()->elementCount; }
//===----------------------------------------------------------------------===//
// ImageType
@ -216,28 +216,32 @@ ImageType
ImageType::get(std::tuple<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
ImageSamplingInfo, ImageSamplerUseInfo, ImageFormat>
value) {
return Base::get(std::get<0>(value).getContext(), TypeKind::ImageType, value);
return Base::get(std::get<0>(value).getContext(), TypeKind::Image, value);
}
Type ImageType::getElementType() { return getImpl()->elementType; }
Type ImageType::getElementType() const { return getImpl()->elementType; }
Dim ImageType::getDim() { return getImpl()->getDim(); }
Dim ImageType::getDim() const { return getImpl()->getDim(); }
ImageDepthInfo ImageType::getDepthInfo() { return getImpl()->getDepthInfo(); }
ImageDepthInfo ImageType::getDepthInfo() const {
return getImpl()->getDepthInfo();
}
ImageArrayedInfo ImageType::getArrayedInfo() {
ImageArrayedInfo ImageType::getArrayedInfo() const {
return getImpl()->getArrayedInfo();
}
ImageSamplingInfo ImageType::getSamplingInfo() {
ImageSamplingInfo ImageType::getSamplingInfo() const {
return getImpl()->getSamplingInfo();
}
ImageSamplerUseInfo ImageType::getSamplerUseInfo() {
ImageSamplerUseInfo ImageType::getSamplerUseInfo() const {
return getImpl()->getSamplerUseInfo();
}
ImageFormat ImageType::getImageFormat() { return getImpl()->getImageFormat(); }
ImageFormat ImageType::getImageFormat() const {
return getImpl()->getImageFormat();
}
//===----------------------------------------------------------------------===//
// PointerType
@ -274,12 +278,16 @@ PointerType PointerType::get(Type pointeeType, StorageClass storageClass) {
storageClass);
}
Type PointerType::getPointeeType() { return getImpl()->pointeeType; }
Type PointerType::getPointeeType() const { return getImpl()->pointeeType; }
StorageClass PointerType::getStorageClass() {
StorageClass PointerType::getStorageClass() const {
return getImpl()->getStorageClass();
}
StringRef PointerType::getStorageClassStr() const {
return stringifyStorageClass(getStorageClass());
}
//===----------------------------------------------------------------------===//
// RuntimeArrayType
//===----------------------------------------------------------------------===//
@ -305,4 +313,88 @@ RuntimeArrayType RuntimeArrayType::get(Type elementType) {
elementType);
}
Type RuntimeArrayType::getElementType() { return getImpl()->elementType; }
Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; }
//===----------------------------------------------------------------------===//
// StructType
//===----------------------------------------------------------------------===//
struct spirv::detail::StructTypeStorage : public TypeStorage {
StructTypeStorage(unsigned numMembers, Type const *memberTypes,
StructType::LayoutInfo const *layoutInfo)
: TypeStorage(numMembers), memberTypes(memberTypes),
layoutInfo(layoutInfo) {}
using KeyTy = std::pair<ArrayRef<Type>, ArrayRef<StructType::LayoutInfo>>;
bool operator==(const KeyTy &key) const {
return key == KeyTy(getMemberTypes(), getLayoutInfo());
}
static StructTypeStorage *construct(TypeStorageAllocator &allocator,
const KeyTy &key) {
ArrayRef<Type> keyTypes = key.first;
// Copy the member type and layout information into the bump pointer
auto typesList = allocator.copyInto(keyTypes).data();
const StructType::LayoutInfo *layoutInfoList = nullptr;
if (!key.second.empty()) {
ArrayRef<StructType::LayoutInfo> keyLayoutInfo = key.second;
assert(keyLayoutInfo.size() == keyTypes.size() &&
"size of layout information must be same as the size of number of "
"elements");
layoutInfoList = allocator.copyInto(keyLayoutInfo).data();
}
return new (allocator.allocate<StructTypeStorage>())
StructTypeStorage(keyTypes.size(), typesList, layoutInfoList);
}
ArrayRef<Type> getMemberTypes() const {
return ArrayRef<Type>(memberTypes, getSubclassData());
}
ArrayRef<StructType::LayoutInfo> getLayoutInfo() const {
if (layoutInfo) {
return ArrayRef<StructType::LayoutInfo>(layoutInfo, getSubclassData());
}
return ArrayRef<StructType::LayoutInfo>(nullptr, size_t(0));
}
Type const *memberTypes;
StructType::LayoutInfo const *layoutInfo;
};
StructType StructType::get(ArrayRef<Type> memberTypes) {
assert(!memberTypes.empty() && "Struct needs at least one member type");
ArrayRef<StructType::LayoutInfo> noLayout(nullptr, size_t(0));
return Base::get(memberTypes[0].getContext(), TypeKind::Struct, memberTypes,
noLayout);
}
StructType StructType::get(ArrayRef<Type> memberTypes,
ArrayRef<StructType::LayoutInfo> layoutInfo) {
assert(!memberTypes.empty() && "Struct needs at least one member type");
return Base::get(memberTypes.vec().front().getContext(), TypeKind::Struct,
memberTypes, layoutInfo);
}
size_t StructType::getNumMembers() const {
return getImpl()->getSubclassData();
}
Type StructType::getMemberType(size_t i) const {
assert(
getNumMembers() > i &&
"element index is more than number of members of the SPIR-V StructType");
return getImpl()->memberTypes[i];
}
bool StructType::hasLayout() const { return getImpl()->layoutInfo; }
uint64_t StructType::getOffset(size_t i) const {
assert(
getNumMembers() > i &&
"element index is more than number of members of the SPIR-V StructType");
return getImpl()->layoutInfo[i];
}

View File

@ -200,3 +200,51 @@ func @image_parameters_nocomma_4(!spv.image<f32, 1D, NoDepth, NonArrayed, Single
// expected-error @+1 {{expected more parameters for image type 'SamplerUnknown Unknown'}}
func @image_parameters_nocomma_5(!spv.image<f32, 1D, NoDepth, NonArrayed, SingleSampled, SamplerUnknown Unknown>) -> ()
// -----
//===----------------------------------------------------------------------===//
// StructType
//===----------------------------------------------------------------------===//
// CHECK: func @struct_type(!spv.struct<f32>)
func @struct_type(!spv.struct<f32>) -> ()
// CHECK: func @struct_type2(!spv.struct<f32 [0]>)
func @struct_type2(!spv.struct<f32 [0]>) -> ()
// CHECK: func @struct_type_simple(!spv.struct<f32, !spv.image<f32, 1D, NoDepth, NonArrayed, SingleSampled, SamplerUnknown, Unknown>>)
func @struct_type_simple(!spv.struct<f32, !spv.image<f32, 1D, NoDepth, NonArrayed, SingleSampled, SamplerUnknown, Unknown>>) -> ()
// CHECK: func @struct_type_with_offset(!spv.struct<f32 [0], i32 [4]>)
func @struct_type_with_offset(!spv.struct<f32 [0], i32 [4]>) -> ()
// CHECK: func @nested_struct(!spv.struct<f32, !spv.struct<f32, i32>>)
func @nested_struct(!spv.struct<f32, !spv.struct<f32, i32>>)
// CHECK: func @nested_struct_with_offset(!spv.struct<f32 [0], !spv.struct<f32 [0], i32 [4]> [4]>)
func @nested_struct_with_offset(!spv.struct<f32 [0], !spv.struct<f32 [0], i32 [4]> [4]>)
// -----
// expected-error @+1 {{layout specification must be given for all members}}
func @struct_type_missing_offset1((!spv.struct<f32, i32 [4]>) -> ()
// -----
// expected-error @+1 {{layout specification must be given for all members}}
func @struct_type_missing_offset2(!spv.struct<f32 [3], i32>) -> ()
// -----
// expected-error @+1 {{cannot parse type: f32 i32}}
func @struct_type_missing_comma1(!spv.struct<f32 i32>) -> ()
// -----
// expected-error @+1 {{unexpected extra tokens in layout information: ' i32'}}
func @struct_type_missing_comma2(!spv.struct<f32 [0] i32>) -> ()
// -----
// expected-error @+1 {{expected unsigned integer to specify offset of member in struct}}
func @struct_type_neg_offset(!spv.struct<f32 [-1]>) -> ()