[spirv] Support i1 as bool type

PiperOrigin-RevId: 264612014
This commit is contained in:
Lei Zhang 2019-08-21 08:17:19 -07:00 committed by A. Unique TensorFlower
parent 31cfee6077
commit 8d18fdf2d3
3 changed files with 26 additions and 13 deletions

View File

@ -84,17 +84,26 @@ static bool parseNumberX(StringRef &spec, int64_t &number) {
return true;
}
static bool isValidSPIRVIntType(IntegerType type) {
return llvm::is_contained(llvm::ArrayRef<unsigned>({1, 8, 16, 32, 64}),
type.getWidth());
}
static bool isValidSPIRVScalarType(Type type) {
if (type.isa<FloatType>()) {
return !type.isBF16();
}
if (auto intType = type.dyn_cast<IntegerType>()) {
return llvm::is_contained(llvm::ArrayRef<unsigned>({1, 8, 16, 32, 64}),
intType.getWidth());
return isValidSPIRVIntType(intType);
}
return false;
}
static bool isValidSPIRVVectorType(VectorType type) {
return type.getRank() == 1 && isValidSPIRVScalarType(type.getElementType()) &&
type.getNumElements() >= 2 && type.getNumElements() <= 4;
}
bool SPIRVDialect::isValidSPIRVType(Type type) const {
// Allow SPIR-V dialect types
if (&type.getDialect() == this) {
@ -104,9 +113,7 @@ bool SPIRVDialect::isValidSPIRVType(Type type) const {
return true;
}
if (auto vectorType = type.dyn_cast<VectorType>()) {
return (isValidSPIRVScalarType(vectorType.getElementType()) &&
vectorType.getNumElements() >= 2 &&
vectorType.getNumElements() <= 4);
return isValidSPIRVVectorType(vectorType);
}
return false;
}
@ -132,9 +139,8 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect, StringRef spec,
return Type();
}
} else if (auto t = type.dyn_cast<IntegerType>()) {
if (!llvm::is_contained(llvm::ArrayRef<unsigned>({8, 16, 32, 64}),
t.getWidth())) {
emitError(loc, "only 8/16/32/64-bit integer type allowed but found ")
if (!isValidSPIRVIntType(t)) {
emitError(loc, "only 1/8/16/32/64-bit integer type allowed but found ")
<< type;
return Type();
}

View File

@ -3,11 +3,15 @@
func @spirv_module() -> () {
spv.module "Logical" "GLSL450" {
// CHECK-LABEL: @bool_const
func @bool_const() -> (i1) {
func @bool_const() -> () {
// CHECK: spv.constant true
%0 = spv.constant true
// CHECK: spv.constant false
%1 = spv.constant false
spv.ReturnValue %0 : i1
%2 = spv.Variable init(%0): !spv.ptr<i1, Function>
%3 = spv.Variable init(%1): !spv.ptr<i1, Function>
spv.Return
}
// CHECK-LABEL: @i32_const

View File

@ -62,7 +62,7 @@ func @bf16_type(!spv.array<4xbf16>) -> ()
// -----
// expected-error @+1 {{only 8/16/32/64-bit integer type allowed but found 'i256'}}
// expected-error @+1 {{only 1/8/16/32/64-bit integer type allowed but found 'i256'}}
func @i256_type(!spv.array<4xi256>) -> ()
// -----
@ -86,10 +86,13 @@ func @array_type_zero_stide(!spv.array<4xi32 [0]>) -> ()
// PointerType
//===----------------------------------------------------------------------===//
// CHECK: func @scalar_ptr_type(!spv.ptr<f32, Uniform>)
// CHECK: @bool_ptr_type(!spv.ptr<i1, Uniform>)
func @bool_ptr_type(!spv.ptr<i1, Uniform>) -> ()
// CHECK: @scalar_ptr_type(!spv.ptr<f32, Uniform>)
func @scalar_ptr_type(!spv.ptr<f32, Uniform>) -> ()
// CHECK: func @vector_ptr_type(!spv.ptr<vector<4xi32>, PushConstant>)
// CHECK: @vector_ptr_type(!spv.ptr<vector<4xi32>, PushConstant>)
func @vector_ptr_type(!spv.ptr<vector<4xi32>,PushConstant>) -> ()
// -----