Support SPIR-V constant op to take DenseElementsAttr as input.

Iterates each element to build the array. This includes a little refactor to
combine bool/int/float into a function, since they are similar. The only
difference is calling different function in the end.

PiperOrigin-RevId: 281210288
This commit is contained in:
Hanhan Wang 2019-11-18 20:01:28 -08:00 committed by A. Unique TensorFlower
parent d8563c0e3a
commit c614c92fdc
4 changed files with 155 additions and 175 deletions

View File

@ -1079,12 +1079,10 @@ static ParseResult parseConstantOp(OpAsmParser &parser, OperationState &state) {
if (parser.parseAttribute(value, kValueAttrName, state.attributes))
return failure();
Type type;
if (value.getType().isa<NoneType>()) {
Type type = value.getType();
if (type.isa<NoneType>() || type.isa<TensorType>()) {
if (parser.parseColonType(type))
return failure();
} else {
type = value.getType();
}
return parser.addTypeToList(type, state.types);
@ -1108,14 +1106,46 @@ static LogicalResult verify(spirv::ConstantOp constOp) {
switch (value.getKind()) {
case StandardAttributes::Bool:
case StandardAttributes::Integer:
case StandardAttributes::Float:
case StandardAttributes::DenseElements:
case StandardAttributes::SparseElements: {
case StandardAttributes::Float: {
if (valueType != opType)
return constOp.emitOpError("result type (")
<< opType << ") does not match value type (" << valueType << ")";
return success();
} break;
case StandardAttributes::DenseElements:
case StandardAttributes::SparseElements: {
if (valueType == opType)
break;
auto arrayType = opType.dyn_cast<spirv::ArrayType>();
auto shapedType = valueType.dyn_cast<ShapedType>();
if (!arrayType) {
return constOp.emitOpError(
"must have spv.array result type for array value");
}
int numElements = arrayType.getNumElements();
auto opElemType = arrayType.getElementType();
while (auto t = opElemType.dyn_cast<spirv::ArrayType>()) {
numElements *= t.getNumElements();
opElemType = t.getElementType();
}
if (!opElemType.isIntOrFloat()) {
return constOp.emitOpError("only support nested array result type");
}
auto valueElemType = shapedType.getElementType();
if (valueElemType != opElemType) {
return constOp.emitOpError("result element type (")
<< opElemType << ") does not match value element type ("
<< valueElemType << ")";
}
if (numElements != shapedType.getNumElements()) {
return constOp.emitOpError("result number of elements (")
<< numElements << ") does not match value number of elements ("
<< shapedType.getNumElements() << ")";
}
} break;
case StandardAttributes::Array: {
auto arrayType = opType.dyn_cast<spirv::ArrayType>();
if (!arrayType)

View File

@ -240,29 +240,21 @@ private:
/// constants.
uint32_t prepareConstant(Location loc, Type constType, Attribute valueAttr);
/// Prepares bool ElementsAttr serialization. This method updates `opcode`
/// with a proper OpConstant* instruction and pushes literal values for the
/// constant to `operands`.
LogicalResult prepareBoolVectorConstant(Location loc,
DenseIntElementsAttr elementsAttr,
spirv::Opcode &opcode,
SmallVectorImpl<uint32_t> &operands);
/// Prepares array attribute serialization. This method emits corresponding
/// OpConstant* and returns the result <id> associated with it. Returns 0 if
/// failed.
uint32_t prepareArrayConstant(Location loc, Type constType, ArrayAttr attr);
/// Prepares int ElementsAttr serialization. This method updates `opcode` with
/// a proper OpConstant* instruction and pushes literal values for the
/// constant to `operands`.
LogicalResult prepareIntVectorConstant(Location loc,
DenseIntElementsAttr elementsAttr,
spirv::Opcode &opcode,
SmallVectorImpl<uint32_t> &operands);
/// Prepares float ElementsAttr serialization. This method updates `opcode`
/// with a proper OpConstant* instruction and pushes literal values for the
/// constant to `operands`.
LogicalResult prepareFloatVectorConstant(Location loc,
DenseFPElementsAttr elementsAttr,
spirv::Opcode &opcode,
SmallVectorImpl<uint32_t> &operands);
/// Prepares bool/int/float DenseElementsAttr serialization. This method
/// iterates the DenseElementsAttr to construct the constant array, and
/// returns the result <id> associated with it. Returns 0 if failed. Note
/// that the size of `index` must match the rank.
/// TODO(hanchung): Consider to enhance splat elements cases. For splat cases,
/// we don't need to loop over all elements, especially when the splat value
/// is zero. We can use OpConstantNull when the value is zero.
uint32_t prepareDenseElementsConstant(Location loc, Type constType,
DenseElementsAttr valueAttr, int dim,
MutableArrayRef<uint64_t> index);
/// Prepares scalar attribute serialization. This method emits corresponding
/// OpConstant* and returns the result <id> associated with it. Returns 0 if
@ -1064,6 +1056,7 @@ uint32_t Serializer::prepareConstant(Location loc, Type constType,
if (auto id = prepareConstantScalar(loc, valueAttr)) {
return id;
}
// This is a composite literal. We need to handle each component separately
// and then emit an OpConstantComposite for the whole.
@ -1075,179 +1068,92 @@ uint32_t Serializer::prepareConstant(Location loc, Type constType,
if (failed(processType(loc, constType, typeID))) {
return 0;
}
auto resultID = getNextID();
spirv::Opcode opcode = spirv::Opcode::OpNop;
SmallVector<uint32_t, 4> operands;
operands.push_back(typeID);
operands.push_back(resultID);
if (auto vectorAttr = valueAttr.dyn_cast<DenseIntElementsAttr>()) {
if (vectorAttr.getType().getElementType().isInteger(1)) {
if (failed(prepareBoolVectorConstant(loc, vectorAttr, opcode, operands)))
return 0;
} else if (failed(
prepareIntVectorConstant(loc, vectorAttr, opcode, operands)))
return 0;
} else if (auto vectorAttr = valueAttr.dyn_cast<DenseFPElementsAttr>()) {
if (failed(prepareFloatVectorConstant(loc, vectorAttr, opcode, operands)))
return 0;
uint32_t resultID = 0;
if (auto attr = valueAttr.dyn_cast<DenseElementsAttr>()) {
int rank = attr.getType().dyn_cast<ShapedType>().getRank();
SmallVector<uint64_t, 4> index(rank);
resultID = prepareDenseElementsConstant(loc, constType, attr,
/*dim=*/0, index);
} else if (auto arrayAttr = valueAttr.dyn_cast<ArrayAttr>()) {
opcode = spirv::Opcode::OpConstantComposite;
operands.reserve(arrayAttr.size() + 2);
resultID = prepareArrayConstant(loc, constType, arrayAttr);
}
auto elementType = constType.cast<spirv::ArrayType>().getElementType();
for (Attribute elementAttr : arrayAttr)
if (auto elementID = prepareConstant(loc, elementType, elementAttr)) {
operands.push_back(elementID);
} else {
return 0;
}
} else {
if (resultID == 0) {
emitError(loc, "cannot serialize attribute: ") << valueAttr;
return 0;
}
encodeInstructionInto(typesGlobalValues, opcode, operands);
constIDMap[valueAttr] = resultID;
return resultID;
}
LogicalResult Serializer::prepareBoolVectorConstant(
Location loc, DenseIntElementsAttr elementsAttr, spirv::Opcode &opcode,
SmallVectorImpl<uint32_t> &operands) {
auto type = elementsAttr.getType();
assert(type.hasRank() && type.getRank() == 1 &&
"spv.constant should have verified only vector literal uses "
"ElementsAttr");
assert(type.getElementType().isInteger(1) && "must be bool ElementsAttr");
auto count = type.getNumElements();
// Operands for constructing the SPIR-V OpConstant* instruction
operands.reserve(count + 2);
// For splat cases, we don't need to loop over all elements, especially when
// the splat value is zero.
if (elementsAttr.isSplat()) {
// We can use OpConstantNull if this bool ElementsAttr is splatting false.
if (!elementsAttr.getSplatValue<bool>()) {
opcode = spirv::Opcode::OpConstantNull;
return success();
}
if (auto id =
prepareConstantBool(loc, elementsAttr.getSplatValue<BoolAttr>())) {
opcode = spirv::Opcode::OpConstantComposite;
operands.append(count, id);
return success();
}
return failure();
uint32_t Serializer::prepareArrayConstant(Location loc, Type constType,
ArrayAttr attr) {
uint32_t typeID = 0;
if (failed(processType(loc, constType, typeID))) {
return 0;
}
// Otherwise, we need to process each element and compose them with
// OpConstantComposite.
opcode = spirv::Opcode::OpConstantComposite;
for (auto boolAttr : elementsAttr.getValues<BoolAttr>()) {
// We are constructing an BoolAttr for each value here. But given that
// we only use ElementsAttr for vectors with no more than 4 elements, it
// should be fine here.
if (auto elementID = prepareConstantBool(loc, boolAttr)) {
uint32_t resultID = getNextID();
SmallVector<uint32_t, 4> operands = {typeID, resultID};
operands.reserve(attr.size() + 2);
auto elementType = constType.cast<spirv::ArrayType>().getElementType();
for (Attribute elementAttr : attr) {
if (auto elementID = prepareConstant(loc, elementType, elementAttr)) {
operands.push_back(elementID);
} else {
return failure();
return 0;
}
}
return success();
spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
encodeInstructionInto(typesGlobalValues, opcode, operands);
return resultID;
}
LogicalResult Serializer::prepareIntVectorConstant(
Location loc, DenseIntElementsAttr elementsAttr, spirv::Opcode &opcode,
SmallVectorImpl<uint32_t> &operands) {
auto type = elementsAttr.getType();
assert(type.hasRank() && type.getRank() == 1 &&
"spv.constant should have verified only vector literal uses "
"ElementsAttr");
assert(!type.getElementType().isInteger(1) &&
"must be non-bool ElementsAttr");
auto count = type.getNumElements();
// Operands for constructing the SPIR-V OpConstant* instruction
operands.reserve(count + 2);
// For splat cases, we don't need to loop over all elements, especially when
// the splat value is zero.
if (elementsAttr.isSplat()) {
auto splatAttr = elementsAttr.getSplatValue<IntegerAttr>();
// We can use OpConstantNull if this int ElementsAttr is splatting 0.
if (splatAttr.getValue().isNullValue()) {
opcode = spirv::Opcode::OpConstantNull;
return success();
// TODO(hanchung): Turn the below function into iterative function, instead of
// recursive function.
uint32_t
Serializer::prepareDenseElementsConstant(Location loc, Type constType,
DenseElementsAttr valueAttr, int dim,
MutableArrayRef<uint64_t> index) {
auto shapedType = valueAttr.getType().dyn_cast<ShapedType>();
assert(dim <= shapedType.getRank());
if (shapedType.getRank() == dim) {
if (auto attr = valueAttr.dyn_cast<DenseIntElementsAttr>()) {
return attr.getType().getElementType().isInteger(1)
? prepareConstantBool(loc, attr.getValue<BoolAttr>(index))
: prepareConstantInt(loc, attr.getValue<IntegerAttr>(index));
}
if (auto id = prepareConstantInt(loc, splatAttr)) {
opcode = spirv::Opcode::OpConstantComposite;
operands.append(count, id);
return success();
if (auto attr = valueAttr.dyn_cast<DenseFPElementsAttr>()) {
return prepareConstantFp(loc, attr.getValue<FloatAttr>(index));
}
return failure();
return 0;
}
// Otherwise, we need to process each element and compose them with
// OpConstantComposite.
opcode = spirv::Opcode::OpConstantComposite;
for (auto intAttr : elementsAttr.getValues<IntegerAttr>()) {
// We are constructing an IntegerAttr for each value here. But given that
// we only use ElementsAttr for vectors with no more than 4 elements, it
// should be fine here.
// TODO(antiagainst): revisit this if special extensions enabling large
// vectors are supported.
if (auto elementID = prepareConstantInt(loc, intAttr)) {
uint32_t typeID = 0;
if (failed(processType(loc, constType, typeID))) {
return 0;
}
uint32_t resultID = getNextID();
SmallVector<uint32_t, 4> operands = {typeID, resultID};
operands.reserve(shapedType.getDimSize(dim) + 2);
auto elementType = constType.cast<spirv::CompositeType>().getElementType(0);
for (int i = 0; i < shapedType.getDimSize(dim); ++i) {
index[dim] = i;
if (auto elementID = prepareDenseElementsConstant(
loc, elementType, valueAttr, dim + 1, index)) {
operands.push_back(elementID);
} else {
return failure();
return 0;
}
}
return success();
}
spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
encodeInstructionInto(typesGlobalValues, opcode, operands);
LogicalResult Serializer::prepareFloatVectorConstant(
Location loc, DenseFPElementsAttr elementsAttr, spirv::Opcode &opcode,
SmallVectorImpl<uint32_t> &operands) {
auto type = elementsAttr.getType();
assert(type.hasRank() && type.getRank() == 1 &&
"spv.constant should have verified only vector literal uses "
"ElementsAttr");
auto count = type.getNumElements();
operands.reserve(count + 2);
if (elementsAttr.isSplat()) {
FloatAttr splatAttr = elementsAttr.getSplatValue<FloatAttr>();
if (splatAttr.getValue().isZero()) {
opcode = spirv::Opcode::OpConstantNull;
return success();
}
if (auto id = prepareConstantFp(loc, splatAttr)) {
opcode = spirv::Opcode::OpConstantComposite;
operands.append(count, id);
return success();
}
return failure();
}
opcode = spirv::Opcode::OpConstantComposite;
for (auto fpAttr : elementsAttr.getValues<FloatAttr>()) {
if (auto elementID = prepareConstantFp(loc, fpAttr)) {
operands.push_back(elementID);
} else {
return failure();
}
}
return success();
return resultID;
}
uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr,

View File

@ -178,4 +178,18 @@ spv.module "Logical" "GLSL450" {
%4 = spv.IAdd %0, %3 : i32
spv.Return
}
// CHECK-LABEL: @multi_dimensions_const
func @multi_dimensions_const() -> (!spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>) {
// CHECK: spv.constant {{\[}}{{\[}}[1 : i32, 2 : i32, 3 : i32], [4 : i32, 5 : i32, 6 : i32]], {{\[}}[7 : i32, 8 : i32, 9 : i32], [10 : i32, 11 : i32, 12 : i32]]] : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>
%0 = spv.constant dense<[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]> : tensor<2x2x3xi32> : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>
spv.ReturnValue %0 : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>
}
// CHECK-LABEL: @multi_dimensions_splat_const
func @multi_dimensions_splat_const() -> (!spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>) {
// CHECK: spv.constant {{\[}}{{\[}}[1 : i32, 1 : i32, 1 : i32], [1 : i32, 1 : i32, 1 : i32]], {{\[}}[1 : i32, 1 : i32, 1 : i32], [1 : i32, 1 : i32, 1 : i32]]] : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>
%0 = spv.constant dense<1> : tensor<2x2x3xi32> : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>
spv.ReturnValue %0 : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>
}
}

View File

@ -48,12 +48,20 @@ func @const() -> () {
// CHECK: %2 = spv.constant 5.000000e-01 : f32
// CHECK: %3 = spv.constant dense<[2, 3]> : vector<2xi32>
// CHECK: %4 = spv.constant [dense<3.000000e+00> : vector<2xf32>] : !spv.array<1 x vector<2xf32>>
// CHECK: %5 = spv.constant dense<1> : tensor<2x3xi32> : !spv.array<2 x !spv.array<3 x i32 [4]> [12]>
// CHECK: %6 = spv.constant dense<1.000000e+00> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32 [4]> [12]>
// CHECK: %7 = spv.constant dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> : !spv.array<2 x !spv.array<3 x i32 [4]> [12]>
// CHECK: %8 = spv.constant dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32 [4]> [12]>
%0 = spv.constant true
%1 = spv.constant 42 : i32
%2 = spv.constant 0.5 : f32
%3 = spv.constant dense<[2, 3]> : vector<2xi32>
%4 = spv.constant [dense<3.0> : vector<2xf32>] : !spv.array<1xvector<2xf32>>
%5 = spv.constant dense<1> : tensor<2x3xi32> : !spv.array<2 x !spv.array<3 x i32 [4]> [12]>
%6 = spv.constant dense<1.0> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32 [4]> [12]>
%7 = spv.constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> : !spv.array<2 x !spv.array<3 x i32 [4]> [12]>
%8 = spv.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32 [4]> [12]>
return
}
@ -83,11 +91,33 @@ func @array_constant() -> () {
// -----
func @non_nested_array_constant() -> () {
// expected-error @+1 {{only support nested array result type}}
%0 = spv.constant dense<3.0> : tensor<2x2xf32> : !spv.array<2xvector<2xf32>>
return
}
// -----
func @value_result_type_mismatch() -> () {
// expected-error @+1 {{result type ('vector<4xi32>') does not match value type ('tensor<4xi32>')}}
// expected-error @+1 {{must have spv.array result type for array value}}
%0 = "spv.constant"() {value = dense<0> : tensor<4xi32>} : () -> (vector<4xi32>)
}
// -----
func @value_result_type_mismatch() -> () {
// expected-error @+1 {{result element type ('i32') does not match value element type ('f32')}}
%0 = spv.constant dense<1.0> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x i32 [4]> [12]>
}
// -----
func @value_result_num_elements_mismatch() -> () {
// expected-error @+1 {{result number of elements (6) does not match value number of elements (4)}}
%0 = spv.constant dense<1.0> : tensor<2x2xf32> : !spv.array<2 x !spv.array<3 x f32 [4]> [12]>
return
}
// -----