forked from OSchip/llvm-project
Update TypeBase::verifyConstructionInvariants to use a LogicalResult return instead of bool.
-- PiperOrigin-RevId: 241045568
This commit is contained in:
parent
90d2e16e63
commit
258dbdafa8
|
@ -109,9 +109,9 @@ public:
|
|||
Location location);
|
||||
|
||||
/// Verify the construction of an integer type.
|
||||
static bool verifyConstructionInvariants(llvm::Optional<Location> loc,
|
||||
MLIRContext *context,
|
||||
unsigned width);
|
||||
static LogicalResult
|
||||
verifyConstructionInvariants(llvm::Optional<Location> loc,
|
||||
MLIRContext *context, unsigned width);
|
||||
|
||||
/// Return the bitwidth of this integer type.
|
||||
unsigned getWidth() const;
|
||||
|
@ -249,10 +249,10 @@ public:
|
|||
Location location);
|
||||
|
||||
/// Verify the construction of a vector type.
|
||||
static bool verifyConstructionInvariants(llvm::Optional<Location> loc,
|
||||
MLIRContext *context,
|
||||
ArrayRef<int64_t> shape,
|
||||
Type elementType);
|
||||
static LogicalResult
|
||||
verifyConstructionInvariants(llvm::Optional<Location> loc,
|
||||
MLIRContext *context, ArrayRef<int64_t> shape,
|
||||
Type elementType);
|
||||
|
||||
/// Returns true of the given type can be used as an element of a vector type.
|
||||
/// In particular, vectors can consist of integer or float primitives.
|
||||
|
@ -308,10 +308,10 @@ public:
|
|||
Location location);
|
||||
|
||||
/// Verify the construction of a ranked tensor type.
|
||||
static bool verifyConstructionInvariants(llvm::Optional<Location> loc,
|
||||
MLIRContext *context,
|
||||
ArrayRef<int64_t> shape,
|
||||
Type elementType);
|
||||
static LogicalResult
|
||||
verifyConstructionInvariants(llvm::Optional<Location> loc,
|
||||
MLIRContext *context, ArrayRef<int64_t> shape,
|
||||
Type elementType);
|
||||
|
||||
ArrayRef<int64_t> getShape() const;
|
||||
|
||||
|
@ -339,9 +339,9 @@ public:
|
|||
static UnrankedTensorType getChecked(Type elementType, Location location);
|
||||
|
||||
/// Verify the construction of a unranked tensor type.
|
||||
static bool verifyConstructionInvariants(llvm::Optional<Location> loc,
|
||||
MLIRContext *context,
|
||||
Type elementType);
|
||||
static LogicalResult
|
||||
verifyConstructionInvariants(llvm::Optional<Location> loc,
|
||||
MLIRContext *context, Type elementType);
|
||||
|
||||
ArrayRef<int64_t> getShape() const { return llvm::None; }
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/TypeSupport.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/DenseMapInfo.h"
|
||||
|
||||
|
@ -56,14 +57,15 @@ struct UnknownTypeStorage;
|
|||
/// current type. Used for isa/dyn_cast casting functionality.
|
||||
///
|
||||
/// * Optional:
|
||||
/// - static bool verifyConstructionInvariants(llvm::Optional<Location> loc,
|
||||
/// - static LogicalResult verifyConstructionInvariants(
|
||||
/// llvm::Optional<Location> loc,
|
||||
/// MLIRContext *context,
|
||||
/// Args... args)
|
||||
/// * This method is invoked when calling the 'TypeBase::get/getChecked'
|
||||
/// methods to ensure that the arguments passed in are valid to construct
|
||||
/// a type instance with.
|
||||
/// * This method is expected to return true if a type cannot be
|
||||
/// constructed with 'args'.
|
||||
/// * This method is expected to return failure if a type cannot be
|
||||
/// constructed with 'args', success otherwise.
|
||||
/// * 'args' must correspond with the arguments passed into the
|
||||
/// 'TypeBase::get' call after the type kind.
|
||||
///
|
||||
|
@ -141,8 +143,8 @@ public:
|
|||
template <typename... Args>
|
||||
static ConcreteType get(MLIRContext *context, unsigned kind, Args... args) {
|
||||
// Ensure that the invariants are correct for type construction.
|
||||
assert(!ConcreteType::verifyConstructionInvariants(llvm::None, context,
|
||||
args...));
|
||||
assert(succeeded(ConcreteType::verifyConstructionInvariants(
|
||||
llvm::None, context, args...)));
|
||||
return detail::TypeUniquer::get<ConcreteType>(context, kind, args...);
|
||||
}
|
||||
|
||||
|
@ -153,17 +155,18 @@ public:
|
|||
static ConcreteType getChecked(Location loc, MLIRContext *context,
|
||||
unsigned kind, Args... args) {
|
||||
// If the construction invariants fail then we return a null type.
|
||||
if (ConcreteType::verifyConstructionInvariants(loc, context, args...))
|
||||
if (failed(ConcreteType::verifyConstructionInvariants(loc, context,
|
||||
args...)))
|
||||
return ConcreteType();
|
||||
return detail::TypeUniquer::get<ConcreteType>(context, kind, args...);
|
||||
}
|
||||
|
||||
/// Default implementation that just returns false for success.
|
||||
/// Default implementation that just returns success.
|
||||
template <typename... Args>
|
||||
static bool verifyConstructionInvariants(llvm::Optional<Location> loc,
|
||||
MLIRContext *context,
|
||||
Args... args) {
|
||||
return false;
|
||||
static LogicalResult
|
||||
verifyConstructionInvariants(llvm::Optional<Location> loc,
|
||||
MLIRContext *context, Args... args) {
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Utility for easy access to the storage instance.
|
||||
|
@ -302,10 +305,10 @@ public:
|
|||
StringRef getTypeData() const;
|
||||
|
||||
/// Verify the construction of an unknown type.
|
||||
static bool verifyConstructionInvariants(llvm::Optional<Location> loc,
|
||||
MLIRContext *context,
|
||||
Identifier dialect,
|
||||
StringRef typeData);
|
||||
static LogicalResult
|
||||
verifyConstructionInvariants(llvm::Optional<Location> loc,
|
||||
MLIRContext *context, Identifier dialect,
|
||||
StringRef typeData);
|
||||
|
||||
static bool kindof(unsigned kind) { return kind == Kind::Unknown; }
|
||||
};
|
||||
|
|
|
@ -29,16 +29,15 @@ using namespace mlir::detail;
|
|||
/// Integer Type.
|
||||
|
||||
/// Verify the construction of an integer type.
|
||||
bool IntegerType::verifyConstructionInvariants(llvm::Optional<Location> loc,
|
||||
MLIRContext *context,
|
||||
unsigned width) {
|
||||
LogicalResult IntegerType::verifyConstructionInvariants(
|
||||
llvm::Optional<Location> loc, MLIRContext *context, unsigned width) {
|
||||
if (width > IntegerType::kMaxWidth) {
|
||||
if (loc)
|
||||
context->emitError(*loc, "integer bitwidth is limited to " +
|
||||
Twine(IntegerType::kMaxWidth) + " bits");
|
||||
return true;
|
||||
return failure();
|
||||
}
|
||||
return false;
|
||||
return success();
|
||||
}
|
||||
|
||||
IntegerType IntegerType::get(unsigned width, MLIRContext *context) {
|
||||
|
@ -194,29 +193,28 @@ VectorType VectorType::getChecked(ArrayRef<int64_t> shape, Type elementType,
|
|||
StandardTypes::Vector, shape, elementType);
|
||||
}
|
||||
|
||||
bool VectorType::verifyConstructionInvariants(llvm::Optional<Location> loc,
|
||||
MLIRContext *context,
|
||||
ArrayRef<int64_t> shape,
|
||||
Type elementType) {
|
||||
LogicalResult VectorType::verifyConstructionInvariants(
|
||||
llvm::Optional<Location> loc, MLIRContext *context, ArrayRef<int64_t> shape,
|
||||
Type elementType) {
|
||||
if (shape.empty()) {
|
||||
if (loc)
|
||||
context->emitError(*loc, "vector types must have at least one dimension");
|
||||
return true;
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (!isValidElementType(elementType)) {
|
||||
if (loc)
|
||||
context->emitError(*loc, "vector elements must be int or float type");
|
||||
return true;
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (any_of(shape, [](int64_t i) { return i <= 0; })) {
|
||||
if (loc)
|
||||
context->emitError(*loc,
|
||||
"vector types must have positive constant sizes");
|
||||
return true;
|
||||
return failure();
|
||||
}
|
||||
return false;
|
||||
return success();
|
||||
}
|
||||
|
||||
ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); }
|
||||
|
@ -224,16 +222,16 @@ ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); }
|
|||
/// TensorType
|
||||
|
||||
// Check if "elementType" can be an element type of a tensor. Emit errors if
|
||||
// location is not nullptr. Returns true if check failed.
|
||||
static inline bool checkTensorElementType(Optional<Location> location,
|
||||
MLIRContext *context,
|
||||
Type elementType) {
|
||||
// location is not nullptr. Returns failure if check failed.
|
||||
static inline LogicalResult checkTensorElementType(Optional<Location> location,
|
||||
MLIRContext *context,
|
||||
Type elementType) {
|
||||
if (!TensorType::isValidElementType(elementType)) {
|
||||
if (location)
|
||||
context->emitError(*location, "invalid tensor element type");
|
||||
return true;
|
||||
return failure();
|
||||
}
|
||||
return false;
|
||||
return success();
|
||||
}
|
||||
|
||||
/// RankedTensorType
|
||||
|
@ -251,14 +249,14 @@ RankedTensorType RankedTensorType::getChecked(ArrayRef<int64_t> shape,
|
|||
StandardTypes::RankedTensor, shape, elementType);
|
||||
}
|
||||
|
||||
bool RankedTensorType::verifyConstructionInvariants(
|
||||
LogicalResult RankedTensorType::verifyConstructionInvariants(
|
||||
llvm::Optional<Location> loc, MLIRContext *context, ArrayRef<int64_t> shape,
|
||||
Type elementType) {
|
||||
for (int64_t s : shape) {
|
||||
if (s < -1) {
|
||||
if (loc)
|
||||
context->emitError(*loc, "invalid tensor dimension size");
|
||||
return true;
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
return checkTensorElementType(loc, context, elementType);
|
||||
|
@ -281,7 +279,7 @@ UnrankedTensorType UnrankedTensorType::getChecked(Type elementType,
|
|||
StandardTypes::UnrankedTensor, elementType);
|
||||
}
|
||||
|
||||
bool UnrankedTensorType::verifyConstructionInvariants(
|
||||
LogicalResult UnrankedTensorType::verifyConstructionInvariants(
|
||||
llvm::Optional<Location> loc, MLIRContext *context, Type elementType) {
|
||||
return checkTensorElementType(loc, context, elementType);
|
||||
}
|
||||
|
|
|
@ -71,15 +71,14 @@ Identifier UnknownType::getDialectNamespace() const {
|
|||
StringRef UnknownType::getTypeData() const { return getImpl()->typeData; }
|
||||
|
||||
/// Verify the construction of an unknown type.
|
||||
bool UnknownType::verifyConstructionInvariants(llvm::Optional<Location> loc,
|
||||
MLIRContext *context,
|
||||
Identifier dialect,
|
||||
StringRef typeData) {
|
||||
LogicalResult UnknownType::verifyConstructionInvariants(
|
||||
llvm::Optional<Location> loc, MLIRContext *context, Identifier dialect,
|
||||
StringRef typeData) {
|
||||
if (!Dialect::isValidNamespace(dialect.strref())) {
|
||||
if (loc)
|
||||
context->emitError(*loc, "invalid dialect namespace '" +
|
||||
dialect.strref() + "'");
|
||||
return true;
|
||||
return failure();
|
||||
}
|
||||
return false;
|
||||
return success();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue