forked from OSchip/llvm-project
Enable (de)serialization support for spirv::AccessChainOp
Automatic generation of spirv::AccessChainOp (de)serialization needs the (de)serialization emitters to handle argument specified as Variadic<...>. To handle this correctly, this argument can only be the last entry in the arguments list. Add a test to (de)serialize spirv::AccessChainOp PiperOrigin-RevId: 260532598
This commit is contained in:
parent
d5a02fcd96
commit
673bb7cbbe
|
@ -95,8 +95,6 @@ def SPV_AccessChainOp : SPV_Op<"AccessChain", [NoSideEffect]> {
|
|||
let results = (outs
|
||||
SPV_AnyPtr:$component_ptr
|
||||
);
|
||||
|
||||
let autogenSerialization = 0;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
// RUN: mlir-translate -serialize-spirv %s | mlir-translate -deserialize-spirv | FileCheck %s
|
||||
|
||||
func @foo() {
|
||||
spv.module "Logical" "VulkanKHR" {
|
||||
func @access_chain(%arg0 : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>,
|
||||
%arg1 : i32, %arg2 : i32) {
|
||||
// CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.array<4 x !spv.array<4 x f32>>, Function>
|
||||
// CHECK-NEXT: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spv.ptr<!spv.array<4 x !spv.array<4 x f32>>, Function>
|
||||
%1 = spv.AccessChain %arg0[%arg1] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
|
||||
%2 = spv.AccessChain %arg0[%arg1, %arg2] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
|
||||
spv.Return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
|
@ -126,15 +126,13 @@ static void emitSerializationFunction(const Record *record, const Operator &op,
|
|||
auto argument = op.getArg(i);
|
||||
os << " {\n";
|
||||
if (argument.is<NamedTypeConstraint *>()) {
|
||||
os << " if (" << operandNum
|
||||
<< " < op.getOperation()->getNumOperands()) {\n";
|
||||
os << " auto arg = findValueID(op.getOperation()->getOperand("
|
||||
<< operandNum << "));\n";
|
||||
os << " if (!arg) {\n";
|
||||
os << " for (auto arg : op.getODSOperands(" << i << ")) {\n";
|
||||
os << " auto argID = findValueID(arg);\n";
|
||||
os << " if (!argID) {\n";
|
||||
os << " emitError(op.getLoc(), \"operand " << operandNum
|
||||
<< " has a use before def\");\n";
|
||||
os << " }\n";
|
||||
os << " operands.push_back(arg);\n";
|
||||
os << " operands.push_back(argID);\n";
|
||||
os << " }\n";
|
||||
operandNum++;
|
||||
} else {
|
||||
|
@ -243,32 +241,53 @@ static void emitDeserializationFunction(const Record *record,
|
|||
"SPIR-V ops can have only zero or one result");
|
||||
}
|
||||
|
||||
// Process arguments/attributes
|
||||
// Process operands/attributes
|
||||
os << " SmallVector<Value *, 4> operands;\n";
|
||||
os << " SmallVector<NamedAttribute, 4> attributes;\n";
|
||||
unsigned operandNum = 0;
|
||||
for (unsigned i = 0, e = op.getNumArgs(); i < e; ++i) {
|
||||
auto argument = op.getArg(i);
|
||||
os << " if (wordIndex < words.size()) {\n";
|
||||
if (argument.is<NamedTypeConstraint *>()) {
|
||||
if (auto valueArg = argument.dyn_cast<NamedTypeConstraint *>()) {
|
||||
if (valueArg->isVariadic()) {
|
||||
if (i != e - 1) {
|
||||
PrintFatalError(record->getLoc(),
|
||||
"SPIR-V ops can have Variadic<..> argument only if "
|
||||
"it's the last argument");
|
||||
}
|
||||
os << " for (; wordIndex < words.size(); ++wordIndex)";
|
||||
} else {
|
||||
os << " if (wordIndex < words.size())";
|
||||
}
|
||||
os << " {\n";
|
||||
os << " auto arg = getValue(words[wordIndex]);\n";
|
||||
os << " if (!arg) {\n";
|
||||
os << " return emitError(unknownLoc, \"unknown result <id> : \") << "
|
||||
"words[wordIndex];\n";
|
||||
os << " }\n";
|
||||
os << " operands.push_back(arg);\n";
|
||||
os << " wordIndex++;\n";
|
||||
if (!valueArg->isVariadic()) {
|
||||
os << " wordIndex++;\n";
|
||||
}
|
||||
operandNum++;
|
||||
os << " }\n";
|
||||
} else {
|
||||
os << " if (wordIndex < words.size()) {\n";
|
||||
auto attr = argument.get<NamedAttribute *>();
|
||||
emitAttributeDeserialization(
|
||||
(attr->attr.isOptional() ? attr->attr.getBaseAttr() : attr->attr),
|
||||
record->getLoc(), "attributes", attr->name, "words", "wordIndex",
|
||||
"words.size()", os);
|
||||
os << " }\n";
|
||||
}
|
||||
os << " }\n";
|
||||
}
|
||||
|
||||
os << " if (wordIndex != words.size()) {\n";
|
||||
os << " return emitError(unknownLoc, \"found more operands than expected "
|
||||
"when deserializing "
|
||||
<< op.getQualCppClassName()
|
||||
<< ", only \") << wordIndex << \" of \" << words.size() << \" "
|
||||
"processed\";\n";
|
||||
os << " }\n";
|
||||
os << formatv(" auto op = opBuilder.create<{0}>(unknownLoc, resultTypes, "
|
||||
"operands, attributes); (void)op;\n",
|
||||
op.getQualCppClassName());
|
||||
|
|
Loading…
Reference in New Issue