Add emitOptional(Error|Warning|Remark) functions to simplify emission with an optional location.

In some situations a diagnostic may optionally be emitted by the presence of a location, e.g. attribute and type verification. These situations currently require extra 'if(loc) emitError(...); return failure()' wrappers that make verification clunky. These new overloads take an optional location and a list of arguments to the diagnostic, and return a LogicalResult. We take the arguments directly and return LogicalResult instead of returning InFlightDiagnostic because we cannot create a valid diagnostic with a null location. This creates an awkward situation where a user may try to treat the, potentially null, diagnostic as a valid one and encounter crashes when attaching notes/etc. Below is an example of how these methods simplify some existing usages:

Before:

  if (loc)
    emitError(*loc, "this is my diagnostic with argument: ") << 5;
  return failure();

After:

  return emitOptionalError(loc, "this is my diagnostic with argument: ", 5);

PiperOrigin-RevId: 283853599
This commit is contained in:
River Riddle 2019-12-04 15:49:09 -08:00 committed by A. Unique TensorFlower
parent b3f7cf80a7
commit 2c930f8d9d
9 changed files with 151 additions and 194 deletions

View File

@ -75,10 +75,10 @@ public:
static constexpr unsigned MaxStorageBits = 32;
static LogicalResult
verifyConstructionInvariants(llvm::Optional<Location> loc,
MLIRContext *context, unsigned flags,
Type storageType, Type expressedType,
int64_t storageTypeMin, int64_t storageTypeMax);
verifyConstructionInvariants(Optional<Location> loc, MLIRContext *context,
unsigned flags, Type storageType,
Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax);
/// Support method to enable LLVM-style type casting.
static bool classof(Type type) {
@ -238,10 +238,10 @@ public:
/// Verifies construction invariants and issues errors/warnings.
static LogicalResult
verifyConstructionInvariants(llvm::Optional<Location> loc,
MLIRContext *context, unsigned flags,
Type storageType, Type expressedType,
int64_t storageTypeMin, int64_t storageTypeMax);
verifyConstructionInvariants(Optional<Location> loc, MLIRContext *context,
unsigned flags, Type storageType,
Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax);
};
/// Represents a family of uniform, quantized types.
@ -298,7 +298,7 @@ public:
/// Verifies construction invariants and issues errors/warnings.
static LogicalResult verifyConstructionInvariants(
llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
Optional<Location> loc, MLIRContext *context, unsigned flags,
Type storageType, Type expressedType, double scale, int64_t zeroPoint,
int64_t storageTypeMin, int64_t storageTypeMax);

View File

@ -321,12 +321,12 @@ public:
}
/// Verify the construction invariants for a double value.
static LogicalResult
verifyConstructionInvariants(llvm::Optional<Location> loc, MLIRContext *ctx,
Type type, double value);
static LogicalResult
verifyConstructionInvariants(llvm::Optional<Location> loc, MLIRContext *ctx,
Type type, const APFloat &value);
static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
MLIRContext *ctx, Type type,
double value);
static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
MLIRContext *ctx, Type type,
const APFloat &value);
};
//===----------------------------------------------------------------------===//
@ -403,10 +403,11 @@ public:
StringRef getAttrData() const;
/// Verify the construction of an opaque attribute.
static LogicalResult
verifyConstructionInvariants(llvm::Optional<Location> loc,
MLIRContext *context, Identifier dialect,
StringRef attrData, Type type);
static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
MLIRContext *context,
Identifier dialect,
StringRef attrData,
Type type);
static bool kindof(unsigned kind) {
return kind == StandardAttributes::Opaque;

View File

@ -481,6 +481,30 @@ InFlightDiagnostic emitWarning(Location loc, const Twine &message);
InFlightDiagnostic emitRemark(Location loc);
InFlightDiagnostic emitRemark(Location loc, const Twine &message);
/// Overloads of the above emission functions that take an optionally null
/// location. If the location is null, no diagnostic is emitted and a failure is
/// returned. Given that the provided location may be null, these methods take
/// the diagnostic arguments directly instead of relying on the returned
/// InFlightDiagnostic.
template <typename... Args>
LogicalResult emitOptionalError(Optional<Location> loc, Args &&... args) {
if (loc)
return emitError(*loc).append(std::forward<Args>(args)...);
return failure();
}
template <typename... Args>
LogicalResult emitOptionalWarning(Optional<Location> loc, Args &&... args) {
if (loc)
return emitWarning(*loc).append(std::forward<Args>(args)...);
return failure();
}
template <typename... Args>
LogicalResult emitOptionalRemark(Optional<Location> loc, Args &&... args) {
if (loc)
return emitRemark(*loc).append(std::forward<Args>(args)...);
return failure();
}
//===----------------------------------------------------------------------===//
// ScopedDiagnosticHandler
//===----------------------------------------------------------------------===//

View File

@ -102,9 +102,9 @@ public:
Location location);
/// Verify the construction of an integer type.
static LogicalResult
verifyConstructionInvariants(llvm::Optional<Location> loc,
MLIRContext *context, unsigned width);
static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
MLIRContext *context,
unsigned width);
/// Return the bitwidth of this integer type.
unsigned getWidth() const;
@ -168,9 +168,9 @@ public:
static ComplexType getChecked(Type elementType, Location location);
/// Verify the construction of an integer type.
static LogicalResult
verifyConstructionInvariants(llvm::Optional<Location> loc,
MLIRContext *context, Type elementType);
static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
MLIRContext *context,
Type elementType);
Type getElementType();
@ -269,9 +269,9 @@ public:
Location location);
/// Verify the construction of a vector type.
static LogicalResult
verifyConstructionInvariants(llvm::Optional<Location> loc,
MLIRContext *context, ArrayRef<int64_t> shape,
static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
MLIRContext *context,
ArrayRef<int64_t> shape,
Type elementType);
/// Returns true of the given type can be used as an element of a vector type.
@ -328,9 +328,9 @@ public:
Location location);
/// Verify the construction of a ranked tensor type.
static LogicalResult
verifyConstructionInvariants(llvm::Optional<Location> loc,
MLIRContext *context, ArrayRef<int64_t> shape,
static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
MLIRContext *context,
ArrayRef<int64_t> shape,
Type elementType);
ArrayRef<int64_t> getShape() const;
@ -359,9 +359,9 @@ public:
static UnrankedTensorType getChecked(Type elementType, Location location);
/// Verify the construction of a unranked tensor type.
static LogicalResult
verifyConstructionInvariants(llvm::Optional<Location> loc,
MLIRContext *context, Type elementType);
static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
MLIRContext *context,
Type elementType);
ArrayRef<int64_t> getShape() const { return llvm::None; }

View File

@ -56,7 +56,7 @@ struct OpaqueTypeStorage;
///
/// * Optional:
/// - static LogicalResult verifyConstructionInvariants(
/// llvm::Optional<Location> loc,
/// Optional<Location> loc,
/// MLIRContext *context,
/// Args... args)
/// * This method is invoked when calling the 'TypeBase::get/getChecked'
@ -250,9 +250,9 @@ public:
StringRef getTypeData() const;
/// Verify the construction of an opaque type.
static LogicalResult
verifyConstructionInvariants(llvm::Optional<Location> loc,
MLIRContext *context, Identifier dialect,
static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
MLIRContext *context,
Identifier dialect,
StringRef typeData);
static bool kindof(unsigned kind) { return kind == Kind::Opaque; }

View File

@ -33,28 +33,20 @@ unsigned QuantizedType::getFlags() const {
}
LogicalResult QuantizedType::verifyConstructionInvariants(
llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
Optional<Location> loc, MLIRContext *context, unsigned flags,
Type storageType, Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax) {
// Verify that the storage type is integral.
// This restriction may be lifted at some point in favor of using bf16
// or f16 as exact representations on hardware where that is advantageous.
auto intStorageType = storageType.dyn_cast<IntegerType>();
if (!intStorageType) {
if (loc) {
emitError(*loc, "storage type must be integral");
}
return failure();
}
if (!intStorageType)
return emitOptionalError(loc, "storage type must be integral");
unsigned integralWidth = intStorageType.getWidth();
// Verify storage width.
if (integralWidth == 0 || integralWidth > MaxStorageBits) {
if (loc) {
emitError(*loc, "illegal storage type size: ") << integralWidth;
}
return failure();
}
if (integralWidth == 0 || integralWidth > MaxStorageBits)
return emitOptionalError(loc, "illegal storage type size: ", integralWidth);
// Verify storageTypeMin and storageTypeMax.
bool isSigned =
@ -66,11 +58,8 @@ LogicalResult QuantizedType::verifyConstructionInvariants(
if (storageTypeMax - storageTypeMin <= 0 ||
storageTypeMin < defaultIntegerMin ||
storageTypeMax > defaultIntegerMax) {
if (loc) {
emitError(*loc, "illegal storage min and storage max: (")
<< storageTypeMin << ":" << storageTypeMax << ")";
}
return failure();
return emitOptionalError(loc, "illegal storage min and storage max: (",
storageTypeMin, ":", storageTypeMax, ")");
}
return success();
}
@ -235,7 +224,7 @@ AnyQuantizedType AnyQuantizedType::getChecked(unsigned flags, Type storageType,
}
LogicalResult AnyQuantizedType::verifyConstructionInvariants(
llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
Optional<Location> loc, MLIRContext *context, unsigned flags,
Type storageType, Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax) {
if (failed(QuantizedType::verifyConstructionInvariants(
@ -247,12 +236,8 @@ LogicalResult AnyQuantizedType::verifyConstructionInvariants(
// Verify that the expressed type is floating point.
// If this restriction is ever eliminated, the parser/printer must be
// extended.
if (expressedType && !expressedType.isa<FloatType>()) {
if (loc) {
emitError(*loc, "expressed type must be floating point");
}
return failure();
}
if (expressedType && !expressedType.isa<FloatType>())
return emitOptionalError(loc, "expressed type must be floating point");
return success();
}
@ -280,7 +265,7 @@ UniformQuantizedType::getChecked(unsigned flags, Type storageType,
}
LogicalResult UniformQuantizedType::verifyConstructionInvariants(
llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
Optional<Location> loc, MLIRContext *context, unsigned flags,
Type storageType, Type expressedType, double scale, int64_t zeroPoint,
int64_t storageTypeMin, int64_t storageTypeMax) {
if (failed(QuantizedType::verifyConstructionInvariants(
@ -291,30 +276,19 @@ LogicalResult UniformQuantizedType::verifyConstructionInvariants(
// Uniform quantization requires fully expressed parameters, including
// expressed type.
if (!expressedType) {
if (loc) {
emitError(*loc, "uniform quantization requires expressed type");
}
return failure();
}
if (!expressedType)
return emitOptionalError(loc,
"uniform quantization requires expressed type");
// Verify that the expressed type is floating point.
// If this restriction is ever eliminated, the parser/printer must be
// extended.
if (!expressedType.isa<FloatType>()) {
if (loc) {
emitError(*loc, "expressed type must be floating point");
}
return failure();
}
if (!expressedType.isa<FloatType>())
return emitOptionalError(loc, "expressed type must be floating point");
// Verify scale.
if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale)) {
if (loc) {
emitError(*loc) << "illegal scale: " << scale;
}
return failure();
}
if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
return emitOptionalError(loc, "illegal scale: ", scale);
return success();
}
@ -348,7 +322,7 @@ UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
}
LogicalResult UniformQuantizedPerAxisType::verifyConstructionInvariants(
llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
Optional<Location> loc, MLIRContext *context, unsigned flags,
Type storageType, Type expressedType, ArrayRef<double> scales,
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
int64_t storageTypeMin, int64_t storageTypeMax) {
@ -360,40 +334,25 @@ LogicalResult UniformQuantizedPerAxisType::verifyConstructionInvariants(
// Uniform quantization requires fully expressed parameters, including
// expressed type.
if (!expressedType) {
if (loc) {
emitError(*loc, "uniform quantization requires expressed type");
}
return failure();
}
if (!expressedType)
return emitOptionalError(loc,
"uniform quantization requires expressed type");
// Verify that the expressed type is floating point.
// If this restriction is ever eliminated, the parser/printer must be
// extended.
if (!expressedType.isa<FloatType>()) {
if (loc) {
emitError(*loc, "expressed type must be floating point");
}
return failure();
}
if (!expressedType.isa<FloatType>())
return emitOptionalError(loc, "expressed type must be floating point");
// Ensure that the number of scales and zeroPoints match.
if (scales.size() != zeroPoints.size()) {
if (loc) {
emitError(*loc, "illegal number of scales and zeroPoints: ")
<< scales.size() << ", " << zeroPoints.size();
}
return failure();
}
if (scales.size() != zeroPoints.size())
return emitOptionalError(loc, "illegal number of scales and zeroPoints: ",
scales.size(), ", ", zeroPoints.size());
// Verify scale.
for (double scale : scales) {
if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale)) {
if (loc) {
emitError(*loc) << "illegal scale: " << scale;
}
return failure();
}
if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
return emitOptionalError(loc, "illegal scale: ", scale);
}
return success();

View File

@ -214,24 +214,22 @@ double FloatAttr::getValueAsDouble(APFloat value) {
}
/// Verify construction invariants.
static LogicalResult verifyFloatTypeInvariants(llvm::Optional<Location> loc,
static LogicalResult verifyFloatTypeInvariants(Optional<Location> loc,
Type type) {
if (!type.isa<FloatType>()) {
if (loc)
emitError(*loc, "expected floating point type");
return failure();
}
if (!type.isa<FloatType>())
return emitOptionalError(loc, "expected floating point type");
return success();
}
LogicalResult FloatAttr::verifyConstructionInvariants(
llvm::Optional<Location> loc, MLIRContext *ctx, Type type, double value) {
LogicalResult FloatAttr::verifyConstructionInvariants(Optional<Location> loc,
MLIRContext *ctx,
Type type, double value) {
return verifyFloatTypeInvariants(loc, type);
}
LogicalResult
FloatAttr::verifyConstructionInvariants(llvm::Optional<Location> loc,
MLIRContext *ctx, Type type,
LogicalResult FloatAttr::verifyConstructionInvariants(Optional<Location> loc,
MLIRContext *ctx,
Type type,
const APFloat &value) {
// Verify that the type is correct.
if (failed(verifyFloatTypeInvariants(loc, type)))
@ -239,10 +237,8 @@ FloatAttr::verifyConstructionInvariants(llvm::Optional<Location> loc,
// Verify that the type semantics match that of the value.
if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
if (loc)
emitError(*loc,
"FloatAttr type doesn't match the type implied by its value");
return failure();
return emitOptionalError(
loc, "FloatAttr type doesn't match the type implied by its value");
}
return success();
}
@ -330,14 +326,13 @@ Identifier OpaqueAttr::getDialectNamespace() const {
StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; }
/// Verify the construction of an opaque attribute.
LogicalResult OpaqueAttr::verifyConstructionInvariants(
llvm::Optional<Location> loc, MLIRContext *context, Identifier dialect,
StringRef attrData, Type type) {
if (!Dialect::isValidNamespace(dialect.strref())) {
if (loc)
emitError(*loc) << "invalid dialect namespace '" << dialect << "'";
return failure();
}
LogicalResult OpaqueAttr::verifyConstructionInvariants(Optional<Location> loc,
MLIRContext *context,
Identifier dialect,
StringRef attrData,
Type type) {
if (!Dialect::isValidNamespace(dialect.strref()))
return emitOptionalError(loc, "invalid dialect namespace '", dialect, "'");
return success();
}

View File

@ -61,13 +61,12 @@ bool Type::isIntOrFloat() { return isa<IntegerType>() || isa<FloatType>(); }
constexpr unsigned IntegerType::kMaxWidth;
/// Verify the construction of an integer type.
LogicalResult IntegerType::verifyConstructionInvariants(
llvm::Optional<Location> loc, MLIRContext *context, unsigned width) {
LogicalResult IntegerType::verifyConstructionInvariants(Optional<Location> loc,
MLIRContext *context,
unsigned width) {
if (width > IntegerType::kMaxWidth) {
if (loc)
emitError(*loc) << "integer bitwidth is limited to "
<< IntegerType::kMaxWidth << " bits";
return failure();
return emitOptionalError(loc, "integer bitwidth is limited to ",
IntegerType::kMaxWidth, " bits");
}
return success();
}
@ -213,26 +212,21 @@ VectorType VectorType::getChecked(ArrayRef<int64_t> shape, Type elementType,
StandardTypes::Vector, shape, elementType);
}
LogicalResult VectorType::verifyConstructionInvariants(
llvm::Optional<Location> loc, MLIRContext *context, ArrayRef<int64_t> shape,
LogicalResult VectorType::verifyConstructionInvariants(Optional<Location> loc,
MLIRContext *context,
ArrayRef<int64_t> shape,
Type elementType) {
if (shape.empty()) {
if (loc)
emitError(*loc, "vector types must have at least one dimension");
return failure();
}
if (shape.empty())
return emitOptionalError(loc,
"vector types must have at least one dimension");
if (!isValidElementType(elementType)) {
if (loc)
emitError(*loc, "vector elements must be int or float type");
return failure();
}
if (!isValidElementType(elementType))
return emitOptionalError(loc, "vector elements must be int or float type");
if (any_of(shape, [](int64_t i) { return i <= 0; }))
return emitOptionalError(loc,
"vector types must have positive constant sizes");
if (any_of(shape, [](int64_t i) { return i <= 0; })) {
if (loc)
emitError(*loc, "vector types must have positive constant sizes");
return failure();
}
return success();
}
@ -247,11 +241,8 @@ ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); }
static inline LogicalResult checkTensorElementType(Optional<Location> location,
MLIRContext *context,
Type elementType) {
if (!TensorType::isValidElementType(elementType)) {
if (location)
emitError(*location, "invalid tensor element type");
return failure();
}
if (!TensorType::isValidElementType(elementType))
return emitOptionalError(location, "invalid tensor element type");
return success();
}
@ -273,14 +264,11 @@ RankedTensorType RankedTensorType::getChecked(ArrayRef<int64_t> shape,
}
LogicalResult RankedTensorType::verifyConstructionInvariants(
llvm::Optional<Location> loc, MLIRContext *context, ArrayRef<int64_t> shape,
Optional<Location> loc, MLIRContext *context, ArrayRef<int64_t> shape,
Type elementType) {
for (int64_t s : shape) {
if (s < -1) {
if (loc)
emitError(*loc, "invalid tensor dimension size");
return failure();
}
if (s < -1)
return emitOptionalError(loc, "invalid tensor dimension size");
}
return checkTensorElementType(loc, context, elementType);
}
@ -305,7 +293,7 @@ UnrankedTensorType UnrankedTensorType::getChecked(Type elementType,
}
LogicalResult UnrankedTensorType::verifyConstructionInvariants(
llvm::Optional<Location> loc, MLIRContext *context, Type elementType) {
Optional<Location> loc, MLIRContext *context, Type elementType) {
return checkTensorElementType(loc, context, elementType);
}
@ -350,19 +338,14 @@ MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
auto *context = elementType.getContext();
// Check that memref is formed from allowed types.
if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>()) {
if (location)
emitError(*location, "invalid memref element type");
return nullptr;
}
if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>())
return emitOptionalError(location, "invalid memref element type"),
MemRefType();
for (int64_t s : shape) {
// Negative sizes are not allowed except for `-1` that means dynamic size.
if (s < -1) {
if (location)
emitError(*location, "invalid memref size");
return {};
}
if (s < -1)
return emitOptionalError(location, "invalid memref size"), MemRefType();
}
// Check that the structure of the composition is valid, i.e. that each
@ -631,11 +614,8 @@ ComplexType ComplexType::getChecked(Type elementType, Location location) {
/// Verify the construction of an integer type.
LogicalResult ComplexType::verifyConstructionInvariants(
llvm::Optional<Location> loc, MLIRContext *context, Type elementType) {
if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>()) {
if (loc)
emitError(*loc, "invalid element type for complex");
return failure();
}
if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>())
return emitOptionalError(loc, "invalid element type for complex");
return success();
}

View File

@ -80,13 +80,11 @@ Identifier OpaqueType::getDialectNamespace() const {
StringRef OpaqueType::getTypeData() const { return getImpl()->typeData; }
/// Verify the construction of an opaque type.
LogicalResult OpaqueType::verifyConstructionInvariants(
llvm::Optional<Location> loc, MLIRContext *context, Identifier dialect,
LogicalResult OpaqueType::verifyConstructionInvariants(Optional<Location> loc,
MLIRContext *context,
Identifier dialect,
StringRef typeData) {
if (!Dialect::isValidNamespace(dialect.strref())) {
if (loc)
emitError(*loc) << "invalid dialect namespace '" << dialect << "'";
return failure();
}
if (!Dialect::isValidNamespace(dialect.strref()))
return emitOptionalError(loc, "invalid dialect namespace '", dialect, "'");
return success();
}