forked from OSchip/llvm-project
[mlir][spirv] Support integer signedness
This commit updates SPIR-V dialect to support integer signedness by relaxing various checks for signless to just normal integers. The hack for spv.Bitcast can now be removed. Differential Revision: https://reviews.llvm.org/D75611
This commit is contained in:
parent
c72d60d42f
commit
9600b55ac8
|
@ -252,12 +252,18 @@ The SPIR-V dialect reuses standard integer, float, and vector types:
|
|||
Specification | Dialect
|
||||
:----------------------------------: | :-------------------------------:
|
||||
`OpTypeBool` | `i1`
|
||||
`OpTypeInt <bitwidth>` | `i<bitwidth>`
|
||||
`OpTypeFloat <bitwidth>` | `f<bitwidth>`
|
||||
`OpTypeVector <scalar-type> <count>` | `vector<<count> x <scalar-type>>`
|
||||
|
||||
Similarly, `mlir::NoneType` can be used for SPIR-V `OpTypeVoid`; builtin
|
||||
function types can be used for SPIR-V `OpTypeFunction` types.
|
||||
For integer types, the SPIR-V dialect supports all signedness semantics
|
||||
(signless, signed, unsigned) in order to ease transformations from higher level
|
||||
dialects. However, SPIR-V spec only defines two signedness semantics state: 0
|
||||
indicates unsigned, or no signedness semantics, 1 indicates signed semantics. So
|
||||
both `iN` and `uiN` are serialized into the same `OpTypeInt N 0`. For
|
||||
deserialization, we always treat `OpTypeInt N 0` as `iN`.
|
||||
|
||||
`mlir::NoneType` is used for SPIR-V `OpTypeVoid`; builtin function types are
|
||||
used for SPIR-V `OpTypeFunction` types.
|
||||
|
||||
The SPIR-V dialect and defines the following dialect-specific types:
|
||||
|
||||
|
|
|
@ -2945,6 +2945,17 @@ def SPV_SamplerUseAttr:
|
|||
// SPIR-V type definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class IOrUI<int width>
|
||||
: Type<Or<[CPred<"$_self.isSignlessInteger(" # width # ")">,
|
||||
CPred<"$_self.isUnsignedInteger(" # width # ")">]>,
|
||||
width # "-bit signless/unsigned integer"> {
|
||||
int bitwidth = width;
|
||||
}
|
||||
|
||||
class SignlessOrUnsignedIntOfWidths<list<int> widths> :
|
||||
AnyTypeOf<!foreach(w, widths, IOrUI<w>),
|
||||
StrJoinInt<widths, "/">.result # "-bit signless/unsigned integer">;
|
||||
|
||||
def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">;
|
||||
def SPV_IsArrayType : CPred<"$_self.isa<::mlir::spirv::ArrayType>()">;
|
||||
def SPV_IsRTArrayType : CPred<"$_self.isa<::mlir::spirv::RuntimeArrayType>()">;
|
||||
|
@ -2953,8 +2964,8 @@ def SPV_IsStructType : CPred<"$_self.isa<::mlir::spirv::StructType>()">;
|
|||
// See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types
|
||||
// for the definition of the following types and type categories.
|
||||
|
||||
def SPV_Void : TypeAlias<NoneType, "void type">;
|
||||
def SPV_Bool : I<1>;
|
||||
def SPV_Void : TypeAlias<NoneType, "void">;
|
||||
def SPV_Bool : TypeAlias<I1, "bool">;
|
||||
def SPV_Integer : AnyIntOfWidths<[8, 16, 32, 64]>;
|
||||
def SPV_Float : FloatOfWidths<[16, 32, 64]>;
|
||||
def SPV_Float16or32 : FloatOfWidths<[16, 32]>;
|
||||
|
@ -2977,6 +2988,8 @@ def SPV_Type : AnyTypeOf<[
|
|||
SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct
|
||||
]>;
|
||||
|
||||
def SPV_SignlessOrUnsignedInt : SignlessOrUnsignedIntOfWidths<[8, 16, 32, 64]>;
|
||||
|
||||
class SPV_ScalarOrVectorOf<Type type> :
|
||||
AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4], [type]>]>;
|
||||
|
||||
|
@ -2985,7 +2998,8 @@ def SPV_ScalarOrVectorOrPtr : AnyTypeOf<[SPV_ScalarOrVector, SPV_AnyPtr]>;
|
|||
|
||||
class SPV_Vec4<Type type> : VectorOfLengthAndType<[4], [type]>;
|
||||
def SPV_IntVec4 : SPV_Vec4<SPV_Integer>;
|
||||
def SPV_I32Vec4 : SPV_Vec4<I32>;
|
||||
def SPV_IOrUIVec4 : SPV_Vec4<SPV_SignlessOrUnsignedInt>;
|
||||
def SPV_Int32Vec4 : SPV_Vec4<AnyI32>;
|
||||
|
||||
// TODO(antiagainst): Use a more appropriate way to model optional operands
|
||||
class SPV_Optional<Type type> : Variadic<type>;
|
||||
|
|
|
@ -61,7 +61,7 @@ def SPV_SubgroupBallotKHROp : SPV_Op<"SubgroupBallotKHR", []> {
|
|||
);
|
||||
|
||||
let results = (outs
|
||||
SPV_I32Vec4:$result
|
||||
SPV_Int32Vec4:$result
|
||||
);
|
||||
|
||||
let verifier = [{ return success(); }];
|
||||
|
|
|
@ -95,7 +95,7 @@ def SPV_GroupNonUniformBallotOp : SPV_Op<"GroupNonUniformBallot", []> {
|
|||
);
|
||||
|
||||
let results = (outs
|
||||
SPV_IntVec4:$result
|
||||
SPV_IOrUIVec4:$result
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
|
|
|
@ -363,11 +363,12 @@ static unsigned getBitWidth(Type type) {
|
|||
// TODO: Make sure not caller relies on the actual pointer width value.
|
||||
return 64;
|
||||
}
|
||||
if (type.isSignlessIntOrFloat()) {
|
||||
|
||||
if (type.isIntOrFloat())
|
||||
return type.getIntOrFloatBitWidth();
|
||||
}
|
||||
|
||||
if (auto vectorType = type.dyn_cast<VectorType>()) {
|
||||
assert(vectorType.getElementType().isSignlessIntOrFloat());
|
||||
assert(vectorType.getElementType().isIntOrFloat());
|
||||
return vectorType.getNumElements() *
|
||||
vectorType.getElementType().getIntOrFloatBitWidth();
|
||||
}
|
||||
|
@ -500,7 +501,7 @@ static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) {
|
|||
static LogicalResult verifyAtomicUpdateOp(Operation *op) {
|
||||
auto ptrType = op->getOperand(0).getType().cast<spirv::PointerType>();
|
||||
auto elementType = ptrType.getPointeeType();
|
||||
if (!elementType.isSignlessInteger())
|
||||
if (!elementType.isa<IntegerType>())
|
||||
return op->emitOpError(
|
||||
"pointer operand must point to an integer value, found ")
|
||||
<< elementType;
|
||||
|
@ -1265,7 +1266,7 @@ static LogicalResult verify(spirv::ConstantOp constOp) {
|
|||
numElements *= t.getNumElements();
|
||||
opElemType = t.getElementType();
|
||||
}
|
||||
if (!opElemType.isSignlessIntOrFloat()) {
|
||||
if (!opElemType.isIntOrFloat()) {
|
||||
return constOp.emitOpError("only support nested array result type");
|
||||
}
|
||||
|
||||
|
@ -1769,8 +1770,6 @@ static LogicalResult verify(spirv::GlobalVariableOp varOp) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) {
|
||||
// TODO(antiagainst): check the result integer type's signedness bit is 0.
|
||||
|
||||
spirv::Scope scope = ballotOp.execution_scope();
|
||||
if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
|
||||
return ballotOp.emitOpError(
|
||||
|
|
|
@ -344,9 +344,6 @@ private:
|
|||
/// insertion point.
|
||||
LogicalResult processUndef(ArrayRef<uint32_t> operands);
|
||||
|
||||
/// Processes an OpBitcast instruction.
|
||||
LogicalResult processBitcast(ArrayRef<uint32_t> words);
|
||||
|
||||
/// Method to dispatch to the specialized deserialization function for an
|
||||
/// operation in SPIR-V dialect that is a mirror of an instruction in the
|
||||
/// SPIR-V spec. This is auto-generated from ODS. Dispatch is handled for
|
||||
|
@ -1045,30 +1042,35 @@ LogicalResult Deserializer::processType(spirv::Opcode opcode,
|
|||
|
||||
switch (opcode) {
|
||||
case spirv::Opcode::OpTypeVoid:
|
||||
if (operands.size() != 1) {
|
||||
if (operands.size() != 1)
|
||||
return emitError(unknownLoc, "OpTypeVoid must have no parameters");
|
||||
}
|
||||
typeMap[operands[0]] = opBuilder.getNoneType();
|
||||
break;
|
||||
case spirv::Opcode::OpTypeBool:
|
||||
if (operands.size() != 1) {
|
||||
if (operands.size() != 1)
|
||||
return emitError(unknownLoc, "OpTypeBool must have no parameters");
|
||||
}
|
||||
typeMap[operands[0]] = opBuilder.getI1Type();
|
||||
break;
|
||||
case spirv::Opcode::OpTypeInt:
|
||||
if (operands.size() != 3) {
|
||||
case spirv::Opcode::OpTypeInt: {
|
||||
if (operands.size() != 3)
|
||||
return emitError(
|
||||
unknownLoc, "OpTypeInt must have bitwidth and signedness parameters");
|
||||
}
|
||||
// TODO: Ignoring the signedness right now. Need to handle this effectively
|
||||
// in the MLIR representation.
|
||||
typeMap[operands[0]] = opBuilder.getIntegerType(operands[1]);
|
||||
break;
|
||||
|
||||
// SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
|
||||
// to preserve or validate.
|
||||
// 0 indicates unsigned, or no signedness semantics
|
||||
// 1 indicates signed semantics."
|
||||
//
|
||||
// So we cannot differentiate signless and unsigned integers; always use
|
||||
// signless semantics for such cases.
|
||||
auto sign = operands[2] == 1 ? IntegerType::SignednessSemantics::Signed
|
||||
: IntegerType::SignednessSemantics::Signless;
|
||||
typeMap[operands[0]] = IntegerType::get(operands[1], sign, context);
|
||||
} break;
|
||||
case spirv::Opcode::OpTypeFloat: {
|
||||
if (operands.size() != 2) {
|
||||
if (operands.size() != 2)
|
||||
return emitError(unknownLoc, "OpTypeFloat must have bitwidth parameter");
|
||||
}
|
||||
|
||||
Type floatTy;
|
||||
switch (operands[1]) {
|
||||
case 16:
|
||||
|
@ -1146,7 +1148,7 @@ LogicalResult Deserializer::processArrayType(ArrayRef<uint32_t> operands) {
|
|||
}
|
||||
|
||||
if (auto intVal = countInfo->first.dyn_cast<IntegerAttr>()) {
|
||||
count = intVal.getInt();
|
||||
count = intVal.getValue().getZExtValue();
|
||||
} else {
|
||||
return emitError(unknownLoc, "OpTypeArray count must come from a "
|
||||
"scalar integer constant instruction");
|
||||
|
@ -1451,8 +1453,7 @@ LogicalResult Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
|
|||
}
|
||||
|
||||
auto resultID = operands[1];
|
||||
if (resultType.isSignlessInteger() || resultType.isa<FloatType>() ||
|
||||
resultType.isa<VectorType>()) {
|
||||
if (resultType.isIntOrFloat() || resultType.isa<VectorType>()) {
|
||||
auto attr = opBuilder.getZeroAttr(resultType);
|
||||
// For normal constants, we just record the attribute (and its type) for
|
||||
// later materialization at use sites.
|
||||
|
@ -2051,8 +2052,6 @@ LogicalResult Deserializer::processInstruction(spirv::Opcode opcode,
|
|||
// First dispatch all the instructions whose opcode does not correspond to
|
||||
// those that have a direct mirror in the SPIR-V dialect
|
||||
switch (opcode) {
|
||||
case spirv::Opcode::OpBitcast:
|
||||
return processBitcast(operands);
|
||||
case spirv::Opcode::OpCapability:
|
||||
return processCapability(operands);
|
||||
case spirv::Opcode::OpExtension:
|
||||
|
@ -2152,76 +2151,6 @@ LogicalResult Deserializer::processUndef(ArrayRef<uint32_t> operands) {
|
|||
return success();
|
||||
}
|
||||
|
||||
// TODO(b/130356985): This method is copied from the auto-generated
|
||||
// deserialization function for OpBitcast instruction. This is to avoid
|
||||
// generating a Bitcast operations for cast from signed integer to unsigned
|
||||
// integer and viceversa. MLIR doesn't have native support for this so they both
|
||||
// end up mapping to the same type right now which is illegal according to
|
||||
// OpBitcast semantics (and enforced by the SPIR-V dialect).
|
||||
LogicalResult Deserializer::processBitcast(ArrayRef<uint32_t> words) {
|
||||
SmallVector<Type, 1> resultTypes;
|
||||
size_t wordIndex = 0;
|
||||
(void)wordIndex;
|
||||
uint32_t valueID = 0;
|
||||
(void)valueID;
|
||||
{
|
||||
if (wordIndex >= words.size()) {
|
||||
return emitError(
|
||||
unknownLoc,
|
||||
"expected result type <id> while deserializing spirv::BitcastOp");
|
||||
}
|
||||
auto ty = getType(words[wordIndex]);
|
||||
if (!ty) {
|
||||
return emitError(unknownLoc, "unknown type result <id> : ")
|
||||
<< words[wordIndex];
|
||||
}
|
||||
resultTypes.push_back(ty);
|
||||
wordIndex++;
|
||||
if (wordIndex >= words.size()) {
|
||||
return emitError(
|
||||
unknownLoc,
|
||||
"expected result <id> while deserializing spirv::BitcastOp");
|
||||
}
|
||||
}
|
||||
valueID = words[wordIndex++];
|
||||
SmallVector<Value, 4> operands;
|
||||
SmallVector<NamedAttribute, 4> attributes;
|
||||
if (wordIndex < words.size()) {
|
||||
auto arg = getValue(words[wordIndex]);
|
||||
if (!arg) {
|
||||
return emitError(unknownLoc, "unknown result <id> : ")
|
||||
<< words[wordIndex];
|
||||
}
|
||||
operands.push_back(arg);
|
||||
wordIndex++;
|
||||
}
|
||||
if (wordIndex != words.size()) {
|
||||
return emitError(unknownLoc,
|
||||
"found more operands than expected when deserializing "
|
||||
"spirv::BitcastOp, only ")
|
||||
<< wordIndex << " of " << words.size() << " processed";
|
||||
}
|
||||
if (resultTypes[0] == operands[0].getType() &&
|
||||
resultTypes[0].isSignlessInteger()) {
|
||||
// TODO(b/130356985): This check is added to ignore error in Op verification
|
||||
// due to both signed and unsigned integers mapping to the same
|
||||
// type. Without this check this method is same as what is auto-generated.
|
||||
valueMap[valueID] = operands[0];
|
||||
return success();
|
||||
}
|
||||
|
||||
auto op = opBuilder.create<spirv::BitcastOp>(unknownLoc, resultTypes,
|
||||
operands, attributes);
|
||||
(void)op;
|
||||
valueMap[valueID] = op.getResult();
|
||||
|
||||
if (decorations.count(valueID)) {
|
||||
auto attrs = decorations[valueID].getAttrs();
|
||||
attributes.append(attrs.begin(), attrs.end());
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult Deserializer::processExtInst(ArrayRef<uint32_t> operands) {
|
||||
if (operands.size() < 4) {
|
||||
return emitError(unknownLoc,
|
||||
|
|
|
@ -932,8 +932,11 @@ Serializer::prepareBasicType(Location loc, Type type, uint32_t resultID,
|
|||
|
||||
typeEnum = spirv::Opcode::OpTypeInt;
|
||||
operands.push_back(intType.getWidth());
|
||||
// TODO(antiagainst): support unsigned integers
|
||||
operands.push_back(1);
|
||||
// SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics
|
||||
// to preserve or validate.
|
||||
// 0 indicates unsigned, or no signedness semantics
|
||||
// 1 indicates signed semantics."
|
||||
operands.push_back(intType.isSigned() ? 1 : 0);
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -4,6 +4,10 @@ spv.module "Logical" "GLSL450" {
|
|||
spv.func @bit_cast(%arg0 : f32) "None" {
|
||||
// CHECK: {{%.*}} = spv.Bitcast {{%.*}} : f32 to i32
|
||||
%0 = spv.Bitcast %arg0 : f32 to i32
|
||||
// CHECK: {{%.*}} = spv.Bitcast {{%.*}} : i32 to si32
|
||||
%1 = spv.Bitcast %0 : i32 to si32
|
||||
// CHECK: {{%.*}} = spv.Bitcast {{%.*}} : si32 to i32
|
||||
%2 = spv.Bitcast %1 : si32 to ui32
|
||||
spv.Return
|
||||
}
|
||||
}
|
||||
|
|
|
@ -27,6 +27,37 @@ spv.module "Logical" "GLSL450" {
|
|||
spv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @si32_const
|
||||
spv.func @si32_const() -> () "None" {
|
||||
// CHECK: spv.constant 0 : si32
|
||||
%0 = spv.constant 0 : si32
|
||||
// CHECK: spv.constant 10 : si32
|
||||
%1 = spv.constant 10 : si32
|
||||
// CHECK: spv.constant -5 : si32
|
||||
%2 = spv.constant -5 : si32
|
||||
|
||||
%3 = spv.IAdd %0, %1 : si32
|
||||
%4 = spv.IAdd %2, %3 : si32
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @ui32_const
|
||||
// We cannot differentiate signless vs. unsigned integers in SPIR-V blob
|
||||
// because they all use 1 as the signedness bit. So we always treat them
|
||||
// as signless integers.
|
||||
spv.func @ui32_const() -> () "None" {
|
||||
// CHECK: spv.constant 0 : i32
|
||||
%0 = spv.constant 0 : ui32
|
||||
// CHECK: spv.constant 10 : i32
|
||||
%1 = spv.constant 10 : ui32
|
||||
// CHECK: spv.constant -5 : i32
|
||||
%2 = spv.constant 4294967291 : ui32
|
||||
|
||||
%3 = spv.IAdd %0, %1 : ui32
|
||||
%4 = spv.IAdd %2, %3 : ui32
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @i64_const
|
||||
spv.func @i64_const() -> () "None" {
|
||||
// CHECK: spv.constant 4294967296 : i64
|
||||
|
@ -141,8 +172,23 @@ spv.module "Logical" "GLSL450" {
|
|||
spv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @array_const
|
||||
spv.func @array_const() -> (!spv.array<2 x vector<2xf32>>) "None" {
|
||||
// CHECK-LABEL: @ui64_array_const
|
||||
spv.func @ui64_array_const() -> (!spv.array<3xui64>) "None" {
|
||||
// CHECK: spv.constant [5, 6, 7] : !spv.array<3 x i64>
|
||||
%0 = spv.constant [5 : ui64, 6 : ui64, 7 : ui64] : !spv.array<3 x ui64>
|
||||
|
||||
spv.ReturnValue %0: !spv.array<3xui64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @si32_array_const
|
||||
spv.func @si32_array_const() -> (!spv.array<3xsi32>) "None" {
|
||||
// CHECK: spv.constant [5 : si32, 6 : si32, 7 : si32] : !spv.array<3 x si32>
|
||||
%0 = spv.constant [5 : si32, 6 : si32, 7 : si32] : !spv.array<3 x si32>
|
||||
|
||||
spv.ReturnValue %0 : !spv.array<3xsi32>
|
||||
}
|
||||
// CHECK-LABEL: @float_array_const
|
||||
spv.func @float_array_const() -> (!spv.array<2 x vector<2xf32>>) "None" {
|
||||
// CHECK: spv.constant [dense<3.000000e+00> : vector<2xf32>, dense<[4.000000e+00, 5.000000e+00]> : vector<2xf32>] : !spv.array<2 x vector<2xf32>>
|
||||
%0 = spv.constant [dense<3.0> : vector<2xf32>, dense<[4., 5.]> : vector<2xf32>] : !spv.array<2 x vector<2xf32>>
|
||||
|
||||
|
|
|
@ -20,6 +20,14 @@ func @group_non_uniform_ballot(%predicate: i1) -> vector<4xi32> {
|
|||
|
||||
// -----
|
||||
|
||||
func @group_non_uniform_ballot(%predicate: i1) -> vector<4xsi32> {
|
||||
// expected-error @+1 {{op result #0 must be vector of 8/16/32/64-bit signless/unsigned integer values of length 4, but got 'vector<4xsi32>'}}
|
||||
%0 = spv.GroupNonUniformBallot "Workgroup" %predicate : vector<4xsi32>
|
||||
return %0: vector<4xsi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.GroupNonUniformElect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -752,7 +752,7 @@ func @logicalUnary(%arg0 : i1)
|
|||
|
||||
func @logicalUnary(%arg0 : i32)
|
||||
{
|
||||
// expected-error @+1 {{operand #0 must be 1-bit signless integer or vector of 1-bit signless integer values of length 2/3/4, but got 'i32'}}
|
||||
// expected-error @+1 {{operand #0 must be bool or vector of bool values of length 2/3/4, but got 'i32'}}
|
||||
%0 = spv.LogicalNot %arg0 : i32
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue