forked from OSchip/llvm-project
[spirv] Use mlir::parseType in type parsers and add more checks
PiperOrigin-RevId: 252874386
This commit is contained in:
parent
a3e6f102ca
commit
8c6f188143
|
@ -25,19 +25,29 @@
|
|||
#include "mlir/IR/Dialect.h"
|
||||
|
||||
namespace mlir {
|
||||
class MLIRContext;
|
||||
|
||||
namespace spirv {
|
||||
|
||||
class SPIRVDialect : public Dialect {
|
||||
public:
|
||||
explicit SPIRVDialect(MLIRContext *context);
|
||||
|
||||
static StringRef getDialectNamespace() { return "spv"; }
|
||||
|
||||
/// Parses a type registered to this dialect.
|
||||
Type parseType(llvm::StringRef spec, Location loc) const override;
|
||||
|
||||
/// Prints a type registered to this dialect.
|
||||
void printType(Type type, llvm::raw_ostream &os) const override;
|
||||
|
||||
private:
|
||||
/// Parses `spec` as a type and verifies it can be used in SPIR-V types.
|
||||
Type parseAndVerifyType(StringRef spec, Location loc) const;
|
||||
|
||||
/// Parses `spec` as a SPIR-V array type.
|
||||
Type parseArrayType(StringRef spec, Location loc) const;
|
||||
|
||||
/// Parses `spec` as a SPIR-V run-time array type.
|
||||
Type parseRuntimeArrayType(StringRef spec, Location loc) const;
|
||||
};
|
||||
|
||||
} // end namespace spirv
|
||||
|
|
|
@ -14,4 +14,7 @@ add_dependencies(MLIRSPIRV
|
|||
MLIRSPIRVEnumsIncGen
|
||||
MLIRStdOpsToSPIRVConversionIncGen)
|
||||
|
||||
target_link_libraries(MLIRSPIRV MLIRIR MLIRSupport)
|
||||
target_link_libraries(MLIRSPIRV
|
||||
MLIRIR
|
||||
MLIRParser
|
||||
MLIRSupport)
|
||||
|
|
|
@ -21,9 +21,9 @@
|
|||
|
||||
#include "mlir/SPIRV/SPIRVDialect.h"
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Parser.h"
|
||||
#include "mlir/SPIRV/SPIRVOps.h"
|
||||
#include "mlir/SPIRV/SPIRVTypes.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
|
@ -37,7 +37,8 @@ using namespace mlir::spirv;
|
|||
// SPIR-V Dialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
SPIRVDialect::SPIRVDialect(MLIRContext *context) : Dialect("spv", context) {
|
||||
SPIRVDialect::SPIRVDialect(MLIRContext *context)
|
||||
: Dialect(getDialectNamespace(), context) {
|
||||
addTypes<ArrayType, RuntimeArrayType>();
|
||||
|
||||
addOperations<
|
||||
|
@ -53,19 +54,6 @@ SPIRVDialect::SPIRVDialect(MLIRContext *context) : Dialect("spv", context) {
|
|||
// Type Parsing
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// TODO(b/133530217): The following implements some type parsing logic. It is
|
||||
// intended to be short-lived and used just before the main parser logic gets
|
||||
// exposed to dialects. So there is little type checking inside.
|
||||
|
||||
static Type parseScalarType(StringRef spec, Builder builder) {
|
||||
return llvm::StringSwitch<Type>(spec)
|
||||
.Case("f32", builder.getF32Type())
|
||||
.Case("i32", builder.getIntegerType(32))
|
||||
.Case("f16", builder.getF16Type())
|
||||
.Case("i16", builder.getIntegerType(16))
|
||||
.Default(Type());
|
||||
}
|
||||
|
||||
// Parses "<number> x" from the beginning of `spec`.
|
||||
static bool parseNumberX(StringRef &spec, int64_t &number) {
|
||||
spec = spec.ltrim();
|
||||
|
@ -85,56 +73,77 @@ static bool parseNumberX(StringRef &spec, int64_t &number) {
|
|||
return true;
|
||||
}
|
||||
|
||||
static Type parseVectorType(StringRef spec, Builder builder) {
|
||||
if (!spec.consume_front("vector<") || !spec.consume_back(">"))
|
||||
Type SPIRVDialect::parseAndVerifyType(StringRef spec, Location loc) const {
|
||||
auto *context = getContext();
|
||||
auto type = mlir::parseType(spec, context);
|
||||
if (!type) {
|
||||
context->emitError(loc, "cannot parse type: ") << spec;
|
||||
return Type();
|
||||
}
|
||||
|
||||
int64_t count = 0;
|
||||
if (!parseNumberX(spec, count))
|
||||
// Allow SPIR-V dialect types
|
||||
if (&type.getDialect() == this)
|
||||
return type;
|
||||
|
||||
// Check other allowed types
|
||||
if (auto t = type.dyn_cast<FloatType>()) {
|
||||
if (type.isBF16()) {
|
||||
context->emitError(loc, "cannot use 'bf16' to compose SPIR-V types");
|
||||
return Type();
|
||||
}
|
||||
} else if (auto t = type.dyn_cast<IntegerType>()) {
|
||||
if (!llvm::is_contained(llvm::ArrayRef<unsigned>({8, 16, 32, 64}),
|
||||
t.getWidth())) {
|
||||
context->emitError(loc,
|
||||
"only 8/16/32/64-bit integer type allowed but found ")
|
||||
<< type;
|
||||
return Type();
|
||||
}
|
||||
} else if (auto t = type.dyn_cast<VectorType>()) {
|
||||
if (t.getRank() != 1) {
|
||||
context->emitError(loc, "only 1-D vector allowed but found ") << t;
|
||||
return Type();
|
||||
}
|
||||
} else {
|
||||
context->emitError(loc, "cannot use ")
|
||||
<< type << " to compose SPIR-V types";
|
||||
return Type();
|
||||
}
|
||||
|
||||
spec = spec.trim();
|
||||
auto scalarType = parseScalarType(spec, builder);
|
||||
if (!scalarType)
|
||||
return Type();
|
||||
|
||||
return VectorType::get({count}, scalarType);
|
||||
return type;
|
||||
}
|
||||
|
||||
static Type parseArrayType(StringRef spec, Builder builder) {
|
||||
if (!spec.consume_front("array<") || !spec.consume_back(">"))
|
||||
Type SPIRVDialect::parseArrayType(StringRef spec, Location loc) const {
|
||||
auto *context = getContext();
|
||||
if (!spec.consume_front("array<") || !spec.consume_back(">")) {
|
||||
context->emitError(loc, "spv.array delimiter <...> mismatch");
|
||||
return Type();
|
||||
|
||||
Type elementType;
|
||||
int64_t count = 0;
|
||||
|
||||
spec = spec.trim();
|
||||
if (!parseNumberX(spec, count))
|
||||
return Type();
|
||||
|
||||
spec = spec.ltrim();
|
||||
if (spec.startswith("vector")) {
|
||||
elementType = parseVectorType(spec, builder);
|
||||
} else {
|
||||
elementType = parseScalarType(spec, builder);
|
||||
}
|
||||
|
||||
int64_t count = 0;
|
||||
spec = spec.trim();
|
||||
if (!parseNumberX(spec, count)) {
|
||||
context->emitError(
|
||||
loc, "expected array element count followed by 'x' but found '")
|
||||
<< spec << "'";
|
||||
return Type();
|
||||
}
|
||||
|
||||
Type elementType = parseAndVerifyType(spec, loc);
|
||||
if (!elementType)
|
||||
return Type();
|
||||
|
||||
return ArrayType::get(elementType, count);
|
||||
}
|
||||
|
||||
static Type parseRuntimeArrayType(StringRef spec, Builder builder) {
|
||||
if (!spec.consume_front("rtarray<") || !spec.consume_back(">"))
|
||||
Type SPIRVDialect::parseRuntimeArrayType(StringRef spec, Location loc) const {
|
||||
auto *context = getContext();
|
||||
if (!spec.consume_front("rtarray<") || !spec.consume_back(">")) {
|
||||
context->emitError(loc, "spv.rtarray delimiter <...> mismatch");
|
||||
return Type();
|
||||
|
||||
Type elementType;
|
||||
spec = spec.trim();
|
||||
if (spec.startswith("vector")) {
|
||||
elementType = parseVectorType(spec, builder);
|
||||
} else {
|
||||
elementType = parseScalarType(spec, builder);
|
||||
}
|
||||
|
||||
Type elementType = parseAndVerifyType(spec, loc);
|
||||
if (!elementType)
|
||||
return Type();
|
||||
|
||||
|
@ -142,12 +151,13 @@ static Type parseRuntimeArrayType(StringRef spec, Builder builder) {
|
|||
}
|
||||
|
||||
Type SPIRVDialect::parseType(StringRef spec, Location loc) const {
|
||||
Builder builder(getContext());
|
||||
|
||||
if (auto type = parseArrayType(spec, builder))
|
||||
return type;
|
||||
if (auto type = parseRuntimeArrayType(spec, builder))
|
||||
return type;
|
||||
if (spec.startswith("array")) {
|
||||
return parseArrayType(spec, loc);
|
||||
}
|
||||
if (spec.startswith("rtarray")) {
|
||||
return parseRuntimeArrayType(spec, loc);
|
||||
}
|
||||
|
||||
getContext()->emitError(loc, "unknown SPIR-V type: ") << spec;
|
||||
return Type();
|
||||
|
|
|
@ -14,21 +14,61 @@ func @vector_array_type(!spv.array< 32 x vector<4xf32> >) -> ()
|
|||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{unknown SPIR-V type}}
|
||||
// expected-error @+1 {{spv.array delimiter <...> mismatch}}
|
||||
func @missing_left_angle_bracket(!spv.array 4xf32>) -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{expected array element count followed by 'x' but found 'f32'}}
|
||||
func @missing_count(!spv.array<f32>) -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{unknown SPIR-V type}}
|
||||
// expected-error @+1 {{expected array element count followed by 'x' but found 'f32'}}
|
||||
func @missing_x(!spv.array<4 f32>) -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{unknown SPIR-V type}}
|
||||
// expected-error @+1 {{cannot parse type: blabla}}
|
||||
func @cannot_parse_type(!spv.array<4xblabla>) -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{cannot parse type: 3xf32}}
|
||||
func @more_than_one_dim(!spv.array<4x3xf32>) -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{only 1-D vector allowed but found 'vector<4x3xf32>'}}
|
||||
func @non_1D_vector(!spv.array<4xvector<4x3xf32>>) -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{cannot use 'tensor<4xf32>' to compose SPIR-V types}}
|
||||
func @tensor_type(!spv.array<4xtensor<4xf32>>) -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{cannot use 'bf16' to compose SPIR-V types}}
|
||||
func @bf16_type(!spv.array<4xbf16>) -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{only 8/16/32/64-bit integer type allowed but found 'i256'}}
|
||||
func @i256_type(!spv.array<4xi256>) -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{cannot use 'index' to compose SPIR-V types}}
|
||||
func @index_type(!spv.array<4xindex>) -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{cannot use '!llvm.i32' to compose SPIR-V types}}
|
||||
func @llvm_type(!spv.array<4x!llvm.i32>) -> ()
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RuntimeArrayType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -41,5 +81,10 @@ func @vector_runtime_array_type(!spv.rtarray< vector<4xf32> >) -> ()
|
|||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{unknown SPIR-V type}}
|
||||
// expected-error @+1 {{spv.rtarray delimiter <...> mismatch}}
|
||||
func @missing_left_angle_bracket(!spv.rtarray f32>) -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{cannot parse type: 4xf32}}
|
||||
func @redundant_count(!spv.rtarray<4xf32>) -> ()
|
||||
|
|
Loading…
Reference in New Issue