From c614c92fdc679d484662756234c9d9cf5c7238e4 Mon Sep 17 00:00:00 2001 From: Hanhan Wang Date: Mon, 18 Nov 2019 20:01:28 -0800 Subject: [PATCH] Support SPIR-V constant op to take DenseElementsAttr as input. Iterates each element to build the array. This includes a little refactor to combine bool/int/float into a function, since they are similar. The only difference is calling different function in the end. PiperOrigin-RevId: 281210288 --- mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 44 +++- .../SPIRV/Serialization/Serializer.cpp | 240 ++++++------------ .../Dialect/SPIRV/Serialization/constant.mlir | 14 + mlir/test/Dialect/SPIRV/structure-ops.mlir | 32 ++- 4 files changed, 155 insertions(+), 175 deletions(-) diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 8964963cb0b9..4c9dd5b79d20 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -1079,12 +1079,10 @@ static ParseResult parseConstantOp(OpAsmParser &parser, OperationState &state) { if (parser.parseAttribute(value, kValueAttrName, state.attributes)) return failure(); - Type type; - if (value.getType().isa()) { + Type type = value.getType(); + if (type.isa() || type.isa()) { if (parser.parseColonType(type)) return failure(); - } else { - type = value.getType(); } return parser.addTypeToList(type, state.types); @@ -1108,14 +1106,46 @@ static LogicalResult verify(spirv::ConstantOp constOp) { switch (value.getKind()) { case StandardAttributes::Bool: case StandardAttributes::Integer: - case StandardAttributes::Float: - case StandardAttributes::DenseElements: - case StandardAttributes::SparseElements: { + case StandardAttributes::Float: { if (valueType != opType) return constOp.emitOpError("result type (") << opType << ") does not match value type (" << valueType << ")"; return success(); } break; + case StandardAttributes::DenseElements: + case StandardAttributes::SparseElements: { + if (valueType == opType) + break; + auto arrayType = opType.dyn_cast(); + auto shapedType = valueType.dyn_cast(); + if (!arrayType) { + return constOp.emitOpError( + "must have spv.array result type for array value"); + } + + int numElements = arrayType.getNumElements(); + auto opElemType = arrayType.getElementType(); + while (auto t = opElemType.dyn_cast()) { + numElements *= t.getNumElements(); + opElemType = t.getElementType(); + } + if (!opElemType.isIntOrFloat()) { + return constOp.emitOpError("only support nested array result type"); + } + + auto valueElemType = shapedType.getElementType(); + if (valueElemType != opElemType) { + return constOp.emitOpError("result element type (") + << opElemType << ") does not match value element type (" + << valueElemType << ")"; + } + + if (numElements != shapedType.getNumElements()) { + return constOp.emitOpError("result number of elements (") + << numElements << ") does not match value number of elements (" + << shapedType.getNumElements() << ")"; + } + } break; case StandardAttributes::Array: { auto arrayType = opType.dyn_cast(); if (!arrayType) diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp index 0ff79d92ee10..ebe3ceba3366 100644 --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -240,29 +240,21 @@ private: /// constants. uint32_t prepareConstant(Location loc, Type constType, Attribute valueAttr); - /// Prepares bool ElementsAttr serialization. This method updates `opcode` - /// with a proper OpConstant* instruction and pushes literal values for the - /// constant to `operands`. - LogicalResult prepareBoolVectorConstant(Location loc, - DenseIntElementsAttr elementsAttr, - spirv::Opcode &opcode, - SmallVectorImpl &operands); + /// Prepares array attribute serialization. This method emits corresponding + /// OpConstant* and returns the result associated with it. Returns 0 if + /// failed. + uint32_t prepareArrayConstant(Location loc, Type constType, ArrayAttr attr); - /// Prepares int ElementsAttr serialization. This method updates `opcode` with - /// a proper OpConstant* instruction and pushes literal values for the - /// constant to `operands`. - LogicalResult prepareIntVectorConstant(Location loc, - DenseIntElementsAttr elementsAttr, - spirv::Opcode &opcode, - SmallVectorImpl &operands); - - /// Prepares float ElementsAttr serialization. This method updates `opcode` - /// with a proper OpConstant* instruction and pushes literal values for the - /// constant to `operands`. - LogicalResult prepareFloatVectorConstant(Location loc, - DenseFPElementsAttr elementsAttr, - spirv::Opcode &opcode, - SmallVectorImpl &operands); + /// Prepares bool/int/float DenseElementsAttr serialization. This method + /// iterates the DenseElementsAttr to construct the constant array, and + /// returns the result associated with it. Returns 0 if failed. Note + /// that the size of `index` must match the rank. + /// TODO(hanchung): Consider to enhance splat elements cases. For splat cases, + /// we don't need to loop over all elements, especially when the splat value + /// is zero. We can use OpConstantNull when the value is zero. + uint32_t prepareDenseElementsConstant(Location loc, Type constType, + DenseElementsAttr valueAttr, int dim, + MutableArrayRef index); /// Prepares scalar attribute serialization. This method emits corresponding /// OpConstant* and returns the result associated with it. Returns 0 if @@ -1064,6 +1056,7 @@ uint32_t Serializer::prepareConstant(Location loc, Type constType, if (auto id = prepareConstantScalar(loc, valueAttr)) { return id; } + // This is a composite literal. We need to handle each component separately // and then emit an OpConstantComposite for the whole. @@ -1075,179 +1068,92 @@ uint32_t Serializer::prepareConstant(Location loc, Type constType, if (failed(processType(loc, constType, typeID))) { return 0; } - auto resultID = getNextID(); - spirv::Opcode opcode = spirv::Opcode::OpNop; - SmallVector operands; - operands.push_back(typeID); - operands.push_back(resultID); - - if (auto vectorAttr = valueAttr.dyn_cast()) { - if (vectorAttr.getType().getElementType().isInteger(1)) { - if (failed(prepareBoolVectorConstant(loc, vectorAttr, opcode, operands))) - return 0; - } else if (failed( - prepareIntVectorConstant(loc, vectorAttr, opcode, operands))) - return 0; - } else if (auto vectorAttr = valueAttr.dyn_cast()) { - if (failed(prepareFloatVectorConstant(loc, vectorAttr, opcode, operands))) - return 0; + uint32_t resultID = 0; + if (auto attr = valueAttr.dyn_cast()) { + int rank = attr.getType().dyn_cast().getRank(); + SmallVector index(rank); + resultID = prepareDenseElementsConstant(loc, constType, attr, + /*dim=*/0, index); } else if (auto arrayAttr = valueAttr.dyn_cast()) { - opcode = spirv::Opcode::OpConstantComposite; - operands.reserve(arrayAttr.size() + 2); + resultID = prepareArrayConstant(loc, constType, arrayAttr); + } - auto elementType = constType.cast().getElementType(); - for (Attribute elementAttr : arrayAttr) - if (auto elementID = prepareConstant(loc, elementType, elementAttr)) { - operands.push_back(elementID); - } else { - return 0; - } - } else { + if (resultID == 0) { emitError(loc, "cannot serialize attribute: ") << valueAttr; return 0; } - encodeInstructionInto(typesGlobalValues, opcode, operands); constIDMap[valueAttr] = resultID; return resultID; } -LogicalResult Serializer::prepareBoolVectorConstant( - Location loc, DenseIntElementsAttr elementsAttr, spirv::Opcode &opcode, - SmallVectorImpl &operands) { - auto type = elementsAttr.getType(); - assert(type.hasRank() && type.getRank() == 1 && - "spv.constant should have verified only vector literal uses " - "ElementsAttr"); - assert(type.getElementType().isInteger(1) && "must be bool ElementsAttr"); - auto count = type.getNumElements(); - - // Operands for constructing the SPIR-V OpConstant* instruction - operands.reserve(count + 2); - - // For splat cases, we don't need to loop over all elements, especially when - // the splat value is zero. - if (elementsAttr.isSplat()) { - // We can use OpConstantNull if this bool ElementsAttr is splatting false. - if (!elementsAttr.getSplatValue()) { - opcode = spirv::Opcode::OpConstantNull; - return success(); - } - - if (auto id = - prepareConstantBool(loc, elementsAttr.getSplatValue())) { - opcode = spirv::Opcode::OpConstantComposite; - operands.append(count, id); - return success(); - } - - return failure(); +uint32_t Serializer::prepareArrayConstant(Location loc, Type constType, + ArrayAttr attr) { + uint32_t typeID = 0; + if (failed(processType(loc, constType, typeID))) { + return 0; } - // Otherwise, we need to process each element and compose them with - // OpConstantComposite. - opcode = spirv::Opcode::OpConstantComposite; - for (auto boolAttr : elementsAttr.getValues()) { - // We are constructing an BoolAttr for each value here. But given that - // we only use ElementsAttr for vectors with no more than 4 elements, it - // should be fine here. - if (auto elementID = prepareConstantBool(loc, boolAttr)) { + uint32_t resultID = getNextID(); + SmallVector operands = {typeID, resultID}; + operands.reserve(attr.size() + 2); + auto elementType = constType.cast().getElementType(); + for (Attribute elementAttr : attr) { + if (auto elementID = prepareConstant(loc, elementType, elementAttr)) { operands.push_back(elementID); } else { - return failure(); + return 0; } } - return success(); + spirv::Opcode opcode = spirv::Opcode::OpConstantComposite; + encodeInstructionInto(typesGlobalValues, opcode, operands); + + return resultID; } -LogicalResult Serializer::prepareIntVectorConstant( - Location loc, DenseIntElementsAttr elementsAttr, spirv::Opcode &opcode, - SmallVectorImpl &operands) { - auto type = elementsAttr.getType(); - assert(type.hasRank() && type.getRank() == 1 && - "spv.constant should have verified only vector literal uses " - "ElementsAttr"); - assert(!type.getElementType().isInteger(1) && - "must be non-bool ElementsAttr"); - auto count = type.getNumElements(); - - // Operands for constructing the SPIR-V OpConstant* instruction - operands.reserve(count + 2); - - // For splat cases, we don't need to loop over all elements, especially when - // the splat value is zero. - if (elementsAttr.isSplat()) { - auto splatAttr = elementsAttr.getSplatValue(); - - // We can use OpConstantNull if this int ElementsAttr is splatting 0. - if (splatAttr.getValue().isNullValue()) { - opcode = spirv::Opcode::OpConstantNull; - return success(); +// TODO(hanchung): Turn the below function into iterative function, instead of +// recursive function. +uint32_t +Serializer::prepareDenseElementsConstant(Location loc, Type constType, + DenseElementsAttr valueAttr, int dim, + MutableArrayRef index) { + auto shapedType = valueAttr.getType().dyn_cast(); + assert(dim <= shapedType.getRank()); + if (shapedType.getRank() == dim) { + if (auto attr = valueAttr.dyn_cast()) { + return attr.getType().getElementType().isInteger(1) + ? prepareConstantBool(loc, attr.getValue(index)) + : prepareConstantInt(loc, attr.getValue(index)); } - - if (auto id = prepareConstantInt(loc, splatAttr)) { - opcode = spirv::Opcode::OpConstantComposite; - operands.append(count, id); - return success(); + if (auto attr = valueAttr.dyn_cast()) { + return prepareConstantFp(loc, attr.getValue(index)); } - return failure(); + return 0; } - // Otherwise, we need to process each element and compose them with - // OpConstantComposite. - opcode = spirv::Opcode::OpConstantComposite; - for (auto intAttr : elementsAttr.getValues()) { - // We are constructing an IntegerAttr for each value here. But given that - // we only use ElementsAttr for vectors with no more than 4 elements, it - // should be fine here. - // TODO(antiagainst): revisit this if special extensions enabling large - // vectors are supported. - if (auto elementID = prepareConstantInt(loc, intAttr)) { + uint32_t typeID = 0; + if (failed(processType(loc, constType, typeID))) { + return 0; + } + + uint32_t resultID = getNextID(); + SmallVector operands = {typeID, resultID}; + operands.reserve(shapedType.getDimSize(dim) + 2); + auto elementType = constType.cast().getElementType(0); + for (int i = 0; i < shapedType.getDimSize(dim); ++i) { + index[dim] = i; + if (auto elementID = prepareDenseElementsConstant( + loc, elementType, valueAttr, dim + 1, index)) { operands.push_back(elementID); } else { - return failure(); + return 0; } } - return success(); -} + spirv::Opcode opcode = spirv::Opcode::OpConstantComposite; + encodeInstructionInto(typesGlobalValues, opcode, operands); -LogicalResult Serializer::prepareFloatVectorConstant( - Location loc, DenseFPElementsAttr elementsAttr, spirv::Opcode &opcode, - SmallVectorImpl &operands) { - auto type = elementsAttr.getType(); - assert(type.hasRank() && type.getRank() == 1 && - "spv.constant should have verified only vector literal uses " - "ElementsAttr"); - auto count = type.getNumElements(); - - operands.reserve(count + 2); - - if (elementsAttr.isSplat()) { - FloatAttr splatAttr = elementsAttr.getSplatValue(); - if (splatAttr.getValue().isZero()) { - opcode = spirv::Opcode::OpConstantNull; - return success(); - } - - if (auto id = prepareConstantFp(loc, splatAttr)) { - opcode = spirv::Opcode::OpConstantComposite; - operands.append(count, id); - return success(); - } - - return failure(); - } - - opcode = spirv::Opcode::OpConstantComposite; - for (auto fpAttr : elementsAttr.getValues()) { - if (auto elementID = prepareConstantFp(loc, fpAttr)) { - operands.push_back(elementID); - } else { - return failure(); - } - } - return success(); + return resultID; } uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr, diff --git a/mlir/test/Dialect/SPIRV/Serialization/constant.mlir b/mlir/test/Dialect/SPIRV/Serialization/constant.mlir index 953120946db5..50005bed5e2e 100644 --- a/mlir/test/Dialect/SPIRV/Serialization/constant.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/constant.mlir @@ -178,4 +178,18 @@ spv.module "Logical" "GLSL450" { %4 = spv.IAdd %0, %3 : i32 spv.Return } + + // CHECK-LABEL: @multi_dimensions_const + func @multi_dimensions_const() -> (!spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>) { + // CHECK: spv.constant {{\[}}{{\[}}[1 : i32, 2 : i32, 3 : i32], [4 : i32, 5 : i32, 6 : i32]], {{\[}}[7 : i32, 8 : i32, 9 : i32], [10 : i32, 11 : i32, 12 : i32]]] : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]> + %0 = spv.constant dense<[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]> : tensor<2x2x3xi32> : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]> + spv.ReturnValue %0 : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]> + } + + // CHECK-LABEL: @multi_dimensions_splat_const + func @multi_dimensions_splat_const() -> (!spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]>) { + // CHECK: spv.constant {{\[}}{{\[}}[1 : i32, 1 : i32, 1 : i32], [1 : i32, 1 : i32, 1 : i32]], {{\[}}[1 : i32, 1 : i32, 1 : i32], [1 : i32, 1 : i32, 1 : i32]]] : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]> + %0 = spv.constant dense<1> : tensor<2x2x3xi32> : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]> + spv.ReturnValue %0 : !spv.array<2 x !spv.array<2 x !spv.array<3 x i32 [4]> [12]> [24]> + } } diff --git a/mlir/test/Dialect/SPIRV/structure-ops.mlir b/mlir/test/Dialect/SPIRV/structure-ops.mlir index 2bbb03fe1e94..8fe03f46323f 100644 --- a/mlir/test/Dialect/SPIRV/structure-ops.mlir +++ b/mlir/test/Dialect/SPIRV/structure-ops.mlir @@ -48,12 +48,20 @@ func @const() -> () { // CHECK: %2 = spv.constant 5.000000e-01 : f32 // CHECK: %3 = spv.constant dense<[2, 3]> : vector<2xi32> // CHECK: %4 = spv.constant [dense<3.000000e+00> : vector<2xf32>] : !spv.array<1 x vector<2xf32>> + // CHECK: %5 = spv.constant dense<1> : tensor<2x3xi32> : !spv.array<2 x !spv.array<3 x i32 [4]> [12]> + // CHECK: %6 = spv.constant dense<1.000000e+00> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32 [4]> [12]> + // CHECK: %7 = spv.constant dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> : !spv.array<2 x !spv.array<3 x i32 [4]> [12]> + // CHECK: %8 = spv.constant dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32 [4]> [12]> %0 = spv.constant true %1 = spv.constant 42 : i32 %2 = spv.constant 0.5 : f32 %3 = spv.constant dense<[2, 3]> : vector<2xi32> %4 = spv.constant [dense<3.0> : vector<2xf32>] : !spv.array<1xvector<2xf32>> + %5 = spv.constant dense<1> : tensor<2x3xi32> : !spv.array<2 x !spv.array<3 x i32 [4]> [12]> + %6 = spv.constant dense<1.0> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32 [4]> [12]> + %7 = spv.constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> : !spv.array<2 x !spv.array<3 x i32 [4]> [12]> + %8 = spv.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x f32 [4]> [12]> return } @@ -83,11 +91,33 @@ func @array_constant() -> () { // ----- +func @non_nested_array_constant() -> () { + // expected-error @+1 {{only support nested array result type}} + %0 = spv.constant dense<3.0> : tensor<2x2xf32> : !spv.array<2xvector<2xf32>> + return +} + +// ----- + func @value_result_type_mismatch() -> () { - // expected-error @+1 {{result type ('vector<4xi32>') does not match value type ('tensor<4xi32>')}} + // expected-error @+1 {{must have spv.array result type for array value}} %0 = "spv.constant"() {value = dense<0> : tensor<4xi32>} : () -> (vector<4xi32>) } +// ----- + +func @value_result_type_mismatch() -> () { + // expected-error @+1 {{result element type ('i32') does not match value element type ('f32')}} + %0 = spv.constant dense<1.0> : tensor<2x3xf32> : !spv.array<2 x !spv.array<3 x i32 [4]> [12]> +} + +// ----- + +func @value_result_num_elements_mismatch() -> () { + // expected-error @+1 {{result number of elements (6) does not match value number of elements (4)}} + %0 = spv.constant dense<1.0> : tensor<2x2xf32> : !spv.array<2 x !spv.array<3 x f32 [4]> [12]> + return +} // -----