[mlir] Refactor the structure of the 'verifyConstructionInvariants' methods.

Summary:
The current structure suffers from several problems, but the main one is that a construction failure is impossible to debug when using the 'get' methods. This is because we only optionally emit errors, so there is no context given to the user about the problem. This revision restructures this so that errors are always emitted, and the 'get' methods simply pass in an UnknownLoc to emit to. This allows for removing usages of the more constrained "emitOptionalLoc", as well as removing the need for the context parameter.

Fixes [PR#44964](https://bugs.llvm.org/show_bug.cgi?id=44964)

Differential Revision: https://reviews.llvm.org/D74876
This commit is contained in:
River Riddle 2020-02-20 10:31:44 -08:00
parent 726c342ce2
commit 70d8fec7c9
14 changed files with 160 additions and 189 deletions

View File

@ -194,42 +194,34 @@ public:
/// This method is used to get an instance of the 'ComplexType'. This method /// This method is used to get an instance of the 'ComplexType'. This method
/// asserts that all of the construction invariants were satisfied. To /// asserts that all of the construction invariants were satisfied. To
/// gracefully handle failed construction, getChecked should be used instead. /// gracefully handle failed construction, getChecked should be used instead.
static ComplexType get(MLIRContext *context, unsigned param, Type type) { static ComplexType get(unsigned param, Type type) {
// Call into a helper 'get' method in 'TypeBase' to get a uniqued instance // Call into a helper 'get' method in 'TypeBase' to get a uniqued instance
// of this type. All parameters to the storage class are passed after the // of this type. All parameters to the storage class are passed after the
// type kind. // type kind.
return Base::get(context, MyTypes::Complex, param, type); return Base::get(type.getContext(), MyTypes::Complex, param, type);
} }
/// This method is used to get an instance of the 'ComplexType', defined at /// This method is used to get an instance of the 'ComplexType', defined at
/// the given location. If any of the construction invariants are invalid, /// the given location. If any of the construction invariants are invalid,
/// errors are emitted with the provided location and a null type is returned. /// errors are emitted with the provided location and a null type is returned.
/// Note: This method is completely optional. /// Note: This method is completely optional.
static ComplexType getChecked(MLIRContext *context, unsigned param, Type type, static ComplexType getChecked(unsigned param, Type type, Location location) {
Location location) {
// Call into a helper 'getChecked' method in 'TypeBase' to get a uniqued // Call into a helper 'getChecked' method in 'TypeBase' to get a uniqued
// instance of this type. All parameters to the storage class are passed // instance of this type. All parameters to the storage class are passed
// after the type kind. // after the type kind.
return Base::getChecked(location, context, MyTypes::Complex, param, type); return Base::getChecked(location, MyTypes::Complex, param, type);
} }
/// This method is used to verify the construction invariants passed into the /// This method is used to verify the construction invariants passed into the
/// 'get' and 'getChecked' methods. Note: This method is completely optional. /// 'get' and 'getChecked' methods. Note: This method is completely optional.
static LogicalResult verifyConstructionInvariants( static LogicalResult verifyConstructionInvariants(
llvm::Optional<Location> loc, MLIRContext *context, unsigned param, Location loc, unsigned param, Type type) {
Type type) {
// Our type only allows non-zero parameters. // Our type only allows non-zero parameters.
if (param == 0) { if (param == 0)
if (loc) return emitError(loc) << "non-zero parameter passed to 'ComplexType'";
context->emitError(loc) << "non-zero parameter passed to 'ComplexType'";
return failure();
}
// Our type also expects an integer type. // Our type also expects an integer type.
if (!type.isa<IntegerType>()) { if (!type.isa<IntegerType>())
if (loc) return emitError(loc) << "non integer-type passed to 'ComplexType'";
context->emitError(loc) << "non integer-type passed to 'ComplexType'";
return failure();
}
return success(); return success();
} }

View File

@ -66,8 +66,7 @@ public:
static constexpr unsigned MaxStorageBits = 32; static constexpr unsigned MaxStorageBits = 32;
static LogicalResult static LogicalResult
verifyConstructionInvariants(Optional<Location> loc, MLIRContext *context, verifyConstructionInvariants(Location loc, unsigned flags, Type storageType,
unsigned flags, Type storageType,
Type expressedType, int64_t storageTypeMin, Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax); int64_t storageTypeMax);
@ -229,8 +228,7 @@ public:
/// Verifies construction invariants and issues errors/warnings. /// Verifies construction invariants and issues errors/warnings.
static LogicalResult static LogicalResult
verifyConstructionInvariants(Optional<Location> loc, MLIRContext *context, verifyConstructionInvariants(Location loc, unsigned flags, Type storageType,
unsigned flags, Type storageType,
Type expressedType, int64_t storageTypeMin, Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax); int64_t storageTypeMax);
}; };
@ -288,10 +286,11 @@ public:
Location location); Location location);
/// Verifies construction invariants and issues errors/warnings. /// Verifies construction invariants and issues errors/warnings.
static LogicalResult verifyConstructionInvariants( static LogicalResult
Optional<Location> loc, MLIRContext *context, unsigned flags, verifyConstructionInvariants(Location loc, unsigned flags, Type storageType,
Type storageType, Type expressedType, double scale, int64_t zeroPoint, Type expressedType, double scale,
int64_t storageTypeMin, int64_t storageTypeMax); int64_t zeroPoint, int64_t storageTypeMin,
int64_t storageTypeMax);
/// Support method to enable LLVM-style type casting. /// Support method to enable LLVM-style type casting.
static bool kindof(unsigned kind) { static bool kindof(unsigned kind) {
@ -351,11 +350,12 @@ public:
int64_t storageTypeMax, Location location); int64_t storageTypeMax, Location location);
/// Verifies construction invariants and issues errors/warnings. /// Verifies construction invariants and issues errors/warnings.
static LogicalResult verifyConstructionInvariants( static LogicalResult
Optional<Location> loc, MLIRContext *context, unsigned flags, verifyConstructionInvariants(Location loc, unsigned flags, Type storageType,
Type storageType, Type expressedType, ArrayRef<double> scales, Type expressedType, ArrayRef<double> scales,
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension, ArrayRef<int64_t> zeroPoints,
int64_t storageTypeMin, int64_t storageTypeMax); int32_t quantizedDimension,
int64_t storageTypeMin, int64_t storageTypeMax);
/// Support method to enable LLVM-style type casting. /// Support method to enable LLVM-style type casting.
static bool kindof(unsigned kind) { static bool kindof(unsigned kind) {

View File

@ -90,10 +90,11 @@ public:
static bool kindof(unsigned kind) { return kind == AttrKind::TargetEnv; } static bool kindof(unsigned kind) { return kind == AttrKind::TargetEnv; }
static LogicalResult static LogicalResult verifyConstructionInvariants(Location loc,
verifyConstructionInvariants(Optional<Location> loc, MLIRContext *context, IntegerAttr version,
IntegerAttr version, ArrayAttr extensions, ArrayAttr extensions,
ArrayAttr capabilities, DictionaryAttr limits); ArrayAttr capabilities,
DictionaryAttr limits);
}; };
/// Returns the attribute name for specifying argument ABI information. /// Returns the attribute name for specifying argument ABI information.

View File

@ -330,11 +330,9 @@ public:
} }
/// Verify the construction invariants for a double value. /// Verify the construction invariants for a double value.
static LogicalResult verifyConstructionInvariants(Optional<Location> loc, static LogicalResult verifyConstructionInvariants(Location loc, Type type,
MLIRContext *ctx, Type type,
double value); double value);
static LogicalResult verifyConstructionInvariants(Optional<Location> loc, static LogicalResult verifyConstructionInvariants(Location loc, Type type,
MLIRContext *ctx, Type type,
const APFloat &value); const APFloat &value);
}; };
@ -361,11 +359,9 @@ public:
return kind == StandardAttributes::Integer; return kind == StandardAttributes::Integer;
} }
static LogicalResult verifyConstructionInvariants(Optional<Location> loc, static LogicalResult verifyConstructionInvariants(Location loc, Type type,
MLIRContext *ctx, Type type,
int64_t value); int64_t value);
static LogicalResult verifyConstructionInvariants(Optional<Location> loc, static LogicalResult verifyConstructionInvariants(Location loc, Type type,
MLIRContext *ctx, Type type,
const APInt &value); const APInt &value);
}; };
@ -419,8 +415,7 @@ public:
StringRef getAttrData() const; StringRef getAttrData() const;
/// Verify the construction of an opaque attribute. /// Verify the construction of an opaque attribute.
static LogicalResult verifyConstructionInvariants(Optional<Location> loc, static LogicalResult verifyConstructionInvariants(Location loc,
MLIRContext *context,
Identifier dialect, Identifier dialect,
StringRef attrData, StringRef attrData,
Type type); Type type);

View File

@ -54,6 +54,12 @@ public:
Location(LocationAttr loc) : impl(loc) { Location(LocationAttr loc) : impl(loc) {
assert(loc && "location should never be null."); assert(loc && "location should never be null.");
} }
Location(const LocationAttr::ImplType *impl) : impl(impl) {
assert(impl && "location should never be null.");
}
/// Return the context this location is uniqued in.
MLIRContext *getContext() const { return impl.getContext(); }
/// Access the impl location attribute. /// Access the impl location attribute.
operator LocationAttr() const { return impl; } operator LocationAttr() const { return impl; }

View File

@ -96,8 +96,7 @@ public:
Location location); Location location);
/// Verify the construction of an integer type. /// Verify the construction of an integer type.
static LogicalResult verifyConstructionInvariants(Optional<Location> loc, static LogicalResult verifyConstructionInvariants(Location loc,
MLIRContext *context,
unsigned width); unsigned width);
/// Return the bitwidth of this integer type. /// Return the bitwidth of this integer type.
@ -162,8 +161,7 @@ public:
static ComplexType getChecked(Type elementType, Location location); static ComplexType getChecked(Type elementType, Location location);
/// Verify the construction of an integer type. /// Verify the construction of an integer type.
static LogicalResult verifyConstructionInvariants(Optional<Location> loc, static LogicalResult verifyConstructionInvariants(Location loc,
MLIRContext *context,
Type elementType); Type elementType);
Type getElementType(); Type getElementType();
@ -270,8 +268,7 @@ public:
Location location); Location location);
/// Verify the construction of a vector type. /// Verify the construction of a vector type.
static LogicalResult verifyConstructionInvariants(Optional<Location> loc, static LogicalResult verifyConstructionInvariants(Location loc,
MLIRContext *context,
ArrayRef<int64_t> shape, ArrayRef<int64_t> shape,
Type elementType); Type elementType);
@ -329,8 +326,7 @@ public:
Location location); Location location);
/// Verify the construction of a ranked tensor type. /// Verify the construction of a ranked tensor type.
static LogicalResult verifyConstructionInvariants(Optional<Location> loc, static LogicalResult verifyConstructionInvariants(Location loc,
MLIRContext *context,
ArrayRef<int64_t> shape, ArrayRef<int64_t> shape,
Type elementType); Type elementType);
@ -360,8 +356,7 @@ public:
static UnrankedTensorType getChecked(Type elementType, Location location); static UnrankedTensorType getChecked(Type elementType, Location location);
/// Verify the construction of a unranked tensor type. /// Verify the construction of a unranked tensor type.
static LogicalResult verifyConstructionInvariants(Optional<Location> loc, static LogicalResult verifyConstructionInvariants(Location loc,
MLIRContext *context,
Type elementType); Type elementType);
ArrayRef<int64_t> getShape() const { return llvm::None; } ArrayRef<int64_t> getShape() const { return llvm::None; }
@ -505,8 +500,7 @@ public:
Location location); Location location);
/// Verify the construction of a unranked memref type. /// Verify the construction of a unranked memref type.
static LogicalResult verifyConstructionInvariants(Optional<Location> loc, static LogicalResult verifyConstructionInvariants(Location loc,
MLIRContext *context,
Type elementType, Type elementType,
unsigned memorySpace); unsigned memorySpace);

View File

@ -18,10 +18,15 @@
#include "mlir/Support/StorageUniquer.h" #include "mlir/Support/StorageUniquer.h"
namespace mlir { namespace mlir {
class Location; class AttributeStorage;
class MLIRContext; class MLIRContext;
namespace detail { namespace detail {
/// Utility method to generate a raw default location for use when checking the
/// construction invariants of a storage object. This is defined out-of-line to
/// avoid the need to include Location.h.
const AttributeStorage *generateUnknownStorageLocation(MLIRContext *ctx);
/// Utility class for implementing users of storage classes uniqued by a /// Utility class for implementing users of storage classes uniqued by a
/// StorageUniquer. Clients are not expected to interact with this class /// StorageUniquer. Clients are not expected to interact with this class
/// directly. /// directly.
@ -53,21 +58,20 @@ protected:
template <typename... Args> template <typename... Args>
static ConcreteT get(MLIRContext *ctx, unsigned kind, Args... args) { static ConcreteT get(MLIRContext *ctx, unsigned kind, Args... args) {
// Ensure that the invariants are correct for construction. // Ensure that the invariants are correct for construction.
assert(succeeded( assert(succeeded(ConcreteT::verifyConstructionInvariants(
ConcreteT::verifyConstructionInvariants(llvm::None, ctx, args...))); generateUnknownStorageLocation(ctx), args...)));
return UniquerT::template get<ConcreteT>(ctx, kind, args...); return UniquerT::template get<ConcreteT>(ctx, kind, args...);
} }
/// Get or create a new ConcreteT instance within the ctx, defined at /// Get or create a new ConcreteT instance within the ctx, defined at
/// the given, potentially unknown, location. If the arguments provided are /// the given, potentially unknown, location. If the arguments provided are
/// invalid then emit errors and return a null object. /// invalid then emit errors and return a null object.
template <typename... Args> template <typename LocationT, typename... Args>
static ConcreteT getChecked(const Location &loc, MLIRContext *ctx, static ConcreteT getChecked(LocationT loc, unsigned kind, Args... args) {
unsigned kind, Args... args) {
// If the construction invariants fail then we return a null attribute. // If the construction invariants fail then we return a null attribute.
if (failed(ConcreteT::verifyConstructionInvariants(loc, ctx, args...))) if (failed(ConcreteT::verifyConstructionInvariants(loc, args...)))
return ConcreteT(); return ConcreteT();
return UniquerT::template get<ConcreteT>(ctx, kind, args...); return UniquerT::template get<ConcreteT>(loc.getContext(), kind, args...);
} }
/// Default implementation that just returns success. /// Default implementation that just returns success.

View File

@ -46,10 +46,8 @@ struct OpaqueTypeStorage;
/// current type. Used for isa/dyn_cast casting functionality. /// current type. Used for isa/dyn_cast casting functionality.
/// ///
/// * Optional: /// * Optional:
/// - static LogicalResult verifyConstructionInvariants( /// - static LogicalResult verifyConstructionInvariants(Location loc,
/// Optional<Location> loc, /// Args... args)
/// MLIRContext *context,
/// Args... args)
/// * This method is invoked when calling the 'TypeBase::get/getChecked' /// * This method is invoked when calling the 'TypeBase::get/getChecked'
/// methods to ensure that the arguments passed in are valid to construct /// methods to ensure that the arguments passed in are valid to construct
/// a type instance with. /// a type instance with.
@ -238,8 +236,7 @@ public:
StringRef getTypeData() const; StringRef getTypeData() const;
/// Verify the construction of an opaque type. /// Verify the construction of an opaque type.
static LogicalResult verifyConstructionInvariants(Optional<Location> loc, static LogicalResult verifyConstructionInvariants(Location loc,
MLIRContext *context,
Identifier dialect, Identifier dialect,
StringRef typeData); StringRef typeData);

View File

@ -24,20 +24,19 @@ unsigned QuantizedType::getFlags() const {
} }
LogicalResult QuantizedType::verifyConstructionInvariants( LogicalResult QuantizedType::verifyConstructionInvariants(
Optional<Location> loc, MLIRContext *context, unsigned flags, Location loc, unsigned flags, Type storageType, Type expressedType,
Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMin, int64_t storageTypeMax) {
int64_t storageTypeMax) {
// Verify that the storage type is integral. // Verify that the storage type is integral.
// This restriction may be lifted at some point in favor of using bf16 // This restriction may be lifted at some point in favor of using bf16
// or f16 as exact representations on hardware where that is advantageous. // or f16 as exact representations on hardware where that is advantageous.
auto intStorageType = storageType.dyn_cast<IntegerType>(); auto intStorageType = storageType.dyn_cast<IntegerType>();
if (!intStorageType) if (!intStorageType)
return emitOptionalError(loc, "storage type must be integral"); return emitError(loc, "storage type must be integral");
unsigned integralWidth = intStorageType.getWidth(); unsigned integralWidth = intStorageType.getWidth();
// Verify storage width. // Verify storage width.
if (integralWidth == 0 || integralWidth > MaxStorageBits) if (integralWidth == 0 || integralWidth > MaxStorageBits)
return emitOptionalError(loc, "illegal storage type size: ", integralWidth); return emitError(loc, "illegal storage type size: ") << integralWidth;
// Verify storageTypeMin and storageTypeMax. // Verify storageTypeMin and storageTypeMax.
bool isSigned = bool isSigned =
@ -49,8 +48,8 @@ LogicalResult QuantizedType::verifyConstructionInvariants(
if (storageTypeMax - storageTypeMin <= 0 || if (storageTypeMax - storageTypeMin <= 0 ||
storageTypeMin < defaultIntegerMin || storageTypeMin < defaultIntegerMin ||
storageTypeMax > defaultIntegerMax) { storageTypeMax > defaultIntegerMax) {
return emitOptionalError(loc, "illegal storage min and storage max: (", return emitError(loc, "illegal storage min and storage max: (")
storageTypeMin, ":", storageTypeMax, ")"); << storageTypeMin << ":" << storageTypeMax << ")";
} }
return success(); return success();
} }
@ -209,17 +208,15 @@ AnyQuantizedType AnyQuantizedType::getChecked(unsigned flags, Type storageType,
int64_t storageTypeMin, int64_t storageTypeMin,
int64_t storageTypeMax, int64_t storageTypeMax,
Location location) { Location location) {
return Base::getChecked(location, storageType.getContext(), return Base::getChecked(location, QuantizationTypes::Any, flags, storageType,
QuantizationTypes::Any, flags, storageType,
expressedType, storageTypeMin, storageTypeMax); expressedType, storageTypeMin, storageTypeMax);
} }
LogicalResult AnyQuantizedType::verifyConstructionInvariants( LogicalResult AnyQuantizedType::verifyConstructionInvariants(
Optional<Location> loc, MLIRContext *context, unsigned flags, Location loc, unsigned flags, Type storageType, Type expressedType,
Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMin, int64_t storageTypeMax) {
int64_t storageTypeMax) {
if (failed(QuantizedType::verifyConstructionInvariants( if (failed(QuantizedType::verifyConstructionInvariants(
loc, context, flags, storageType, expressedType, storageTypeMin, loc, flags, storageType, expressedType, storageTypeMin,
storageTypeMax))) { storageTypeMax))) {
return failure(); return failure();
} }
@ -228,7 +225,7 @@ LogicalResult AnyQuantizedType::verifyConstructionInvariants(
// If this restriction is ever eliminated, the parser/printer must be // If this restriction is ever eliminated, the parser/printer must be
// extended. // extended.
if (expressedType && !expressedType.isa<FloatType>()) if (expressedType && !expressedType.isa<FloatType>())
return emitOptionalError(loc, "expressed type must be floating point"); return emitError(loc, "expressed type must be floating point");
return success(); return success();
} }
@ -249,18 +246,17 @@ UniformQuantizedType::getChecked(unsigned flags, Type storageType,
Type expressedType, double scale, Type expressedType, double scale,
int64_t zeroPoint, int64_t storageTypeMin, int64_t zeroPoint, int64_t storageTypeMin,
int64_t storageTypeMax, Location location) { int64_t storageTypeMax, Location location) {
return Base::getChecked(location, storageType.getContext(), return Base::getChecked(location, QuantizationTypes::UniformQuantized, flags,
QuantizationTypes::UniformQuantized, flags,
storageType, expressedType, scale, zeroPoint, storageType, expressedType, scale, zeroPoint,
storageTypeMin, storageTypeMax); storageTypeMin, storageTypeMax);
} }
LogicalResult UniformQuantizedType::verifyConstructionInvariants( LogicalResult UniformQuantizedType::verifyConstructionInvariants(
Optional<Location> loc, MLIRContext *context, unsigned flags, Location loc, unsigned flags, Type storageType, Type expressedType,
Type storageType, Type expressedType, double scale, int64_t zeroPoint, double scale, int64_t zeroPoint, int64_t storageTypeMin,
int64_t storageTypeMin, int64_t storageTypeMax) { int64_t storageTypeMax) {
if (failed(QuantizedType::verifyConstructionInvariants( if (failed(QuantizedType::verifyConstructionInvariants(
loc, context, flags, storageType, expressedType, storageTypeMin, loc, flags, storageType, expressedType, storageTypeMin,
storageTypeMax))) { storageTypeMax))) {
return failure(); return failure();
} }
@ -268,18 +264,17 @@ LogicalResult UniformQuantizedType::verifyConstructionInvariants(
// Uniform quantization requires fully expressed parameters, including // Uniform quantization requires fully expressed parameters, including
// expressed type. // expressed type.
if (!expressedType) if (!expressedType)
return emitOptionalError(loc, return emitError(loc, "uniform quantization requires expressed type");
"uniform quantization requires expressed type");
// Verify that the expressed type is floating point. // Verify that the expressed type is floating point.
// If this restriction is ever eliminated, the parser/printer must be // If this restriction is ever eliminated, the parser/printer must be
// extended. // extended.
if (!expressedType.isa<FloatType>()) if (!expressedType.isa<FloatType>())
return emitOptionalError(loc, "expressed type must be floating point"); return emitError(loc, "expressed type must be floating point");
// Verify scale. // Verify scale.
if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale)) if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
return emitOptionalError(loc, "illegal scale: ", scale); return emitError(loc, "illegal scale: ") << scale;
return success(); return success();
} }
@ -306,19 +301,18 @@ UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints, ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax,
Location location) { Location location) {
return Base::getChecked(location, storageType.getContext(), return Base::getChecked(location, QuantizationTypes::UniformQuantizedPerAxis,
QuantizationTypes::UniformQuantizedPerAxis, flags, flags, storageType, expressedType, scales, zeroPoints,
storageType, expressedType, scales, zeroPoints,
quantizedDimension, storageTypeMin, storageTypeMax); quantizedDimension, storageTypeMin, storageTypeMax);
} }
LogicalResult UniformQuantizedPerAxisType::verifyConstructionInvariants( LogicalResult UniformQuantizedPerAxisType::verifyConstructionInvariants(
Optional<Location> loc, MLIRContext *context, unsigned flags, Location loc, unsigned flags, Type storageType, Type expressedType,
Type storageType, Type expressedType, ArrayRef<double> scales, ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension, int32_t quantizedDimension, int64_t storageTypeMin,
int64_t storageTypeMin, int64_t storageTypeMax) { int64_t storageTypeMax) {
if (failed(QuantizedType::verifyConstructionInvariants( if (failed(QuantizedType::verifyConstructionInvariants(
loc, context, flags, storageType, expressedType, storageTypeMin, loc, flags, storageType, expressedType, storageTypeMin,
storageTypeMax))) { storageTypeMax))) {
return failure(); return failure();
} }
@ -326,24 +320,23 @@ LogicalResult UniformQuantizedPerAxisType::verifyConstructionInvariants(
// Uniform quantization requires fully expressed parameters, including // Uniform quantization requires fully expressed parameters, including
// expressed type. // expressed type.
if (!expressedType) if (!expressedType)
return emitOptionalError(loc, return emitError(loc, "uniform quantization requires expressed type");
"uniform quantization requires expressed type");
// Verify that the expressed type is floating point. // Verify that the expressed type is floating point.
// If this restriction is ever eliminated, the parser/printer must be // If this restriction is ever eliminated, the parser/printer must be
// extended. // extended.
if (!expressedType.isa<FloatType>()) if (!expressedType.isa<FloatType>())
return emitOptionalError(loc, "expressed type must be floating point"); return emitError(loc, "expressed type must be floating point");
// Ensure that the number of scales and zeroPoints match. // Ensure that the number of scales and zeroPoints match.
if (scales.size() != zeroPoints.size()) if (scales.size() != zeroPoints.size())
return emitOptionalError(loc, "illegal number of scales and zeroPoints: ", return emitError(loc, "illegal number of scales and zeroPoints: ")
scales.size(), ", ", zeroPoints.size()); << scales.size() << ", " << zeroPoints.size();
// Verify scale. // Verify scale.
for (double scale : scales) { for (double scale : scales) {
if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale)) if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
return emitOptionalError(loc, "illegal scale: ", scale); return emitError(loc, "illegal scale: ") << scale;
} }
return success(); return success();

View File

@ -103,10 +103,10 @@ DictionaryAttr spirv::TargetEnvAttr::getResourceLimits() {
} }
LogicalResult spirv::TargetEnvAttr::verifyConstructionInvariants( LogicalResult spirv::TargetEnvAttr::verifyConstructionInvariants(
Optional<Location> loc, MLIRContext *context, IntegerAttr version, Location loc, IntegerAttr version, ArrayAttr extensions,
ArrayAttr extensions, ArrayAttr capabilities, DictionaryAttr limits) { ArrayAttr capabilities, DictionaryAttr limits) {
if (!version.getType().isInteger(32)) if (!version.getType().isInteger(32))
return emitOptionalError(loc, "expected 32-bit integer for version"); return emitError(loc, "expected 32-bit integer for version");
if (!llvm::all_of(extensions.getValue(), [](Attribute attr) { if (!llvm::all_of(extensions.getValue(), [](Attribute attr) {
if (auto strAttr = attr.dyn_cast<StringAttr>()) if (auto strAttr = attr.dyn_cast<StringAttr>())
@ -114,7 +114,7 @@ LogicalResult spirv::TargetEnvAttr::verifyConstructionInvariants(
return true; return true;
return false; return false;
})) }))
return emitOptionalError(loc, "unknown extension in extension list"); return emitError(loc, "unknown extension in extension list");
if (!llvm::all_of(capabilities.getValue(), [](Attribute attr) { if (!llvm::all_of(capabilities.getValue(), [](Attribute attr) {
if (auto intAttr = attr.dyn_cast<IntegerAttr>()) if (auto intAttr = attr.dyn_cast<IntegerAttr>())
@ -122,11 +122,10 @@ LogicalResult spirv::TargetEnvAttr::verifyConstructionInvariants(
return true; return true;
return false; return false;
})) }))
return emitOptionalError(loc, "unknown capability in capability list"); return emitError(loc, "unknown capability in capability list");
if (!limits.isa<spirv::ResourceLimitsAttr>()) if (!limits.isa<spirv::ResourceLimitsAttr>())
return emitOptionalError(loc, return emitError(loc, "expected spirv::ResourceLimitsAttr for limits");
"expected spirv::ResourceLimitsAttr for limits");
return success(); return success();
} }

View File

@ -182,8 +182,7 @@ FloatAttr FloatAttr::get(Type type, double value) {
} }
FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) { FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
return Base::getChecked(loc, type.getContext(), StandardAttributes::Float, return Base::getChecked(loc, StandardAttributes::Float, type, value);
type, value);
} }
FloatAttr FloatAttr::get(Type type, const APFloat &value) { FloatAttr FloatAttr::get(Type type, const APFloat &value) {
@ -191,8 +190,7 @@ FloatAttr FloatAttr::get(Type type, const APFloat &value) {
} }
FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) { FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) {
return Base::getChecked(loc, type.getContext(), StandardAttributes::Float, return Base::getChecked(loc, StandardAttributes::Float, type, value);
type, value);
} }
APFloat FloatAttr::getValue() const { return getImpl()->getValue(); } APFloat FloatAttr::getValue() const { return getImpl()->getValue(); }
@ -210,22 +208,18 @@ double FloatAttr::getValueAsDouble(APFloat value) {
} }
/// Verify construction invariants. /// Verify construction invariants.
static LogicalResult verifyFloatTypeInvariants(Optional<Location> loc, static LogicalResult verifyFloatTypeInvariants(Location loc, Type type) {
Type type) {
if (!type.isa<FloatType>()) if (!type.isa<FloatType>())
return emitOptionalError(loc, "expected floating point type"); return emitError(loc, "expected floating point type");
return success(); return success();
} }
LogicalResult FloatAttr::verifyConstructionInvariants(Optional<Location> loc, LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
MLIRContext *ctx, double value) {
Type type, double value) {
return verifyFloatTypeInvariants(loc, type); return verifyFloatTypeInvariants(loc, type);
} }
LogicalResult FloatAttr::verifyConstructionInvariants(Optional<Location> loc, LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
MLIRContext *ctx,
Type type,
const APFloat &value) { const APFloat &value) {
// Verify that the type is correct. // Verify that the type is correct.
if (failed(verifyFloatTypeInvariants(loc, type))) if (failed(verifyFloatTypeInvariants(loc, type)))
@ -233,7 +227,7 @@ LogicalResult FloatAttr::verifyConstructionInvariants(Optional<Location> loc,
// Verify that the type semantics match that of the value. // Verify that the type semantics match that of the value.
if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) { if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
return emitOptionalError( return emitError(
loc, "FloatAttr type doesn't match the type implied by its value"); loc, "FloatAttr type doesn't match the type implied by its value");
} }
return success(); return success();
@ -286,31 +280,26 @@ APInt IntegerAttr::getValue() const { return getImpl()->getValue(); }
int64_t IntegerAttr::getInt() const { return getValue().getSExtValue(); } int64_t IntegerAttr::getInt() const { return getValue().getSExtValue(); }
static LogicalResult verifyIntegerTypeInvariants(Optional<Location> loc, static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) {
Type type) {
if (type.isa<IntegerType>() || type.isa<IndexType>()) if (type.isa<IntegerType>() || type.isa<IndexType>())
return success(); return success();
return emitOptionalError(loc, "expected integer or index type"); return emitError(loc, "expected integer or index type");
} }
LogicalResult IntegerAttr::verifyConstructionInvariants(Optional<Location> loc, LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
MLIRContext *ctx,
Type type,
int64_t value) { int64_t value) {
return verifyIntegerTypeInvariants(loc, type); return verifyIntegerTypeInvariants(loc, type);
} }
LogicalResult IntegerAttr::verifyConstructionInvariants(Optional<Location> loc, LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
MLIRContext *ctx,
Type type,
const APInt &value) { const APInt &value) {
if (failed(verifyIntegerTypeInvariants(loc, type))) if (failed(verifyIntegerTypeInvariants(loc, type)))
return failure(); return failure();
if (auto integerType = type.dyn_cast<IntegerType>()) if (auto integerType = type.dyn_cast<IntegerType>())
if (integerType.getWidth() != value.getBitWidth()) if (integerType.getWidth() != value.getBitWidth())
return emitOptionalError( return emitError(loc, "integer type bit width (")
loc, "integer type bit width (", integerType.getWidth(), << integerType.getWidth() << ") doesn't match value bit width ("
") doesn't match value bit width (", value.getBitWidth(), ")"); << value.getBitWidth() << ")";
return success(); return success();
} }
@ -337,8 +326,8 @@ OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type,
OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData, OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData,
Type type, Location location) { Type type, Location location) {
return Base::getChecked(location, type.getContext(), return Base::getChecked(location, StandardAttributes::Opaque, dialect,
StandardAttributes::Opaque, dialect, attrData, type); attrData, type);
} }
/// Returns the dialect namespace of the opaque attribute. /// Returns the dialect namespace of the opaque attribute.
@ -350,13 +339,12 @@ Identifier OpaqueAttr::getDialectNamespace() const {
StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; } StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; }
/// Verify the construction of an opaque attribute. /// Verify the construction of an opaque attribute.
LogicalResult OpaqueAttr::verifyConstructionInvariants(Optional<Location> loc, LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc,
MLIRContext *context,
Identifier dialect, Identifier dialect,
StringRef attrData, StringRef attrData,
Type type) { Type type) {
if (!Dialect::isValidNamespace(dialect.strref())) if (!Dialect::isValidNamespace(dialect.strref()))
return emitOptionalError(loc, "invalid dialect namespace '", dialect, "'"); return emitError(loc, "invalid dialect namespace '") << dialect << "'";
return success(); return success();
} }

View File

@ -518,7 +518,7 @@ IntegerType IntegerType::getChecked(unsigned width, MLIRContext *context,
Location location) { Location location) {
if (auto cached = getCachedIntegerType(width, context)) if (auto cached = getCachedIntegerType(width, context))
return cached; return cached;
return Base::getChecked(location, context, StandardTypes::Integer, width); return Base::getChecked(location, StandardTypes::Integer, width);
} }
/// Get an instance of the NoneType. /// Get an instance of the NoneType.
@ -639,3 +639,16 @@ IntegerSet IntegerSet::get(unsigned dimCount, unsigned symbolCount,
llvm::sys::SmartScopedWriter<true> affineLock(impl.affineMutex); llvm::sys::SmartScopedWriter<true> affineLock(impl.affineMutex);
return constructorFn(); return constructorFn();
} }
//===----------------------------------------------------------------------===//
// StorageUniquerSupport
//===----------------------------------------------------------------------===//
/// Utility method to generate a default location for use when checking the
/// construction invariants of a storage object. This is defined out-of-line to
/// avoid the need to include Location.h.
const AttributeStorage *
mlir::detail::generateUnknownStorageLocation(MLIRContext *ctx) {
return reinterpret_cast<const AttributeStorage *>(
ctx->getImpl().unknownLocAttr.getAsOpaquePointer());
}

View File

@ -52,12 +52,11 @@ bool Type::isIntOrFloat() { return isa<IntegerType>() || isa<FloatType>(); }
constexpr unsigned IntegerType::kMaxWidth; constexpr unsigned IntegerType::kMaxWidth;
/// Verify the construction of an integer type. /// Verify the construction of an integer type.
LogicalResult IntegerType::verifyConstructionInvariants(Optional<Location> loc, LogicalResult IntegerType::verifyConstructionInvariants(Location loc,
MLIRContext *context,
unsigned width) { unsigned width) {
if (width > IntegerType::kMaxWidth) { if (width > IntegerType::kMaxWidth) {
return emitOptionalError(loc, "integer bitwidth is limited to ", return emitError(loc) << "integer bitwidth is limited to "
IntegerType::kMaxWidth, " bits"); << IntegerType::kMaxWidth << " bits";
} }
return success(); return success();
} }
@ -203,24 +202,20 @@ VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) {
VectorType VectorType::getChecked(ArrayRef<int64_t> shape, Type elementType, VectorType VectorType::getChecked(ArrayRef<int64_t> shape, Type elementType,
Location location) { Location location) {
return Base::getChecked(location, elementType.getContext(), return Base::getChecked(location, StandardTypes::Vector, shape, elementType);
StandardTypes::Vector, shape, elementType);
} }
LogicalResult VectorType::verifyConstructionInvariants(Optional<Location> loc, LogicalResult VectorType::verifyConstructionInvariants(Location loc,
MLIRContext *context,
ArrayRef<int64_t> shape, ArrayRef<int64_t> shape,
Type elementType) { Type elementType) {
if (shape.empty()) if (shape.empty())
return emitOptionalError(loc, return emitError(loc, "vector types must have at least one dimension");
"vector types must have at least one dimension");
if (!isValidElementType(elementType)) if (!isValidElementType(elementType))
return emitOptionalError(loc, "vector elements must be int or float type"); return emitError(loc, "vector elements must be int or float type");
if (any_of(shape, [](int64_t i) { return i <= 0; })) if (any_of(shape, [](int64_t i) { return i <= 0; }))
return emitOptionalError(loc, return emitError(loc, "vector types must have positive constant sizes");
"vector types must have positive constant sizes");
return success(); return success();
} }
@ -233,11 +228,10 @@ ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); }
// Check if "elementType" can be an element type of a tensor. Emit errors if // Check if "elementType" can be an element type of a tensor. Emit errors if
// location is not nullptr. Returns failure if check failed. // location is not nullptr. Returns failure if check failed.
static inline LogicalResult checkTensorElementType(Optional<Location> location, static inline LogicalResult checkTensorElementType(Location location,
MLIRContext *context,
Type elementType) { Type elementType) {
if (!TensorType::isValidElementType(elementType)) if (!TensorType::isValidElementType(elementType))
return emitOptionalError(location, "invalid tensor element type"); return emitError(location, "invalid tensor element type");
return success(); return success();
} }
@ -254,18 +248,17 @@ RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape,
RankedTensorType RankedTensorType::getChecked(ArrayRef<int64_t> shape, RankedTensorType RankedTensorType::getChecked(ArrayRef<int64_t> shape,
Type elementType, Type elementType,
Location location) { Location location) {
return Base::getChecked(location, elementType.getContext(), return Base::getChecked(location, StandardTypes::RankedTensor, shape,
StandardTypes::RankedTensor, shape, elementType); elementType);
} }
LogicalResult RankedTensorType::verifyConstructionInvariants( LogicalResult RankedTensorType::verifyConstructionInvariants(
Optional<Location> loc, MLIRContext *context, ArrayRef<int64_t> shape, Location loc, ArrayRef<int64_t> shape, Type elementType) {
Type elementType) {
for (int64_t s : shape) { for (int64_t s : shape) {
if (s < -1) if (s < -1)
return emitOptionalError(loc, "invalid tensor dimension size"); return emitError(loc, "invalid tensor dimension size");
} }
return checkTensorElementType(loc, context, elementType); return checkTensorElementType(loc, elementType);
} }
ArrayRef<int64_t> RankedTensorType::getShape() const { ArrayRef<int64_t> RankedTensorType::getShape() const {
@ -283,13 +276,13 @@ UnrankedTensorType UnrankedTensorType::get(Type elementType) {
UnrankedTensorType UnrankedTensorType::getChecked(Type elementType, UnrankedTensorType UnrankedTensorType::getChecked(Type elementType,
Location location) { Location location) {
return Base::getChecked(location, elementType.getContext(), return Base::getChecked(location, StandardTypes::UnrankedTensor, elementType);
StandardTypes::UnrankedTensor, elementType);
} }
LogicalResult UnrankedTensorType::verifyConstructionInvariants( LogicalResult
Optional<Location> loc, MLIRContext *context, Type elementType) { UnrankedTensorType::verifyConstructionInvariants(Location loc,
return checkTensorElementType(loc, context, elementType); Type elementType) {
return checkTensorElementType(loc, elementType);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -399,8 +392,7 @@ UnrankedMemRefType UnrankedMemRefType::get(Type elementType,
UnrankedMemRefType UnrankedMemRefType::getChecked(Type elementType, UnrankedMemRefType UnrankedMemRefType::getChecked(Type elementType,
unsigned memorySpace, unsigned memorySpace,
Location location) { Location location) {
return Base::getChecked(location, elementType.getContext(), return Base::getChecked(location, StandardTypes::UnrankedMemRef, elementType,
StandardTypes::UnrankedMemRef, elementType,
memorySpace); memorySpace);
} }
@ -408,13 +400,13 @@ unsigned UnrankedMemRefType::getMemorySpace() const {
return getImpl()->memorySpace; return getImpl()->memorySpace;
} }
LogicalResult UnrankedMemRefType::verifyConstructionInvariants( LogicalResult
Optional<Location> loc, MLIRContext *context, Type elementType, UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType,
unsigned memorySpace) { unsigned memorySpace) {
// Check that memref is formed from allowed types. // Check that memref is formed from allowed types.
if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() && if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() &&
!elementType.isa<ComplexType>()) !elementType.isa<ComplexType>())
return emitOptionalError(*loc, "invalid memref element type"); return emitError(loc, "invalid memref element type");
return success(); return success();
} }
@ -621,16 +613,14 @@ ComplexType ComplexType::get(Type elementType) {
} }
ComplexType ComplexType::getChecked(Type elementType, Location location) { ComplexType ComplexType::getChecked(Type elementType, Location location) {
return Base::getChecked(location, elementType.getContext(), return Base::getChecked(location, StandardTypes::Complex, elementType);
StandardTypes::Complex, elementType);
} }
/// Verify the construction of an integer type. /// Verify the construction of an integer type.
LogicalResult ComplexType::verifyConstructionInvariants(Optional<Location> loc, LogicalResult ComplexType::verifyConstructionInvariants(Location loc,
MLIRContext *context,
Type elementType) { Type elementType) {
if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>()) if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>())
return emitOptionalError(loc, "invalid element type for complex"); return emitError(loc, "invalid element type for complex");
return success(); return success();
} }

View File

@ -59,7 +59,7 @@ OpaqueType OpaqueType::get(Identifier dialect, StringRef typeData,
OpaqueType OpaqueType::getChecked(Identifier dialect, StringRef typeData, OpaqueType OpaqueType::getChecked(Identifier dialect, StringRef typeData,
MLIRContext *context, Location location) { MLIRContext *context, Location location) {
return Base::getChecked(location, context, Kind::Opaque, dialect, typeData); return Base::getChecked(location, Kind::Opaque, dialect, typeData);
} }
/// Returns the dialect namespace of the opaque type. /// Returns the dialect namespace of the opaque type.
@ -71,11 +71,10 @@ Identifier OpaqueType::getDialectNamespace() const {
StringRef OpaqueType::getTypeData() const { return getImpl()->typeData; } StringRef OpaqueType::getTypeData() const { return getImpl()->typeData; }
/// Verify the construction of an opaque type. /// Verify the construction of an opaque type.
LogicalResult OpaqueType::verifyConstructionInvariants(Optional<Location> loc, LogicalResult OpaqueType::verifyConstructionInvariants(Location loc,
MLIRContext *context,
Identifier dialect, Identifier dialect,
StringRef typeData) { StringRef typeData) {
if (!Dialect::isValidNamespace(dialect.strref())) if (!Dialect::isValidNamespace(dialect.strref()))
return emitOptionalError(loc, "invalid dialect namespace '", dialect, "'"); return emitError(loc, "invalid dialect namespace '") << dialect << "'";
return success(); return success();
} }