forked from OSchip/llvm-project
(De)serialize float scalar spv.constant
This CL adds support for float scalar spv.constant in (de)serialization. PiperOrigin-RevId: 259311776
This commit is contained in:
parent
c1844220cd
commit
83c97a6784
|
@ -443,8 +443,8 @@ LogicalResult Deserializer::processType(spirv::Opcode opcode,
|
|||
floatTy = opBuilder.getF64Type();
|
||||
break;
|
||||
default:
|
||||
return emitError(unknownLoc, "unsupported bitwdith ")
|
||||
<< operands[1] << " with OpTypeFloat";
|
||||
return emitError(unknownLoc, "unsupported OpTypeFloat bitwdith: ")
|
||||
<< operands[1];
|
||||
}
|
||||
typeMap[operands[0]] = floatTy;
|
||||
} break;
|
||||
|
@ -556,6 +556,31 @@ LogicalResult Deserializer::processConstant(ArrayRef<uint32_t> operands) {
|
|||
|
||||
auto attr = opBuilder.getIntegerAttr(intType, value);
|
||||
op = opBuilder.create<spirv::ConstantOp>(unknownLoc, intType, attr);
|
||||
} else if (auto floatType = resultType.dyn_cast<FloatType>()) {
|
||||
auto bitwidth = floatType.getWidth();
|
||||
if (failed(checkOperandSizeForBitwidth(bitwidth))) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
APFloat value(0.f);
|
||||
if (floatType.isF64()) {
|
||||
// Double values are represented with two SPIR-V words. According to
|
||||
// SPIR-V spec: "When the type’s bit width is larger than one word, the
|
||||
// literal’s low-order words appear first."
|
||||
struct DoubleWord {
|
||||
uint32_t word1;
|
||||
uint32_t word2;
|
||||
} words = {operands[2], operands[3]};
|
||||
value = APFloat(llvm::bit_cast<double>(words));
|
||||
} else if (floatType.isF32()) {
|
||||
value = APFloat(llvm::bit_cast<float>(operands[2]));
|
||||
} else if (floatType.isF16()) {
|
||||
APInt data(16, operands[2]);
|
||||
value = APFloat(APFloat::IEEEhalf(), data);
|
||||
}
|
||||
|
||||
auto attr = opBuilder.getFloatAttr(floatType, value);
|
||||
op = opBuilder.create<spirv::ConstantOp>(unknownLoc, floatType, attr);
|
||||
} else {
|
||||
return emitError(unknownLoc, "OpConstant can only generate values of "
|
||||
"scalar integer or floating-point type");
|
||||
|
@ -564,6 +589,7 @@ LogicalResult Deserializer::processConstant(ArrayRef<uint32_t> operands) {
|
|||
valueMap[operands[1]] = op.getResult();
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult Deserializer::processConstantBool(bool isTrue,
|
||||
ArrayRef<uint32_t> operands) {
|
||||
if (operands.size() != 2) {
|
||||
|
|
|
@ -173,6 +173,8 @@ private:
|
|||
|
||||
uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr);
|
||||
|
||||
uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr);
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Operations
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
@ -488,6 +490,9 @@ Serializer::prepareFunctionType(Location loc, FunctionType type,
|
|||
|
||||
uint32_t Serializer::prepareConstant(Location loc, Type constType,
|
||||
Attribute valueAttr) {
|
||||
if (auto floatAttr = valueAttr.dyn_cast<FloatAttr>()) {
|
||||
return prepareConstantFp(loc, floatAttr);
|
||||
}
|
||||
if (auto intAttr = valueAttr.dyn_cast<IntegerAttr>()) {
|
||||
return prepareConstantInt(loc, intAttr);
|
||||
}
|
||||
|
@ -575,6 +580,50 @@ uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr) {
|
|||
return constIDMap[intAttr] = resultID;
|
||||
}
|
||||
|
||||
uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr) {
|
||||
if (auto id = findConstantID(floatAttr)) {
|
||||
return id;
|
||||
}
|
||||
|
||||
// Process the type for this float literal
|
||||
uint32_t typeID = 0;
|
||||
if (failed(processType(loc, floatAttr.getType(), typeID))) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
auto resultID = getNextID();
|
||||
APFloat value = floatAttr.getValue();
|
||||
APInt intValue = value.bitcastToAPInt();
|
||||
|
||||
if (&value.getSemantics() == &APFloat::IEEEsingle()) {
|
||||
uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
|
||||
encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstant,
|
||||
{typeID, resultID, word});
|
||||
} else if (&value.getSemantics() == &APFloat::IEEEdouble()) {
|
||||
struct DoubleWord {
|
||||
uint32_t word1;
|
||||
uint32_t word2;
|
||||
} words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
|
||||
encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstant,
|
||||
{typeID, resultID, words.word1, words.word2});
|
||||
} else if (&value.getSemantics() == &APFloat::IEEEhalf()) {
|
||||
uint32_t word =
|
||||
static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
|
||||
encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstant,
|
||||
{typeID, resultID, word});
|
||||
} else {
|
||||
std::string valueStr;
|
||||
llvm::raw_string_ostream rss(valueStr);
|
||||
value.print(rss);
|
||||
|
||||
emitError(loc, "cannot serialize ")
|
||||
<< floatAttr.getType() << "-typed float literal: " << rss.str();
|
||||
return 0;
|
||||
}
|
||||
|
||||
return constIDMap[floatAttr] = resultID;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Operation
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -32,6 +32,33 @@ func @spirv_module() -> () {
|
|||
%9 = spv.constant -32768 : i16 // -2^15
|
||||
// CHECK: spv.constant 32767 : i16
|
||||
%10 = spv.constant 32767 : i16 // 2^15 - 1
|
||||
|
||||
// float
|
||||
// CHECK: spv.constant 0.000000e+00 : f32
|
||||
%11 = spv.constant 0. : f32
|
||||
// CHECK: spv.constant 1.000000e+00 : f32
|
||||
%12 = spv.constant 1. : f32
|
||||
// CHECK: spv.constant -0.000000e+00 : f32
|
||||
%13 = spv.constant -0. : f32
|
||||
// CHECK: spv.constant -1.000000e+00 : f32
|
||||
%14 = spv.constant -1. : f32
|
||||
// CHECK: spv.constant 7.500000e-01 : f32
|
||||
%15 = spv.constant 0.75 : f32
|
||||
// CHECK: spv.constant -2.500000e-01 : f32
|
||||
%16 = spv.constant -0.25 : f32
|
||||
|
||||
// double
|
||||
// TODO(antiagainst): test range boundary values
|
||||
// CHECK: spv.constant 1.024000e+03 : f64
|
||||
%17 = spv.constant 1024. : f64
|
||||
// CHECK: spv.constant -1.024000e+03 : f64
|
||||
%18 = spv.constant -1024. : f64
|
||||
|
||||
// half
|
||||
// CHECK: spv.constant 5.120000e+02 : f16
|
||||
%19 = spv.constant 512. : f16
|
||||
// CHECK: spv.constant -5.120000e+02 : f16
|
||||
%20 = spv.constant -512. : f16
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue