[spirv] Use mlir::parseType in type parsers and add more checks

PiperOrigin-RevId: 252874386
This commit is contained in:
Lei Zhang 2019-06-12 12:16:05 -07:00 committed by Mehdi Amini
parent a3e6f102ca
commit 8c6f188143
4 changed files with 129 additions and 61 deletions

View File

@ -25,19 +25,29 @@
#include "mlir/IR/Dialect.h"
namespace mlir {
class MLIRContext;
namespace spirv {
class SPIRVDialect : public Dialect {
public:
explicit SPIRVDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "spv"; }
/// Parses a type registered to this dialect.
Type parseType(llvm::StringRef spec, Location loc) const override;
/// Prints a type registered to this dialect.
void printType(Type type, llvm::raw_ostream &os) const override;
private:
/// Parses `spec` as a type and verifies it can be used in SPIR-V types.
Type parseAndVerifyType(StringRef spec, Location loc) const;
/// Parses `spec` as a SPIR-V array type.
Type parseArrayType(StringRef spec, Location loc) const;
/// Parses `spec` as a SPIR-V run-time array type.
Type parseRuntimeArrayType(StringRef spec, Location loc) const;
};
} // end namespace spirv

View File

@ -14,4 +14,7 @@ add_dependencies(MLIRSPIRV
MLIRSPIRVEnumsIncGen
MLIRStdOpsToSPIRVConversionIncGen)
target_link_libraries(MLIRSPIRV MLIRIR MLIRSupport)
target_link_libraries(MLIRSPIRV
MLIRIR
MLIRParser
MLIRSupport)

View File

@ -21,9 +21,9 @@
#include "mlir/SPIRV/SPIRVDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Parser.h"
#include "mlir/SPIRV/SPIRVOps.h"
#include "mlir/SPIRV/SPIRVTypes.h"
#include "llvm/ADT/StringExtras.h"
@ -37,7 +37,8 @@ using namespace mlir::spirv;
// SPIR-V Dialect
//===----------------------------------------------------------------------===//
SPIRVDialect::SPIRVDialect(MLIRContext *context) : Dialect("spv", context) {
SPIRVDialect::SPIRVDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
addTypes<ArrayType, RuntimeArrayType>();
addOperations<
@ -53,19 +54,6 @@ SPIRVDialect::SPIRVDialect(MLIRContext *context) : Dialect("spv", context) {
// Type Parsing
//===----------------------------------------------------------------------===//
// TODO(b/133530217): The following implements some type parsing logic. It is
// intended to be short-lived and used just before the main parser logic gets
// exposed to dialects. So there is little type checking inside.
static Type parseScalarType(StringRef spec, Builder builder) {
return llvm::StringSwitch<Type>(spec)
.Case("f32", builder.getF32Type())
.Case("i32", builder.getIntegerType(32))
.Case("f16", builder.getF16Type())
.Case("i16", builder.getIntegerType(16))
.Default(Type());
}
// Parses "<number> x" from the beginning of `spec`.
static bool parseNumberX(StringRef &spec, int64_t &number) {
spec = spec.ltrim();
@ -85,56 +73,77 @@ static bool parseNumberX(StringRef &spec, int64_t &number) {
return true;
}
static Type parseVectorType(StringRef spec, Builder builder) {
if (!spec.consume_front("vector<") || !spec.consume_back(">"))
Type SPIRVDialect::parseAndVerifyType(StringRef spec, Location loc) const {
auto *context = getContext();
auto type = mlir::parseType(spec, context);
if (!type) {
context->emitError(loc, "cannot parse type: ") << spec;
return Type();
}
int64_t count = 0;
if (!parseNumberX(spec, count))
// Allow SPIR-V dialect types
if (&type.getDialect() == this)
return type;
// Check other allowed types
if (auto t = type.dyn_cast<FloatType>()) {
if (type.isBF16()) {
context->emitError(loc, "cannot use 'bf16' to compose SPIR-V types");
return Type();
}
} else if (auto t = type.dyn_cast<IntegerType>()) {
if (!llvm::is_contained(llvm::ArrayRef<unsigned>({8, 16, 32, 64}),
t.getWidth())) {
context->emitError(loc,
"only 8/16/32/64-bit integer type allowed but found ")
<< type;
return Type();
}
} else if (auto t = type.dyn_cast<VectorType>()) {
if (t.getRank() != 1) {
context->emitError(loc, "only 1-D vector allowed but found ") << t;
return Type();
}
} else {
context->emitError(loc, "cannot use ")
<< type << " to compose SPIR-V types";
return Type();
}
spec = spec.trim();
auto scalarType = parseScalarType(spec, builder);
if (!scalarType)
return Type();
return VectorType::get({count}, scalarType);
return type;
}
static Type parseArrayType(StringRef spec, Builder builder) {
if (!spec.consume_front("array<") || !spec.consume_back(">"))
Type SPIRVDialect::parseArrayType(StringRef spec, Location loc) const {
auto *context = getContext();
if (!spec.consume_front("array<") || !spec.consume_back(">")) {
context->emitError(loc, "spv.array delimiter <...> mismatch");
return Type();
Type elementType;
int64_t count = 0;
spec = spec.trim();
if (!parseNumberX(spec, count))
return Type();
spec = spec.ltrim();
if (spec.startswith("vector")) {
elementType = parseVectorType(spec, builder);
} else {
elementType = parseScalarType(spec, builder);
}
int64_t count = 0;
spec = spec.trim();
if (!parseNumberX(spec, count)) {
context->emitError(
loc, "expected array element count followed by 'x' but found '")
<< spec << "'";
return Type();
}
Type elementType = parseAndVerifyType(spec, loc);
if (!elementType)
return Type();
return ArrayType::get(elementType, count);
}
static Type parseRuntimeArrayType(StringRef spec, Builder builder) {
if (!spec.consume_front("rtarray<") || !spec.consume_back(">"))
Type SPIRVDialect::parseRuntimeArrayType(StringRef spec, Location loc) const {
auto *context = getContext();
if (!spec.consume_front("rtarray<") || !spec.consume_back(">")) {
context->emitError(loc, "spv.rtarray delimiter <...> mismatch");
return Type();
Type elementType;
spec = spec.trim();
if (spec.startswith("vector")) {
elementType = parseVectorType(spec, builder);
} else {
elementType = parseScalarType(spec, builder);
}
Type elementType = parseAndVerifyType(spec, loc);
if (!elementType)
return Type();
@ -142,12 +151,13 @@ static Type parseRuntimeArrayType(StringRef spec, Builder builder) {
}
Type SPIRVDialect::parseType(StringRef spec, Location loc) const {
Builder builder(getContext());
if (auto type = parseArrayType(spec, builder))
return type;
if (auto type = parseRuntimeArrayType(spec, builder))
return type;
if (spec.startswith("array")) {
return parseArrayType(spec, loc);
}
if (spec.startswith("rtarray")) {
return parseRuntimeArrayType(spec, loc);
}
getContext()->emitError(loc, "unknown SPIR-V type: ") << spec;
return Type();

View File

@ -14,21 +14,61 @@ func @vector_array_type(!spv.array< 32 x vector<4xf32> >) -> ()
// -----
// expected-error @+1 {{unknown SPIR-V type}}
// expected-error @+1 {{spv.array delimiter <...> mismatch}}
func @missing_left_angle_bracket(!spv.array 4xf32>) -> ()
// -----
// expected-error @+1 {{expected array element count followed by 'x' but found 'f32'}}
func @missing_count(!spv.array<f32>) -> ()
// -----
// expected-error @+1 {{unknown SPIR-V type}}
// expected-error @+1 {{expected array element count followed by 'x' but found 'f32'}}
func @missing_x(!spv.array<4 f32>) -> ()
// -----
// expected-error @+1 {{unknown SPIR-V type}}
// expected-error @+1 {{cannot parse type: blabla}}
func @cannot_parse_type(!spv.array<4xblabla>) -> ()
// -----
// expected-error @+1 {{cannot parse type: 3xf32}}
func @more_than_one_dim(!spv.array<4x3xf32>) -> ()
// -----
// expected-error @+1 {{only 1-D vector allowed but found 'vector<4x3xf32>'}}
func @non_1D_vector(!spv.array<4xvector<4x3xf32>>) -> ()
// -----
// expected-error @+1 {{cannot use 'tensor<4xf32>' to compose SPIR-V types}}
func @tensor_type(!spv.array<4xtensor<4xf32>>) -> ()
// -----
// expected-error @+1 {{cannot use 'bf16' to compose SPIR-V types}}
func @bf16_type(!spv.array<4xbf16>) -> ()
// -----
// expected-error @+1 {{only 8/16/32/64-bit integer type allowed but found 'i256'}}
func @i256_type(!spv.array<4xi256>) -> ()
// -----
// expected-error @+1 {{cannot use 'index' to compose SPIR-V types}}
func @index_type(!spv.array<4xindex>) -> ()
// -----
// expected-error @+1 {{cannot use '!llvm.i32' to compose SPIR-V types}}
func @llvm_type(!spv.array<4x!llvm.i32>) -> ()
// -----
//===----------------------------------------------------------------------===//
// RuntimeArrayType
//===----------------------------------------------------------------------===//
@ -41,5 +81,10 @@ func @vector_runtime_array_type(!spv.rtarray< vector<4xf32> >) -> ()
// -----
// expected-error @+1 {{unknown SPIR-V type}}
// expected-error @+1 {{spv.rtarray delimiter <...> mismatch}}
func @missing_left_angle_bracket(!spv.rtarray f32>) -> ()
// -----
// expected-error @+1 {{cannot parse type: 4xf32}}
func @redundant_count(!spv.rtarray<4xf32>) -> ()