forked from OSchip/llvm-project
Support SPIR-V constant op to take DenseElementsAttr as input.
Iterates each element to build the array. This includes a little refactor to combine bool/int/float into a function, since they are similar. The only difference is calling different function in the end. PiperOrigin-RevId: 281210288
This commit is contained in:
parent
d8563c0e3a
commit
c614c92fdc
|
@ -1079,12 +1079,10 @@ static ParseResult parseConstantOp(OpAsmParser &parser, OperationState &state) {
|
|||
if (parser.parseAttribute(value, kValueAttrName, state.attributes))
|
||||
return failure();
|
||||
|
||||
Type type;
|
||||
if (value.getType().isa<NoneType>()) {
|
||||
Type type = value.getType();
|
||||
if (type.isa<NoneType>() || type.isa<TensorType>()) {
|
||||
if (parser.parseColonType(type))
|
||||
return failure();
|
||||
} else {
|
||||
type = value.getType();
|
||||
}
|
||||
|
||||
return parser.addTypeToList(type, state.types);
|
||||
|
@ -1108,14 +1106,46 @@ static LogicalResult verify(spirv::ConstantOp constOp) {
|
|||
switch (value.getKind()) {
|
||||
case StandardAttributes::Bool:
|
||||
case StandardAttributes::Integer:
|
||||
case StandardAttributes::Float:
|
||||
case StandardAttributes::DenseElements:
|
||||
case StandardAttributes::SparseElements: {
|
||||
case StandardAttributes::Float: {
|
||||
if (valueType != opType)
|
||||
return constOp.emitOpError("result type (")
|
||||
<< opType << ") does not match value type (" << valueType << ")";
|
||||
return success();
|
||||
} break;
|
||||
case StandardAttributes::DenseElements:
|
||||
case StandardAttributes::SparseElements: {
|
||||
if (valueType == opType)
|
||||
break;
|
||||
auto arrayType = opType.dyn_cast<spirv::ArrayType>();
|
||||
auto shapedType = valueType.dyn_cast<ShapedType>();
|
||||
if (!arrayType) {
|
||||
return constOp.emitOpError(
|
||||
"must have spv.array result type for array value");
|
||||
}
|
||||
|
||||
int numElements = arrayType.getNumElements();
|
||||
auto opElemType = arrayType.getElementType();
|
||||
while (auto t = opElemType.dyn_cast<spirv::ArrayType>()) {
|
||||
numElements *= t.getNumElements();
|
||||
opElemType = t.getElementType();
|
||||
}
|
||||
if (!opElemType.isIntOrFloat()) {
|
||||
return constOp.emitOpError("only support nested array result type");
|
||||
}
|
||||
|
||||
auto valueElemType = shapedType.getElementType();
|
||||
if (valueElemType != opElemType) {
|
||||
return constOp.emitOpError("result element type (")
|
||||
<< opElemType << ") does not match value element type ("
|
||||
<< valueElemType << ")";
|
||||
}
|
||||
|
||||
if (numElements != shapedType.getNumElements()) {
|
||||
return constOp.emitOpError("result number of elements (")
|
||||
<< numElements << ") does not match value number of elements ("
|
||||
<< shapedType.getNumElements() << ")";
|
||||
}
|
||||
} break;
|
||||
case StandardAttributes::Array: {
|
||||
auto arrayType = opType.dyn_cast<spirv::ArrayType>();
|
||||
if (!arrayType)
|
||||
|
|
|
@ -240,29 +240,21 @@ private:
|
|||
/// constants.
|
||||
uint32_t prepareConstant(Location loc, Type constType, Attribute valueAttr);
|
||||
|
||||
/// Prepares bool ElementsAttr serialization. This method updates `opcode`
|
||||
/// with a proper OpConstant* instruction and pushes literal values for the
|
||||
/// constant to `operands`.
|
||||
LogicalResult prepareBoolVectorConstant(Location loc,
|
||||
DenseIntElementsAttr elementsAttr,
|
||||
spirv::Opcode &opcode,
|
||||
SmallVectorImpl<uint32_t> &operands);
|
||||
/// Prepares array attribute serialization. This method emits corresponding
|
||||
/// OpConstant* and returns the result <id> associated with it. Returns 0 if
|
||||
/// failed.
|
||||
uint32_t prepareArrayConstant(Location loc, Type constType, ArrayAttr attr);
|
||||
|
||||
/// Prepares int ElementsAttr serialization. This method updates `opcode` with
|
||||
/// a proper OpConstant* instruction and pushes literal values for the
|
||||
/// constant to `operands`.
|
||||
LogicalResult prepareIntVectorConstant(Location loc,
|
||||
DenseIntElementsAttr elementsAttr,
|
||||
spirv::Opcode &opcode,
|
||||
SmallVectorImpl<uint32_t> &operands);
|
||||
|
||||
/// Prepares float ElementsAttr serialization. This method updates `opcode`
|
||||
/// with a proper OpConstant* instruction and pushes literal values for the
|
||||
/// constant to `operands`.
|
||||
LogicalResult prepareFloatVectorConstant(Location loc,
|
||||
DenseFPElementsAttr elementsAttr,
|
||||
spirv::Opcode &opcode,
|
||||
SmallVectorImpl<uint32_t> &operands);
|
||||
/// Prepares bool/int/float DenseElementsAttr serialization. This method
|
||||
/// iterates the DenseElementsAttr to construct the constant array, and
|
||||
/// returns the result <id> associated with it. Returns 0 if failed. Note
|
||||
/// that the size of `index` must match the rank.
|
||||
/// TODO(hanchung): Consider to enhance splat elements cases. For splat cases,
|
||||
/// we don't need to loop over all elements, especially when the splat value
|
||||
/// is zero. We can use OpConstantNull when the value is zero.
|
||||
uint32_t prepareDenseElementsConstant(Location loc, Type constType,
|
||||
DenseElementsAttr valueAttr, int dim,
|
||||
MutableArrayRef<uint64_t> index);
|
||||
|
||||
/// Prepares scalar attribute serialization. This method emits corresponding
|
||||
/// OpConstant* and returns the result <id> associated with it. Returns 0 if
|
||||
|
@ -1064,6 +1056,7 @@ uint32_t Serializer::prepareConstant(Location loc, Type constType,
|
|||
if (auto id = prepareConstantScalar(loc, valueAttr)) {
|
||||
return id;
|
||||
}
|
||||
|
||||
// This is a composite literal. We need to handle each component separately
|
||||
// and then emit an OpConstantComposite for the whole.
|
||||
|
||||
|
@ -1075,179 +1068,92 @@ uint32_t Serializer::prepareConstant(Location loc, Type constType,
|
|||
if (failed(processType(loc, constType, typeID))) {
|
||||
return 0;
|
||||
}
|
||||
auto resultID = getNextID();
|
||||
|
||||
spirv::Opcode opcode = spirv::Opcode::OpNop;
|
||||
SmallVector<uint32_t, 4> operands;
|
||||
operands.push_back(typeID);
|
||||
operands.push_back(resultID);
|
||||
|
||||
if (auto vectorAttr = valueAttr.dyn_cast<DenseIntElementsAttr>()) {
|
||||
if (vectorAttr.getType().getElementType().isInteger(1)) {
|
||||
if (failed(prepareBoolVectorConstant(loc, vectorAttr, opcode, operands)))
|
||||
return 0;
|
||||
} else if (failed(
|
||||
prepareIntVectorConstant(loc, vectorAttr, opcode, operands)))
|
||||
return 0;
|
||||
} else if (auto vectorAttr = valueAttr.dyn_cast<DenseFPElementsAttr>()) {
|
||||
if (failed(prepareFloatVectorConstant(loc, vectorAttr, opcode, operands)))
|
||||
return 0;
|
||||
uint32_t resultID = 0;
|
||||
if (auto attr = valueAttr.dyn_cast<DenseElementsAttr>()) {
|
||||
int rank = attr.getType().dyn_cast<ShapedType>().getRank();
|
||||
SmallVector<uint64_t, 4> index(rank);
|
||||
resultID = prepareDenseElementsConstant(loc, constType, attr,
|
||||
/*dim=*/0, index);
|
||||
} else if (auto arrayAttr = valueAttr.dyn_cast<ArrayAttr>()) {
|
||||
opcode = spirv::Opcode::OpConstantComposite;
|
||||
operands.reserve(arrayAttr.size() + 2);
|
||||
resultID = prepareArrayConstant(loc, constType, arrayAttr);
|
||||
}
|
||||
|
||||
if (resultID == 0) {
|
||||
emitError(loc, "cannot serialize attribute: ") << valueAttr;
|
||||
return 0;
|
||||
}
|
||||
|
||||
constIDMap[valueAttr] = resultID;
|
||||
return resultID;
|
||||
}
|
||||
|
||||
uint32_t Serializer::prepareArrayConstant(Location loc, Type constType,
|
||||
ArrayAttr attr) {
|
||||
uint32_t typeID = 0;
|
||||
if (failed(processType(loc, constType, typeID))) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint32_t resultID = getNextID();
|
||||
SmallVector<uint32_t, 4> operands = {typeID, resultID};
|
||||
operands.reserve(attr.size() + 2);
|
||||
auto elementType = constType.cast<spirv::ArrayType>().getElementType();
|
||||
for (Attribute elementAttr : arrayAttr)
|
||||
for (Attribute elementAttr : attr) {
|
||||
if (auto elementID = prepareConstant(loc, elementType, elementAttr)) {
|
||||
operands.push_back(elementID);
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
} else {
|
||||
emitError(loc, "cannot serialize attribute: ") << valueAttr;
|
||||
return 0;
|
||||
}
|
||||
|
||||
spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
|
||||
encodeInstructionInto(typesGlobalValues, opcode, operands);
|
||||
constIDMap[valueAttr] = resultID;
|
||||
|
||||
return resultID;
|
||||
}
|
||||
|
||||
LogicalResult Serializer::prepareBoolVectorConstant(
|
||||
Location loc, DenseIntElementsAttr elementsAttr, spirv::Opcode &opcode,
|
||||
SmallVectorImpl<uint32_t> &operands) {
|
||||
auto type = elementsAttr.getType();
|
||||
assert(type.hasRank() && type.getRank() == 1 &&
|
||||
"spv.constant should have verified only vector literal uses "
|
||||
"ElementsAttr");
|
||||
assert(type.getElementType().isInteger(1) && "must be bool ElementsAttr");
|
||||
auto count = type.getNumElements();
|
||||
|
||||
// Operands for constructing the SPIR-V OpConstant* instruction
|
||||
operands.reserve(count + 2);
|
||||
|
||||
// For splat cases, we don't need to loop over all elements, especially when
|
||||
// the splat value is zero.
|
||||
if (elementsAttr.isSplat()) {
|
||||
// We can use OpConstantNull if this bool ElementsAttr is splatting false.
|
||||
if (!elementsAttr.getSplatValue<bool>()) {
|
||||
opcode = spirv::Opcode::OpConstantNull;
|
||||
return success();
|
||||
// TODO(hanchung): Turn the below function into iterative function, instead of
|
||||
// recursive function.
|
||||
uint32_t
|
||||
Serializer::prepareDenseElementsConstant(Location loc, Type constType,
|
||||
DenseElementsAttr valueAttr, int dim,
|
||||
MutableArrayRef<uint64_t> index) {
|
||||
auto shapedType = valueAttr.getType().dyn_cast<ShapedType>();
|
||||
assert(dim <= shapedType.getRank());
|
||||
if (shapedType.getRank() == dim) {
|
||||
if (auto attr = valueAttr.dyn_cast<DenseIntElementsAttr>()) {
|
||||
return attr.getType().getElementType().isInteger(1)
|
||||
? prepareConstantBool(loc, attr.getValue<BoolAttr>(index))
|
||||
: prepareConstantInt(loc, attr.getValue<IntegerAttr>(index));
|
||||
}
|
||||
if (auto attr = valueAttr.dyn_cast<DenseFPElementsAttr>()) {
|
||||
return prepareConstantFp(loc, attr.getValue<FloatAttr>(index));
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (auto id =
|
||||
prepareConstantBool(loc, elementsAttr.getSplatValue<BoolAttr>())) {
|
||||
opcode = spirv::Opcode::OpConstantComposite;
|
||||
operands.append(count, id);
|
||||
return success();
|
||||
uint32_t typeID = 0;
|
||||
if (failed(processType(loc, constType, typeID))) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Otherwise, we need to process each element and compose them with
|
||||
// OpConstantComposite.
|
||||
opcode = spirv::Opcode::OpConstantComposite;
|
||||
for (auto boolAttr : elementsAttr.getValues<BoolAttr>()) {
|
||||
// We are constructing an BoolAttr for each value here. But given that
|
||||
// we only use ElementsAttr for vectors with no more than 4 elements, it
|
||||
// should be fine here.
|
||||
if (auto elementID = prepareConstantBool(loc, boolAttr)) {
|
||||
uint32_t resultID = getNextID();
|
||||
SmallVector<uint32_t, 4> operands = {typeID, resultID};
|
||||
operands.reserve(shapedType.getDimSize(dim) + 2);
|
||||
auto elementType = constType.cast<spirv::CompositeType>().getElementType(0);
|
||||
for (int i = 0; i < shapedType.getDimSize(dim); ++i) {
|
||||
index[dim] = i;
|
||||
if (auto elementID = prepareDenseElementsConstant(
|
||||
loc, elementType, valueAttr, dim + 1, index)) {
|
||||
operands.push_back(elementID);
|
||||
} else {
|
||||
return failure();
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
|
||||
encodeInstructionInto(typesGlobalValues, opcode, operands);
|
||||
|
||||
LogicalResult Serializer::prepareIntVectorConstant(
|
||||
Location loc, DenseIntElementsAttr elementsAttr, spirv::Opcode &opcode,
|
||||
SmallVectorImpl<uint32_t> &operands) {
|
||||
auto type = elementsAttr.getType();
|
||||
assert(type.hasRank() && type.getRank() == 1 &&
|
||||
"spv.constant should have verified only vector literal uses "
|
||||
"ElementsAttr");
|
||||
assert(!type.getElementType().isInteger(1) &&
|
||||
"must be non-bool ElementsAttr");
|
||||
auto count = type.getNumElements();
|
||||
|
||||
// Operands for constructing the SPIR-V OpConstant* instruction
|
||||
operands.reserve(count + 2);
|
||||
|
||||
// For splat cases, we don't need to loop over all elements, especially when
|
||||
// the splat value is zero.
|
||||
if (elementsAttr.isSplat()) {
|
||||
auto splatAttr = elementsAttr.getSplatValue<IntegerAttr>();
|
||||
|
||||
// We can use OpConstantNull if this int ElementsAttr is splatting 0.
|
||||
if (splatAttr.getValue().isNullValue()) {
|
||||
opcode = spirv::Opcode::OpConstantNull;
|
||||
return success();
|
||||
}
|
||||
|
||||
if (auto id = prepareConstantInt(loc, splatAttr)) {
|
||||
opcode = spirv::Opcode::OpConstantComposite;
|
||||
operands.append(count, id);
|
||||
return success();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Otherwise, we need to process each element and compose them with
|
||||
// OpConstantComposite.
|
||||
opcode = spirv::Opcode::OpConstantComposite;
|
||||
for (auto intAttr : elementsAttr.getValues<IntegerAttr>()) {
|
||||
// We are constructing an IntegerAttr for each value here. But given that
|
||||
// we only use ElementsAttr for vectors with no more than 4 elements, it
|
||||
// should be fine here.
|
||||
// TODO(antiagainst): revisit this if special extensions enabling large
|
||||
// vectors are supported.
|
||||
if (auto elementID = prepareConstantInt(loc, intAttr)) {
|
||||
operands.push_back(elementID);
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult Serializer::prepareFloatVectorConstant(
|
||||
Location loc, DenseFPElementsAttr elementsAttr, spirv::Opcode &opcode,
|
||||
SmallVectorImpl<uint32_t> &operands) {
|
||||
auto type = elementsAttr.getType();
|
||||
assert(type.hasRank() && type.getRank() == 1 &&
|
||||
"spv.constant should have verified only vector literal uses "
|
||||
"ElementsAttr");
|
||||
auto count = type.getNumElements();
|
||||
|
||||
operands.reserve(count + 2);
|
||||
|
||||
if (elementsAttr.isSplat()) {
|
||||
FloatAttr splatAttr = elementsAttr.getSplatValue<FloatAttr>();
|
||||
if (splatAttr.getValue().isZero()) {
|
||||
opcode = spirv::Opcode::OpConstantNull;
|
||||
return success();
|
||||
}
|
||||
|
||||
if (auto id = prepareConstantFp(loc, splatAttr)) {
|
||||
opcode = spirv::Opcode::OpConstantComposite;
|
||||
operands.append(count, id);
|
||||
return success();
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
|
||||
opcode = spirv::Opcode::OpConstantComposite;
|
||||
for (auto fpAttr : elementsAttr.getValues<FloatAttr>()) {
|
||||
if (auto elementID = prepareConstantFp(loc, fpAttr)) {
|
||||
operands.push_back(elementID);
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
return success();
|
||||
return resultID;
|
||||
}
|
||||
|
||||
uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr,
|
||||
|
|
|
@ -178,4 +178,18 @@ spv.module "Logical" "GLSL450" {
|
|||
%4 = spv.IAdd %0, %3 : i32
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @multi_dimensions_const
|
||||
func @multi_dimensions_const() -> (!spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>) {
|
||||
// CHECK: spv.constant {{\[}}{{\[}}[1 : i32, 2 : i32, 3 : i32], [4 : i32, 5 : i32, 6 : i32]], {{\[}}[7 : i32, 8 : i32, 9 : i32], [10 : i32, 11 : i32, 12 : i32]]] : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>
|
||||
%0 = spv.constant dense<[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]> : tensor<2x2x3xi32> : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>
|
||||
spv.ReturnValue %0 : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @multi_dimensions_splat_const
|
||||
func @multi_dimensions_splat_const() -> (!spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>) {
|
||||
// CHECK: spv.constant {{\[}}{{\[}}[1 : i32, 1 : i32, 1 : i32], [1 : i32, 1 : i32, 1 : i32]], {{\[}}[1 : i32, 1 : i32, 1 : i32], [1 : i32, 1 : i32, 1 : i32]]] : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>
|
||||
%0 = spv.constant dense<1> : tensor<2x2x3xi32> : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>
|
||||
spv.ReturnValue %0 : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>
|
||||
}
|
||||
}
|
||||
|
|
|
@ -48,12 +48,20 @@ func @const() -> () {
|
|||
// CHECK: %2 = spv.constant 5.000000e-01 : f32
|
||||
// CHECK: %3 = spv.constant dense<[2, 3]> : vector<2xi32>
|
||||
// CHECK: %4 = spv.constant [dense<3.000000e+00> : vector<2xf32>] : !spv.array<1 x vector<2xf32>>
|
||||
// CHECK: %5 = spv.constant dense<1> : tensor<2x3xi32> : !spv.array<2 x !spv.array<3 x i32 [4]> [12]>
|
||||
// CHECK: %6 = spv.constant dense<1.000000e+00> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32 [4]> [12]>
|
||||
// CHECK: %7 = spv.constant dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> : !spv.array<2 x !spv.array<3 x i32 [4]> [12]>
|
||||
// CHECK: %8 = spv.constant dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32 [4]> [12]>
|
||||
|
||||
%0 = spv.constant true
|
||||
%1 = spv.constant 42 : i32
|
||||
%2 = spv.constant 0.5 : f32
|
||||
%3 = spv.constant dense<[2, 3]> : vector<2xi32>
|
||||
%4 = spv.constant [dense<3.0> : vector<2xf32>] : !spv.array<1xvector<2xf32>>
|
||||
%5 = spv.constant dense<1> : tensor<2x3xi32> : !spv.array<2 x !spv.array<3 x i32 [4]> [12]>
|
||||
%6 = spv.constant dense<1.0> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32 [4]> [12]>
|
||||
%7 = spv.constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> : !spv.array<2 x !spv.array<3 x i32 [4]> [12]>
|
||||
%8 = spv.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32 [4]> [12]>
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -83,11 +91,33 @@ func @array_constant() -> () {
|
|||
|
||||
// -----
|
||||
|
||||
func @non_nested_array_constant() -> () {
|
||||
// expected-error @+1 {{only support nested array result type}}
|
||||
%0 = spv.constant dense<3.0> : tensor<2x2xf32> : !spv.array<2xvector<2xf32>>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @value_result_type_mismatch() -> () {
|
||||
// expected-error @+1 {{result type ('vector<4xi32>') does not match value type ('tensor<4xi32>')}}
|
||||
// expected-error @+1 {{must have spv.array result type for array value}}
|
||||
%0 = "spv.constant"() {value = dense<0> : tensor<4xi32>} : () -> (vector<4xi32>)
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @value_result_type_mismatch() -> () {
|
||||
// expected-error @+1 {{result element type ('i32') does not match value element type ('f32')}}
|
||||
%0 = spv.constant dense<1.0> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x i32 [4]> [12]>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @value_result_num_elements_mismatch() -> () {
|
||||
// expected-error @+1 {{result number of elements (6) does not match value number of elements (4)}}
|
||||
%0 = spv.constant dense<1.0> : tensor<2x2xf32> : !spv.array<2 x !spv.array<3 x f32 [4]> [12]>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
|
|
Loading…
Reference in New Issue