forked from OSchip/llvm-project
[spirv] Support i1 as bool type
PiperOrigin-RevId: 264612014
This commit is contained in:
parent
31cfee6077
commit
8d18fdf2d3
|
@ -84,17 +84,26 @@ static bool parseNumberX(StringRef &spec, int64_t &number) {
|
|||
return true;
|
||||
}
|
||||
|
||||
static bool isValidSPIRVIntType(IntegerType type) {
|
||||
return llvm::is_contained(llvm::ArrayRef<unsigned>({1, 8, 16, 32, 64}),
|
||||
type.getWidth());
|
||||
}
|
||||
|
||||
static bool isValidSPIRVScalarType(Type type) {
|
||||
if (type.isa<FloatType>()) {
|
||||
return !type.isBF16();
|
||||
}
|
||||
if (auto intType = type.dyn_cast<IntegerType>()) {
|
||||
return llvm::is_contained(llvm::ArrayRef<unsigned>({1, 8, 16, 32, 64}),
|
||||
intType.getWidth());
|
||||
return isValidSPIRVIntType(intType);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool isValidSPIRVVectorType(VectorType type) {
|
||||
return type.getRank() == 1 && isValidSPIRVScalarType(type.getElementType()) &&
|
||||
type.getNumElements() >= 2 && type.getNumElements() <= 4;
|
||||
}
|
||||
|
||||
bool SPIRVDialect::isValidSPIRVType(Type type) const {
|
||||
// Allow SPIR-V dialect types
|
||||
if (&type.getDialect() == this) {
|
||||
|
@ -104,9 +113,7 @@ bool SPIRVDialect::isValidSPIRVType(Type type) const {
|
|||
return true;
|
||||
}
|
||||
if (auto vectorType = type.dyn_cast<VectorType>()) {
|
||||
return (isValidSPIRVScalarType(vectorType.getElementType()) &&
|
||||
vectorType.getNumElements() >= 2 &&
|
||||
vectorType.getNumElements() <= 4);
|
||||
return isValidSPIRVVectorType(vectorType);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
@ -132,9 +139,8 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect, StringRef spec,
|
|||
return Type();
|
||||
}
|
||||
} else if (auto t = type.dyn_cast<IntegerType>()) {
|
||||
if (!llvm::is_contained(llvm::ArrayRef<unsigned>({8, 16, 32, 64}),
|
||||
t.getWidth())) {
|
||||
emitError(loc, "only 8/16/32/64-bit integer type allowed but found ")
|
||||
if (!isValidSPIRVIntType(t)) {
|
||||
emitError(loc, "only 1/8/16/32/64-bit integer type allowed but found ")
|
||||
<< type;
|
||||
return Type();
|
||||
}
|
||||
|
|
|
@ -3,11 +3,15 @@
|
|||
func @spirv_module() -> () {
|
||||
spv.module "Logical" "GLSL450" {
|
||||
// CHECK-LABEL: @bool_const
|
||||
func @bool_const() -> (i1) {
|
||||
func @bool_const() -> () {
|
||||
// CHECK: spv.constant true
|
||||
%0 = spv.constant true
|
||||
// CHECK: spv.constant false
|
||||
%1 = spv.constant false
|
||||
|
||||
spv.ReturnValue %0 : i1
|
||||
%2 = spv.Variable init(%0): !spv.ptr<i1, Function>
|
||||
%3 = spv.Variable init(%1): !spv.ptr<i1, Function>
|
||||
spv.Return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @i32_const
|
||||
|
|
|
@ -62,7 +62,7 @@ func @bf16_type(!spv.array<4xbf16>) -> ()
|
|||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{only 8/16/32/64-bit integer type allowed but found 'i256'}}
|
||||
// expected-error @+1 {{only 1/8/16/32/64-bit integer type allowed but found 'i256'}}
|
||||
func @i256_type(!spv.array<4xi256>) -> ()
|
||||
|
||||
// -----
|
||||
|
@ -86,10 +86,13 @@ func @array_type_zero_stide(!spv.array<4xi32 [0]>) -> ()
|
|||
// PointerType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK: func @scalar_ptr_type(!spv.ptr<f32, Uniform>)
|
||||
// CHECK: @bool_ptr_type(!spv.ptr<i1, Uniform>)
|
||||
func @bool_ptr_type(!spv.ptr<i1, Uniform>) -> ()
|
||||
|
||||
// CHECK: @scalar_ptr_type(!spv.ptr<f32, Uniform>)
|
||||
func @scalar_ptr_type(!spv.ptr<f32, Uniform>) -> ()
|
||||
|
||||
// CHECK: func @vector_ptr_type(!spv.ptr<vector<4xi32>, PushConstant>)
|
||||
// CHECK: @vector_ptr_type(!spv.ptr<vector<4xi32>, PushConstant>)
|
||||
func @vector_ptr_type(!spv.ptr<vector<4xi32>,PushConstant>) -> ()
|
||||
|
||||
// -----
|
||||
|
|
Loading…
Reference in New Issue