Some cleanup of ShapedType now that MemRef subclasses it.

Extract common methods into ShapedType.
    Simplify methods.
    Remove some extraneous asserts.
    Replace sentinel value with a helper method to check the same.

--

PiperOrigin-RevId: 250945261
This commit is contained in:
Geoffrey Martin-Noble 2019-05-31 13:28:19 -07:00 committed by Mehdi Amini
parent 32de860a09
commit ac4b0a1e7b
3 changed files with 16 additions and 34 deletions

View File

@ -214,6 +214,10 @@ public:
/// has static shape.
bool hasStaticShape() const;
/// If this is a ranked type, return the number of dimensions with dynamic
/// size. Otherwise, abort.
unsigned getNumDynamicDims() const;
/// If this is ranked type, return the size of the specified dimension.
/// Otherwise, abort.
int64_t getDimSize(unsigned i) const;
@ -233,6 +237,9 @@ public:
type.getKind() == StandardTypes::UnrankedTensor ||
type.getKind() == StandardTypes::MemRef;
}
/// Whether the given dimension size indicates a dynamic dimension.
static constexpr bool isDynamic(int64_t dSize) { return dSize < 0; }
};
/// Vector types represent multi-dimensional SIMD vectors, and have a fixed
@ -402,21 +409,8 @@ public:
/// Returns the memory space in which data referred to by this memref resides.
unsigned getMemorySpace() const;
// TODO(b/132735995) Extract into shaped type.
/// Returns the number of dimensions with dynamic size.
unsigned getNumDynamicDims() const;
// TODO(b/132735995) Extract into shaped type.
/// If any dimension of the shape has unknown size (<0), it doesn't have
/// static shape.
bool hasStaticShape() const { return getNumDynamicDims() == 0; }
static bool kindof(unsigned kind) { return kind == StandardTypes::MemRef; }
// TODO(b/132735995) Extract into shaped type.
/// Integer value indicating that the size in a dimension is dynamic.
static constexpr int64_t kDynamicDimSize = -1;
private:
/// Get or create a new MemRefType defined by the arguments. If the resulting
/// type would be ill-formed, return nullptr. If the location is provided,

View File

@ -308,7 +308,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
for (unsigned r = 0; r < rank; r++) {
cst.addConstantLowerBound(r, 0);
int64_t dimSize = memRefType.getDimSize(r);
if (dimSize == MemRefType::kDynamicDimSize)
if (ShapedType::isDynamic(dimSize))
continue;
cst.addConstantUpperBound(r, dimSize - 1);
}

View File

@ -116,7 +116,7 @@ unsigned ShapedType::getElementTypeBitWidth() const {
}
unsigned ShapedType::getNumElements() const {
assert(hasStaticShape() && "expected type to have static shape");
assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
auto shape = getShape();
unsigned num = 1;
for (auto dim : shape)
@ -124,18 +124,11 @@ unsigned ShapedType::getNumElements() const {
return num;
}
int64_t ShapedType::getRank() const {
assert(hasRank());
return getShape().size();
}
int64_t ShapedType::getRank() const { return getShape().size(); }
bool ShapedType::hasRank() const { return !isa<UnrankedTensorType>(); }
int64_t ShapedType::getDimSize(unsigned i) const {
if (hasRank())
return getShape()[i];
llvm_unreachable("not a ShapedType or not ranked");
}
int64_t ShapedType::getDimSize(unsigned i) const { return getShape()[i]; }
/// Get the number of bits require to store a value of the given shaped type.
/// Compute the value recursively since tensors are allowed to have vectors as
@ -169,10 +162,12 @@ ArrayRef<int64_t> ShapedType::getShape() const {
}
}
unsigned ShapedType::getNumDynamicDims() const {
return llvm::count_if(getShape(), isDynamic);
}
bool ShapedType::hasStaticShape() const {
if (!hasRank())
return false;
return llvm::none_of(getShape(), [](int64_t i) { return i < 0; });
return hasRank() && llvm::none_of(getShape(), isDynamic);
}
//===----------------------------------------------------------------------===//
@ -291,9 +286,6 @@ LogicalResult UnrankedTensorType::verifyConstructionInvariants(
// MemRefType
//===----------------------------------------------------------------------===//
// static constexpr must have a definition (until in C++17 and inline variable).
constexpr int64_t MemRefType::kDynamicDimSize;
/// Get or create a new MemRefType defined by the arguments. If the resulting
/// type would be ill-formed, return nullptr. If the location is provided,
/// emit detailed error messages. To emit errors when the location is unknown,
@ -355,10 +347,6 @@ ArrayRef<AffineMap> MemRefType::getAffineMaps() const {
unsigned MemRefType::getMemorySpace() const { return getImpl()->memorySpace; }
unsigned MemRefType::getNumDynamicDims() const {
return llvm::count_if(getShape(), [](int64_t i) { return i < 0; });
}
//===----------------------------------------------------------------------===//
/// ComplexType
//===----------------------------------------------------------------------===//