forked from OSchip/llvm-project
[mlir][spirv] Fix encoding of cooperative matrix type to match SPIRV spec
Scope, rows and columns need to be encoded in a separate constant operation. Differential Revision: https://reviews.llvm.org/D80852
This commit is contained in:
parent
587af86f1d
commit
bbe79e27bd
|
@ -1251,15 +1251,15 @@ Deserializer::processCooperativeMatrixType(ArrayRef<uint32_t> operands) {
|
||||||
<< operands[1];
|
<< operands[1];
|
||||||
}
|
}
|
||||||
|
|
||||||
auto scope = spirv::symbolizeScope(operands[2]);
|
auto scope = spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
|
||||||
if (!scope) {
|
if (!scope) {
|
||||||
return emitError(unknownLoc,
|
return emitError(unknownLoc,
|
||||||
"OpTypeCooperativeMatrix references undefined scope <id> ")
|
"OpTypeCooperativeMatrix references undefined scope <id> ")
|
||||||
<< operands[2];
|
<< operands[2];
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned rows = operands[3];
|
unsigned rows = getConstantInt(operands[3]).getInt();
|
||||||
unsigned columns = operands[4];
|
unsigned columns = getConstantInt(operands[4]).getInt();
|
||||||
|
|
||||||
typeMap[operands[0]] = spirv::CooperativeMatrixNVType::get(
|
typeMap[operands[0]] = spirv::CooperativeMatrixNVType::get(
|
||||||
elementTy, scope.getValue(), rows, columns);
|
elementTy, scope.getValue(), rows, columns);
|
||||||
|
|
|
@ -1104,10 +1104,15 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV;
|
typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV;
|
||||||
|
auto getConstantOp = [&](uint32_t id) {
|
||||||
|
auto attr = IntegerAttr::get(IntegerType::get(32, type.getContext()), id);
|
||||||
|
return prepareConstantInt(loc, attr);
|
||||||
|
};
|
||||||
operands.push_back(elementTypeID);
|
operands.push_back(elementTypeID);
|
||||||
operands.push_back(static_cast<uint32_t>(cooperativeMatrixType.getScope()));
|
operands.push_back(
|
||||||
operands.push_back(cooperativeMatrixType.getRows());
|
getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())));
|
||||||
operands.push_back(cooperativeMatrixType.getColumns());
|
operands.push_back(getConstantOp(cooperativeMatrixType.getRows()));
|
||||||
|
operands.push_back(getConstantOp(cooperativeMatrixType.getColumns()));
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue