forked from OSchip/llvm-project
Some cleanup of ShapedType now that MemRef subclasses it.
Extract common methods into ShapedType. Simplify methods. Remove some extraneous asserts. Replace sentinel value with a helper method to check the same. -- PiperOrigin-RevId: 250945261
This commit is contained in:
parent
32de860a09
commit
ac4b0a1e7b
|
@ -214,6 +214,10 @@ public:
|
|||
/// has static shape.
|
||||
bool hasStaticShape() const;
|
||||
|
||||
/// If this is a ranked type, return the number of dimensions with dynamic
|
||||
/// size. Otherwise, abort.
|
||||
unsigned getNumDynamicDims() const;
|
||||
|
||||
/// If this is ranked type, return the size of the specified dimension.
|
||||
/// Otherwise, abort.
|
||||
int64_t getDimSize(unsigned i) const;
|
||||
|
@ -233,6 +237,9 @@ public:
|
|||
type.getKind() == StandardTypes::UnrankedTensor ||
|
||||
type.getKind() == StandardTypes::MemRef;
|
||||
}
|
||||
|
||||
/// Whether the given dimension size indicates a dynamic dimension.
|
||||
static constexpr bool isDynamic(int64_t dSize) { return dSize < 0; }
|
||||
};
|
||||
|
||||
/// Vector types represent multi-dimensional SIMD vectors, and have a fixed
|
||||
|
@ -402,21 +409,8 @@ public:
|
|||
/// Returns the memory space in which data referred to by this memref resides.
|
||||
unsigned getMemorySpace() const;
|
||||
|
||||
// TODO(b/132735995) Extract into shaped type.
|
||||
/// Returns the number of dimensions with dynamic size.
|
||||
unsigned getNumDynamicDims() const;
|
||||
|
||||
// TODO(b/132735995) Extract into shaped type.
|
||||
/// If any dimension of the shape has unknown size (<0), it doesn't have
|
||||
/// static shape.
|
||||
bool hasStaticShape() const { return getNumDynamicDims() == 0; }
|
||||
|
||||
static bool kindof(unsigned kind) { return kind == StandardTypes::MemRef; }
|
||||
|
||||
// TODO(b/132735995) Extract into shaped type.
|
||||
/// Integer value indicating that the size in a dimension is dynamic.
|
||||
static constexpr int64_t kDynamicDimSize = -1;
|
||||
|
||||
private:
|
||||
/// Get or create a new MemRefType defined by the arguments. If the resulting
|
||||
/// type would be ill-formed, return nullptr. If the location is provided,
|
||||
|
|
|
@ -308,7 +308,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth,
|
|||
for (unsigned r = 0; r < rank; r++) {
|
||||
cst.addConstantLowerBound(r, 0);
|
||||
int64_t dimSize = memRefType.getDimSize(r);
|
||||
if (dimSize == MemRefType::kDynamicDimSize)
|
||||
if (ShapedType::isDynamic(dimSize))
|
||||
continue;
|
||||
cst.addConstantUpperBound(r, dimSize - 1);
|
||||
}
|
||||
|
|
|
@ -116,7 +116,7 @@ unsigned ShapedType::getElementTypeBitWidth() const {
|
|||
}
|
||||
|
||||
unsigned ShapedType::getNumElements() const {
|
||||
assert(hasStaticShape() && "expected type to have static shape");
|
||||
assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
|
||||
auto shape = getShape();
|
||||
unsigned num = 1;
|
||||
for (auto dim : shape)
|
||||
|
@ -124,18 +124,11 @@ unsigned ShapedType::getNumElements() const {
|
|||
return num;
|
||||
}
|
||||
|
||||
int64_t ShapedType::getRank() const {
|
||||
assert(hasRank());
|
||||
return getShape().size();
|
||||
}
|
||||
int64_t ShapedType::getRank() const { return getShape().size(); }
|
||||
|
||||
bool ShapedType::hasRank() const { return !isa<UnrankedTensorType>(); }
|
||||
|
||||
int64_t ShapedType::getDimSize(unsigned i) const {
|
||||
if (hasRank())
|
||||
return getShape()[i];
|
||||
llvm_unreachable("not a ShapedType or not ranked");
|
||||
}
|
||||
int64_t ShapedType::getDimSize(unsigned i) const { return getShape()[i]; }
|
||||
|
||||
/// Get the number of bits require to store a value of the given shaped type.
|
||||
/// Compute the value recursively since tensors are allowed to have vectors as
|
||||
|
@ -169,10 +162,12 @@ ArrayRef<int64_t> ShapedType::getShape() const {
|
|||
}
|
||||
}
|
||||
|
||||
unsigned ShapedType::getNumDynamicDims() const {
|
||||
return llvm::count_if(getShape(), isDynamic);
|
||||
}
|
||||
|
||||
bool ShapedType::hasStaticShape() const {
|
||||
if (!hasRank())
|
||||
return false;
|
||||
return llvm::none_of(getShape(), [](int64_t i) { return i < 0; });
|
||||
return hasRank() && llvm::none_of(getShape(), isDynamic);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -291,9 +286,6 @@ LogicalResult UnrankedTensorType::verifyConstructionInvariants(
|
|||
// MemRefType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// static constexpr must have a definition (until in C++17 and inline variable).
|
||||
constexpr int64_t MemRefType::kDynamicDimSize;
|
||||
|
||||
/// Get or create a new MemRefType defined by the arguments. If the resulting
|
||||
/// type would be ill-formed, return nullptr. If the location is provided,
|
||||
/// emit detailed error messages. To emit errors when the location is unknown,
|
||||
|
@ -355,10 +347,6 @@ ArrayRef<AffineMap> MemRefType::getAffineMaps() const {
|
|||
|
||||
unsigned MemRefType::getMemorySpace() const { return getImpl()->memorySpace; }
|
||||
|
||||
unsigned MemRefType::getNumDynamicDims() const {
|
||||
return llvm::count_if(getShape(), [](int64_t i) { return i < 0; });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// ComplexType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
Loading…
Reference in New Issue