diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp index 40d877a7225a..85a6a6221d96 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp @@ -84,17 +84,26 @@ static bool parseNumberX(StringRef &spec, int64_t &number) { return true; } +static bool isValidSPIRVIntType(IntegerType type) { + return llvm::is_contained(llvm::ArrayRef({1, 8, 16, 32, 64}), + type.getWidth()); +} + static bool isValidSPIRVScalarType(Type type) { if (type.isa()) { return !type.isBF16(); } if (auto intType = type.dyn_cast()) { - return llvm::is_contained(llvm::ArrayRef({1, 8, 16, 32, 64}), - intType.getWidth()); + return isValidSPIRVIntType(intType); } return false; } +static bool isValidSPIRVVectorType(VectorType type) { + return type.getRank() == 1 && isValidSPIRVScalarType(type.getElementType()) && + type.getNumElements() >= 2 && type.getNumElements() <= 4; +} + bool SPIRVDialect::isValidSPIRVType(Type type) const { // Allow SPIR-V dialect types if (&type.getDialect() == this) { @@ -104,9 +113,7 @@ bool SPIRVDialect::isValidSPIRVType(Type type) const { return true; } if (auto vectorType = type.dyn_cast()) { - return (isValidSPIRVScalarType(vectorType.getElementType()) && - vectorType.getNumElements() >= 2 && - vectorType.getNumElements() <= 4); + return isValidSPIRVVectorType(vectorType); } return false; } @@ -132,9 +139,8 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect, StringRef spec, return Type(); } } else if (auto t = type.dyn_cast()) { - if (!llvm::is_contained(llvm::ArrayRef({8, 16, 32, 64}), - t.getWidth())) { - emitError(loc, "only 8/16/32/64-bit integer type allowed but found ") + if (!isValidSPIRVIntType(t)) { + emitError(loc, "only 1/8/16/32/64-bit integer type allowed but found ") << type; return Type(); } diff --git a/mlir/test/Dialect/SPIRV/Serialization/constant.mlir b/mlir/test/Dialect/SPIRV/Serialization/constant.mlir index 1d4d4756f6bf..ff422ca04906 100644 --- a/mlir/test/Dialect/SPIRV/Serialization/constant.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/constant.mlir @@ -3,11 +3,15 @@ func @spirv_module() -> () { spv.module "Logical" "GLSL450" { // CHECK-LABEL: @bool_const - func @bool_const() -> (i1) { + func @bool_const() -> () { // CHECK: spv.constant true %0 = spv.constant true + // CHECK: spv.constant false + %1 = spv.constant false - spv.ReturnValue %0 : i1 + %2 = spv.Variable init(%0): !spv.ptr + %3 = spv.Variable init(%1): !spv.ptr + spv.Return } // CHECK-LABEL: @i32_const diff --git a/mlir/test/Dialect/SPIRV/types.mlir b/mlir/test/Dialect/SPIRV/types.mlir index 2bfadae6b737..552ef6ac0fb4 100644 --- a/mlir/test/Dialect/SPIRV/types.mlir +++ b/mlir/test/Dialect/SPIRV/types.mlir @@ -62,7 +62,7 @@ func @bf16_type(!spv.array<4xbf16>) -> () // ----- -// expected-error @+1 {{only 8/16/32/64-bit integer type allowed but found 'i256'}} +// expected-error @+1 {{only 1/8/16/32/64-bit integer type allowed but found 'i256'}} func @i256_type(!spv.array<4xi256>) -> () // ----- @@ -86,10 +86,13 @@ func @array_type_zero_stide(!spv.array<4xi32 [0]>) -> () // PointerType //===----------------------------------------------------------------------===// -// CHECK: func @scalar_ptr_type(!spv.ptr) +// CHECK: @bool_ptr_type(!spv.ptr) +func @bool_ptr_type(!spv.ptr) -> () + +// CHECK: @scalar_ptr_type(!spv.ptr) func @scalar_ptr_type(!spv.ptr) -> () -// CHECK: func @vector_ptr_type(!spv.ptr, PushConstant>) +// CHECK: @vector_ptr_type(!spv.ptr, PushConstant>) func @vector_ptr_type(!spv.ptr,PushConstant>) -> () // -----