Update TypeBase::verifyConstructionInvariants to use a LogicalResult return instead of bool.

--

PiperOrigin-RevId: 241045568
This commit is contained in:
River Riddle 2019-03-29 13:59:28 -07:00 committed by jpienaar
parent 90d2e16e63
commit 258dbdafa8
4 changed files with 57 additions and 57 deletions

View File

@ -109,9 +109,9 @@ public:
Location location);
/// Verify the construction of an integer type.
static bool verifyConstructionInvariants(llvm::Optional<Location> loc,
MLIRContext *context,
unsigned width);
static LogicalResult
verifyConstructionInvariants(llvm::Optional<Location> loc,
MLIRContext *context, unsigned width);
/// Return the bitwidth of this integer type.
unsigned getWidth() const;
@ -249,10 +249,10 @@ public:
Location location);
/// Verify the construction of a vector type.
static bool verifyConstructionInvariants(llvm::Optional<Location> loc,
MLIRContext *context,
ArrayRef<int64_t> shape,
Type elementType);
static LogicalResult
verifyConstructionInvariants(llvm::Optional<Location> loc,
MLIRContext *context, ArrayRef<int64_t> shape,
Type elementType);
/// Returns true of the given type can be used as an element of a vector type.
/// In particular, vectors can consist of integer or float primitives.
@ -308,10 +308,10 @@ public:
Location location);
/// Verify the construction of a ranked tensor type.
static bool verifyConstructionInvariants(llvm::Optional<Location> loc,
MLIRContext *context,
ArrayRef<int64_t> shape,
Type elementType);
static LogicalResult
verifyConstructionInvariants(llvm::Optional<Location> loc,
MLIRContext *context, ArrayRef<int64_t> shape,
Type elementType);
ArrayRef<int64_t> getShape() const;
@ -339,9 +339,9 @@ public:
static UnrankedTensorType getChecked(Type elementType, Location location);
/// Verify the construction of a unranked tensor type.
static bool verifyConstructionInvariants(llvm::Optional<Location> loc,
MLIRContext *context,
Type elementType);
static LogicalResult
verifyConstructionInvariants(llvm::Optional<Location> loc,
MLIRContext *context, Type elementType);
ArrayRef<int64_t> getShape() const { return llvm::None; }

View File

@ -21,6 +21,7 @@
#include "mlir/IR/Location.h"
#include "mlir/IR/TypeSupport.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMapInfo.h"
@ -56,14 +57,15 @@ struct UnknownTypeStorage;
/// current type. Used for isa/dyn_cast casting functionality.
///
/// * Optional:
/// - static bool verifyConstructionInvariants(llvm::Optional<Location> loc,
/// - static LogicalResult verifyConstructionInvariants(
/// llvm::Optional<Location> loc,
/// MLIRContext *context,
/// Args... args)
/// * This method is invoked when calling the 'TypeBase::get/getChecked'
/// methods to ensure that the arguments passed in are valid to construct
/// a type instance with.
/// * This method is expected to return true if a type cannot be
/// constructed with 'args'.
/// * This method is expected to return failure if a type cannot be
/// constructed with 'args', success otherwise.
/// * 'args' must correspond with the arguments passed into the
/// 'TypeBase::get' call after the type kind.
///
@ -141,8 +143,8 @@ public:
template <typename... Args>
static ConcreteType get(MLIRContext *context, unsigned kind, Args... args) {
// Ensure that the invariants are correct for type construction.
assert(!ConcreteType::verifyConstructionInvariants(llvm::None, context,
args...));
assert(succeeded(ConcreteType::verifyConstructionInvariants(
llvm::None, context, args...)));
return detail::TypeUniquer::get<ConcreteType>(context, kind, args...);
}
@ -153,17 +155,18 @@ public:
static ConcreteType getChecked(Location loc, MLIRContext *context,
unsigned kind, Args... args) {
// If the construction invariants fail then we return a null type.
if (ConcreteType::verifyConstructionInvariants(loc, context, args...))
if (failed(ConcreteType::verifyConstructionInvariants(loc, context,
args...)))
return ConcreteType();
return detail::TypeUniquer::get<ConcreteType>(context, kind, args...);
}
/// Default implementation that just returns false for success.
/// Default implementation that just returns success.
template <typename... Args>
static bool verifyConstructionInvariants(llvm::Optional<Location> loc,
MLIRContext *context,
Args... args) {
return false;
static LogicalResult
verifyConstructionInvariants(llvm::Optional<Location> loc,
MLIRContext *context, Args... args) {
return success();
}
/// Utility for easy access to the storage instance.
@ -302,10 +305,10 @@ public:
StringRef getTypeData() const;
/// Verify the construction of an unknown type.
static bool verifyConstructionInvariants(llvm::Optional<Location> loc,
MLIRContext *context,
Identifier dialect,
StringRef typeData);
static LogicalResult
verifyConstructionInvariants(llvm::Optional<Location> loc,
MLIRContext *context, Identifier dialect,
StringRef typeData);
static bool kindof(unsigned kind) { return kind == Kind::Unknown; }
};

View File

@ -29,16 +29,15 @@ using namespace mlir::detail;
/// Integer Type.
/// Verify the construction of an integer type.
bool IntegerType::verifyConstructionInvariants(llvm::Optional<Location> loc,
MLIRContext *context,
unsigned width) {
LogicalResult IntegerType::verifyConstructionInvariants(
llvm::Optional<Location> loc, MLIRContext *context, unsigned width) {
if (width > IntegerType::kMaxWidth) {
if (loc)
context->emitError(*loc, "integer bitwidth is limited to " +
Twine(IntegerType::kMaxWidth) + " bits");
return true;
return failure();
}
return false;
return success();
}
IntegerType IntegerType::get(unsigned width, MLIRContext *context) {
@ -194,29 +193,28 @@ VectorType VectorType::getChecked(ArrayRef<int64_t> shape, Type elementType,
StandardTypes::Vector, shape, elementType);
}
bool VectorType::verifyConstructionInvariants(llvm::Optional<Location> loc,
MLIRContext *context,
ArrayRef<int64_t> shape,
Type elementType) {
LogicalResult VectorType::verifyConstructionInvariants(
llvm::Optional<Location> loc, MLIRContext *context, ArrayRef<int64_t> shape,
Type elementType) {
if (shape.empty()) {
if (loc)
context->emitError(*loc, "vector types must have at least one dimension");
return true;
return failure();
}
if (!isValidElementType(elementType)) {
if (loc)
context->emitError(*loc, "vector elements must be int or float type");
return true;
return failure();
}
if (any_of(shape, [](int64_t i) { return i <= 0; })) {
if (loc)
context->emitError(*loc,
"vector types must have positive constant sizes");
return true;
return failure();
}
return false;
return success();
}
ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); }
@ -224,16 +222,16 @@ ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); }
/// TensorType
// Check if "elementType" can be an element type of a tensor. Emit errors if
// location is not nullptr. Returns true if check failed.
static inline bool checkTensorElementType(Optional<Location> location,
MLIRContext *context,
Type elementType) {
// location is not nullptr. Returns failure if check failed.
static inline LogicalResult checkTensorElementType(Optional<Location> location,
MLIRContext *context,
Type elementType) {
if (!TensorType::isValidElementType(elementType)) {
if (location)
context->emitError(*location, "invalid tensor element type");
return true;
return failure();
}
return false;
return success();
}
/// RankedTensorType
@ -251,14 +249,14 @@ RankedTensorType RankedTensorType::getChecked(ArrayRef<int64_t> shape,
StandardTypes::RankedTensor, shape, elementType);
}
bool RankedTensorType::verifyConstructionInvariants(
LogicalResult RankedTensorType::verifyConstructionInvariants(
llvm::Optional<Location> loc, MLIRContext *context, ArrayRef<int64_t> shape,
Type elementType) {
for (int64_t s : shape) {
if (s < -1) {
if (loc)
context->emitError(*loc, "invalid tensor dimension size");
return true;
return failure();
}
}
return checkTensorElementType(loc, context, elementType);
@ -281,7 +279,7 @@ UnrankedTensorType UnrankedTensorType::getChecked(Type elementType,
StandardTypes::UnrankedTensor, elementType);
}
bool UnrankedTensorType::verifyConstructionInvariants(
LogicalResult UnrankedTensorType::verifyConstructionInvariants(
llvm::Optional<Location> loc, MLIRContext *context, Type elementType) {
return checkTensorElementType(loc, context, elementType);
}

View File

@ -71,15 +71,14 @@ Identifier UnknownType::getDialectNamespace() const {
StringRef UnknownType::getTypeData() const { return getImpl()->typeData; }
/// Verify the construction of an unknown type.
bool UnknownType::verifyConstructionInvariants(llvm::Optional<Location> loc,
MLIRContext *context,
Identifier dialect,
StringRef typeData) {
LogicalResult UnknownType::verifyConstructionInvariants(
llvm::Optional<Location> loc, MLIRContext *context, Identifier dialect,
StringRef typeData) {
if (!Dialect::isValidNamespace(dialect.strref())) {
if (loc)
context->emitError(*loc, "invalid dialect namespace '" +
dialect.strref() + "'");
return true;
return failure();
}
return false;
return success();
}