From 75906bd565a199edb6b9b8088236376d6234a9a6 Mon Sep 17 00:00:00 2001 From: Mahesh Ravishankar Date: Mon, 23 Sep 2019 17:10:49 -0700 Subject: [PATCH] Handle OpMemberName instruction in SPIR-V deserializer. Sdd support in deserializer for OpMemberName instruction. For now the name is just processed and not associated with the spirv::StructType being built. That needs an enhancement to spirv::StructTypes itself. Add tests to check for errors reported during deserialization with some refactoring to common out some utility functions. PiperOrigin-RevId: 270794524 --- mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td | 44 ++++++----- .../mlir/Dialect/SPIRV/SPIRVBinaryUtils.h | 6 ++ .../SPIRV/Serialization/Deserializer.cpp | 29 ++++++- .../SPIRV/Serialization/SPIRVBinaryUtils.cpp | 16 ++++ .../SPIRV/Serialization/Serializer.cpp | 35 +++------ mlir/test/Dialect/SPIRV/ops.mlir | 6 -- .../Dialect/SPIRV/DeserializationTest.cpp | 76 ++++++++++++++++++- 7 files changed, 153 insertions(+), 59 deletions(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td index bb1805456c3f..341f6ca71cd7 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -74,6 +74,7 @@ class SPV_OpCode { def SPV_OC_OpNop : I32EnumAttrCase<"OpNop", 0>; def SPV_OC_OpName : I32EnumAttrCase<"OpName", 5>; +def SPV_OC_OpMemberName : I32EnumAttrCase<"OpMemberName", 6>; def SPV_OC_OpExtension : I32EnumAttrCase<"OpExtension", 10>; def SPV_OC_OpExtInstImport : I32EnumAttrCase<"OpExtInstImport", 11>; def SPV_OC_OpExtInst : I32EnumAttrCase<"OpExtInst", 12>; @@ -159,27 +160,28 @@ def SPV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>; def SPV_OpcodeAttr : I32EnumAttr<"Opcode", "valid SPIR-V instructions", [ - SPV_OC_OpNop, SPV_OC_OpName, SPV_OC_OpExtension, SPV_OC_OpExtInstImport, - SPV_OC_OpExtInst, SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint, - SPV_OC_OpExecutionMode, SPV_OC_OpCapability, SPV_OC_OpTypeVoid, - SPV_OC_OpTypeBool, SPV_OC_OpTypeInt, SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector, - SPV_OC_OpTypeArray, SPV_OC_OpTypeRuntimeArray, SPV_OC_OpTypeStruct, - SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue, - SPV_OC_OpConstantFalse, SPV_OC_OpConstant, SPV_OC_OpConstantComposite, - SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse, - SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, - SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, - SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain, - SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, - SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, - SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, - SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpSelect, - SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, - SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, - SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, - SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, - SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, - SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan, + SPV_OC_OpNop, SPV_OC_OpName, SPV_OC_OpMemberName, SPV_OC_OpExtension, + SPV_OC_OpExtInstImport, SPV_OC_OpExtInst, SPV_OC_OpMemoryModel, + SPV_OC_OpEntryPoint, SPV_OC_OpExecutionMode, SPV_OC_OpCapability, + SPV_OC_OpTypeVoid, SPV_OC_OpTypeBool, SPV_OC_OpTypeInt, SPV_OC_OpTypeFloat, + SPV_OC_OpTypeVector, SPV_OC_OpTypeArray, SPV_OC_OpTypeRuntimeArray, + SPV_OC_OpTypeStruct, SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction, + SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse, SPV_OC_OpConstant, + SPV_OC_OpConstantComposite, SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, + SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant, + SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter, + SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad, + SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate, + SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpIAdd, + SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, + SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, + SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpSelect, SPV_OC_OpIEqual, + SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, + SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, + SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, + SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual, + SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan, + SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan, SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual, SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual, SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier, SPV_OC_OpLoopMerge, diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBinaryUtils.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVBinaryUtils.h index f255446b91d2..3229e28ef1a3 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBinaryUtils.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBinaryUtils.h @@ -46,6 +46,12 @@ constexpr uint32_t kGeneratorNumber = 22; /// Appends a SPRI-V module header to `header` with the given `idBound`. void appendModuleHeader(SmallVectorImpl &header, uint32_t idBound); +/// Returns the word-count-prefixed opcode for an SPIR-V instruction. +uint32_t getPrefixedOpcode(uint32_t wordCount, spirv::Opcode opcode); + +/// Encodes an SPIR-V `literal` string into the given `binary` vector. +LogicalResult encodeStringLiteralInto(SmallVectorImpl &binary, + StringRef literal); } // end namespace spirv } // end namespace mlir diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp index aba6aeef2314..0d9bf362ff6d 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp @@ -112,12 +112,15 @@ private: /// Process SPIR-V OpName with `operands`. LogicalResult processName(ArrayRef operands); - /// Method to process an OpDecorate instruction. + /// Processes an OpDecorate instruction. LogicalResult processDecoration(ArrayRef words); - // Method to process an OpMemberDecorate instruction. + // Processes an OpMemberDecorate instruction. LogicalResult processMemberDecoration(ArrayRef words); + /// Processes an OpMemberName instruction. + LogicalResult processMemberName(ArrayRef words); + /// Gets the FuncOp associated with a result of OpFunction. FuncOp getFunction(uint32_t id) { return funcMap.lookup(id); } @@ -410,6 +413,10 @@ private: DenseMap>>> memberDecorationMap; + // Result to member name. + // struct-type- -> (struct-member-index -> name) + DenseMap> memberNameMap; + // Result to extended instruction set name. DenseMap extendedInstSets; @@ -650,6 +657,20 @@ LogicalResult Deserializer::processMemberDecoration(ArrayRef words) { return success(); } +LogicalResult Deserializer::processMemberName(ArrayRef words) { + if (words.size() < 3) { + return emitError(unknownLoc, "OpMemberName must have at least 3 operands"); + } + unsigned wordIndex = 2; + auto name = decodeStringLiteral(words, wordIndex); + if (wordIndex != words.size()) { + return emitError(unknownLoc, + "unexpected trailing words in OpMemberName instruction"); + } + memberNameMap[words[0]][words[1]] = name; + return success(); +} + LogicalResult Deserializer::processFunction(ArrayRef operands) { if (curFunction) { return emitError(unknownLoc, "found function inside function"); @@ -1151,6 +1172,8 @@ LogicalResult Deserializer::processStructType(ArrayRef operands) { } typeMap[operands[0]] = spirv::StructType::get(memberTypes, layoutInfo, memberDecorationsInfo); + // TODO(ravishankarm): Update StructType to have member name as attribute as + // well. return success(); } @@ -1746,6 +1769,8 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode, return processExtInst(operands); case spirv::Opcode::OpExtInstImport: return processExtInstImport(operands); + case spirv::Opcode::OpMemberName: + return processMemberName(operands); case spirv::Opcode::OpMemoryModel: return processMemoryModel(operands); case spirv::Opcode::OpEntryPoint: diff --git a/mlir/lib/Dialect/SPIRV/Serialization/SPIRVBinaryUtils.cpp b/mlir/lib/Dialect/SPIRV/Serialization/SPIRVBinaryUtils.cpp index 1e432b38ef8f..ba383b2cc6ce 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/SPIRVBinaryUtils.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/SPIRVBinaryUtils.cpp @@ -51,3 +51,19 @@ void spirv::appendModuleHeader(SmallVectorImpl &header, header.push_back(idBound); // bound header.push_back(0); // Schema (reserved word) } + +/// Returns the word-count-prefixed opcode for an SPIR-V instruction. +uint32_t spirv::getPrefixedOpcode(uint32_t wordCount, spirv::Opcode opcode) { + assert(((wordCount >> 16) == 0) && "word count out of range!"); + return (wordCount << 16) | static_cast(opcode); +} + +LogicalResult spirv::encodeStringLiteralInto(SmallVectorImpl &binary, + StringRef literal) { + // We need to encode the literal and the null termination. + auto encodingSize = literal.size() / 4 + 1; + auto bufferStartSize = binary.size(); + binary.resize(bufferStartSize + encodingSize, 0); + std::memcpy(binary.data() + bufferStartSize, literal.data(), literal.size()); + return success(); +} diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index ecca61bdf199..58aebddf29da 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -36,37 +36,19 @@ using namespace mlir; -/// Returns the word-count-prefixed opcode for an SPIR-V instruction. -static inline uint32_t getPrefixedOpcode(uint32_t wordCount, - spirv::Opcode opcode) { - assert(((wordCount >> 16) == 0) && "word count out of range!"); - return (wordCount << 16) | static_cast(opcode); -} - /// Encodes an SPIR-V instruction with the given `opcode` and `operands` into /// the given `binary` vector. -static LogicalResult encodeInstructionInto(SmallVectorImpl &binary, - spirv::Opcode op, - ArrayRef operands) { +LogicalResult encodeInstructionInto(SmallVectorImpl &binary, + spirv::Opcode op, + ArrayRef operands) { uint32_t wordCount = 1 + operands.size(); - binary.push_back(getPrefixedOpcode(wordCount, op)); + binary.push_back(spirv::getPrefixedOpcode(wordCount, op)); if (!operands.empty()) { binary.append(operands.begin(), operands.end()); } return success(); } -/// Encodes an SPIR-V `literal` string into the given `binary` vector. -static LogicalResult encodeStringLiteralInto(SmallVectorImpl &binary, - StringRef literal) { - // We need to encode the literal and the null termination. - auto encodingSize = literal.size() / 4 + 1; - auto bufferStartSize = binary.size(); - binary.resize(bufferStartSize + encodingSize, 0); - std::memcpy(binary.data() + bufferStartSize, literal.data(), literal.size()); - return success(); -} - namespace { /// A SPIR-V module serializer. @@ -435,7 +417,7 @@ void Serializer::processExtension() { for (auto ext : exts.getValue()) { auto extStr = ext.cast().getValue(); extName.clear(); - encodeStringLiteralInto(extName, extStr); + spirv::encodeStringLiteralInto(extName, extStr); encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName); } } @@ -508,7 +490,7 @@ LogicalResult Serializer::processName(uint32_t resultID, StringRef name) { SmallVector nameOperands; nameOperands.push_back(resultID); - if (failed(encodeStringLiteralInto(nameOperands, name))) { + if (failed(spirv::encodeStringLiteralInto(nameOperands, name))) { return failure(); } return encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands); @@ -1388,7 +1370,8 @@ LogicalResult Serializer::encodeExtensionInstruction( setID = getNextID(); SmallVector importOperands; importOperands.push_back(setID); - if (failed(encodeStringLiteralInto(importOperands, extensionSetName)) || + if (failed( + spirv::encodeStringLiteralInto(importOperands, extensionSetName)) || failed(encodeInstructionInto( extendedSets, spirv::Opcode::OpExtInstImport, importOperands))) { return failure(); @@ -1490,7 +1473,7 @@ Serializer::processOp(spirv::EntryPointOp op) { } operands.push_back(funcID); // Add the name of the function. - encodeStringLiteralInto(operands, op.fn()); + spirv::encodeStringLiteralInto(operands, op.fn()); // Add the interface values. if (auto interface = op.interface()) { diff --git a/mlir/test/Dialect/SPIRV/ops.mlir b/mlir/test/Dialect/SPIRV/ops.mlir index 57adf97fb4c0..bda89d6784d1 100644 --- a/mlir/test/Dialect/SPIRV/ops.mlir +++ b/mlir/test/Dialect/SPIRV/ops.mlir @@ -12,8 +12,6 @@ func @access_chain_struct() -> () { return } -// ----- - func @access_chain_1D_array(%arg0 : i32) -> () { %0 = spv.Variable : !spv.ptr, Function> // CHECK: spv.AccessChain {{.*}}[{{.*}}] : !spv.ptr, Function> @@ -21,8 +19,6 @@ func @access_chain_1D_array(%arg0 : i32) -> () { return } -// ----- - func @access_chain_2D_array_1(%arg0 : i32) -> () { %0 = spv.Variable : !spv.ptr>, Function> // CHECK: spv.AccessChain {{.*}}[{{.*}}, {{.*}}] : !spv.ptr>, Function> @@ -31,8 +27,6 @@ func @access_chain_2D_array_1(%arg0 : i32) -> () { return } -// ----- - func @access_chain_2D_array_2(%arg0 : i32) -> () { %0 = spv.Variable : !spv.ptr>, Function> // CHECK: spv.AccessChain {{.*}}[{{.*}}] : !spv.ptr>, Function> diff --git a/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp b/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp index 9a180811d2cf..ed5bd9ebeccf 100644 --- a/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp +++ b/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp @@ -73,10 +73,7 @@ protected: /// Adds the SPIR-V instruction into `binary`. void addInstruction(spirv::Opcode op, ArrayRef operands) { uint32_t wordCount = 1 + operands.size(); - assert(((wordCount >> 16) == 0) && "word count out of range!"); - - uint32_t prefixedOpcode = (wordCount << 16) | static_cast(op); - binary.push_back(prefixedOpcode); + binary.push_back(spirv::getPrefixedOpcode(wordCount, op)); binary.append(operands.begin(), operands.end()); } @@ -92,6 +89,15 @@ protected: return id; } + uint32_t addStructType(ArrayRef memberTypes) { + auto id = nextID++; + SmallVector words; + words.push_back(id); + words.append(memberTypes.begin(), memberTypes.end()); + addInstruction(spirv::Opcode::OpTypeStruct, words); + return id; + } + uint32_t addFunctionType(uint32_t retType, ArrayRef paramTypes) { auto id = nextID++; SmallVector operands; @@ -173,6 +179,68 @@ TEST_F(DeserializationTest, IntTypeMissingSignednessFailure) { expectDiagnostic("OpTypeInt must have bitwidth and signedness parameters"); } +//===----------------------------------------------------------------------===// +// StructType +//===----------------------------------------------------------------------===// + +TEST_F(DeserializationTest, OpMemberNameSuccess) { + addHeader(); + SmallVector typeDecl; + std::swap(typeDecl, binary); + + auto int32Type = addIntType(32); + auto structType = addStructType({int32Type, int32Type}); + std::swap(typeDecl, binary); + + SmallVector operands1 = {structType, 0}; + spirv::encodeStringLiteralInto(operands1, "i1"); + addInstruction(spirv::Opcode::OpMemberName, operands1); + + SmallVector operands2 = {structType, 1}; + spirv::encodeStringLiteralInto(operands2, "i2"); + addInstruction(spirv::Opcode::OpMemberName, operands2); + + binary.append(typeDecl.begin(), typeDecl.end()); + EXPECT_NE(llvm::None, deserialize()); +} + +TEST_F(DeserializationTest, OpMemberNameMissingOperands) { + addHeader(); + SmallVector typeDecl; + std::swap(typeDecl, binary); + + auto int32Type = addIntType(32); + auto int64Type = addIntType(64); + auto structType = addStructType({int32Type, int64Type}); + std::swap(typeDecl, binary); + + SmallVector operands1 = {structType}; + addInstruction(spirv::Opcode::OpMemberName, operands1); + + binary.append(typeDecl.begin(), typeDecl.end()); + ASSERT_EQ(llvm::None, deserialize()); + expectDiagnostic("OpMemberName must have at least 3 operands"); +} + +TEST_F(DeserializationTest, OpMemberNameExcessOperands) { + addHeader(); + SmallVector typeDecl; + std::swap(typeDecl, binary); + + auto int32Type = addIntType(32); + auto structType = addStructType({int32Type}); + std::swap(typeDecl, binary); + + SmallVector operands = {structType, 0}; + spirv::encodeStringLiteralInto(operands, "int32"); + operands.push_back(42); + addInstruction(spirv::Opcode::OpMemberName, operands); + + binary.append(typeDecl.begin(), typeDecl.end()); + ASSERT_EQ(llvm::None, deserialize()); + expectDiagnostic("unexpected trailing words in OpMemberName instruction"); +} + //===----------------------------------------------------------------------===// // Functions //===----------------------------------------------------------------------===//