Handle OpMemberName instruction in SPIR-V deserializer.

Sdd support in deserializer for OpMemberName instruction. For now
the name is just processed and not associated with the
spirv::StructType being built. That needs an enhancement to
spirv::StructTypes itself.
Add tests to check for errors reported during deserialization with
some refactoring to common out some utility functions.
PiperOrigin-RevId: 270794524
This commit is contained in:
Mahesh Ravishankar 2019-09-23 17:10:49 -07:00 committed by A. Unique TensorFlower
parent 4a862fbd63
commit 75906bd565
7 changed files with 153 additions and 59 deletions

View File

@ -74,6 +74,7 @@ class SPV_OpCode<string name, int val> {
def SPV_OC_OpNop : I32EnumAttrCase<"OpNop", 0>;
def SPV_OC_OpName : I32EnumAttrCase<"OpName", 5>;
def SPV_OC_OpMemberName : I32EnumAttrCase<"OpMemberName", 6>;
def SPV_OC_OpExtension : I32EnumAttrCase<"OpExtension", 10>;
def SPV_OC_OpExtInstImport : I32EnumAttrCase<"OpExtInstImport", 11>;
def SPV_OC_OpExtInst : I32EnumAttrCase<"OpExtInst", 12>;
@ -159,27 +160,28 @@ def SPV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>;
def SPV_OpcodeAttr :
I32EnumAttr<"Opcode", "valid SPIR-V instructions", [
SPV_OC_OpNop, SPV_OC_OpName, SPV_OC_OpExtension, SPV_OC_OpExtInstImport,
SPV_OC_OpExtInst, SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint,
SPV_OC_OpExecutionMode, SPV_OC_OpCapability, SPV_OC_OpTypeVoid,
SPV_OC_OpTypeBool, SPV_OC_OpTypeInt, SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector,
SPV_OC_OpTypeArray, SPV_OC_OpTypeRuntimeArray, SPV_OC_OpTypeStruct,
SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue,
SPV_OC_OpConstantFalse, SPV_OC_OpConstant, SPV_OC_OpConstantComposite,
SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse,
SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction,
SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall,
SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain,
SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract,
SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul,
SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod,
SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpSelect,
SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan,
SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
SPV_OC_OpNop, SPV_OC_OpName, SPV_OC_OpMemberName, SPV_OC_OpExtension,
SPV_OC_OpExtInstImport, SPV_OC_OpExtInst, SPV_OC_OpMemoryModel,
SPV_OC_OpEntryPoint, SPV_OC_OpExecutionMode, SPV_OC_OpCapability,
SPV_OC_OpTypeVoid, SPV_OC_OpTypeBool, SPV_OC_OpTypeInt, SPV_OC_OpTypeFloat,
SPV_OC_OpTypeVector, SPV_OC_OpTypeArray, SPV_OC_OpTypeRuntimeArray,
SPV_OC_OpTypeStruct, SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction,
SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse, SPV_OC_OpConstant,
SPV_OC_OpConstantComposite, SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue,
SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant,
SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad,
SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpIAdd,
SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul,
SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem,
SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpSelect, SPV_OC_OpIEqual,
SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual,
SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan,
SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier, SPV_OC_OpLoopMerge,

View File

@ -46,6 +46,12 @@ constexpr uint32_t kGeneratorNumber = 22;
/// Appends a SPRI-V module header to `header` with the given `idBound`.
void appendModuleHeader(SmallVectorImpl<uint32_t> &header, uint32_t idBound);
/// Returns the word-count-prefixed opcode for an SPIR-V instruction.
uint32_t getPrefixedOpcode(uint32_t wordCount, spirv::Opcode opcode);
/// Encodes an SPIR-V `literal` string into the given `binary` vector.
LogicalResult encodeStringLiteralInto(SmallVectorImpl<uint32_t> &binary,
StringRef literal);
} // end namespace spirv
} // end namespace mlir

View File

@ -112,12 +112,15 @@ private:
/// Process SPIR-V OpName with `operands`.
LogicalResult processName(ArrayRef<uint32_t> operands);
/// Method to process an OpDecorate instruction.
/// Processes an OpDecorate instruction.
LogicalResult processDecoration(ArrayRef<uint32_t> words);
// Method to process an OpMemberDecorate instruction.
// Processes an OpMemberDecorate instruction.
LogicalResult processMemberDecoration(ArrayRef<uint32_t> words);
/// Processes an OpMemberName instruction.
LogicalResult processMemberName(ArrayRef<uint32_t> words);
/// Gets the FuncOp associated with a result <id> of OpFunction.
FuncOp getFunction(uint32_t id) { return funcMap.lookup(id); }
@ -410,6 +413,10 @@ private:
DenseMap<uint32_t, DenseMap<spirv::Decoration, ArrayRef<uint32_t>>>>
memberDecorationMap;
// Result <id> to member name.
// struct-type-<id> -> (struct-member-index -> name)
DenseMap<uint32_t, DenseMap<uint32_t, StringRef>> memberNameMap;
// Result <id> to extended instruction set name.
DenseMap<uint32_t, StringRef> extendedInstSets;
@ -650,6 +657,20 @@ LogicalResult Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
return success();
}
LogicalResult Deserializer::processMemberName(ArrayRef<uint32_t> words) {
if (words.size() < 3) {
return emitError(unknownLoc, "OpMemberName must have at least 3 operands");
}
unsigned wordIndex = 2;
auto name = decodeStringLiteral(words, wordIndex);
if (wordIndex != words.size()) {
return emitError(unknownLoc,
"unexpected trailing words in OpMemberName instruction");
}
memberNameMap[words[0]][words[1]] = name;
return success();
}
LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
if (curFunction) {
return emitError(unknownLoc, "found function inside function");
@ -1151,6 +1172,8 @@ LogicalResult Deserializer::processStructType(ArrayRef<uint32_t> operands) {
}
typeMap[operands[0]] =
spirv::StructType::get(memberTypes, layoutInfo, memberDecorationsInfo);
// TODO(ravishankarm): Update StructType to have member name as attribute as
// well.
return success();
}
@ -1746,6 +1769,8 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
return processExtInst(operands);
case spirv::Opcode::OpExtInstImport:
return processExtInstImport(operands);
case spirv::Opcode::OpMemberName:
return processMemberName(operands);
case spirv::Opcode::OpMemoryModel:
return processMemoryModel(operands);
case spirv::Opcode::OpEntryPoint:

View File

@ -51,3 +51,19 @@ void spirv::appendModuleHeader(SmallVectorImpl<uint32_t> &header,
header.push_back(idBound); // <id> bound
header.push_back(0); // Schema (reserved word)
}
/// Returns the word-count-prefixed opcode for an SPIR-V instruction.
uint32_t spirv::getPrefixedOpcode(uint32_t wordCount, spirv::Opcode opcode) {
assert(((wordCount >> 16) == 0) && "word count out of range!");
return (wordCount << 16) | static_cast<uint32_t>(opcode);
}
LogicalResult spirv::encodeStringLiteralInto(SmallVectorImpl<uint32_t> &binary,
StringRef literal) {
// We need to encode the literal and the null termination.
auto encodingSize = literal.size() / 4 + 1;
auto bufferStartSize = binary.size();
binary.resize(bufferStartSize + encodingSize, 0);
std::memcpy(binary.data() + bufferStartSize, literal.data(), literal.size());
return success();
}

View File

@ -36,37 +36,19 @@
using namespace mlir;
/// Returns the word-count-prefixed opcode for an SPIR-V instruction.
static inline uint32_t getPrefixedOpcode(uint32_t wordCount,
spirv::Opcode opcode) {
assert(((wordCount >> 16) == 0) && "word count out of range!");
return (wordCount << 16) | static_cast<uint32_t>(opcode);
}
/// Encodes an SPIR-V instruction with the given `opcode` and `operands` into
/// the given `binary` vector.
static LogicalResult encodeInstructionInto(SmallVectorImpl<uint32_t> &binary,
spirv::Opcode op,
ArrayRef<uint32_t> operands) {
LogicalResult encodeInstructionInto(SmallVectorImpl<uint32_t> &binary,
spirv::Opcode op,
ArrayRef<uint32_t> operands) {
uint32_t wordCount = 1 + operands.size();
binary.push_back(getPrefixedOpcode(wordCount, op));
binary.push_back(spirv::getPrefixedOpcode(wordCount, op));
if (!operands.empty()) {
binary.append(operands.begin(), operands.end());
}
return success();
}
/// Encodes an SPIR-V `literal` string into the given `binary` vector.
static LogicalResult encodeStringLiteralInto(SmallVectorImpl<uint32_t> &binary,
StringRef literal) {
// We need to encode the literal and the null termination.
auto encodingSize = literal.size() / 4 + 1;
auto bufferStartSize = binary.size();
binary.resize(bufferStartSize + encodingSize, 0);
std::memcpy(binary.data() + bufferStartSize, literal.data(), literal.size());
return success();
}
namespace {
/// A SPIR-V module serializer.
@ -435,7 +417,7 @@ void Serializer::processExtension() {
for (auto ext : exts.getValue()) {
auto extStr = ext.cast<StringAttr>().getValue();
extName.clear();
encodeStringLiteralInto(extName, extStr);
spirv::encodeStringLiteralInto(extName, extStr);
encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName);
}
}
@ -508,7 +490,7 @@ LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
SmallVector<uint32_t, 4> nameOperands;
nameOperands.push_back(resultID);
if (failed(encodeStringLiteralInto(nameOperands, name))) {
if (failed(spirv::encodeStringLiteralInto(nameOperands, name))) {
return failure();
}
return encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands);
@ -1388,7 +1370,8 @@ LogicalResult Serializer::encodeExtensionInstruction(
setID = getNextID();
SmallVector<uint32_t, 16> importOperands;
importOperands.push_back(setID);
if (failed(encodeStringLiteralInto(importOperands, extensionSetName)) ||
if (failed(
spirv::encodeStringLiteralInto(importOperands, extensionSetName)) ||
failed(encodeInstructionInto(
extendedSets, spirv::Opcode::OpExtInstImport, importOperands))) {
return failure();
@ -1490,7 +1473,7 @@ Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
}
operands.push_back(funcID);
// Add the name of the function.
encodeStringLiteralInto(operands, op.fn());
spirv::encodeStringLiteralInto(operands, op.fn());
// Add the interface values.
if (auto interface = op.interface()) {

View File

@ -12,8 +12,6 @@ func @access_chain_struct() -> () {
return
}
// -----
func @access_chain_1D_array(%arg0 : i32) -> () {
%0 = spv.Variable : !spv.ptr<!spv.array<4xf32>, Function>
// CHECK: spv.AccessChain {{.*}}[{{.*}}] : !spv.ptr<!spv.array<4 x f32>, Function>
@ -21,8 +19,6 @@ func @access_chain_1D_array(%arg0 : i32) -> () {
return
}
// -----
func @access_chain_2D_array_1(%arg0 : i32) -> () {
%0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
// CHECK: spv.AccessChain {{.*}}[{{.*}}, {{.*}}] : !spv.ptr<!spv.array<4 x !spv.array<4 x f32>>, Function>
@ -31,8 +27,6 @@ func @access_chain_2D_array_1(%arg0 : i32) -> () {
return
}
// -----
func @access_chain_2D_array_2(%arg0 : i32) -> () {
%0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
// CHECK: spv.AccessChain {{.*}}[{{.*}}] : !spv.ptr<!spv.array<4 x !spv.array<4 x f32>>, Function>

View File

@ -73,10 +73,7 @@ protected:
/// Adds the SPIR-V instruction into `binary`.
void addInstruction(spirv::Opcode op, ArrayRef<uint32_t> operands) {
uint32_t wordCount = 1 + operands.size();
assert(((wordCount >> 16) == 0) && "word count out of range!");
uint32_t prefixedOpcode = (wordCount << 16) | static_cast<uint32_t>(op);
binary.push_back(prefixedOpcode);
binary.push_back(spirv::getPrefixedOpcode(wordCount, op));
binary.append(operands.begin(), operands.end());
}
@ -92,6 +89,15 @@ protected:
return id;
}
uint32_t addStructType(ArrayRef<uint32_t> memberTypes) {
auto id = nextID++;
SmallVector<uint32_t, 2> words;
words.push_back(id);
words.append(memberTypes.begin(), memberTypes.end());
addInstruction(spirv::Opcode::OpTypeStruct, words);
return id;
}
uint32_t addFunctionType(uint32_t retType, ArrayRef<uint32_t> paramTypes) {
auto id = nextID++;
SmallVector<uint32_t, 4> operands;
@ -173,6 +179,68 @@ TEST_F(DeserializationTest, IntTypeMissingSignednessFailure) {
expectDiagnostic("OpTypeInt must have bitwidth and signedness parameters");
}
//===----------------------------------------------------------------------===//
// StructType
//===----------------------------------------------------------------------===//
TEST_F(DeserializationTest, OpMemberNameSuccess) {
addHeader();
SmallVector<uint32_t, 5> typeDecl;
std::swap(typeDecl, binary);
auto int32Type = addIntType(32);
auto structType = addStructType({int32Type, int32Type});
std::swap(typeDecl, binary);
SmallVector<uint32_t, 5> operands1 = {structType, 0};
spirv::encodeStringLiteralInto(operands1, "i1");
addInstruction(spirv::Opcode::OpMemberName, operands1);
SmallVector<uint32_t, 5> operands2 = {structType, 1};
spirv::encodeStringLiteralInto(operands2, "i2");
addInstruction(spirv::Opcode::OpMemberName, operands2);
binary.append(typeDecl.begin(), typeDecl.end());
EXPECT_NE(llvm::None, deserialize());
}
TEST_F(DeserializationTest, OpMemberNameMissingOperands) {
addHeader();
SmallVector<uint32_t, 5> typeDecl;
std::swap(typeDecl, binary);
auto int32Type = addIntType(32);
auto int64Type = addIntType(64);
auto structType = addStructType({int32Type, int64Type});
std::swap(typeDecl, binary);
SmallVector<uint32_t, 5> operands1 = {structType};
addInstruction(spirv::Opcode::OpMemberName, operands1);
binary.append(typeDecl.begin(), typeDecl.end());
ASSERT_EQ(llvm::None, deserialize());
expectDiagnostic("OpMemberName must have at least 3 operands");
}
TEST_F(DeserializationTest, OpMemberNameExcessOperands) {
addHeader();
SmallVector<uint32_t, 5> typeDecl;
std::swap(typeDecl, binary);
auto int32Type = addIntType(32);
auto structType = addStructType({int32Type});
std::swap(typeDecl, binary);
SmallVector<uint32_t, 5> operands = {structType, 0};
spirv::encodeStringLiteralInto(operands, "int32");
operands.push_back(42);
addInstruction(spirv::Opcode::OpMemberName, operands);
binary.append(typeDecl.begin(), typeDecl.end());
ASSERT_EQ(llvm::None, deserialize());
expectDiagnostic("unexpected trailing words in OpMemberName instruction");
}
//===----------------------------------------------------------------------===//
// Functions
//===----------------------------------------------------------------------===//