[mlir] Add a DialectAsmParser::getChecked method

This function simplifies calling the getChecked methods on Attributes and Types from within the parser, and removes any need to use `getEncodedSourceLocation` for these methods (by using an SMLoc instead). This is much more efficient than using an mlir::Location, as the encoding process to produce an mlir::Location is inefficient and undesirable for parsing (locations used during parsing should not persist afterwards unless otherwise necessary).

Differential Revision: https://reviews.llvm.org/D97900
This commit is contained in:
River Riddle 2021-03-04 11:52:14 -08:00
parent 9783e20988
commit 6bc767cd07
4 changed files with 50 additions and 33 deletions

View File

@ -121,6 +121,10 @@ public:
virtual llvm::SMLoc getNameLoc() const = 0;
/// Re-encode the given source location as an MLIR location and return it.
/// Note: This method should only be used when a `Location` is necessary, as
/// the encoding process is not efficient. In other cases a more suitable
/// alternative should be used, such as the `getChecked` methods defined
/// below.
virtual Location getEncodedSourceLoc(llvm::SMLoc loc) = 0;
/// Returns the full specification of the symbol being parsed. This allows for
@ -163,6 +167,22 @@ public:
return success();
}
/// Invoke the `getChecked` method of the given Attribute or Type class, using
/// the provided location to emit errors in the case of failure. Note that
/// unlike `OpBuilder::getType`, this method does not implicitly insert a
/// context parameter.
template <typename T, typename... ParamsT>
T getChecked(llvm::SMLoc loc, ParamsT &&...params) {
return T::getChecked([&] { return emitError(loc); },
std::forward<ParamsT>(params)...);
}
/// A variant of `getChecked` that uses the result of `getNameLoc` to emit
/// errors.
template <typename T, typename... ParamsT> T getChecked(ParamsT &&...params) {
return T::getChecked([&] { return emitError(getNameLoc()); },
std::forward<ParamsT>(params)...);
}
//===--------------------------------------------------------------------===//
// Token Parsing
//===--------------------------------------------------------------------===//

View File

@ -178,7 +178,7 @@ static ParseResult dispatchParse(DialectAsmParser &parser, Type &type);
/// Parses an LLVM dialect function type.
/// llvm-type :: = `func<` llvm-type `(` llvm-type-list `...`? `)>`
static LLVMFunctionType parseFunctionType(DialectAsmParser &parser) {
Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
llvm::SMLoc loc = parser.getCurrentLocation();
Type returnType;
if (parser.parseLess() || dispatchParse(parser, returnType) ||
parser.parseLParen())
@ -187,8 +187,8 @@ static LLVMFunctionType parseFunctionType(DialectAsmParser &parser) {
// Function type without arguments.
if (succeeded(parser.parseOptionalRParen())) {
if (succeeded(parser.parseGreater()))
return LLVMFunctionType::getChecked(loc, returnType, llvm::None,
/*isVarArg=*/false);
return parser.getChecked<LLVMFunctionType>(loc, returnType, llvm::None,
/*isVarArg=*/false);
return LLVMFunctionType();
}
@ -198,8 +198,8 @@ static LLVMFunctionType parseFunctionType(DialectAsmParser &parser) {
if (succeeded(parser.parseOptionalEllipsis())) {
if (parser.parseOptionalRParen() || parser.parseOptionalGreater())
return LLVMFunctionType();
return LLVMFunctionType::getChecked(loc, returnType, argTypes,
/*isVarArg=*/true);
return parser.getChecked<LLVMFunctionType>(loc, returnType, argTypes,
/*isVarArg=*/true);
}
Type arg;
@ -210,14 +210,14 @@ static LLVMFunctionType parseFunctionType(DialectAsmParser &parser) {
if (parser.parseOptionalRParen() || parser.parseOptionalGreater())
return LLVMFunctionType();
return LLVMFunctionType::getChecked(loc, returnType, argTypes,
/*isVarArg=*/false);
return parser.getChecked<LLVMFunctionType>(loc, returnType, argTypes,
/*isVarArg=*/false);
}
/// Parses an LLVM dialect pointer type.
/// llvm-type ::= `ptr<` llvm-type (`,` integer)? `>`
static LLVMPointerType parsePointerType(DialectAsmParser &parser) {
Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
llvm::SMLoc loc = parser.getCurrentLocation();
Type elementType;
if (parser.parseLess() || dispatchParse(parser, elementType))
return LLVMPointerType();
@ -228,7 +228,7 @@ static LLVMPointerType parsePointerType(DialectAsmParser &parser) {
return LLVMPointerType();
if (failed(parser.parseGreater()))
return LLVMPointerType();
return LLVMPointerType::getChecked(loc, elementType, addressSpace);
return parser.getChecked<LLVMPointerType>(loc, elementType, addressSpace);
}
/// Parses an LLVM dialect vector type.
@ -238,7 +238,7 @@ static Type parseVectorType(DialectAsmParser &parser) {
SmallVector<int64_t, 2> dims;
llvm::SMLoc dimPos, typePos;
Type elementType;
Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
llvm::SMLoc loc = parser.getCurrentLocation();
if (parser.parseLess() || parser.getCurrentLocation(&dimPos) ||
parser.parseDimensionList(dims, /*allowDynamic=*/true) ||
parser.getCurrentLocation(&typePos) ||
@ -259,13 +259,13 @@ static Type parseVectorType(DialectAsmParser &parser) {
bool isScalable = dims.size() == 2;
if (isScalable)
return LLVMScalableVectorType::getChecked(loc, elementType, dims[1]);
return parser.getChecked<LLVMScalableVectorType>(loc, elementType, dims[1]);
if (elementType.isSignlessIntOrFloat()) {
parser.emitError(typePos)
<< "cannot use !llvm.vec for built-in primitives, use 'vector' instead";
return Type();
}
return LLVMFixedVectorType::getChecked(loc, elementType, dims[0]);
return parser.getChecked<LLVMFixedVectorType>(loc, elementType, dims[0]);
}
/// Parses an LLVM dialect array type.
@ -274,7 +274,7 @@ static LLVMArrayType parseArrayType(DialectAsmParser &parser) {
SmallVector<int64_t, 1> dims;
llvm::SMLoc sizePos;
Type elementType;
Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
llvm::SMLoc loc = parser.getCurrentLocation();
if (parser.parseLess() || parser.getCurrentLocation(&sizePos) ||
parser.parseDimensionList(dims, /*allowDynamic=*/false) ||
dispatchParse(parser, elementType) || parser.parseGreater())
@ -285,7 +285,7 @@ static LLVMArrayType parseArrayType(DialectAsmParser &parser) {
return LLVMArrayType();
}
return LLVMArrayType::getChecked(loc, elementType, dims[0]);
return parser.getChecked<LLVMArrayType>(loc, elementType, dims[0]);
}
/// Attempts to set the body of an identified structure type. Reports a parsing

View File

@ -117,7 +117,7 @@ static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser,
/// storage-range ::= integer-literal `:` integer-literal
/// storage-type ::= (`i` | `u`) integer-literal
/// expressed-type-spec ::= `:` `f` integer-literal
static Type parseAnyType(DialectAsmParser &parser, Location loc) {
static Type parseAnyType(DialectAsmParser &parser) {
IntegerType storageType;
FloatType expressedType;
unsigned typeFlags = 0;
@ -155,9 +155,8 @@ static Type parseAnyType(DialectAsmParser &parser, Location loc) {
return nullptr;
}
return AnyQuantizedType::getChecked(loc, typeFlags, storageType,
expressedType, storageTypeMin,
storageTypeMax);
return parser.getChecked<AnyQuantizedType>(
typeFlags, storageType, expressedType, storageTypeMin, storageTypeMax);
}
static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale,
@ -192,7 +191,7 @@ static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale,
/// axis-spec ::= `:` integer-literal
/// scale-zero ::= float-literal `:` integer-literal
/// scale-zero-list ::= `{` scale-zero (`,` scale-zero)* `}`
static Type parseUniformType(DialectAsmParser &parser, Location loc) {
static Type parseUniformType(DialectAsmParser &parser) {
IntegerType storageType;
FloatType expressedType;
unsigned typeFlags = 0;
@ -279,14 +278,14 @@ static Type parseUniformType(DialectAsmParser &parser, Location loc) {
if (isPerAxis) {
ArrayRef<double> scalesRef(scales.begin(), scales.end());
ArrayRef<int64_t> zeroPointsRef(zeroPoints.begin(), zeroPoints.end());
return UniformQuantizedPerAxisType::getChecked(
loc, typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
return parser.getChecked<UniformQuantizedPerAxisType>(
typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
quantizedDimension, storageTypeMin, storageTypeMax);
}
return UniformQuantizedType::getChecked(
loc, typeFlags, storageType, expressedType, scales.front(),
zeroPoints.front(), storageTypeMin, storageTypeMax);
return parser.getChecked<UniformQuantizedType>(
typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(),
storageTypeMin, storageTypeMax);
}
/// Parses an CalibratedQuantizedType.
@ -295,7 +294,7 @@ static Type parseUniformType(DialectAsmParser &parser, Location loc) {
/// expressed-spec ::= expressed-type `<` calibrated-range `>`
/// expressed-type ::= `f` integer-literal
/// calibrated-range ::= float-literal `:` float-literal
static Type parseCalibratedType(DialectAsmParser &parser, Location loc) {
static Type parseCalibratedType(DialectAsmParser &parser) {
FloatType expressedType;
double min;
double max;
@ -314,24 +313,22 @@ static Type parseCalibratedType(DialectAsmParser &parser, Location loc) {
return nullptr;
}
return CalibratedQuantizedType::getChecked(loc, expressedType, min, max);
return parser.getChecked<CalibratedQuantizedType>(expressedType, min, max);
}
/// Parse a type registered to this dialect.
Type QuantizationDialect::parseType(DialectAsmParser &parser) const {
Location loc = parser.getEncodedSourceLoc(parser.getNameLoc());
// All types start with an identifier that we switch on.
StringRef typeNameSpelling;
if (failed(parser.parseKeyword(&typeNameSpelling)))
return nullptr;
if (typeNameSpelling == "uniform")
return parseUniformType(parser, loc);
return parseUniformType(parser);
if (typeNameSpelling == "any")
return parseAnyType(parser, loc);
return parseAnyType(parser);
if (typeNameSpelling == "calibrated")
return parseCalibratedType(parser, loc);
return parseCalibratedType(parser);
parser.emitError(parser.getNameLoc(),
"unknown quantized type " + typeNameSpelling);

View File

@ -524,7 +524,7 @@ Attribute Parser::parseExtendedAttr(Type type) {
// Otherwise, form a new opaque attribute.
return OpaqueAttr::getChecked(
getEncodedSourceLocation(loc),
[&] { return emitError(loc); },
Identifier::get(dialectName, state.context), symbolData,
attrType ? attrType : NoneType::get(state.context));
});
@ -563,7 +563,7 @@ Type Parser::parseExtendedType() {
// Otherwise, form a new opaque type.
return OpaqueType::getChecked(
getEncodedSourceLocation(loc),
[&] { return emitError(loc); },
Identifier::get(dialectName, state.context), symbolData);
});
}