From 9a4f5d2ee324f536cede769c10022d1ce7b875f1 Mon Sep 17 00:00:00 2001 From: Mahesh Ravishankar Date: Thu, 19 Sep 2019 14:49:29 -0700 Subject: [PATCH] Allow specification of decorators on SPIR-V StructType members. Allow specification of decorators on SPIR-V StructType members. If the struct has layout information, these decorations are to be specified after the offset specification of the member. These decorations are emitted as OpMemberDecorate instructions on the struct . Update (de)serialization to handle these decorations. PiperOrigin-RevId: 270130136 --- mlir/g3doc/Dialects/SPIR-V.md | 5 +- mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h | 17 +- mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp | 155 +++++++++++------- mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp | 87 +++++++--- .../SPIRV/Serialization/Deserializer.cpp | 68 +++++--- .../SPIRV/Serialization/Serializer.cpp | 23 ++- .../Dialect/SPIRV/Serialization/struct.mlir | 9 + mlir/test/Dialect/SPIRV/types.mlir | 37 ++++- 8 files changed, 287 insertions(+), 114 deletions(-) diff --git a/mlir/g3doc/Dialects/SPIR-V.md b/mlir/g3doc/Dialects/SPIR-V.md index 0a6bc70c1a76..b1d83b16231f 100644 --- a/mlir/g3doc/Dialects/SPIR-V.md +++ b/mlir/g3doc/Dialects/SPIR-V.md @@ -179,8 +179,9 @@ For example, This corresponds to SPIR-V [struct type][StructType]. Its syntax is ``` {.ebnf} -struct-type ::= `!spv.struct<` spirv-type (` [` integer-literal `]` )? - (`, ` spirv-type ( ` [` integer-literal `] ` )? )* `>` +struct-member-decoration ::= integer-literal? spirv-decoration* +struct-type ::= `!spv.struct<` spirv-type (`[` struct-member-decoration `]`)? + (`, ` spirv-type (`[` struct-member-decoration `]`)? ``` For Example, diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h index 679d37a7ad33..cb749b2fb467 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h @@ -174,12 +174,13 @@ public: // types using LayoutInfo = uint64_t; + using MemberDecorationInfo = std::pair; + static bool kindof(unsigned kind) { return kind == TypeKind::Struct; } - static StructType get(ArrayRef memberTypes); - static StructType get(ArrayRef memberTypes, - ArrayRef layoutInfo); + ArrayRef layoutInfo = {}, + ArrayRef memberDecorations = {}); unsigned getNumElements() const; @@ -188,6 +189,16 @@ public: bool hasLayout() const; uint64_t getOffset(unsigned) const; + + // Returns in `allMemberDecorations` the spirv::Decorations (apart from + // Offset) associated with all members of the StructType. + void getMemberDecorations(SmallVectorImpl + &allMemberDecorations) const; + + // Returns in `memberDecorations` all the spirv::Decorations (apart from + // Offset) associated with the `i`-th member of the StructType. + void getMemberDecorations( + unsigned i, SmallVectorImpl &memberDecorations) const; }; } // end namespace spirv diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp index 4660aa82de92..04cee460a41f 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp @@ -64,8 +64,8 @@ Optional parseAndVerify(SPIRVDialect const &dialect, Location loc, StringRef spec); template <> -Optional parseAndVerify(SPIRVDialect const &dialect, Location loc, - StringRef spec); +Optional parseAndVerify(SPIRVDialect const &dialect, + Location loc, StringRef spec); // Parses " x" from the beginning of `spec`. static bool parseNumberX(StringRef &spec, int64_t &number) { @@ -206,6 +206,13 @@ static Type parseArrayType(SPIRVDialect const &dialect, StringRef spec, if (lastLSquare != StringRef::npos) { auto layoutSpec = spec.substr(lastLSquare); + layoutSpec = layoutSpec.trim(); + if (!layoutSpec.consume_front("[") || !layoutSpec.consume_back("]")) { + emitError(loc, "expected array stride within '[' ']' in '") + << layoutSpec << "'"; + return Type(); + } + layoutSpec = layoutSpec.trim(); auto layout = parseAndVerify(dialect, loc, layoutSpec); if (!layout) { @@ -216,6 +223,7 @@ static Type parseArrayType(SPIRVDialect const &dialect, StringRef spec, emitError(loc, "ArrayStride must be greater than zero"); return Type(); } + spec = spec.substr(0, lastLSquare); } @@ -314,34 +322,23 @@ Optional parseAndVerify(SPIRVDialect const &dialect, Location loc, return ty; } -template <> -Optional parseAndVerify(SPIRVDialect const &dialect, Location loc, - StringRef spec) { - uint64_t offsetVal = std::numeric_limits::max(); - if (!spec.consume_front("[")) { - emitError(loc, "expected '[' while parsing layout specification in '") - << spec << "'"; - return llvm::None; - } +template +static Optional parseAndVerifyInteger(SPIRVDialect const &dialect, + Location loc, StringRef spec) { + IntTy offsetVal = std::numeric_limits::max(); spec = spec.trim(); if (spec.consumeInteger(10, offsetVal)) { - emitError(loc, "expected unsigned integer to specify layout information: '") - << spec << "'"; - return llvm::None; - } - spec = spec.trim(); - if (!spec.consume_front("]")) { - emitError(loc, "missing ']' in decorations spec: '") << spec << "'"; - return llvm::None; - } - if (spec != "") { - emitError(loc, "unexpected extra tokens in layout information: '") - << spec << "'"; return llvm::None; } return offsetVal; } +template <> +Optional parseAndVerify(SPIRVDialect const &dialect, + Location loc, StringRef spec) { + return parseAndVerifyInteger(dialect, loc, spec); +} + // Functor object to parse a comma separated list of specs. The function // parseAndVerify does the actual parsing and verification of individual // elements. This is a functor since parsing the last element of the list @@ -423,33 +420,62 @@ static Type parseImageType(SPIRVDialect const &dialect, StringRef spec, } // Method to parse one member of a struct (including Layout information) -static ParseResult -parseStructElement(SPIRVDialect const &dialect, StringRef spec, Location loc, - SmallVectorImpl &memberTypes, - SmallVectorImpl &layoutInfo) { - // Check for a '[' ']' +static ParseResult parseStructElement( + SPIRVDialect const &dialect, StringRef spec, Location loc, + SmallVectorImpl &memberTypes, + SmallVectorImpl &layoutInfo, + SmallVectorImpl &memberDecorationInfo) { + // Check for '[' integer-literal spirv-decoration* ']' auto lastLSquare = spec.rfind('['); auto typeSpec = spec.substr(0, lastLSquare); - auto layoutSpec = (lastLSquare == StringRef::npos ? StringRef("") - : spec.substr(lastLSquare)); + auto memberDecorationSpec = + (lastLSquare == StringRef::npos ? StringRef("") + : spec.substr(lastLSquare)); auto type = parseAndVerify(dialect, loc, typeSpec); if (!type) { return failure(); } memberTypes.push_back(type.getValue()); - if (layoutSpec.empty()) { + + if (memberDecorationSpec.empty()) { return success(); } - if (layoutInfo.size() != memberTypes.size() - 1) { - emitError(loc, "layout specification must be given for all members"); + memberDecorationSpec = memberDecorationSpec.trim(); + if (!memberDecorationSpec.consume_front("[") || + !memberDecorationSpec.consume_back("]")) { + emitError(loc, + "expected struct member offset/decoration within '[' ']' in '") + << spec << "'"; return failure(); } + + memberDecorationSpec = memberDecorationSpec.trim(); + auto memberInfo = memberDecorationSpec.split(' '); + // Check if the first element is offset. auto layout = - parseAndVerify(dialect, loc, layoutSpec); - if (!layout) { - return failure(); + parseAndVerify(dialect, loc, memberInfo.first); + if (layout) { + if (layoutInfo.size() != memberTypes.size() - 1) { + emitError(loc, "layout specification must be given for all members"); + return failure(); + } + layoutInfo.push_back(layout.getValue()); + memberDecorationSpec = memberInfo.second.trim(); + } + + // Check for spirv::Decorations. + while (!memberDecorationSpec.empty()) { + memberInfo = memberDecorationSpec.split(' '); + auto memberDecoration = + parseAndVerify(dialect, loc, memberInfo.first); + if (!memberDecoration) { + return failure(); + } + memberDecorationInfo.emplace_back( + static_cast(memberTypes.size() - 1), + memberDecoration.getValue()); + memberDecorationSpec = memberInfo.second.trim(); } - layoutInfo.push_back(layout.getValue()); return success(); } @@ -474,33 +500,35 @@ computeMatchingRAngles(Location loc, StringRef const &spec, return true; } -static ParseResult -parseStructHelper(SPIRVDialect const &dialect, StringRef spec, Location loc, - ArrayRef matchingRAngleOffset, - SmallVectorImpl &memberTypes, - SmallVectorImpl &layoutInfo) { +static ParseResult parseStructHelper( + SPIRVDialect const &dialect, StringRef spec, Location loc, + ArrayRef matchingRAngleOffset, SmallVectorImpl &memberTypes, + SmallVectorImpl &layoutInfo, + SmallVectorImpl &memberDecorationsInfo) { // Check if the occurrence of ',' or '<' is before. If former, split using // ','. If latter, split using matching '>' to get the entire type // description auto firstComma = spec.find(','); auto firstLAngle = spec.find('<'); if (firstLAngle == StringRef::npos && firstComma == StringRef::npos) { - return parseStructElement(dialect, spec, loc, memberTypes, layoutInfo); + return parseStructElement(dialect, spec, loc, memberTypes, layoutInfo, + memberDecorationsInfo); } if (firstLAngle == StringRef::npos || firstComma < firstLAngle) { // Parse the type before the ',' if (parseStructElement(dialect, spec.substr(0, firstComma), loc, - memberTypes, layoutInfo)) { + memberTypes, layoutInfo, memberDecorationsInfo)) { return failure(); } return parseStructHelper(dialect, spec.substr(firstComma + 1).ltrim(), loc, - matchingRAngleOffset, memberTypes, layoutInfo); + matchingRAngleOffset, memberTypes, layoutInfo, + memberDecorationsInfo); } auto matchingRAngle = matchingRAngleOffset.front() + firstLAngle; // Find the next ',' or '>' auto endLoc = std::min(spec.find(',', matchingRAngle + 1), spec.size()); if (parseStructElement(dialect, spec.substr(0, endLoc), loc, memberTypes, - layoutInfo)) { + layoutInfo, memberDecorationsInfo)) { return failure(); } auto rest = spec.substr(endLoc + 1).ltrim(); @@ -512,14 +540,15 @@ parseStructHelper(SPIRVDialect const &dialect, StringRef spec, Location loc, dialect, rest.drop_front().trim(), loc, ArrayRef(std::next(matchingRAngleOffset.begin()), matchingRAngleOffset.end()), - memberTypes, layoutInfo); + memberTypes, layoutInfo, memberDecorationsInfo); } emitError(loc, "unexpected string : '") << rest << "'"; return failure(); } -// struct-type ::= `!spv.struct<` spirv-type (` [` integer-literal `]`)? -// (`, ` spirv-type ( ` [` integer-literal `] ` )? )* +// struct-member-decoration ::= integer-literal? spirv-decoration* +// struct-type ::= `!spv.struct<` spirv-type (`[` struct-member-decoration `]`)? +// (`, ` spirv-type (`[` struct-member-decoration `]`)? static Type parseStructType(SPIRVDialect const &dialect, StringRef spec, Location loc) { if (!spec.consume_front("struct<") || !spec.consume_back(">")) { @@ -534,20 +563,18 @@ static Type parseStructType(SPIRVDialect const &dialect, StringRef spec, SmallVector memberTypes; SmallVector layoutInfo; + SmallVector memberDecorationsInfo; SmallVector matchingRAngleOffset; if (!computeMatchingRAngles(loc, spec, matchingRAngleOffset) || parseStructHelper(dialect, spec, loc, matchingRAngleOffset, memberTypes, - layoutInfo)) { + layoutInfo, memberDecorationsInfo)) { return Type(); } - if (layoutInfo.empty()) { - return StructType::get(memberTypes); - } - if (memberTypes.size() != layoutInfo.size()) { + if (!layoutInfo.empty() && memberTypes.size() != layoutInfo.size()) { emitError(loc, "layout specification must be given for all members"); return Type(); } - return StructType::get(memberTypes, layoutInfo); + return StructType::get(memberTypes, layoutInfo, memberDecorationsInfo); } // spirv-type ::= array-type @@ -606,8 +633,22 @@ static void print(StructType type, llvm::raw_ostream &os) { os << "struct<"; auto printMember = [&](unsigned i) { os << type.getElementType(i); - if (type.hasLayout()) { - os << " [" << type.getOffset(i) << "]"; + SmallVector decorations; + type.getMemberDecorations(i, decorations); + if (type.hasLayout() || !decorations.empty()) { + os << " ["; + if (type.hasLayout()) { + os << type.getOffset(i); + if (!decorations.empty()) + os << " "; + } + auto between_fn = [&os]() { os << " "; }; + auto each_fn = [&os](spirv::Decoration decoration) { + os << stringifyDecoration(decoration); + }; + mlir::interleave(decorations.begin(), decorations.end(), each_fn, + between_fn); + os << "]"; } }; mlir::interleaveComma(llvm::seq(0, type.getNumElements()), os, diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp index f18d313ea1e4..ceaf71dafc99 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp @@ -363,34 +363,47 @@ Type RuntimeArrayType::getElementType() const { return getImpl()->elementType; } //===----------------------------------------------------------------------===// struct spirv::detail::StructTypeStorage : public TypeStorage { - StructTypeStorage(unsigned numMembers, Type const *memberTypes, - StructType::LayoutInfo const *layoutInfo) + StructTypeStorage( + unsigned numMembers, Type const *memberTypes, + StructType::LayoutInfo const *layoutInfo, unsigned numMemberDecorations, + StructType::MemberDecorationInfo const *memberDecorationsInfo) : TypeStorage(numMembers), memberTypes(memberTypes), - layoutInfo(layoutInfo) {} + layoutInfo(layoutInfo), numMemberDecorations(numMemberDecorations), + memberDecorationsInfo(memberDecorationsInfo) {} - using KeyTy = std::pair, ArrayRef>; + using KeyTy = std::tuple, ArrayRef, + ArrayRef>; bool operator==(const KeyTy &key) const { - return key == KeyTy(getMemberTypes(), getLayoutInfo()); + return key == + KeyTy(getMemberTypes(), getLayoutInfo(), getMemberDecorationsInfo()); } static StructTypeStorage *construct(TypeStorageAllocator &allocator, const KeyTy &key) { - ArrayRef keyTypes = key.first; + ArrayRef keyTypes = std::get<0>(key); // Copy the member type and layout information into the bump pointer auto typesList = allocator.copyInto(keyTypes).data(); const StructType::LayoutInfo *layoutInfoList = nullptr; - if (!key.second.empty()) { - ArrayRef keyLayoutInfo = key.second; + if (!std::get<1>(key).empty()) { + ArrayRef keyLayoutInfo = std::get<1>(key); assert(keyLayoutInfo.size() == keyTypes.size() && "size of layout information must be same as the size of number of " "elements"); layoutInfoList = allocator.copyInto(keyLayoutInfo).data(); } + const StructType::MemberDecorationInfo *memberDecorationList = nullptr; + unsigned numMemberDecorations = 0; + if (!std::get<2>(key).empty()) { + auto keyMemberDecorations = std::get<2>(key); + numMemberDecorations = keyMemberDecorations.size(); + memberDecorationList = allocator.copyInto(keyMemberDecorations).data(); + } return new (allocator.allocate()) - StructTypeStorage(keyTypes.size(), typesList, layoutInfoList); + StructTypeStorage(keyTypes.size(), typesList, layoutInfoList, + numMemberDecorations, memberDecorationList); } ArrayRef getMemberTypes() const { @@ -401,25 +414,34 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { if (layoutInfo) { return ArrayRef(layoutInfo, getSubclassData()); } - return ArrayRef(nullptr, size_t(0)); + return {}; + } + + ArrayRef getMemberDecorationsInfo() const { + if (memberDecorationsInfo) { + return ArrayRef(memberDecorationsInfo, + numMemberDecorations); + } + return {}; } Type const *memberTypes; StructType::LayoutInfo const *layoutInfo; + unsigned numMemberDecorations; + StructType::MemberDecorationInfo const *memberDecorationsInfo; }; -StructType StructType::get(ArrayRef memberTypes) { - assert(!memberTypes.empty() && "Struct needs at least one member type"); - ArrayRef noLayout(nullptr, size_t(0)); - return Base::get(memberTypes[0].getContext(), TypeKind::Struct, memberTypes, - noLayout); -} - -StructType StructType::get(ArrayRef memberTypes, - ArrayRef layoutInfo) { +StructType +StructType::get(ArrayRef memberTypes, + ArrayRef layoutInfo, + ArrayRef memberDecorations) { assert(!memberTypes.empty() && "Struct needs at least one member type"); + // Sort the decorations. + SmallVector sortedDecorations( + memberDecorations.begin(), memberDecorations.end()); + llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end()); return Base::get(memberTypes.vec().front().getContext(), TypeKind::Struct, - memberTypes, layoutInfo); + memberTypes, layoutInfo, sortedDecorations); } unsigned StructType::getNumElements() const { @@ -441,3 +463,28 @@ uint64_t StructType::getOffset(unsigned index) const { "element index is more than number of members of the SPIR-V StructType"); return getImpl()->layoutInfo[index]; } + +void StructType::getMemberDecorations( + SmallVectorImpl &memberDecorations) + const { + memberDecorations.clear(); + auto implMemberDecorations = getImpl()->getMemberDecorationsInfo(); + memberDecorations.append(implMemberDecorations.begin(), + implMemberDecorations.end()); +} + +void StructType::getMemberDecorations( + unsigned index, SmallVectorImpl &decorations) const { + assert(getNumElements() > index && "member index out of range"); + auto memberDecorations = getImpl()->getMemberDecorationsInfo(); + decorations.clear(); + for (auto &memberDecoration : memberDecorations) { + if (memberDecoration.first == index) { + decorations.push_back(memberDecoration.second); + } + if (memberDecoration.first > index) { + // Early exit since the decorations are stored sorted. + return; + } + } +} diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index 7c62dca0665f..3c024f5c682f 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -399,7 +399,11 @@ private: DenseMap typeDecorations; // Result to member decorations. - DenseMap> memberDecorationMap; + // decorated-struct-type- -> + // (struct-member-index -> (decoration -> decoration-operands)) + DenseMap>>> + memberDecorationMap; // Result to extended instruction set name. DenseMap extendedInstSets; @@ -622,18 +626,22 @@ LogicalResult Deserializer::processDecoration(ArrayRef words) { LogicalResult Deserializer::processMemberDecoration(ArrayRef words) { // The binary layout of OpMemberDecorate is different comparing to OpDecorate - if (words.size() != 4) { - return emitError(unknownLoc, "OpMemberDecorate must have 4 operands"); + if (words.size() < 3) { + return emitError(unknownLoc, + "OpMemberDecorate must have at least 3 operands"); } - switch (static_cast(words[2])) { - case spirv::Decoration::Offset: - memberDecorationMap[words[0]][words[1]] = words[3]; - break; - default: - return emitError(unknownLoc, "unhandled OpMemberDecoration case: ") - << words[2]; + auto decoration = static_cast(words[2]); + if (decoration == spirv::Decoration::Offset && words.size() != 4) { + return emitError(unknownLoc, + " missing offset specification in OpMemberDecorate with " + "Offset decoration"); } + ArrayRef decorationOperands; + if (words.size() > 3) { + decorationOperands = words.slice(3); + } + memberDecorationMap[words[0]][words[1]][decoration] = decorationOperands; return success(); } @@ -1098,25 +1106,35 @@ LogicalResult Deserializer::processStructType(ArrayRef operands) { } SmallVector layoutInfo; - // Check for layoutinfo - auto memberDecorationIt = memberDecorationMap.find(operands[0]); - if (memberDecorationIt != memberDecorationMap.end()) { - // Each member must have an offset - const auto &offsetDecorationMap = memberDecorationIt->second; - auto offsetDecorationMapEnd = offsetDecorationMap.end(); + SmallVector memberDecorationsInfo; + if (memberDecorationMap.count(operands[0])) { + auto &allMemberDecorations = memberDecorationMap[operands[0]]; for (auto memberIndex : llvm::seq(0, memberTypes.size())) { - // Check that specific member has an offset - auto offsetIt = offsetDecorationMap.find(memberIndex); - if (offsetIt == offsetDecorationMapEnd) { - return emitError(unknownLoc, "OpTypeStruct with ") - << operands[0] << " must have an offset for " << memberIndex - << "-th member"; + if (allMemberDecorations.count(memberIndex)) { + for (auto &memberDecoration : allMemberDecorations[memberIndex]) { + // Check for offset. + if (memberDecoration.first == spirv::Decoration::Offset) { + // If layoutInfo is empty, resize to the number of members; + if (layoutInfo.empty()) { + layoutInfo.resize(memberTypes.size()); + } + layoutInfo[memberIndex] = memberDecoration.second[0]; + } else { + if (!memberDecoration.second.empty()) { + return emitError(unknownLoc, + "unhandled OpMemberDecoration with decoration ") + << stringifyDecoration(memberDecoration.first) + << " which has additional operands"; + } + memberDecorationsInfo.emplace_back(memberIndex, + memberDecoration.first); + } + } } - layoutInfo.push_back( - static_cast(offsetIt->second)); } } - typeMap[operands[0]] = spirv::StructType::get(memberTypes, layoutInfo); + typeMap[operands[0]] = + spirv::StructType::get(memberTypes, layoutInfo, memberDecorationsInfo); return success(); } diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index 28e68eac15d3..34bc0e657ac8 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -167,7 +167,7 @@ private: /// Process member decoration LogicalResult processMemberDecoration(uint32_t structID, uint32_t memberNum, spirv::Decoration decorationType, - uint32_t value); + ArrayRef values = {}); //===--------------------------------------------------------------------===// // Types @@ -532,9 +532,12 @@ LogicalResult Serializer::processTypeDecoration( LogicalResult Serializer::processMemberDecoration(uint32_t structID, uint32_t memberIndex, spirv::Decoration decorationType, - uint32_t value) { + ArrayRef values) { SmallVector args( - {structID, memberIndex, static_cast(decorationType), value}); + {structID, memberIndex, static_cast(decorationType)}); + if (!values.empty()) { + args.append(values.begin(), values.end()); + } return encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, args); } @@ -793,11 +796,21 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID, resultID, elementIndex, spirv::Decoration::Offset, static_cast(structType.getOffset(elementIndex))))) { return emitError(loc, "cannot decorate ") - << elementIndex << "-th member of : " << structType - << "with its offset"; + << elementIndex << "-th member of " << structType + << " with its offset"; } } } + SmallVector memberDecorations; + structType.getMemberDecorations(memberDecorations); + for (auto &memberDecoration : memberDecorations) { + if (failed(processMemberDecoration(resultID, memberDecoration.first, + memberDecoration.second))) { + return emitError(loc, "cannot decorate ") + << memberDecoration.first << "-th member of " << structType + << " with " << stringifyDecoration(memberDecoration.second); + } + } typeEnum = spirv::Opcode::OpTypeStruct; return success(); } diff --git a/mlir/test/Dialect/SPIRV/Serialization/struct.mlir b/mlir/test/Dialect/SPIRV/Serialization/struct.mlir index b061dfebcc6b..7a4695e0cfc8 100644 --- a/mlir/test/Dialect/SPIRV/Serialization/struct.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/struct.mlir @@ -13,6 +13,15 @@ spv.module "Logical" "GLSL450" { // CHECK: !spv.ptr [0]> [4]> [0]>, StorageBuffer> spv.globalVariable @var3 : !spv.ptr [0]> [4]> [0]>, StorageBuffer> + // CHECK: !spv.ptr, StorageBuffer> + spv.globalVariable @var4 : !spv.ptr, StorageBuffer> + + // CHECK: !spv.ptr, StorageBuffer> + spv.globalVariable @var5 : !spv.ptr, StorageBuffer> + + // CHECK: !spv.ptr, StorageBuffer> + spv.globalVariable @var6 : !spv.ptr, StorageBuffer> + // CHECK: !spv.ptr [0]>, Input>, // CHECK-SAME: !spv.ptr [0]>, Output> func @kernel_1(%arg0: !spv.ptr [0]>, Input>, %arg1: !spv.ptr [0]>, Output>) -> () { diff --git a/mlir/test/Dialect/SPIRV/types.mlir b/mlir/test/Dialect/SPIRV/types.mlir index 552ef6ac0fb4..cc1bee951fd2 100644 --- a/mlir/test/Dialect/SPIRV/types.mlir +++ b/mlir/test/Dialect/SPIRV/types.mlir @@ -235,6 +235,24 @@ func @nested_struct(!spv.struct>) // CHECK: func @nested_struct_with_offset(!spv.struct [4]>) func @nested_struct_with_offset(!spv.struct [4]>) +// CHECK: func @struct_type_with_decoration(!spv.struct) +func @struct_type_with_decoration(!spv.struct) + +// CHECK: func @struct_type_with_decoration_and_offset(!spv.struct) +func @struct_type_with_decoration_and_offset(!spv.struct) + +// CHECK: func @struct_type_with_decoration2(!spv.struct) +func @struct_type_with_decoration2(!spv.struct) + +// CHECK: func @struct_type_with_decoration3(!spv.struct) +func @struct_type_with_decoration3(!spv.struct) + +// CHECK: func @struct_type_with_decoration4(!spv.struct) +func @struct_type_with_decoration4(!spv.struct) + +// CHECK: func @struct_type_with_decoration5(!spv.struct) +func @struct_type_with_decoration5(!spv.struct) + // ----- // expected-error @+1 {{layout specification must be given for all members}} @@ -252,10 +270,25 @@ func @struct_type_missing_comma1(!spv.struct) -> () // ----- -// expected-error @+1 {{unexpected extra tokens in layout information: ' i32'}} +// expected-error @+1 {{expected struct member offset/decoration within '[' ']' in 'f32 [0] i32'}} func @struct_type_missing_comma2(!spv.struct) -> () // ----- -// expected-error @+1 {{expected unsigned integer to specify layout info}} +// expected-error @+1 {{unknown attribute: '-1'}} func @struct_type_neg_offset(!spv.struct) -> () + +// ----- + +// expected-error @+1 {{unbalanced '>' character in pretty dialect name}} +func @struct_type_neg_offset(!spv.struct) -> () + +// ----- + +// expected-error @+1 {{unbalanced ']' character in pretty dialect name}} +func @struct_type_neg_offset(!spv.struct) -> () + +// ----- + +// expected-error @+1 {{unknown attribute: '0'}} +func @struct_type_neg_offset(!spv.struct) -> ()