forked from OSchip/llvm-project
Handle OpMemberName instruction in SPIR-V deserializer.
Sdd support in deserializer for OpMemberName instruction. For now the name is just processed and not associated with the spirv::StructType being built. That needs an enhancement to spirv::StructTypes itself. Add tests to check for errors reported during deserialization with some refactoring to common out some utility functions. PiperOrigin-RevId: 270794524
This commit is contained in:
parent
4a862fbd63
commit
75906bd565
|
@ -74,6 +74,7 @@ class SPV_OpCode<string name, int val> {
|
|||
|
||||
def SPV_OC_OpNop : I32EnumAttrCase<"OpNop", 0>;
|
||||
def SPV_OC_OpName : I32EnumAttrCase<"OpName", 5>;
|
||||
def SPV_OC_OpMemberName : I32EnumAttrCase<"OpMemberName", 6>;
|
||||
def SPV_OC_OpExtension : I32EnumAttrCase<"OpExtension", 10>;
|
||||
def SPV_OC_OpExtInstImport : I32EnumAttrCase<"OpExtInstImport", 11>;
|
||||
def SPV_OC_OpExtInst : I32EnumAttrCase<"OpExtInst", 12>;
|
||||
|
@ -159,27 +160,28 @@ def SPV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>;
|
|||
|
||||
def SPV_OpcodeAttr :
|
||||
I32EnumAttr<"Opcode", "valid SPIR-V instructions", [
|
||||
SPV_OC_OpNop, SPV_OC_OpName, SPV_OC_OpExtension, SPV_OC_OpExtInstImport,
|
||||
SPV_OC_OpExtInst, SPV_OC_OpMemoryModel, SPV_OC_OpEntryPoint,
|
||||
SPV_OC_OpExecutionMode, SPV_OC_OpCapability, SPV_OC_OpTypeVoid,
|
||||
SPV_OC_OpTypeBool, SPV_OC_OpTypeInt, SPV_OC_OpTypeFloat, SPV_OC_OpTypeVector,
|
||||
SPV_OC_OpTypeArray, SPV_OC_OpTypeRuntimeArray, SPV_OC_OpTypeStruct,
|
||||
SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction, SPV_OC_OpConstantTrue,
|
||||
SPV_OC_OpConstantFalse, SPV_OC_OpConstant, SPV_OC_OpConstantComposite,
|
||||
SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue, SPV_OC_OpSpecConstantFalse,
|
||||
SPV_OC_OpSpecConstant, SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction,
|
||||
SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall,
|
||||
SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain,
|
||||
SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract,
|
||||
SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul,
|
||||
SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod,
|
||||
SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpSelect,
|
||||
SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan,
|
||||
SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual,
|
||||
SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual,
|
||||
SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual,
|
||||
SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan,
|
||||
SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
|
||||
SPV_OC_OpNop, SPV_OC_OpName, SPV_OC_OpMemberName, SPV_OC_OpExtension,
|
||||
SPV_OC_OpExtInstImport, SPV_OC_OpExtInst, SPV_OC_OpMemoryModel,
|
||||
SPV_OC_OpEntryPoint, SPV_OC_OpExecutionMode, SPV_OC_OpCapability,
|
||||
SPV_OC_OpTypeVoid, SPV_OC_OpTypeBool, SPV_OC_OpTypeInt, SPV_OC_OpTypeFloat,
|
||||
SPV_OC_OpTypeVector, SPV_OC_OpTypeArray, SPV_OC_OpTypeRuntimeArray,
|
||||
SPV_OC_OpTypeStruct, SPV_OC_OpTypePointer, SPV_OC_OpTypeFunction,
|
||||
SPV_OC_OpConstantTrue, SPV_OC_OpConstantFalse, SPV_OC_OpConstant,
|
||||
SPV_OC_OpConstantComposite, SPV_OC_OpConstantNull, SPV_OC_OpSpecConstantTrue,
|
||||
SPV_OC_OpSpecConstantFalse, SPV_OC_OpSpecConstant,
|
||||
SPV_OC_OpSpecConstantComposite, SPV_OC_OpFunction, SPV_OC_OpFunctionParameter,
|
||||
SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad,
|
||||
SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate,
|
||||
SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeExtract, SPV_OC_OpIAdd,
|
||||
SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul,
|
||||
SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem,
|
||||
SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpSelect, SPV_OC_OpIEqual,
|
||||
SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
|
||||
SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
|
||||
SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
|
||||
SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual,
|
||||
SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan,
|
||||
SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan,
|
||||
SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual,
|
||||
SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual,
|
||||
SPV_OC_OpControlBarrier, SPV_OC_OpMemoryBarrier, SPV_OC_OpLoopMerge,
|
||||
|
|
|
@ -46,6 +46,12 @@ constexpr uint32_t kGeneratorNumber = 22;
|
|||
/// Appends a SPRI-V module header to `header` with the given `idBound`.
|
||||
void appendModuleHeader(SmallVectorImpl<uint32_t> &header, uint32_t idBound);
|
||||
|
||||
/// Returns the word-count-prefixed opcode for an SPIR-V instruction.
|
||||
uint32_t getPrefixedOpcode(uint32_t wordCount, spirv::Opcode opcode);
|
||||
|
||||
/// Encodes an SPIR-V `literal` string into the given `binary` vector.
|
||||
LogicalResult encodeStringLiteralInto(SmallVectorImpl<uint32_t> &binary,
|
||||
StringRef literal);
|
||||
} // end namespace spirv
|
||||
} // end namespace mlir
|
||||
|
||||
|
|
|
@ -112,12 +112,15 @@ private:
|
|||
/// Process SPIR-V OpName with `operands`.
|
||||
LogicalResult processName(ArrayRef<uint32_t> operands);
|
||||
|
||||
/// Method to process an OpDecorate instruction.
|
||||
/// Processes an OpDecorate instruction.
|
||||
LogicalResult processDecoration(ArrayRef<uint32_t> words);
|
||||
|
||||
// Method to process an OpMemberDecorate instruction.
|
||||
// Processes an OpMemberDecorate instruction.
|
||||
LogicalResult processMemberDecoration(ArrayRef<uint32_t> words);
|
||||
|
||||
/// Processes an OpMemberName instruction.
|
||||
LogicalResult processMemberName(ArrayRef<uint32_t> words);
|
||||
|
||||
/// Gets the FuncOp associated with a result <id> of OpFunction.
|
||||
FuncOp getFunction(uint32_t id) { return funcMap.lookup(id); }
|
||||
|
||||
|
@ -410,6 +413,10 @@ private:
|
|||
DenseMap<uint32_t, DenseMap<spirv::Decoration, ArrayRef<uint32_t>>>>
|
||||
memberDecorationMap;
|
||||
|
||||
// Result <id> to member name.
|
||||
// struct-type-<id> -> (struct-member-index -> name)
|
||||
DenseMap<uint32_t, DenseMap<uint32_t, StringRef>> memberNameMap;
|
||||
|
||||
// Result <id> to extended instruction set name.
|
||||
DenseMap<uint32_t, StringRef> extendedInstSets;
|
||||
|
||||
|
@ -650,6 +657,20 @@ LogicalResult Deserializer::processMemberDecoration(ArrayRef<uint32_t> words) {
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult Deserializer::processMemberName(ArrayRef<uint32_t> words) {
|
||||
if (words.size() < 3) {
|
||||
return emitError(unknownLoc, "OpMemberName must have at least 3 operands");
|
||||
}
|
||||
unsigned wordIndex = 2;
|
||||
auto name = decodeStringLiteral(words, wordIndex);
|
||||
if (wordIndex != words.size()) {
|
||||
return emitError(unknownLoc,
|
||||
"unexpected trailing words in OpMemberName instruction");
|
||||
}
|
||||
memberNameMap[words[0]][words[1]] = name;
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult Deserializer::processFunction(ArrayRef<uint32_t> operands) {
|
||||
if (curFunction) {
|
||||
return emitError(unknownLoc, "found function inside function");
|
||||
|
@ -1151,6 +1172,8 @@ LogicalResult Deserializer::processStructType(ArrayRef<uint32_t> operands) {
|
|||
}
|
||||
typeMap[operands[0]] =
|
||||
spirv::StructType::get(memberTypes, layoutInfo, memberDecorationsInfo);
|
||||
// TODO(ravishankarm): Update StructType to have member name as attribute as
|
||||
// well.
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -1746,6 +1769,8 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
|
|||
return processExtInst(operands);
|
||||
case spirv::Opcode::OpExtInstImport:
|
||||
return processExtInstImport(operands);
|
||||
case spirv::Opcode::OpMemberName:
|
||||
return processMemberName(operands);
|
||||
case spirv::Opcode::OpMemoryModel:
|
||||
return processMemoryModel(operands);
|
||||
case spirv::Opcode::OpEntryPoint:
|
||||
|
|
|
@ -51,3 +51,19 @@ void spirv::appendModuleHeader(SmallVectorImpl<uint32_t> &header,
|
|||
header.push_back(idBound); // <id> bound
|
||||
header.push_back(0); // Schema (reserved word)
|
||||
}
|
||||
|
||||
/// Returns the word-count-prefixed opcode for an SPIR-V instruction.
|
||||
uint32_t spirv::getPrefixedOpcode(uint32_t wordCount, spirv::Opcode opcode) {
|
||||
assert(((wordCount >> 16) == 0) && "word count out of range!");
|
||||
return (wordCount << 16) | static_cast<uint32_t>(opcode);
|
||||
}
|
||||
|
||||
LogicalResult spirv::encodeStringLiteralInto(SmallVectorImpl<uint32_t> &binary,
|
||||
StringRef literal) {
|
||||
// We need to encode the literal and the null termination.
|
||||
auto encodingSize = literal.size() / 4 + 1;
|
||||
auto bufferStartSize = binary.size();
|
||||
binary.resize(bufferStartSize + encodingSize, 0);
|
||||
std::memcpy(binary.data() + bufferStartSize, literal.data(), literal.size());
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -36,37 +36,19 @@
|
|||
|
||||
using namespace mlir;
|
||||
|
||||
/// Returns the word-count-prefixed opcode for an SPIR-V instruction.
|
||||
static inline uint32_t getPrefixedOpcode(uint32_t wordCount,
|
||||
spirv::Opcode opcode) {
|
||||
assert(((wordCount >> 16) == 0) && "word count out of range!");
|
||||
return (wordCount << 16) | static_cast<uint32_t>(opcode);
|
||||
}
|
||||
|
||||
/// Encodes an SPIR-V instruction with the given `opcode` and `operands` into
|
||||
/// the given `binary` vector.
|
||||
static LogicalResult encodeInstructionInto(SmallVectorImpl<uint32_t> &binary,
|
||||
spirv::Opcode op,
|
||||
ArrayRef<uint32_t> operands) {
|
||||
LogicalResult encodeInstructionInto(SmallVectorImpl<uint32_t> &binary,
|
||||
spirv::Opcode op,
|
||||
ArrayRef<uint32_t> operands) {
|
||||
uint32_t wordCount = 1 + operands.size();
|
||||
binary.push_back(getPrefixedOpcode(wordCount, op));
|
||||
binary.push_back(spirv::getPrefixedOpcode(wordCount, op));
|
||||
if (!operands.empty()) {
|
||||
binary.append(operands.begin(), operands.end());
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Encodes an SPIR-V `literal` string into the given `binary` vector.
|
||||
static LogicalResult encodeStringLiteralInto(SmallVectorImpl<uint32_t> &binary,
|
||||
StringRef literal) {
|
||||
// We need to encode the literal and the null termination.
|
||||
auto encodingSize = literal.size() / 4 + 1;
|
||||
auto bufferStartSize = binary.size();
|
||||
binary.resize(bufferStartSize + encodingSize, 0);
|
||||
std::memcpy(binary.data() + bufferStartSize, literal.data(), literal.size());
|
||||
return success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
/// A SPIR-V module serializer.
|
||||
|
@ -435,7 +417,7 @@ void Serializer::processExtension() {
|
|||
for (auto ext : exts.getValue()) {
|
||||
auto extStr = ext.cast<StringAttr>().getValue();
|
||||
extName.clear();
|
||||
encodeStringLiteralInto(extName, extStr);
|
||||
spirv::encodeStringLiteralInto(extName, extStr);
|
||||
encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName);
|
||||
}
|
||||
}
|
||||
|
@ -508,7 +490,7 @@ LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
|
|||
|
||||
SmallVector<uint32_t, 4> nameOperands;
|
||||
nameOperands.push_back(resultID);
|
||||
if (failed(encodeStringLiteralInto(nameOperands, name))) {
|
||||
if (failed(spirv::encodeStringLiteralInto(nameOperands, name))) {
|
||||
return failure();
|
||||
}
|
||||
return encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands);
|
||||
|
@ -1388,7 +1370,8 @@ LogicalResult Serializer::encodeExtensionInstruction(
|
|||
setID = getNextID();
|
||||
SmallVector<uint32_t, 16> importOperands;
|
||||
importOperands.push_back(setID);
|
||||
if (failed(encodeStringLiteralInto(importOperands, extensionSetName)) ||
|
||||
if (failed(
|
||||
spirv::encodeStringLiteralInto(importOperands, extensionSetName)) ||
|
||||
failed(encodeInstructionInto(
|
||||
extendedSets, spirv::Opcode::OpExtInstImport, importOperands))) {
|
||||
return failure();
|
||||
|
@ -1490,7 +1473,7 @@ Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
|
|||
}
|
||||
operands.push_back(funcID);
|
||||
// Add the name of the function.
|
||||
encodeStringLiteralInto(operands, op.fn());
|
||||
spirv::encodeStringLiteralInto(operands, op.fn());
|
||||
|
||||
// Add the interface values.
|
||||
if (auto interface = op.interface()) {
|
||||
|
|
|
@ -12,8 +12,6 @@ func @access_chain_struct() -> () {
|
|||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @access_chain_1D_array(%arg0 : i32) -> () {
|
||||
%0 = spv.Variable : !spv.ptr<!spv.array<4xf32>, Function>
|
||||
// CHECK: spv.AccessChain {{.*}}[{{.*}}] : !spv.ptr<!spv.array<4 x f32>, Function>
|
||||
|
@ -21,8 +19,6 @@ func @access_chain_1D_array(%arg0 : i32) -> () {
|
|||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @access_chain_2D_array_1(%arg0 : i32) -> () {
|
||||
%0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
|
||||
// CHECK: spv.AccessChain {{.*}}[{{.*}}, {{.*}}] : !spv.ptr<!spv.array<4 x !spv.array<4 x f32>>, Function>
|
||||
|
@ -31,8 +27,6 @@ func @access_chain_2D_array_1(%arg0 : i32) -> () {
|
|||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @access_chain_2D_array_2(%arg0 : i32) -> () {
|
||||
%0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
|
||||
// CHECK: spv.AccessChain {{.*}}[{{.*}}] : !spv.ptr<!spv.array<4 x !spv.array<4 x f32>>, Function>
|
||||
|
|
|
@ -73,10 +73,7 @@ protected:
|
|||
/// Adds the SPIR-V instruction into `binary`.
|
||||
void addInstruction(spirv::Opcode op, ArrayRef<uint32_t> operands) {
|
||||
uint32_t wordCount = 1 + operands.size();
|
||||
assert(((wordCount >> 16) == 0) && "word count out of range!");
|
||||
|
||||
uint32_t prefixedOpcode = (wordCount << 16) | static_cast<uint32_t>(op);
|
||||
binary.push_back(prefixedOpcode);
|
||||
binary.push_back(spirv::getPrefixedOpcode(wordCount, op));
|
||||
binary.append(operands.begin(), operands.end());
|
||||
}
|
||||
|
||||
|
@ -92,6 +89,15 @@ protected:
|
|||
return id;
|
||||
}
|
||||
|
||||
uint32_t addStructType(ArrayRef<uint32_t> memberTypes) {
|
||||
auto id = nextID++;
|
||||
SmallVector<uint32_t, 2> words;
|
||||
words.push_back(id);
|
||||
words.append(memberTypes.begin(), memberTypes.end());
|
||||
addInstruction(spirv::Opcode::OpTypeStruct, words);
|
||||
return id;
|
||||
}
|
||||
|
||||
uint32_t addFunctionType(uint32_t retType, ArrayRef<uint32_t> paramTypes) {
|
||||
auto id = nextID++;
|
||||
SmallVector<uint32_t, 4> operands;
|
||||
|
@ -173,6 +179,68 @@ TEST_F(DeserializationTest, IntTypeMissingSignednessFailure) {
|
|||
expectDiagnostic("OpTypeInt must have bitwidth and signedness parameters");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// StructType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
TEST_F(DeserializationTest, OpMemberNameSuccess) {
|
||||
addHeader();
|
||||
SmallVector<uint32_t, 5> typeDecl;
|
||||
std::swap(typeDecl, binary);
|
||||
|
||||
auto int32Type = addIntType(32);
|
||||
auto structType = addStructType({int32Type, int32Type});
|
||||
std::swap(typeDecl, binary);
|
||||
|
||||
SmallVector<uint32_t, 5> operands1 = {structType, 0};
|
||||
spirv::encodeStringLiteralInto(operands1, "i1");
|
||||
addInstruction(spirv::Opcode::OpMemberName, operands1);
|
||||
|
||||
SmallVector<uint32_t, 5> operands2 = {structType, 1};
|
||||
spirv::encodeStringLiteralInto(operands2, "i2");
|
||||
addInstruction(spirv::Opcode::OpMemberName, operands2);
|
||||
|
||||
binary.append(typeDecl.begin(), typeDecl.end());
|
||||
EXPECT_NE(llvm::None, deserialize());
|
||||
}
|
||||
|
||||
TEST_F(DeserializationTest, OpMemberNameMissingOperands) {
|
||||
addHeader();
|
||||
SmallVector<uint32_t, 5> typeDecl;
|
||||
std::swap(typeDecl, binary);
|
||||
|
||||
auto int32Type = addIntType(32);
|
||||
auto int64Type = addIntType(64);
|
||||
auto structType = addStructType({int32Type, int64Type});
|
||||
std::swap(typeDecl, binary);
|
||||
|
||||
SmallVector<uint32_t, 5> operands1 = {structType};
|
||||
addInstruction(spirv::Opcode::OpMemberName, operands1);
|
||||
|
||||
binary.append(typeDecl.begin(), typeDecl.end());
|
||||
ASSERT_EQ(llvm::None, deserialize());
|
||||
expectDiagnostic("OpMemberName must have at least 3 operands");
|
||||
}
|
||||
|
||||
TEST_F(DeserializationTest, OpMemberNameExcessOperands) {
|
||||
addHeader();
|
||||
SmallVector<uint32_t, 5> typeDecl;
|
||||
std::swap(typeDecl, binary);
|
||||
|
||||
auto int32Type = addIntType(32);
|
||||
auto structType = addStructType({int32Type});
|
||||
std::swap(typeDecl, binary);
|
||||
|
||||
SmallVector<uint32_t, 5> operands = {structType, 0};
|
||||
spirv::encodeStringLiteralInto(operands, "int32");
|
||||
operands.push_back(42);
|
||||
addInstruction(spirv::Opcode::OpMemberName, operands);
|
||||
|
||||
binary.append(typeDecl.begin(), typeDecl.end());
|
||||
ASSERT_EQ(llvm::None, deserialize());
|
||||
expectDiagnostic("unexpected trailing words in OpMemberName instruction");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Functions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
Loading…
Reference in New Issue