[mlir][spirv] Fix encoding of cooperative matrix type to match SPIRV spec

Scope, rows and columns need to be encoded in a separate constant operation.

Differential Revision: https://reviews.llvm.org/D80852
This commit is contained in:
Thomas Raoux 2020-06-02 16:14:24 -07:00
parent 587af86f1d
commit bbe79e27bd
2 changed files with 11 additions and 6 deletions

View File

@ -1251,15 +1251,15 @@ Deserializer::processCooperativeMatrixType(ArrayRef<uint32_t> operands) {
<< operands[1]; << operands[1];
} }
auto scope = spirv::symbolizeScope(operands[2]); auto scope = spirv::symbolizeScope(getConstantInt(operands[2]).getInt());
if (!scope) { if (!scope) {
return emitError(unknownLoc, return emitError(unknownLoc,
"OpTypeCooperativeMatrix references undefined scope <id> ") "OpTypeCooperativeMatrix references undefined scope <id> ")
<< operands[2]; << operands[2];
} }
unsigned rows = operands[3]; unsigned rows = getConstantInt(operands[3]).getInt();
unsigned columns = operands[4]; unsigned columns = getConstantInt(operands[4]).getInt();
typeMap[operands[0]] = spirv::CooperativeMatrixNVType::get( typeMap[operands[0]] = spirv::CooperativeMatrixNVType::get(
elementTy, scope.getValue(), rows, columns); elementTy, scope.getValue(), rows, columns);

View File

@ -1104,10 +1104,15 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
return failure(); return failure();
} }
typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV; typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV;
auto getConstantOp = [&](uint32_t id) {
auto attr = IntegerAttr::get(IntegerType::get(32, type.getContext()), id);
return prepareConstantInt(loc, attr);
};
operands.push_back(elementTypeID); operands.push_back(elementTypeID);
operands.push_back(static_cast<uint32_t>(cooperativeMatrixType.getScope())); operands.push_back(
operands.push_back(cooperativeMatrixType.getRows()); getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope())));
operands.push_back(cooperativeMatrixType.getColumns()); operands.push_back(getConstantOp(cooperativeMatrixType.getRows()));
operands.push_back(getConstantOp(cooperativeMatrixType.getColumns()));
return success(); return success();
} }