forked from OSchip/llvm-project
Add emitOptional(Error|Warning|Remark) functions to simplify emission with an optional location.
In some situations a diagnostic may optionally be emitted by the presence of a location, e.g. attribute and type verification. These situations currently require extra 'if(loc) emitError(...); return failure()' wrappers that make verification clunky. These new overloads take an optional location and a list of arguments to the diagnostic, and return a LogicalResult. We take the arguments directly and return LogicalResult instead of returning InFlightDiagnostic because we cannot create a valid diagnostic with a null location. This creates an awkward situation where a user may try to treat the, potentially null, diagnostic as a valid one and encounter crashes when attaching notes/etc. Below is an example of how these methods simplify some existing usages: Before: if (loc) emitError(*loc, "this is my diagnostic with argument: ") << 5; return failure(); After: return emitOptionalError(loc, "this is my diagnostic with argument: ", 5); PiperOrigin-RevId: 283853599
This commit is contained in:
parent
b3f7cf80a7
commit
2c930f8d9d
|
@ -75,10 +75,10 @@ public:
|
|||
static constexpr unsigned MaxStorageBits = 32;
|
||||
|
||||
static LogicalResult
|
||||
verifyConstructionInvariants(llvm::Optional<Location> loc,
|
||||
MLIRContext *context, unsigned flags,
|
||||
Type storageType, Type expressedType,
|
||||
int64_t storageTypeMin, int64_t storageTypeMax);
|
||||
verifyConstructionInvariants(Optional<Location> loc, MLIRContext *context,
|
||||
unsigned flags, Type storageType,
|
||||
Type expressedType, int64_t storageTypeMin,
|
||||
int64_t storageTypeMax);
|
||||
|
||||
/// Support method to enable LLVM-style type casting.
|
||||
static bool classof(Type type) {
|
||||
|
@ -238,10 +238,10 @@ public:
|
|||
|
||||
/// Verifies construction invariants and issues errors/warnings.
|
||||
static LogicalResult
|
||||
verifyConstructionInvariants(llvm::Optional<Location> loc,
|
||||
MLIRContext *context, unsigned flags,
|
||||
Type storageType, Type expressedType,
|
||||
int64_t storageTypeMin, int64_t storageTypeMax);
|
||||
verifyConstructionInvariants(Optional<Location> loc, MLIRContext *context,
|
||||
unsigned flags, Type storageType,
|
||||
Type expressedType, int64_t storageTypeMin,
|
||||
int64_t storageTypeMax);
|
||||
};
|
||||
|
||||
/// Represents a family of uniform, quantized types.
|
||||
|
@ -298,7 +298,7 @@ public:
|
|||
|
||||
/// Verifies construction invariants and issues errors/warnings.
|
||||
static LogicalResult verifyConstructionInvariants(
|
||||
llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
|
||||
Optional<Location> loc, MLIRContext *context, unsigned flags,
|
||||
Type storageType, Type expressedType, double scale, int64_t zeroPoint,
|
||||
int64_t storageTypeMin, int64_t storageTypeMax);
|
||||
|
||||
|
|
|
@ -321,12 +321,12 @@ public:
|
|||
}
|
||||
|
||||
/// Verify the construction invariants for a double value.
|
||||
static LogicalResult
|
||||
verifyConstructionInvariants(llvm::Optional<Location> loc, MLIRContext *ctx,
|
||||
Type type, double value);
|
||||
static LogicalResult
|
||||
verifyConstructionInvariants(llvm::Optional<Location> loc, MLIRContext *ctx,
|
||||
Type type, const APFloat &value);
|
||||
static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
|
||||
MLIRContext *ctx, Type type,
|
||||
double value);
|
||||
static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
|
||||
MLIRContext *ctx, Type type,
|
||||
const APFloat &value);
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -403,10 +403,11 @@ public:
|
|||
StringRef getAttrData() const;
|
||||
|
||||
/// Verify the construction of an opaque attribute.
|
||||
static LogicalResult
|
||||
verifyConstructionInvariants(llvm::Optional<Location> loc,
|
||||
MLIRContext *context, Identifier dialect,
|
||||
StringRef attrData, Type type);
|
||||
static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
|
||||
MLIRContext *context,
|
||||
Identifier dialect,
|
||||
StringRef attrData,
|
||||
Type type);
|
||||
|
||||
static bool kindof(unsigned kind) {
|
||||
return kind == StandardAttributes::Opaque;
|
||||
|
|
|
@ -481,6 +481,30 @@ InFlightDiagnostic emitWarning(Location loc, const Twine &message);
|
|||
InFlightDiagnostic emitRemark(Location loc);
|
||||
InFlightDiagnostic emitRemark(Location loc, const Twine &message);
|
||||
|
||||
/// Overloads of the above emission functions that take an optionally null
|
||||
/// location. If the location is null, no diagnostic is emitted and a failure is
|
||||
/// returned. Given that the provided location may be null, these methods take
|
||||
/// the diagnostic arguments directly instead of relying on the returned
|
||||
/// InFlightDiagnostic.
|
||||
template <typename... Args>
|
||||
LogicalResult emitOptionalError(Optional<Location> loc, Args &&... args) {
|
||||
if (loc)
|
||||
return emitError(*loc).append(std::forward<Args>(args)...);
|
||||
return failure();
|
||||
}
|
||||
template <typename... Args>
|
||||
LogicalResult emitOptionalWarning(Optional<Location> loc, Args &&... args) {
|
||||
if (loc)
|
||||
return emitWarning(*loc).append(std::forward<Args>(args)...);
|
||||
return failure();
|
||||
}
|
||||
template <typename... Args>
|
||||
LogicalResult emitOptionalRemark(Optional<Location> loc, Args &&... args) {
|
||||
if (loc)
|
||||
return emitRemark(*loc).append(std::forward<Args>(args)...);
|
||||
return failure();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ScopedDiagnosticHandler
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -102,9 +102,9 @@ public:
|
|||
Location location);
|
||||
|
||||
/// Verify the construction of an integer type.
|
||||
static LogicalResult
|
||||
verifyConstructionInvariants(llvm::Optional<Location> loc,
|
||||
MLIRContext *context, unsigned width);
|
||||
static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
|
||||
MLIRContext *context,
|
||||
unsigned width);
|
||||
|
||||
/// Return the bitwidth of this integer type.
|
||||
unsigned getWidth() const;
|
||||
|
@ -168,9 +168,9 @@ public:
|
|||
static ComplexType getChecked(Type elementType, Location location);
|
||||
|
||||
/// Verify the construction of an integer type.
|
||||
static LogicalResult
|
||||
verifyConstructionInvariants(llvm::Optional<Location> loc,
|
||||
MLIRContext *context, Type elementType);
|
||||
static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
|
||||
MLIRContext *context,
|
||||
Type elementType);
|
||||
|
||||
Type getElementType();
|
||||
|
||||
|
@ -269,9 +269,9 @@ public:
|
|||
Location location);
|
||||
|
||||
/// Verify the construction of a vector type.
|
||||
static LogicalResult
|
||||
verifyConstructionInvariants(llvm::Optional<Location> loc,
|
||||
MLIRContext *context, ArrayRef<int64_t> shape,
|
||||
static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
|
||||
MLIRContext *context,
|
||||
ArrayRef<int64_t> shape,
|
||||
Type elementType);
|
||||
|
||||
/// Returns true of the given type can be used as an element of a vector type.
|
||||
|
@ -328,9 +328,9 @@ public:
|
|||
Location location);
|
||||
|
||||
/// Verify the construction of a ranked tensor type.
|
||||
static LogicalResult
|
||||
verifyConstructionInvariants(llvm::Optional<Location> loc,
|
||||
MLIRContext *context, ArrayRef<int64_t> shape,
|
||||
static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
|
||||
MLIRContext *context,
|
||||
ArrayRef<int64_t> shape,
|
||||
Type elementType);
|
||||
|
||||
ArrayRef<int64_t> getShape() const;
|
||||
|
@ -359,9 +359,9 @@ public:
|
|||
static UnrankedTensorType getChecked(Type elementType, Location location);
|
||||
|
||||
/// Verify the construction of a unranked tensor type.
|
||||
static LogicalResult
|
||||
verifyConstructionInvariants(llvm::Optional<Location> loc,
|
||||
MLIRContext *context, Type elementType);
|
||||
static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
|
||||
MLIRContext *context,
|
||||
Type elementType);
|
||||
|
||||
ArrayRef<int64_t> getShape() const { return llvm::None; }
|
||||
|
||||
|
|
|
@ -56,7 +56,7 @@ struct OpaqueTypeStorage;
|
|||
///
|
||||
/// * Optional:
|
||||
/// - static LogicalResult verifyConstructionInvariants(
|
||||
/// llvm::Optional<Location> loc,
|
||||
/// Optional<Location> loc,
|
||||
/// MLIRContext *context,
|
||||
/// Args... args)
|
||||
/// * This method is invoked when calling the 'TypeBase::get/getChecked'
|
||||
|
@ -250,9 +250,9 @@ public:
|
|||
StringRef getTypeData() const;
|
||||
|
||||
/// Verify the construction of an opaque type.
|
||||
static LogicalResult
|
||||
verifyConstructionInvariants(llvm::Optional<Location> loc,
|
||||
MLIRContext *context, Identifier dialect,
|
||||
static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
|
||||
MLIRContext *context,
|
||||
Identifier dialect,
|
||||
StringRef typeData);
|
||||
|
||||
static bool kindof(unsigned kind) { return kind == Kind::Opaque; }
|
||||
|
|
|
@ -33,28 +33,20 @@ unsigned QuantizedType::getFlags() const {
|
|||
}
|
||||
|
||||
LogicalResult QuantizedType::verifyConstructionInvariants(
|
||||
llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
|
||||
Optional<Location> loc, MLIRContext *context, unsigned flags,
|
||||
Type storageType, Type expressedType, int64_t storageTypeMin,
|
||||
int64_t storageTypeMax) {
|
||||
// Verify that the storage type is integral.
|
||||
// This restriction may be lifted at some point in favor of using bf16
|
||||
// or f16 as exact representations on hardware where that is advantageous.
|
||||
auto intStorageType = storageType.dyn_cast<IntegerType>();
|
||||
if (!intStorageType) {
|
||||
if (loc) {
|
||||
emitError(*loc, "storage type must be integral");
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
if (!intStorageType)
|
||||
return emitOptionalError(loc, "storage type must be integral");
|
||||
unsigned integralWidth = intStorageType.getWidth();
|
||||
|
||||
// Verify storage width.
|
||||
if (integralWidth == 0 || integralWidth > MaxStorageBits) {
|
||||
if (loc) {
|
||||
emitError(*loc, "illegal storage type size: ") << integralWidth;
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
if (integralWidth == 0 || integralWidth > MaxStorageBits)
|
||||
return emitOptionalError(loc, "illegal storage type size: ", integralWidth);
|
||||
|
||||
// Verify storageTypeMin and storageTypeMax.
|
||||
bool isSigned =
|
||||
|
@ -66,11 +58,8 @@ LogicalResult QuantizedType::verifyConstructionInvariants(
|
|||
if (storageTypeMax - storageTypeMin <= 0 ||
|
||||
storageTypeMin < defaultIntegerMin ||
|
||||
storageTypeMax > defaultIntegerMax) {
|
||||
if (loc) {
|
||||
emitError(*loc, "illegal storage min and storage max: (")
|
||||
<< storageTypeMin << ":" << storageTypeMax << ")";
|
||||
}
|
||||
return failure();
|
||||
return emitOptionalError(loc, "illegal storage min and storage max: (",
|
||||
storageTypeMin, ":", storageTypeMax, ")");
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
@ -235,7 +224,7 @@ AnyQuantizedType AnyQuantizedType::getChecked(unsigned flags, Type storageType,
|
|||
}
|
||||
|
||||
LogicalResult AnyQuantizedType::verifyConstructionInvariants(
|
||||
llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
|
||||
Optional<Location> loc, MLIRContext *context, unsigned flags,
|
||||
Type storageType, Type expressedType, int64_t storageTypeMin,
|
||||
int64_t storageTypeMax) {
|
||||
if (failed(QuantizedType::verifyConstructionInvariants(
|
||||
|
@ -247,12 +236,8 @@ LogicalResult AnyQuantizedType::verifyConstructionInvariants(
|
|||
// Verify that the expressed type is floating point.
|
||||
// If this restriction is ever eliminated, the parser/printer must be
|
||||
// extended.
|
||||
if (expressedType && !expressedType.isa<FloatType>()) {
|
||||
if (loc) {
|
||||
emitError(*loc, "expressed type must be floating point");
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
if (expressedType && !expressedType.isa<FloatType>())
|
||||
return emitOptionalError(loc, "expressed type must be floating point");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -280,7 +265,7 @@ UniformQuantizedType::getChecked(unsigned flags, Type storageType,
|
|||
}
|
||||
|
||||
LogicalResult UniformQuantizedType::verifyConstructionInvariants(
|
||||
llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
|
||||
Optional<Location> loc, MLIRContext *context, unsigned flags,
|
||||
Type storageType, Type expressedType, double scale, int64_t zeroPoint,
|
||||
int64_t storageTypeMin, int64_t storageTypeMax) {
|
||||
if (failed(QuantizedType::verifyConstructionInvariants(
|
||||
|
@ -291,30 +276,19 @@ LogicalResult UniformQuantizedType::verifyConstructionInvariants(
|
|||
|
||||
// Uniform quantization requires fully expressed parameters, including
|
||||
// expressed type.
|
||||
if (!expressedType) {
|
||||
if (loc) {
|
||||
emitError(*loc, "uniform quantization requires expressed type");
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
if (!expressedType)
|
||||
return emitOptionalError(loc,
|
||||
"uniform quantization requires expressed type");
|
||||
|
||||
// Verify that the expressed type is floating point.
|
||||
// If this restriction is ever eliminated, the parser/printer must be
|
||||
// extended.
|
||||
if (!expressedType.isa<FloatType>()) {
|
||||
if (loc) {
|
||||
emitError(*loc, "expressed type must be floating point");
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
if (!expressedType.isa<FloatType>())
|
||||
return emitOptionalError(loc, "expressed type must be floating point");
|
||||
|
||||
// Verify scale.
|
||||
if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale)) {
|
||||
if (loc) {
|
||||
emitError(*loc) << "illegal scale: " << scale;
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
|
||||
return emitOptionalError(loc, "illegal scale: ", scale);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -348,7 +322,7 @@ UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
|
|||
}
|
||||
|
||||
LogicalResult UniformQuantizedPerAxisType::verifyConstructionInvariants(
|
||||
llvm::Optional<Location> loc, MLIRContext *context, unsigned flags,
|
||||
Optional<Location> loc, MLIRContext *context, unsigned flags,
|
||||
Type storageType, Type expressedType, ArrayRef<double> scales,
|
||||
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
|
||||
int64_t storageTypeMin, int64_t storageTypeMax) {
|
||||
|
@ -360,40 +334,25 @@ LogicalResult UniformQuantizedPerAxisType::verifyConstructionInvariants(
|
|||
|
||||
// Uniform quantization requires fully expressed parameters, including
|
||||
// expressed type.
|
||||
if (!expressedType) {
|
||||
if (loc) {
|
||||
emitError(*loc, "uniform quantization requires expressed type");
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
if (!expressedType)
|
||||
return emitOptionalError(loc,
|
||||
"uniform quantization requires expressed type");
|
||||
|
||||
// Verify that the expressed type is floating point.
|
||||
// If this restriction is ever eliminated, the parser/printer must be
|
||||
// extended.
|
||||
if (!expressedType.isa<FloatType>()) {
|
||||
if (loc) {
|
||||
emitError(*loc, "expressed type must be floating point");
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
if (!expressedType.isa<FloatType>())
|
||||
return emitOptionalError(loc, "expressed type must be floating point");
|
||||
|
||||
// Ensure that the number of scales and zeroPoints match.
|
||||
if (scales.size() != zeroPoints.size()) {
|
||||
if (loc) {
|
||||
emitError(*loc, "illegal number of scales and zeroPoints: ")
|
||||
<< scales.size() << ", " << zeroPoints.size();
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
if (scales.size() != zeroPoints.size())
|
||||
return emitOptionalError(loc, "illegal number of scales and zeroPoints: ",
|
||||
scales.size(), ", ", zeroPoints.size());
|
||||
|
||||
// Verify scale.
|
||||
for (double scale : scales) {
|
||||
if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale)) {
|
||||
if (loc) {
|
||||
emitError(*loc) << "illegal scale: " << scale;
|
||||
}
|
||||
return failure();
|
||||
}
|
||||
if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
|
||||
return emitOptionalError(loc, "illegal scale: ", scale);
|
||||
}
|
||||
|
||||
return success();
|
||||
|
|
|
@ -214,24 +214,22 @@ double FloatAttr::getValueAsDouble(APFloat value) {
|
|||
}
|
||||
|
||||
/// Verify construction invariants.
|
||||
static LogicalResult verifyFloatTypeInvariants(llvm::Optional<Location> loc,
|
||||
static LogicalResult verifyFloatTypeInvariants(Optional<Location> loc,
|
||||
Type type) {
|
||||
if (!type.isa<FloatType>()) {
|
||||
if (loc)
|
||||
emitError(*loc, "expected floating point type");
|
||||
return failure();
|
||||
}
|
||||
if (!type.isa<FloatType>())
|
||||
return emitOptionalError(loc, "expected floating point type");
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult FloatAttr::verifyConstructionInvariants(
|
||||
llvm::Optional<Location> loc, MLIRContext *ctx, Type type, double value) {
|
||||
LogicalResult FloatAttr::verifyConstructionInvariants(Optional<Location> loc,
|
||||
MLIRContext *ctx,
|
||||
Type type, double value) {
|
||||
return verifyFloatTypeInvariants(loc, type);
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
FloatAttr::verifyConstructionInvariants(llvm::Optional<Location> loc,
|
||||
MLIRContext *ctx, Type type,
|
||||
LogicalResult FloatAttr::verifyConstructionInvariants(Optional<Location> loc,
|
||||
MLIRContext *ctx,
|
||||
Type type,
|
||||
const APFloat &value) {
|
||||
// Verify that the type is correct.
|
||||
if (failed(verifyFloatTypeInvariants(loc, type)))
|
||||
|
@ -239,10 +237,8 @@ FloatAttr::verifyConstructionInvariants(llvm::Optional<Location> loc,
|
|||
|
||||
// Verify that the type semantics match that of the value.
|
||||
if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
|
||||
if (loc)
|
||||
emitError(*loc,
|
||||
"FloatAttr type doesn't match the type implied by its value");
|
||||
return failure();
|
||||
return emitOptionalError(
|
||||
loc, "FloatAttr type doesn't match the type implied by its value");
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
@ -330,14 +326,13 @@ Identifier OpaqueAttr::getDialectNamespace() const {
|
|||
StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; }
|
||||
|
||||
/// Verify the construction of an opaque attribute.
|
||||
LogicalResult OpaqueAttr::verifyConstructionInvariants(
|
||||
llvm::Optional<Location> loc, MLIRContext *context, Identifier dialect,
|
||||
StringRef attrData, Type type) {
|
||||
if (!Dialect::isValidNamespace(dialect.strref())) {
|
||||
if (loc)
|
||||
emitError(*loc) << "invalid dialect namespace '" << dialect << "'";
|
||||
return failure();
|
||||
}
|
||||
LogicalResult OpaqueAttr::verifyConstructionInvariants(Optional<Location> loc,
|
||||
MLIRContext *context,
|
||||
Identifier dialect,
|
||||
StringRef attrData,
|
||||
Type type) {
|
||||
if (!Dialect::isValidNamespace(dialect.strref()))
|
||||
return emitOptionalError(loc, "invalid dialect namespace '", dialect, "'");
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -61,13 +61,12 @@ bool Type::isIntOrFloat() { return isa<IntegerType>() || isa<FloatType>(); }
|
|||
constexpr unsigned IntegerType::kMaxWidth;
|
||||
|
||||
/// Verify the construction of an integer type.
|
||||
LogicalResult IntegerType::verifyConstructionInvariants(
|
||||
llvm::Optional<Location> loc, MLIRContext *context, unsigned width) {
|
||||
LogicalResult IntegerType::verifyConstructionInvariants(Optional<Location> loc,
|
||||
MLIRContext *context,
|
||||
unsigned width) {
|
||||
if (width > IntegerType::kMaxWidth) {
|
||||
if (loc)
|
||||
emitError(*loc) << "integer bitwidth is limited to "
|
||||
<< IntegerType::kMaxWidth << " bits";
|
||||
return failure();
|
||||
return emitOptionalError(loc, "integer bitwidth is limited to ",
|
||||
IntegerType::kMaxWidth, " bits");
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
@ -213,26 +212,21 @@ VectorType VectorType::getChecked(ArrayRef<int64_t> shape, Type elementType,
|
|||
StandardTypes::Vector, shape, elementType);
|
||||
}
|
||||
|
||||
LogicalResult VectorType::verifyConstructionInvariants(
|
||||
llvm::Optional<Location> loc, MLIRContext *context, ArrayRef<int64_t> shape,
|
||||
LogicalResult VectorType::verifyConstructionInvariants(Optional<Location> loc,
|
||||
MLIRContext *context,
|
||||
ArrayRef<int64_t> shape,
|
||||
Type elementType) {
|
||||
if (shape.empty()) {
|
||||
if (loc)
|
||||
emitError(*loc, "vector types must have at least one dimension");
|
||||
return failure();
|
||||
}
|
||||
if (shape.empty())
|
||||
return emitOptionalError(loc,
|
||||
"vector types must have at least one dimension");
|
||||
|
||||
if (!isValidElementType(elementType)) {
|
||||
if (loc)
|
||||
emitError(*loc, "vector elements must be int or float type");
|
||||
return failure();
|
||||
}
|
||||
if (!isValidElementType(elementType))
|
||||
return emitOptionalError(loc, "vector elements must be int or float type");
|
||||
|
||||
if (any_of(shape, [](int64_t i) { return i <= 0; }))
|
||||
return emitOptionalError(loc,
|
||||
"vector types must have positive constant sizes");
|
||||
|
||||
if (any_of(shape, [](int64_t i) { return i <= 0; })) {
|
||||
if (loc)
|
||||
emitError(*loc, "vector types must have positive constant sizes");
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -247,11 +241,8 @@ ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); }
|
|||
static inline LogicalResult checkTensorElementType(Optional<Location> location,
|
||||
MLIRContext *context,
|
||||
Type elementType) {
|
||||
if (!TensorType::isValidElementType(elementType)) {
|
||||
if (location)
|
||||
emitError(*location, "invalid tensor element type");
|
||||
return failure();
|
||||
}
|
||||
if (!TensorType::isValidElementType(elementType))
|
||||
return emitOptionalError(location, "invalid tensor element type");
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -273,14 +264,11 @@ RankedTensorType RankedTensorType::getChecked(ArrayRef<int64_t> shape,
|
|||
}
|
||||
|
||||
LogicalResult RankedTensorType::verifyConstructionInvariants(
|
||||
llvm::Optional<Location> loc, MLIRContext *context, ArrayRef<int64_t> shape,
|
||||
Optional<Location> loc, MLIRContext *context, ArrayRef<int64_t> shape,
|
||||
Type elementType) {
|
||||
for (int64_t s : shape) {
|
||||
if (s < -1) {
|
||||
if (loc)
|
||||
emitError(*loc, "invalid tensor dimension size");
|
||||
return failure();
|
||||
}
|
||||
if (s < -1)
|
||||
return emitOptionalError(loc, "invalid tensor dimension size");
|
||||
}
|
||||
return checkTensorElementType(loc, context, elementType);
|
||||
}
|
||||
|
@ -305,7 +293,7 @@ UnrankedTensorType UnrankedTensorType::getChecked(Type elementType,
|
|||
}
|
||||
|
||||
LogicalResult UnrankedTensorType::verifyConstructionInvariants(
|
||||
llvm::Optional<Location> loc, MLIRContext *context, Type elementType) {
|
||||
Optional<Location> loc, MLIRContext *context, Type elementType) {
|
||||
return checkTensorElementType(loc, context, elementType);
|
||||
}
|
||||
|
||||
|
@ -350,19 +338,14 @@ MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
|
|||
auto *context = elementType.getContext();
|
||||
|
||||
// Check that memref is formed from allowed types.
|
||||
if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>()) {
|
||||
if (location)
|
||||
emitError(*location, "invalid memref element type");
|
||||
return nullptr;
|
||||
}
|
||||
if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>())
|
||||
return emitOptionalError(location, "invalid memref element type"),
|
||||
MemRefType();
|
||||
|
||||
for (int64_t s : shape) {
|
||||
// Negative sizes are not allowed except for `-1` that means dynamic size.
|
||||
if (s < -1) {
|
||||
if (location)
|
||||
emitError(*location, "invalid memref size");
|
||||
return {};
|
||||
}
|
||||
if (s < -1)
|
||||
return emitOptionalError(location, "invalid memref size"), MemRefType();
|
||||
}
|
||||
|
||||
// Check that the structure of the composition is valid, i.e. that each
|
||||
|
@ -631,11 +614,8 @@ ComplexType ComplexType::getChecked(Type elementType, Location location) {
|
|||
/// Verify the construction of an integer type.
|
||||
LogicalResult ComplexType::verifyConstructionInvariants(
|
||||
llvm::Optional<Location> loc, MLIRContext *context, Type elementType) {
|
||||
if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>()) {
|
||||
if (loc)
|
||||
emitError(*loc, "invalid element type for complex");
|
||||
return failure();
|
||||
}
|
||||
if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>())
|
||||
return emitOptionalError(loc, "invalid element type for complex");
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -80,13 +80,11 @@ Identifier OpaqueType::getDialectNamespace() const {
|
|||
StringRef OpaqueType::getTypeData() const { return getImpl()->typeData; }
|
||||
|
||||
/// Verify the construction of an opaque type.
|
||||
LogicalResult OpaqueType::verifyConstructionInvariants(
|
||||
llvm::Optional<Location> loc, MLIRContext *context, Identifier dialect,
|
||||
LogicalResult OpaqueType::verifyConstructionInvariants(Optional<Location> loc,
|
||||
MLIRContext *context,
|
||||
Identifier dialect,
|
||||
StringRef typeData) {
|
||||
if (!Dialect::isValidNamespace(dialect.strref())) {
|
||||
if (loc)
|
||||
emitError(*loc) << "invalid dialect namespace '" << dialect << "'";
|
||||
return failure();
|
||||
}
|
||||
if (!Dialect::isValidNamespace(dialect.strref()))
|
||||
return emitOptionalError(loc, "invalid dialect namespace '", dialect, "'");
|
||||
return success();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue