Enable (de)serialization support for spirv::AccessChainOp

Automatic generation of spirv::AccessChainOp (de)serialization needs
the (de)serialization emitters to handle argument specified as
Variadic<...>. To handle this correctly, this argument can only be
the last entry in the arguments list.
Add a test to (de)serialize spirv::AccessChainOp

PiperOrigin-RevId: 260532598
This commit is contained in:
Mahesh Ravishankar 2019-07-29 10:45:17 -07:00 committed by jpienaar
parent d5a02fcd96
commit 673bb7cbbe
3 changed files with 45 additions and 13 deletions

View File

@ -95,8 +95,6 @@ def SPV_AccessChainOp : SPV_Op<"AccessChain", [NoSideEffect]> {
let results = (outs
SPV_AnyPtr:$component_ptr
);
let autogenSerialization = 0;
}
// -----

View File

@ -0,0 +1,15 @@
// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s
func @foo() {
spv.module "Logical" "VulkanKHR" {
func @access_chain(%arg0 : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>,
%arg1 : i32, %arg2 : i32) {
// CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.array<4 x !spv.array<4 x f32>>, Function>
// CHECK-NEXT: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr<!spv.array<4 x !spv.array<4 x f32>>, Function>
%1 = spv.AccessChain %arg0[%arg1] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
%2 = spv.AccessChain %arg0[%arg1, %arg2] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
spv.Return
}
}
return
}

View File

@ -126,15 +126,13 @@ static void emitSerializationFunction(const Record *record, const Operator &op,
auto argument = op.getArg(i);
os << " {\n";
if (argument.is<NamedTypeConstraint *>()) {
os << " if (" << operandNum
<< " < op.getOperation()->getNumOperands()) {\n";
os << " auto arg = findValueID(op.getOperation()->getOperand("
<< operandNum << "));\n";
os << " if (!arg) {\n";
os << " for (auto arg : op.getODSOperands(" << i << ")) {\n";
os << " auto argID = findValueID(arg);\n";
os << " if (!argID) {\n";
os << " emitError(op.getLoc(), \"operand " << operandNum
<< " has a use before def\");\n";
os << " }\n";
os << " operands.push_back(arg);\n";
os << " operands.push_back(argID);\n";
os << " }\n";
operandNum++;
} else {
@ -243,32 +241,53 @@ static void emitDeserializationFunction(const Record *record,
"SPIR-V ops can have only zero or one result");
}
// Process arguments/attributes
// Process operands/attributes
os << " SmallVector<Value *, 4> operands;\n";
os << " SmallVector<NamedAttribute, 4> attributes;\n";
unsigned operandNum = 0;
for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
auto argument = op.getArg(i);
os << " if (wordIndex < words.size()) {\n";
if (argument.is<NamedTypeConstraint *>()) {
if (auto valueArg = argument.dyn_cast<NamedTypeConstraint *>()) {
if (valueArg->isVariadic()) {
if (i != e - 1) {
PrintFatalError(record->getLoc(),
"SPIR-V ops can have Variadic<..> argument only if "
"it's the last argument");
}
os << " for (; wordIndex < words.size(); ++wordIndex)";
} else {
os << " if (wordIndex < words.size())";
}
os << " {\n";
os << " auto arg = getValue(words[wordIndex]);\n";
os << " if (!arg) {\n";
os << " return emitError(unknownLoc, \"unknown result <id> : \") << "
"words[wordIndex];\n";
os << " }\n";
os << " operands.push_back(arg);\n";
os << " wordIndex++;\n";
if (!valueArg->isVariadic()) {
os << " wordIndex++;\n";
}
operandNum++;
os << " }\n";
} else {
os << " if (wordIndex < words.size()) {\n";
auto attr = argument.get<NamedAttribute *>();
emitAttributeDeserialization(
(attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),
record->getLoc(), "attributes", attr->name, "words", "wordIndex",
"words.size()", os);
os << " }\n";
}
os << " }\n";
}
os << " if (wordIndex != words.size()) {\n";
os << " return emitError(unknownLoc, \"found more operands than expected "
"when deserializing "
<< op.getQualCppClassName()
<< ", only \") << wordIndex << \" of \" << words.size() << \" "
"processed\";\n";
os << " }\n";
os << formatv(" auto op = opBuilder.create<{0}>(unknownLoc, resultTypes, "
"operands, attributes); (void)op;\n",
op.getQualCppClassName());