(De)serialize float scalar spv.constant

This CL adds support for float scalar spv.constant in (de)serialization.

PiperOrigin-RevId: 259311776
This commit is contained in:
Lei Zhang 2019-07-22 06:01:34 -07:00 committed by A. Unique TensorFlower
parent c1844220cd
commit 83c97a6784
3 changed files with 104 additions and 2 deletions

View File

@ -443,8 +443,8 @@ LogicalResult Deserializer::processType(spirv::Opcode opcode,
floatTy = opBuilder.getF64Type();
break;
default:
return emitError(unknownLoc, "unsupported bitwdith ")
<< operands[1] << " with OpTypeFloat";
return emitError(unknownLoc, "unsupported OpTypeFloat bitwdith: ")
<< operands[1];
}
typeMap[operands[0]] = floatTy;
} break;
@ -556,6 +556,31 @@ LogicalResult Deserializer::processConstant(ArrayRef<uint32_t> operands) {
auto attr = opBuilder.getIntegerAttr(intType, value);
op = opBuilder.create<spirv::ConstantOp>(unknownLoc, intType, attr);
} else if (auto floatType = resultType.dyn_cast<FloatType>()) {
auto bitwidth = floatType.getWidth();
if (failed(checkOperandSizeForBitwidth(bitwidth))) {
return failure();
}
APFloat value(0.f);
if (floatType.isF64()) {
// Double values are represented with two SPIR-V words. According to
// SPIR-V spec: "When the types bit width is larger than one word, the
// literals low-order words appear first."
struct DoubleWord {
uint32_t word1;
uint32_t word2;
} words = {operands[2], operands[3]};
value = APFloat(llvm::bit_cast<double>(words));
} else if (floatType.isF32()) {
value = APFloat(llvm::bit_cast<float>(operands[2]));
} else if (floatType.isF16()) {
APInt data(16, operands[2]);
value = APFloat(APFloat::IEEEhalf(), data);
}
auto attr = opBuilder.getFloatAttr(floatType, value);
op = opBuilder.create<spirv::ConstantOp>(unknownLoc, floatType, attr);
} else {
return emitError(unknownLoc, "OpConstant can only generate values of "
"scalar integer or floating-point type");
@ -564,6 +589,7 @@ LogicalResult Deserializer::processConstant(ArrayRef<uint32_t> operands) {
valueMap[operands[1]] = op.getResult();
return success();
}
LogicalResult Deserializer::processConstantBool(bool isTrue,
ArrayRef<uint32_t> operands) {
if (operands.size() != 2) {

View File

@ -173,6 +173,8 @@ private:
uint32_t prepareConstantInt(Location loc, IntegerAttr intAttr);
uint32_t prepareConstantFp(Location loc, FloatAttr floatAttr);
//===--------------------------------------------------------------------===//
// Operations
//===--------------------------------------------------------------------===//
@ -488,6 +490,9 @@ Serializer::prepareFunctionType(Location loc, FunctionType type,
uint32_t Serializer::prepareConstant(Location loc, Type constType,
Attribute valueAttr) {
if (auto floatAttr = valueAttr.dyn_cast<FloatAttr>()) {
return prepareConstantFp(loc, floatAttr);
}
if (auto intAttr = valueAttr.dyn_cast<IntegerAttr>()) {
return prepareConstantInt(loc, intAttr);
}
@ -575,6 +580,50 @@ uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr) {
return constIDMap[intAttr] = resultID;
}
uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr) {
if (auto id = findConstantID(floatAttr)) {
return id;
}
// Process the type for this float literal
uint32_t typeID = 0;
if (failed(processType(loc, floatAttr.getType(), typeID))) {
return 0;
}
auto resultID = getNextID();
APFloat value = floatAttr.getValue();
APInt intValue = value.bitcastToAPInt();
if (&value.getSemantics() == &APFloat::IEEEsingle()) {
uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat());
encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstant,
{typeID, resultID, word});
} else if (&value.getSemantics() == &APFloat::IEEEdouble()) {
struct DoubleWord {
uint32_t word1;
uint32_t word2;
} words = llvm::bit_cast<DoubleWord>(value.convertToDouble());
encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstant,
{typeID, resultID, words.word1, words.word2});
} else if (&value.getSemantics() == &APFloat::IEEEhalf()) {
uint32_t word =
static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue());
encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstant,
{typeID, resultID, word});
} else {
std::string valueStr;
llvm::raw_string_ostream rss(valueStr);
value.print(rss);
emitError(loc, "cannot serialize ")
<< floatAttr.getType() << "-typed float literal: " << rss.str();
return 0;
}
return constIDMap[floatAttr] = resultID;
}
//===----------------------------------------------------------------------===//
// Operation
//===----------------------------------------------------------------------===//

View File

@ -32,6 +32,33 @@ func @spirv_module() -> () {
%9 = spv.constant -32768 : i16 // -2^15
// CHECK: spv.constant 32767 : i16
%10 = spv.constant 32767 : i16 // 2^15 - 1
// float
// CHECK: spv.constant 0.000000e+00 : f32
%11 = spv.constant 0. : f32
// CHECK: spv.constant 1.000000e+00 : f32
%12 = spv.constant 1. : f32
// CHECK: spv.constant -0.000000e+00 : f32
%13 = spv.constant -0. : f32
// CHECK: spv.constant -1.000000e+00 : f32
%14 = spv.constant -1. : f32
// CHECK: spv.constant 7.500000e-01 : f32
%15 = spv.constant 0.75 : f32
// CHECK: spv.constant -2.500000e-01 : f32
%16 = spv.constant -0.25 : f32
// double
// TODO(antiagainst): test range boundary values
// CHECK: spv.constant 1.024000e+03 : f64
%17 = spv.constant 1024. : f64
// CHECK: spv.constant -1.024000e+03 : f64
%18 = spv.constant -1024. : f64
// half
// CHECK: spv.constant 5.120000e+02 : f16
%19 = spv.constant 512. : f16
// CHECK: spv.constant -5.120000e+02 : f16
%20 = spv.constant -512. : f16
}
return
}