diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h index 264fed3c5aeb..b25c7a309174 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h @@ -73,14 +73,23 @@ class ArrayType : public Type::TypeBase { public: using Base::Base; + // Zero layout specifies that is no layout + using LayoutInfo = uint64_t; static bool kindof(unsigned kind) { return kind == TypeKind::Array; } static ArrayType get(Type elementType, unsigned elementCount); + static ArrayType get(Type elementType, unsigned elementCount, + LayoutInfo layoutInfo); + unsigned getNumElements() const; Type getElementType() const; + + bool hasLayout() const; + + uint64_t getArrayStride() const; }; // SPIR-V image type diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp index 622bb221b3fc..40d877a7225a 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp @@ -53,6 +53,18 @@ SPIRVDialect::SPIRVDialect(MLIRContext *context) // Type Parsing //===----------------------------------------------------------------------===// +// Forward declarations. +template +static Optional parseAndVerify(SPIRVDialect const &dialect, Location loc, + StringRef spec); +template <> +Optional parseAndVerify(SPIRVDialect const &dialect, Location loc, + StringRef spec); + +template <> +Optional parseAndVerify(SPIRVDialect const &dialect, Location loc, + StringRef spec); + // Parses " x" from the beginning of `spec`. static bool parseNumberX(StringRef &spec, int64_t &number) { spec = spec.ltrim(); @@ -150,7 +162,8 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect, StringRef spec, // | vector-type // | spirv-type // -// array-type ::= `!spv.array<` integer-literal `x` element-type `>` +// array-type ::= `!spv.array<` integer-literal `x` element-type +// (`[` integer-literal `]`)? `>` static Type parseArrayType(SPIRVDialect const &dialect, StringRef spec, Location loc) { if (!spec.consume_front("array<") || !spec.consume_back(">")) { @@ -171,11 +184,37 @@ static Type parseArrayType(SPIRVDialect const &dialect, StringRef spec, return Type(); } + ArrayType::LayoutInfo layoutInfo = 0; + size_t lastLSquare; + + // Handle case when element type is not a trivial type + auto lastRDelimiter = spec.rfind('>'); + if (lastRDelimiter != StringRef::npos) { + lastLSquare = spec.find('[', lastRDelimiter); + } else { + lastLSquare = spec.rfind('['); + } + + if (lastLSquare != StringRef::npos) { + auto layoutSpec = spec.substr(lastLSquare); + auto layout = + parseAndVerify(dialect, loc, layoutSpec); + if (!layout) { + return Type(); + } + + if (!(layoutInfo = layout.getValue())) { + emitError(loc, "ArrayStride must be greater than zero"); + return Type(); + } + spec = spec.substr(0, lastLSquare); + } + Type elementType = parseAndVerifyType(dialect, spec, loc); if (!elementType) return Type(); - return ArrayType::get(elementType, count); + return ArrayType::get(elementType, count, layoutInfo); } // TODO(ravishankarm) : Reorder methods to be utilities first and parse*Type @@ -267,18 +306,17 @@ Optional parseAndVerify(SPIRVDialect const &dialect, Location loc, } template <> -Optional -parseAndVerify(SPIRVDialect const &dialect, Location loc, StringRef spec) { +Optional parseAndVerify(SPIRVDialect const &dialect, Location loc, + StringRef spec) { uint64_t offsetVal = std::numeric_limits::max(); if (!spec.consume_front("[")) { emitError(loc, "expected '[' while parsing layout specification in '") << spec << "'"; return llvm::None; } + spec = spec.trim(); if (spec.consumeInteger(10, offsetVal)) { - emitError( - loc, - "expected unsigned integer to specify offset of member in struct: '") + emitError(loc, "expected unsigned integer to specify layout information: '") << spec << "'"; return llvm::None; } @@ -292,7 +330,7 @@ parseAndVerify(SPIRVDialect const &dialect, Location loc, StringRef spec) { << spec << "'"; return llvm::None; } - return spirv::StructType::LayoutInfo{offsetVal}; + return offsetVal; } // Functor object to parse a comma separated list of specs. The function @@ -530,8 +568,11 @@ Type SPIRVDialect::parseType(StringRef spec, Location loc) const { //===----------------------------------------------------------------------===// static void print(ArrayType type, llvm::raw_ostream &os) { - os << "array<" << type.getNumElements() << " x " << type.getElementType() - << ">"; + os << "array<" << type.getNumElements() << " x " << type.getElementType(); + if (type.hasLayout()) { + os << " [" << type.getArrayStride() << "]"; + } + os << ">"; } static void print(RuntimeArrayType type, llvm::raw_ostream &os) { diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp index 345d13d42aae..f79db01998f4 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp @@ -34,7 +34,7 @@ using namespace mlir::spirv; //===----------------------------------------------------------------------===// struct spirv::detail::ArrayTypeStorage : public TypeStorage { - using KeyTy = std::pair; + using KeyTy = std::tuple; static ArrayTypeStorage *construct(TypeStorageAllocator &allocator, const KeyTy &key) { @@ -42,18 +42,26 @@ struct spirv::detail::ArrayTypeStorage : public TypeStorage { } bool operator==(const KeyTy &key) const { - return key == KeyTy(elementType, getSubclassData()); + return key == KeyTy(elementType, getSubclassData(), layoutInfo); } ArrayTypeStorage(const KeyTy &key) - : TypeStorage(key.second), elementType(key.first) {} + : TypeStorage(std::get<1>(key)), elementType(std::get<0>(key)), + layoutInfo(std::get<2>(key)) {} Type elementType; + ArrayType::LayoutInfo layoutInfo; }; ArrayType ArrayType::get(Type elementType, unsigned elementCount) { return Base::get(elementType.getContext(), TypeKind::Array, elementType, - elementCount); + elementCount, 0); +} + +ArrayType ArrayType::get(Type elementType, unsigned elementCount, + ArrayType::LayoutInfo layoutInfo) { + return Base::get(elementType.getContext(), TypeKind::Array, elementType, + elementCount, layoutInfo); } unsigned ArrayType::getNumElements() const { @@ -62,6 +70,11 @@ unsigned ArrayType::getNumElements() const { Type ArrayType::getElementType() const { return getImpl()->elementType; } +// ArrayStride must be greater than zero +bool ArrayType::hasLayout() const { return getImpl()->layoutInfo; } + +uint64_t ArrayType::getArrayStride() const { return getImpl()->layoutInfo; } + //===----------------------------------------------------------------------===// // CompositeType //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index 217f9b190dd6..1aad7173dc6c 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -207,6 +207,9 @@ private: // Result to decorations mapping. DenseMap decorations; + // Result to type decorations. + DenseMap typeDecorations; + // List of instructions that are processed in a defered fashion (after an // initial processing of the entire binary). Some operations like // OpEntryPoint, and OpExecutionMode use forward references to function @@ -330,6 +333,13 @@ LogicalResult Deserializer::processDecoration(ArrayRef words) { opBuilder.getStringAttr(stringifyBuiltIn( static_cast(words[2])))); break; + case spirv::Decoration::ArrayStride: + if (words.size() != 3) { + return emitError(unknownLoc, "OpDecorate with ") + << decorationName << " needs a single integer literal"; + } + typeDecorations[words[0]] = static_cast(words[2]); + break; default: return emitError(unknownLoc, "unhandled Decoration : '") << decorationName; } @@ -590,7 +600,8 @@ LogicalResult Deserializer::processArrayType(ArrayRef operands) { << defOp->getName(); } - typeMap[operands[0]] = spirv::ArrayType::get(elementTy, count); + typeMap[operands[0]] = spirv::ArrayType::get( + elementTy, count, typeDecorations.lookup(operands[0])); return success(); } diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index 8b55873c5c0c..d06363a1a8c5 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -132,6 +132,12 @@ private: LogicalResult processDecoration(Location loc, uint32_t resultID, NamedAttribute attr); + template + LogicalResult processTypeDecoration(Location loc, DType type, + uint32_t resultId) { + return emitError(loc, "unhandled decoraion for type:") << type; + } + //===--------------------------------------------------------------------===// // Types //===--------------------------------------------------------------------===// @@ -148,7 +154,7 @@ private: /// Method for preparing basic SPIR-V type serialization. Returns the type's /// opcode and operands for the instruction via `typeEnum` and `operands`. - LogicalResult prepareBasicType(Location loc, Type type, + LogicalResult prepareBasicType(Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum, SmallVectorImpl &operands); @@ -366,6 +372,22 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID, return encodeInstructionInto(decorations, spirv::Opcode::OpDecorate, args); } +namespace { +template <> +LogicalResult Serializer::processTypeDecoration( + Location loc, spirv::ArrayType type, uint32_t resultID) { + if (type.hasLayout()) { + // OpDecorate %arrayTypeSSA ArrayStride strideLiteral + SmallVector args; + args.push_back(resultID); + args.push_back(static_cast(spirv::Decoration::ArrayStride)); + args.push_back(type.getArrayStride()); + return encodeInstructionInto(decorations, spirv::Opcode::OpDecorate, args); + } + return success(); +} +} // namespace + LogicalResult Serializer::processFuncOp(FuncOp op) { uint32_t fnTypeID = 0; // Generate type of the function. @@ -445,7 +467,7 @@ LogicalResult Serializer::processType(Location loc, Type type, if ((type.isa() && succeeded(prepareFunctionType(loc, type.cast(), typeEnum, operands))) || - succeeded(prepareBasicType(loc, type, typeEnum, operands))) { + succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands))) { typeIDMap[type] = typeID; return encodeInstructionInto(typesGlobalValues, typeEnum, operands); } @@ -453,7 +475,8 @@ LogicalResult Serializer::processType(Location loc, Type type, } LogicalResult -Serializer::prepareBasicType(Location loc, Type type, spirv::Opcode &typeEnum, +Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID, + spirv::Opcode &typeEnum, SmallVectorImpl &operands) { if (isVoidType(type)) { typeEnum = spirv::Opcode::OpTypeVoid; @@ -501,9 +524,8 @@ Serializer::prepareBasicType(Location loc, Type type, spirv::Opcode &typeEnum, loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()), /*isSpec=*/false)) { operands.push_back(elementCountID); - return success(); } - return failure(); + return processTypeDecoration(loc, arrayType, resultID); } if (auto ptrType = type.dyn_cast()) { diff --git a/mlir/test/Dialect/SPIRV/Serialization/array_stride.mlir b/mlir/test/Dialect/SPIRV/Serialization/array_stride.mlir new file mode 100644 index 000000000000..b7229e80ffd4 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/Serialization/array_stride.mlir @@ -0,0 +1,13 @@ +// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s + +func @spirvmodule() { + spv.module "Logical" "VulkanKHR" { + func @array_stride(%arg0 : !spv.ptr [128]>, StorageBuffer>, + %arg1 : i32, %arg2 : i32) { + // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr [128]>, StorageBuffer> + %2 = spv.AccessChain %arg0[%arg1, %arg2] : !spv.ptr [128]>, StorageBuffer> + spv.Return + } + } + return +} diff --git a/mlir/test/Dialect/SPIRV/types.mlir b/mlir/test/Dialect/SPIRV/types.mlir index 58d16cf887e4..2bfadae6b737 100644 --- a/mlir/test/Dialect/SPIRV/types.mlir +++ b/mlir/test/Dialect/SPIRV/types.mlir @@ -12,6 +12,9 @@ func @scalar_array_type(!spv.array<16xf32>, !spv.array<8 x i32>) -> () // CHECK: func @vector_array_type(!spv.array<32 x vector<4xf32>>) func @vector_array_type(!spv.array< 32 x vector<4xf32> >) -> () +// CHECK: func @array_type_stride(!spv.array<4 x !spv.array<4 x f32 [4]> [128]>) +func @array_type_stride(!spv.array< 4 x !spv.array<4 x f32 [4]> [128]>) -> () + // ----- // expected-error @+1 {{spv.array delimiter <...> mismatch}} @@ -74,6 +77,11 @@ func @llvm_type(!spv.array<4x!llvm.i32>) -> () // ----- +// expected-error @+1 {{ArrayStride must be greater than zero}} +func @array_type_zero_stide(!spv.array<4xi32 [0]>) -> () + +// ----- + //===----------------------------------------------------------------------===// // PointerType //===----------------------------------------------------------------------===// @@ -246,5 +254,5 @@ func @struct_type_missing_comma2(!spv.struct) -> () // ----- -// expected-error @+1 {{expected unsigned integer to specify offset of member in struct}} +// expected-error @+1 {{expected unsigned integer to specify layout info}} func @struct_type_neg_offset(!spv.struct) -> ()