[mlir] Refactor the structure of the 'verifyConstructionInvariants' methods.

Summary:
The current structure suffers from several problems, but the main one is that a construction failure is impossible to debug when using the 'get' methods. This is because we only optionally emit errors, so there is no context given to the user about the problem. This revision restructures this so that errors are always emitted, and the 'get' methods simply pass in an UnknownLoc to emit to. This allows for removing usages of the more constrained "emitOptionalLoc", as well as removing the need for the context parameter.

Fixes [PR#44964](https://bugs.llvm.org/show_bug.cgi?id=44964)

Differential Revision: https://reviews.llvm.org/D74876
This commit is contained in:
River Riddle 2020-02-20 10:31:44 -08:00
parent 726c342ce2
commit 70d8fec7c9
14 changed files with 160 additions and 189 deletions

View File

@ -194,42 +194,34 @@ public:
/// This method is used to get an instance of the 'ComplexType'. This method
/// asserts that all of the construction invariants were satisfied. To
/// gracefully handle failed construction, getChecked should be used instead.
static ComplexType get(MLIRContext *context, unsigned param, Type type) {
static ComplexType get(unsigned param, Type type) {
// Call into a helper 'get' method in 'TypeBase' to get a uniqued instance
// of this type. All parameters to the storage class are passed after the
// type kind.
return Base::get(context, MyTypes::Complex, param, type);
return Base::get(type.getContext(), MyTypes::Complex, param, type);
}
/// This method is used to get an instance of the 'ComplexType', defined at
/// the given location. If any of the construction invariants are invalid,
/// errors are emitted with the provided location and a null type is returned.
/// Note: This method is completely optional.
static ComplexType getChecked(MLIRContext *context, unsigned param, Type type,
Location location) {
static ComplexType getChecked(unsigned param, Type type, Location location) {
// Call into a helper 'getChecked' method in 'TypeBase' to get a uniqued
// instance of this type. All parameters to the storage class are passed
// after the type kind.
return Base::getChecked(location, context, MyTypes::Complex, param, type);
return Base::getChecked(location, MyTypes::Complex, param, type);
}
/// This method is used to verify the construction invariants passed into the
/// 'get' and 'getChecked' methods. Note: This method is completely optional.
static LogicalResult verifyConstructionInvariants(
llvm::Optional<Location> loc, MLIRContext *context, unsigned param,
Type type) {
Location loc, unsigned param, Type type) {
// Our type only allows non-zero parameters.
if (param == 0) {
if (loc)
context->emitError(loc) << "non-zero parameter passed to 'ComplexType'";
return failure();
}
if (param == 0)
return emitError(loc) << "non-zero parameter passed to 'ComplexType'";
// Our type also expects an integer type.
if (!type.isa<IntegerType>()) {
if (loc)
context->emitError(loc) << "non integer-type passed to 'ComplexType'";
return failure();
}
if (!type.isa<IntegerType>())
return emitError(loc) << "non integer-type passed to 'ComplexType'";
return success();
}

View File

@ -66,8 +66,7 @@ public:
static constexpr unsigned MaxStorageBits = 32;
static LogicalResult
verifyConstructionInvariants(Optional<Location> loc, MLIRContext *context,
unsigned flags, Type storageType,
verifyConstructionInvariants(Location loc, unsigned flags, Type storageType,
Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax);
@ -229,8 +228,7 @@ public:
/// Verifies construction invariants and issues errors/warnings.
static LogicalResult
verifyConstructionInvariants(Optional<Location> loc, MLIRContext *context,
unsigned flags, Type storageType,
verifyConstructionInvariants(Location loc, unsigned flags, Type storageType,
Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax);
};
@ -288,10 +286,11 @@ public:
Location location);
/// Verifies construction invariants and issues errors/warnings.
static LogicalResult verifyConstructionInvariants(
Optional<Location> loc, MLIRContext *context, unsigned flags,
Type storageType, Type expressedType, double scale, int64_t zeroPoint,
int64_t storageTypeMin, int64_t storageTypeMax);
static LogicalResult
verifyConstructionInvariants(Location loc, unsigned flags, Type storageType,
Type expressedType, double scale,
int64_t zeroPoint, int64_t storageTypeMin,
int64_t storageTypeMax);
/// Support method to enable LLVM-style type casting.
static bool kindof(unsigned kind) {
@ -351,11 +350,12 @@ public:
int64_t storageTypeMax, Location location);
/// Verifies construction invariants and issues errors/warnings.
static LogicalResult verifyConstructionInvariants(
Optional<Location> loc, MLIRContext *context, unsigned flags,
Type storageType, Type expressedType, ArrayRef<double> scales,
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
int64_t storageTypeMin, int64_t storageTypeMax);
static LogicalResult
verifyConstructionInvariants(Location loc, unsigned flags, Type storageType,
Type expressedType, ArrayRef<double> scales,
ArrayRef<int64_t> zeroPoints,
int32_t quantizedDimension,
int64_t storageTypeMin, int64_t storageTypeMax);
/// Support method to enable LLVM-style type casting.
static bool kindof(unsigned kind) {

View File

@ -90,10 +90,11 @@ public:
static bool kindof(unsigned kind) { return kind == AttrKind::TargetEnv; }
static LogicalResult
verifyConstructionInvariants(Optional<Location> loc, MLIRContext *context,
IntegerAttr version, ArrayAttr extensions,
ArrayAttr capabilities, DictionaryAttr limits);
static LogicalResult verifyConstructionInvariants(Location loc,
IntegerAttr version,
ArrayAttr extensions,
ArrayAttr capabilities,
DictionaryAttr limits);
};
/// Returns the attribute name for specifying argument ABI information.

View File

@ -330,11 +330,9 @@ public:
}
/// Verify the construction invariants for a double value.
static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
MLIRContext *ctx, Type type,
static LogicalResult verifyConstructionInvariants(Location loc, Type type,
double value);
static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
MLIRContext *ctx, Type type,
static LogicalResult verifyConstructionInvariants(Location loc, Type type,
const APFloat &value);
};
@ -361,11 +359,9 @@ public:
return kind == StandardAttributes::Integer;
}
static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
MLIRContext *ctx, Type type,
static LogicalResult verifyConstructionInvariants(Location loc, Type type,
int64_t value);
static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
MLIRContext *ctx, Type type,
static LogicalResult verifyConstructionInvariants(Location loc, Type type,
const APInt &value);
};
@ -419,8 +415,7 @@ public:
StringRef getAttrData() const;
/// Verify the construction of an opaque attribute.
static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
MLIRContext *context,
static LogicalResult verifyConstructionInvariants(Location loc,
Identifier dialect,
StringRef attrData,
Type type);

View File

@ -54,6 +54,12 @@ public:
Location(LocationAttr loc) : impl(loc) {
assert(loc && "location should never be null.");
}
Location(const LocationAttr::ImplType *impl) : impl(impl) {
assert(impl && "location should never be null.");
}
/// Return the context this location is uniqued in.
MLIRContext *getContext() const { return impl.getContext(); }
/// Access the impl location attribute.
operator LocationAttr() const { return impl; }

View File

@ -96,8 +96,7 @@ public:
Location location);
/// Verify the construction of an integer type.
static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
MLIRContext *context,
static LogicalResult verifyConstructionInvariants(Location loc,
unsigned width);
/// Return the bitwidth of this integer type.
@ -162,8 +161,7 @@ public:
static ComplexType getChecked(Type elementType, Location location);
/// Verify the construction of an integer type.
static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
MLIRContext *context,
static LogicalResult verifyConstructionInvariants(Location loc,
Type elementType);
Type getElementType();
@ -270,8 +268,7 @@ public:
Location location);
/// Verify the construction of a vector type.
static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
MLIRContext *context,
static LogicalResult verifyConstructionInvariants(Location loc,
ArrayRef<int64_t> shape,
Type elementType);
@ -329,8 +326,7 @@ public:
Location location);
/// Verify the construction of a ranked tensor type.
static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
MLIRContext *context,
static LogicalResult verifyConstructionInvariants(Location loc,
ArrayRef<int64_t> shape,
Type elementType);
@ -360,8 +356,7 @@ public:
static UnrankedTensorType getChecked(Type elementType, Location location);
/// Verify the construction of a unranked tensor type.
static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
MLIRContext *context,
static LogicalResult verifyConstructionInvariants(Location loc,
Type elementType);
ArrayRef<int64_t> getShape() const { return llvm::None; }
@ -505,8 +500,7 @@ public:
Location location);
/// Verify the construction of a unranked memref type.
static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
MLIRContext *context,
static LogicalResult verifyConstructionInvariants(Location loc,
Type elementType,
unsigned memorySpace);

View File

@ -18,10 +18,15 @@
#include "mlir/Support/StorageUniquer.h"
namespace mlir {
class Location;
class AttributeStorage;
class MLIRContext;
namespace detail {
/// Utility method to generate a raw default location for use when checking the
/// construction invariants of a storage object. This is defined out-of-line to
/// avoid the need to include Location.h.
const AttributeStorage *generateUnknownStorageLocation(MLIRContext *ctx);
/// Utility class for implementing users of storage classes uniqued by a
/// StorageUniquer. Clients are not expected to interact with this class
/// directly.
@ -53,21 +58,20 @@ protected:
template <typename... Args>
static ConcreteT get(MLIRContext *ctx, unsigned kind, Args... args) {
// Ensure that the invariants are correct for construction.
assert(succeeded(
ConcreteT::verifyConstructionInvariants(llvm::None, ctx, args...)));
assert(succeeded(ConcreteT::verifyConstructionInvariants(
generateUnknownStorageLocation(ctx), args...)));
return UniquerT::template get<ConcreteT>(ctx, kind, args...);
}
/// Get or create a new ConcreteT instance within the ctx, defined at
/// the given, potentially unknown, location. If the arguments provided are
/// invalid then emit errors and return a null object.
template <typename... Args>
static ConcreteT getChecked(const Location &loc, MLIRContext *ctx,
unsigned kind, Args... args) {
template <typename LocationT, typename... Args>
static ConcreteT getChecked(LocationT loc, unsigned kind, Args... args) {
// If the construction invariants fail then we return a null attribute.
if (failed(ConcreteT::verifyConstructionInvariants(loc, ctx, args...)))
if (failed(ConcreteT::verifyConstructionInvariants(loc, args...)))
return ConcreteT();
return UniquerT::template get<ConcreteT>(ctx, kind, args...);
return UniquerT::template get<ConcreteT>(loc.getContext(), kind, args...);
}
/// Default implementation that just returns success.

View File

@ -46,10 +46,8 @@ struct OpaqueTypeStorage;
/// current type. Used for isa/dyn_cast casting functionality.
///
/// * Optional:
/// - static LogicalResult verifyConstructionInvariants(
/// Optional<Location> loc,
/// MLIRContext *context,
/// Args... args)
/// - static LogicalResult verifyConstructionInvariants(Location loc,
/// Args... args)
/// * This method is invoked when calling the 'TypeBase::get/getChecked'
/// methods to ensure that the arguments passed in are valid to construct
/// a type instance with.
@ -238,8 +236,7 @@ public:
StringRef getTypeData() const;
/// Verify the construction of an opaque type.
static LogicalResult verifyConstructionInvariants(Optional<Location> loc,
MLIRContext *context,
static LogicalResult verifyConstructionInvariants(Location loc,
Identifier dialect,
StringRef typeData);

View File

@ -24,20 +24,19 @@ unsigned QuantizedType::getFlags() const {
}
LogicalResult QuantizedType::verifyConstructionInvariants(
Optional<Location> loc, MLIRContext *context, unsigned flags,
Type storageType, Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax) {
Location loc, unsigned flags, Type storageType, Type expressedType,
int64_t storageTypeMin, int64_t storageTypeMax) {
// Verify that the storage type is integral.
// This restriction may be lifted at some point in favor of using bf16
// or f16 as exact representations on hardware where that is advantageous.
auto intStorageType = storageType.dyn_cast<IntegerType>();
if (!intStorageType)
return emitOptionalError(loc, "storage type must be integral");
return emitError(loc, "storage type must be integral");
unsigned integralWidth = intStorageType.getWidth();
// Verify storage width.
if (integralWidth == 0 || integralWidth > MaxStorageBits)
return emitOptionalError(loc, "illegal storage type size: ", integralWidth);
return emitError(loc, "illegal storage type size: ") << integralWidth;
// Verify storageTypeMin and storageTypeMax.
bool isSigned =
@ -49,8 +48,8 @@ LogicalResult QuantizedType::verifyConstructionInvariants(
if (storageTypeMax - storageTypeMin <= 0 ||
storageTypeMin < defaultIntegerMin ||
storageTypeMax > defaultIntegerMax) {
return emitOptionalError(loc, "illegal storage min and storage max: (",
storageTypeMin, ":", storageTypeMax, ")");
return emitError(loc, "illegal storage min and storage max: (")
<< storageTypeMin << ":" << storageTypeMax << ")";
}
return success();
}
@ -209,17 +208,15 @@ AnyQuantizedType AnyQuantizedType::getChecked(unsigned flags, Type storageType,
int64_t storageTypeMin,
int64_t storageTypeMax,
Location location) {
return Base::getChecked(location, storageType.getContext(),
QuantizationTypes::Any, flags, storageType,
return Base::getChecked(location, QuantizationTypes::Any, flags, storageType,
expressedType, storageTypeMin, storageTypeMax);
}
LogicalResult AnyQuantizedType::verifyConstructionInvariants(
Optional<Location> loc, MLIRContext *context, unsigned flags,
Type storageType, Type expressedType, int64_t storageTypeMin,
int64_t storageTypeMax) {
Location loc, unsigned flags, Type storageType, Type expressedType,
int64_t storageTypeMin, int64_t storageTypeMax) {
if (failed(QuantizedType::verifyConstructionInvariants(
loc, context, flags, storageType, expressedType, storageTypeMin,
loc, flags, storageType, expressedType, storageTypeMin,
storageTypeMax))) {
return failure();
}
@ -228,7 +225,7 @@ LogicalResult AnyQuantizedType::verifyConstructionInvariants(
// If this restriction is ever eliminated, the parser/printer must be
// extended.
if (expressedType && !expressedType.isa<FloatType>())
return emitOptionalError(loc, "expressed type must be floating point");
return emitError(loc, "expressed type must be floating point");
return success();
}
@ -249,18 +246,17 @@ UniformQuantizedType::getChecked(unsigned flags, Type storageType,
Type expressedType, double scale,
int64_t zeroPoint, int64_t storageTypeMin,
int64_t storageTypeMax, Location location) {
return Base::getChecked(location, storageType.getContext(),
QuantizationTypes::UniformQuantized, flags,
return Base::getChecked(location, QuantizationTypes::UniformQuantized, flags,
storageType, expressedType, scale, zeroPoint,
storageTypeMin, storageTypeMax);
}
LogicalResult UniformQuantizedType::verifyConstructionInvariants(
Optional<Location> loc, MLIRContext *context, unsigned flags,
Type storageType, Type expressedType, double scale, int64_t zeroPoint,
int64_t storageTypeMin, int64_t storageTypeMax) {
Location loc, unsigned flags, Type storageType, Type expressedType,
double scale, int64_t zeroPoint, int64_t storageTypeMin,
int64_t storageTypeMax) {
if (failed(QuantizedType::verifyConstructionInvariants(
loc, context, flags, storageType, expressedType, storageTypeMin,
loc, flags, storageType, expressedType, storageTypeMin,
storageTypeMax))) {
return failure();
}
@ -268,18 +264,17 @@ LogicalResult UniformQuantizedType::verifyConstructionInvariants(
// Uniform quantization requires fully expressed parameters, including
// expressed type.
if (!expressedType)
return emitOptionalError(loc,
"uniform quantization requires expressed type");
return emitError(loc, "uniform quantization requires expressed type");
// Verify that the expressed type is floating point.
// If this restriction is ever eliminated, the parser/printer must be
// extended.
if (!expressedType.isa<FloatType>())
return emitOptionalError(loc, "expressed type must be floating point");
return emitError(loc, "expressed type must be floating point");
// Verify scale.
if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
return emitOptionalError(loc, "illegal scale: ", scale);
return emitError(loc, "illegal scale: ") << scale;
return success();
}
@ -306,19 +301,18 @@ UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked(
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax,
Location location) {
return Base::getChecked(location, storageType.getContext(),
QuantizationTypes::UniformQuantizedPerAxis, flags,
storageType, expressedType, scales, zeroPoints,
return Base::getChecked(location, QuantizationTypes::UniformQuantizedPerAxis,
flags, storageType, expressedType, scales, zeroPoints,
quantizedDimension, storageTypeMin, storageTypeMax);
}
LogicalResult UniformQuantizedPerAxisType::verifyConstructionInvariants(
Optional<Location> loc, MLIRContext *context, unsigned flags,
Type storageType, Type expressedType, ArrayRef<double> scales,
ArrayRef<int64_t> zeroPoints, int32_t quantizedDimension,
int64_t storageTypeMin, int64_t storageTypeMax) {
Location loc, unsigned flags, Type storageType, Type expressedType,
ArrayRef<double> scales, ArrayRef<int64_t> zeroPoints,
int32_t quantizedDimension, int64_t storageTypeMin,
int64_t storageTypeMax) {
if (failed(QuantizedType::verifyConstructionInvariants(
loc, context, flags, storageType, expressedType, storageTypeMin,
loc, flags, storageType, expressedType, storageTypeMin,
storageTypeMax))) {
return failure();
}
@ -326,24 +320,23 @@ LogicalResult UniformQuantizedPerAxisType::verifyConstructionInvariants(
// Uniform quantization requires fully expressed parameters, including
// expressed type.
if (!expressedType)
return emitOptionalError(loc,
"uniform quantization requires expressed type");
return emitError(loc, "uniform quantization requires expressed type");
// Verify that the expressed type is floating point.
// If this restriction is ever eliminated, the parser/printer must be
// extended.
if (!expressedType.isa<FloatType>())
return emitOptionalError(loc, "expressed type must be floating point");
return emitError(loc, "expressed type must be floating point");
// Ensure that the number of scales and zeroPoints match.
if (scales.size() != zeroPoints.size())
return emitOptionalError(loc, "illegal number of scales and zeroPoints: ",
scales.size(), ", ", zeroPoints.size());
return emitError(loc, "illegal number of scales and zeroPoints: ")
<< scales.size() << ", " << zeroPoints.size();
// Verify scale.
for (double scale : scales) {
if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale))
return emitOptionalError(loc, "illegal scale: ", scale);
return emitError(loc, "illegal scale: ") << scale;
}
return success();

View File

@ -103,10 +103,10 @@ DictionaryAttr spirv::TargetEnvAttr::getResourceLimits() {
}
LogicalResult spirv::TargetEnvAttr::verifyConstructionInvariants(
Optional<Location> loc, MLIRContext *context, IntegerAttr version,
ArrayAttr extensions, ArrayAttr capabilities, DictionaryAttr limits) {
Location loc, IntegerAttr version, ArrayAttr extensions,
ArrayAttr capabilities, DictionaryAttr limits) {
if (!version.getType().isInteger(32))
return emitOptionalError(loc, "expected 32-bit integer for version");
return emitError(loc, "expected 32-bit integer for version");
if (!llvm::all_of(extensions.getValue(), [](Attribute attr) {
if (auto strAttr = attr.dyn_cast<StringAttr>())
@ -114,7 +114,7 @@ LogicalResult spirv::TargetEnvAttr::verifyConstructionInvariants(
return true;
return false;
}))
return emitOptionalError(loc, "unknown extension in extension list");
return emitError(loc, "unknown extension in extension list");
if (!llvm::all_of(capabilities.getValue(), [](Attribute attr) {
if (auto intAttr = attr.dyn_cast<IntegerAttr>())
@ -122,11 +122,10 @@ LogicalResult spirv::TargetEnvAttr::verifyConstructionInvariants(
return true;
return false;
}))
return emitOptionalError(loc, "unknown capability in capability list");
return emitError(loc, "unknown capability in capability list");
if (!limits.isa<spirv::ResourceLimitsAttr>())
return emitOptionalError(loc,
"expected spirv::ResourceLimitsAttr for limits");
return emitError(loc, "expected spirv::ResourceLimitsAttr for limits");
return success();
}

View File

@ -182,8 +182,7 @@ FloatAttr FloatAttr::get(Type type, double value) {
}
FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) {
return Base::getChecked(loc, type.getContext(), StandardAttributes::Float,
type, value);
return Base::getChecked(loc, StandardAttributes::Float, type, value);
}
FloatAttr FloatAttr::get(Type type, const APFloat &value) {
@ -191,8 +190,7 @@ FloatAttr FloatAttr::get(Type type, const APFloat &value) {
}
FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) {
return Base::getChecked(loc, type.getContext(), StandardAttributes::Float,
type, value);
return Base::getChecked(loc, StandardAttributes::Float, type, value);
}
APFloat FloatAttr::getValue() const { return getImpl()->getValue(); }
@ -210,22 +208,18 @@ double FloatAttr::getValueAsDouble(APFloat value) {
}
/// Verify construction invariants.
static LogicalResult verifyFloatTypeInvariants(Optional<Location> loc,
Type type) {
static LogicalResult verifyFloatTypeInvariants(Location loc, Type type) {
if (!type.isa<FloatType>())
return emitOptionalError(loc, "expected floating point type");
return emitError(loc, "expected floating point type");
return success();
}
LogicalResult FloatAttr::verifyConstructionInvariants(Optional<Location> loc,
MLIRContext *ctx,
Type type, double value) {
LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
double value) {
return verifyFloatTypeInvariants(loc, type);
}
LogicalResult FloatAttr::verifyConstructionInvariants(Optional<Location> loc,
MLIRContext *ctx,
Type type,
LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type,
const APFloat &value) {
// Verify that the type is correct.
if (failed(verifyFloatTypeInvariants(loc, type)))
@ -233,7 +227,7 @@ LogicalResult FloatAttr::verifyConstructionInvariants(Optional<Location> loc,
// Verify that the type semantics match that of the value.
if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
return emitOptionalError(
return emitError(
loc, "FloatAttr type doesn't match the type implied by its value");
}
return success();
@ -286,31 +280,26 @@ APInt IntegerAttr::getValue() const { return getImpl()->getValue(); }
int64_t IntegerAttr::getInt() const { return getValue().getSExtValue(); }
static LogicalResult verifyIntegerTypeInvariants(Optional<Location> loc,
Type type) {
static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) {
if (type.isa<IntegerType>() || type.isa<IndexType>())
return success();
return emitOptionalError(loc, "expected integer or index type");
return emitError(loc, "expected integer or index type");
}
LogicalResult IntegerAttr::verifyConstructionInvariants(Optional<Location> loc,
MLIRContext *ctx,
Type type,
LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
int64_t value) {
return verifyIntegerTypeInvariants(loc, type);
}
LogicalResult IntegerAttr::verifyConstructionInvariants(Optional<Location> loc,
MLIRContext *ctx,
Type type,
LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type,
const APInt &value) {
if (failed(verifyIntegerTypeInvariants(loc, type)))
return failure();
if (auto integerType = type.dyn_cast<IntegerType>())
if (integerType.getWidth() != value.getBitWidth())
return emitOptionalError(
loc, "integer type bit width (", integerType.getWidth(),
") doesn't match value bit width (", value.getBitWidth(), ")");
return emitError(loc, "integer type bit width (")
<< integerType.getWidth() << ") doesn't match value bit width ("
<< value.getBitWidth() << ")";
return success();
}
@ -337,8 +326,8 @@ OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type,
OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData,
Type type, Location location) {
return Base::getChecked(location, type.getContext(),
StandardAttributes::Opaque, dialect, attrData, type);
return Base::getChecked(location, StandardAttributes::Opaque, dialect,
attrData, type);
}
/// Returns the dialect namespace of the opaque attribute.
@ -350,13 +339,12 @@ Identifier OpaqueAttr::getDialectNamespace() const {
StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; }
/// Verify the construction of an opaque attribute.
LogicalResult OpaqueAttr::verifyConstructionInvariants(Optional<Location> loc,
MLIRContext *context,
LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc,
Identifier dialect,
StringRef attrData,
Type type) {
if (!Dialect::isValidNamespace(dialect.strref()))
return emitOptionalError(loc, "invalid dialect namespace '", dialect, "'");
return emitError(loc, "invalid dialect namespace '") << dialect << "'";
return success();
}

View File

@ -518,7 +518,7 @@ IntegerType IntegerType::getChecked(unsigned width, MLIRContext *context,
Location location) {
if (auto cached = getCachedIntegerType(width, context))
return cached;
return Base::getChecked(location, context, StandardTypes::Integer, width);
return Base::getChecked(location, StandardTypes::Integer, width);
}
/// Get an instance of the NoneType.
@ -639,3 +639,16 @@ IntegerSet IntegerSet::get(unsigned dimCount, unsigned symbolCount,
llvm::sys::SmartScopedWriter<true> affineLock(impl.affineMutex);
return constructorFn();
}
//===----------------------------------------------------------------------===//
// StorageUniquerSupport
//===----------------------------------------------------------------------===//
/// Utility method to generate a default location for use when checking the
/// construction invariants of a storage object. This is defined out-of-line to
/// avoid the need to include Location.h.
const AttributeStorage *
mlir::detail::generateUnknownStorageLocation(MLIRContext *ctx) {
return reinterpret_cast<const AttributeStorage *>(
ctx->getImpl().unknownLocAttr.getAsOpaquePointer());
}

View File

@ -52,12 +52,11 @@ bool Type::isIntOrFloat() { return isa<IntegerType>() || isa<FloatType>(); }
constexpr unsigned IntegerType::kMaxWidth;
/// Verify the construction of an integer type.
LogicalResult IntegerType::verifyConstructionInvariants(Optional<Location> loc,
MLIRContext *context,
LogicalResult IntegerType::verifyConstructionInvariants(Location loc,
unsigned width) {
if (width > IntegerType::kMaxWidth) {
return emitOptionalError(loc, "integer bitwidth is limited to ",
IntegerType::kMaxWidth, " bits");
return emitError(loc) << "integer bitwidth is limited to "
<< IntegerType::kMaxWidth << " bits";
}
return success();
}
@ -203,24 +202,20 @@ VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) {
VectorType VectorType::getChecked(ArrayRef<int64_t> shape, Type elementType,
Location location) {
return Base::getChecked(location, elementType.getContext(),
StandardTypes::Vector, shape, elementType);
return Base::getChecked(location, StandardTypes::Vector, shape, elementType);
}
LogicalResult VectorType::verifyConstructionInvariants(Optional<Location> loc,
MLIRContext *context,
LogicalResult VectorType::verifyConstructionInvariants(Location loc,
ArrayRef<int64_t> shape,
Type elementType) {
if (shape.empty())
return emitOptionalError(loc,
"vector types must have at least one dimension");
return emitError(loc, "vector types must have at least one dimension");
if (!isValidElementType(elementType))
return emitOptionalError(loc, "vector elements must be int or float type");
return emitError(loc, "vector elements must be int or float type");
if (any_of(shape, [](int64_t i) { return i <= 0; }))
return emitOptionalError(loc,
"vector types must have positive constant sizes");
return emitError(loc, "vector types must have positive constant sizes");
return success();
}
@ -233,11 +228,10 @@ ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); }
// Check if "elementType" can be an element type of a tensor. Emit errors if
// location is not nullptr. Returns failure if check failed.
static inline LogicalResult checkTensorElementType(Optional<Location> location,
MLIRContext *context,
static inline LogicalResult checkTensorElementType(Location location,
Type elementType) {
if (!TensorType::isValidElementType(elementType))
return emitOptionalError(location, "invalid tensor element type");
return emitError(location, "invalid tensor element type");
return success();
}
@ -254,18 +248,17 @@ RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape,
RankedTensorType RankedTensorType::getChecked(ArrayRef<int64_t> shape,
Type elementType,
Location location) {
return Base::getChecked(location, elementType.getContext(),
StandardTypes::RankedTensor, shape, elementType);
return Base::getChecked(location, StandardTypes::RankedTensor, shape,
elementType);
}
LogicalResult RankedTensorType::verifyConstructionInvariants(
Optional<Location> loc, MLIRContext *context, ArrayRef<int64_t> shape,
Type elementType) {
Location loc, ArrayRef<int64_t> shape, Type elementType) {
for (int64_t s : shape) {
if (s < -1)
return emitOptionalError(loc, "invalid tensor dimension size");
return emitError(loc, "invalid tensor dimension size");
}
return checkTensorElementType(loc, context, elementType);
return checkTensorElementType(loc, elementType);
}
ArrayRef<int64_t> RankedTensorType::getShape() const {
@ -283,13 +276,13 @@ UnrankedTensorType UnrankedTensorType::get(Type elementType) {
UnrankedTensorType UnrankedTensorType::getChecked(Type elementType,
Location location) {
return Base::getChecked(location, elementType.getContext(),
StandardTypes::UnrankedTensor, elementType);
return Base::getChecked(location, StandardTypes::UnrankedTensor, elementType);
}
LogicalResult UnrankedTensorType::verifyConstructionInvariants(
Optional<Location> loc, MLIRContext *context, Type elementType) {
return checkTensorElementType(loc, context, elementType);
LogicalResult
UnrankedTensorType::verifyConstructionInvariants(Location loc,
Type elementType) {
return checkTensorElementType(loc, elementType);
}
//===----------------------------------------------------------------------===//
@ -399,8 +392,7 @@ UnrankedMemRefType UnrankedMemRefType::get(Type elementType,
UnrankedMemRefType UnrankedMemRefType::getChecked(Type elementType,
unsigned memorySpace,
Location location) {
return Base::getChecked(location, elementType.getContext(),
StandardTypes::UnrankedMemRef, elementType,
return Base::getChecked(location, StandardTypes::UnrankedMemRef, elementType,
memorySpace);
}
@ -408,13 +400,13 @@ unsigned UnrankedMemRefType::getMemorySpace() const {
return getImpl()->memorySpace;
}
LogicalResult UnrankedMemRefType::verifyConstructionInvariants(
Optional<Location> loc, MLIRContext *context, Type elementType,
unsigned memorySpace) {
LogicalResult
UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType,
unsigned memorySpace) {
// Check that memref is formed from allowed types.
if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() &&
!elementType.isa<ComplexType>())
return emitOptionalError(*loc, "invalid memref element type");
return emitError(loc, "invalid memref element type");
return success();
}
@ -621,16 +613,14 @@ ComplexType ComplexType::get(Type elementType) {
}
ComplexType ComplexType::getChecked(Type elementType, Location location) {
return Base::getChecked(location, elementType.getContext(),
StandardTypes::Complex, elementType);
return Base::getChecked(location, StandardTypes::Complex, elementType);
}
/// Verify the construction of an integer type.
LogicalResult ComplexType::verifyConstructionInvariants(Optional<Location> loc,
MLIRContext *context,
LogicalResult ComplexType::verifyConstructionInvariants(Location loc,
Type elementType) {
if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>())
return emitOptionalError(loc, "invalid element type for complex");
return emitError(loc, "invalid element type for complex");
return success();
}

View File

@ -59,7 +59,7 @@ OpaqueType OpaqueType::get(Identifier dialect, StringRef typeData,
OpaqueType OpaqueType::getChecked(Identifier dialect, StringRef typeData,
MLIRContext *context, Location location) {
return Base::getChecked(location, context, Kind::Opaque, dialect, typeData);
return Base::getChecked(location, Kind::Opaque, dialect, typeData);
}
/// Returns the dialect namespace of the opaque type.
@ -71,11 +71,10 @@ Identifier OpaqueType::getDialectNamespace() const {
StringRef OpaqueType::getTypeData() const { return getImpl()->typeData; }
/// Verify the construction of an opaque type.
LogicalResult OpaqueType::verifyConstructionInvariants(Optional<Location> loc,
MLIRContext *context,
LogicalResult OpaqueType::verifyConstructionInvariants(Location loc,
Identifier dialect,
StringRef typeData) {
if (!Dialect::isValidNamespace(dialect.strref()))
return emitOptionalError(loc, "invalid dialect namespace '", dialect, "'");
return emitError(loc, "invalid dialect namespace '") << dialect << "'";
return success();
}