[mlir][spirv] Fix encoding of cooperative matrix type to match SPIRV spec

Scope, rows and columns need to be encoded in a separate constant operation.

Differential Revision: https://reviews.llvm.org/D80852
This commit is contained in:
Thomas Raoux 2020-06-02 16:14:24 -07:00
parent 587af86f1d
commit bbe79e27bd
2 changed files with 11 additions and 6 deletions

View File

@ -1251,15 +1251,15 @@ Deserializer::processCooperativeMatrixType(ArrayRef<uint32_t> operands) {
<< operands[1];
}
auto scope = spirv::symbolizeScope(operands[2]);
auto scope = spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
if (!scope) {
return emitError(unknownLoc,
"OpTypeCooperativeMatrix references undefined scope <id> ")
<< operands[2];
}
unsigned rows = operands[3];
unsigned columns = operands[4];
unsigned rows = getConstantInt(operands[3]).getInt();
unsigned columns = getConstantInt(operands[4]).getInt();
typeMap[operands[0]] = spirv::CooperativeMatrixNVType::get(
elementTy, scope.getValue(), rows, columns);

View File

@ -1104,10 +1104,15 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
return failure();
}
typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV;
auto getConstantOp = [&](uint32_t id) {
auto attr = IntegerAttr::get(IntegerType::get(32, type.getContext()), id);
return prepareConstantInt(loc, attr);
};
operands.push_back(elementTypeID);
operands.push_back(static_cast<uint32_t>(cooperativeMatrixType.getScope()));
operands.push_back(cooperativeMatrixType.getRows());
operands.push_back(cooperativeMatrixType.getColumns());
operands.push_back(
getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())));
operands.push_back(getConstantOp(cooperativeMatrixType.getRows()));
operands.push_back(getConstantOp(cooperativeMatrixType.getColumns()));
return success();
}