forked from OSchip/llvm-project
Support `getShape`, `hasStaticShape` and `getDimSize` methods for all the Vector and Tensor Types.
PiperOrigin-RevId: 216447553
This commit is contained in:
parent
1d3e7e2616
commit
84a0c40261
|
@ -86,7 +86,7 @@ public:
|
|||
MemRefType *getMemRefType(ArrayRef<int> shape, Type *elementType,
|
||||
ArrayRef<AffineMap> affineMapComposition = {},
|
||||
unsigned memorySpace = 0);
|
||||
VectorType *getVectorType(ArrayRef<unsigned> shape, Type *elementType);
|
||||
VectorType *getVectorType(ArrayRef<int> shape, Type *elementType);
|
||||
RankedTensorType *getTensorType(ArrayRef<int> shape, Type *elementType);
|
||||
UnrankedTensorType *getTensorType(Type *elementType);
|
||||
|
||||
|
|
|
@ -293,9 +293,25 @@ class VectorOrTensorType : public Type {
|
|||
public:
|
||||
Type *getElementType() const { return elementType; }
|
||||
|
||||
/// If this is ranked tensor or vector type, return the rank. If it is an
|
||||
/// If this is ranked tensor or vector type, return the rank. If it is an
|
||||
/// unranked tensor, return -1.
|
||||
int getRankIfPresent() const;
|
||||
int getRank() const;
|
||||
|
||||
/// If this is ranked tensor or vector type, return the shape. If it is an
|
||||
/// unranked tensor, return an empty array.
|
||||
ArrayRef<int> getShape() const;
|
||||
|
||||
/// If any dimension has unknown size (<0), it doesn't have static shape.
|
||||
/// If all dimensions has known size (>= 0), it has static shape.
|
||||
bool hasStaticShape() const {
|
||||
auto dims = getShape();
|
||||
return !std::any_of(dims.begin(), dims.end(), [](int i) { return i < 0; });
|
||||
}
|
||||
|
||||
/// If this is ranked tensor or vector type, return the size of the specified
|
||||
/// dimension. It aborts if the tensor is unranked (this can be checked by
|
||||
/// the getRank call method).
|
||||
int getDimSize(unsigned i) const;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(const Type *type) {
|
||||
|
@ -315,27 +331,22 @@ public:
|
|||
/// known constant shape with one or more dimension.
|
||||
class VectorType : public VectorOrTensorType {
|
||||
public:
|
||||
static VectorType *get(ArrayRef<unsigned> shape, Type *elementType);
|
||||
static VectorType *get(ArrayRef<int> shape, Type *elementType);
|
||||
|
||||
unsigned getRank() const { return getSubclassData(); }
|
||||
|
||||
ArrayRef<unsigned> getShape() const {
|
||||
return ArrayRef<unsigned>(shapeElements, getSubclassData());
|
||||
ArrayRef<int> getShape() const {
|
||||
return ArrayRef<int>(shapeElements, getSubclassData());
|
||||
}
|
||||
|
||||
/// Return the size of the specified dimension.
|
||||
unsigned getDimSize(unsigned i) const { return getShape()[i]; }
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(const Type *type) {
|
||||
return type->getKind() == Kind::Vector;
|
||||
}
|
||||
|
||||
private:
|
||||
const unsigned *shapeElements;
|
||||
const int *shapeElements;
|
||||
Type *elementType;
|
||||
|
||||
VectorType(ArrayRef<unsigned> shape, Type *elementType, MLIRContext *context);
|
||||
VectorType(ArrayRef<int> shape, Type *elementType, MLIRContext *context);
|
||||
~VectorType() = delete;
|
||||
};
|
||||
|
||||
|
@ -363,15 +374,10 @@ public:
|
|||
static RankedTensorType *get(ArrayRef<int> shape,
|
||||
Type *elementType);
|
||||
|
||||
unsigned getRank() const { return getSubclassData(); }
|
||||
|
||||
ArrayRef<int> getShape() const {
|
||||
return ArrayRef<int>(shapeElements, getSubclassData());
|
||||
}
|
||||
|
||||
/// Return the size of the specified dimension, or -1 if unspecified.
|
||||
int getDimSize(unsigned i) const { return getShape()[i]; }
|
||||
|
||||
static bool classof(const Type *type) {
|
||||
return type->getKind() == Kind::RankedTensor;
|
||||
}
|
||||
|
@ -390,6 +396,8 @@ class UnrankedTensorType : public TensorType {
|
|||
public:
|
||||
static UnrankedTensorType *get(Type *elementType);
|
||||
|
||||
ArrayRef<int> getShape() const { return ArrayRef<int>(); }
|
||||
|
||||
static bool classof(const Type *type) {
|
||||
return type->getKind() == Kind::UnrankedTensor;
|
||||
}
|
||||
|
|
|
@ -95,8 +95,7 @@ MemRefType *Builder::getMemRefType(ArrayRef<int> shape, Type *elementType,
|
|||
return MemRefType::get(shape, elementType, affineMapComposition, memorySpace);
|
||||
}
|
||||
|
||||
VectorType *Builder::getVectorType(ArrayRef<unsigned> shape,
|
||||
Type *elementType) {
|
||||
VectorType *Builder::getVectorType(ArrayRef<int> shape, Type *elementType) {
|
||||
return VectorType::get(shape, elementType);
|
||||
}
|
||||
|
||||
|
|
|
@ -85,7 +85,7 @@ struct AffineMapKeyInfo : DenseMapInfo<AffineMap> {
|
|||
|
||||
struct VectorTypeKeyInfo : DenseMapInfo<VectorType *> {
|
||||
// Vectors are uniqued based on their element type and shape.
|
||||
using KeyTy = std::pair<Type *, ArrayRef<unsigned>>;
|
||||
using KeyTy = std::pair<Type *, ArrayRef<int>>;
|
||||
using DenseMapInfo<VectorType *>::getHashValue;
|
||||
using DenseMapInfo<VectorType *>::isEqual;
|
||||
|
||||
|
@ -484,10 +484,13 @@ FunctionType *FunctionType::get(ArrayRef<Type *> inputs,
|
|||
return *existing.first = result;
|
||||
}
|
||||
|
||||
VectorType *VectorType::get(ArrayRef<unsigned> shape, Type *elementType) {
|
||||
VectorType *VectorType::get(ArrayRef<int> shape, Type *elementType) {
|
||||
assert(!shape.empty() && "vector types must have at least one dimension");
|
||||
assert((isa<FloatType>(elementType) || isa<IntegerType>(elementType)) &&
|
||||
"vectors elements must be primitives");
|
||||
assert(!std::any_of(shape.begin(), shape.end(), [](int i) {
|
||||
return i < 0;
|
||||
}) && "vector types must have static shape");
|
||||
|
||||
auto *context = elementType->getContext();
|
||||
auto &impl = context->getImpl();
|
||||
|
|
|
@ -792,7 +792,7 @@ bool ExtractElementOp::verify() const {
|
|||
return emitOpError("index to extract_element must have 'index' type");
|
||||
|
||||
// Verify the # indices match if we have a ranked type.
|
||||
auto aggregateRank = aggregateType->getRankIfPresent();
|
||||
auto aggregateRank = aggregateType->getRank();
|
||||
if (aggregateRank != -1 && aggregateRank != getNumOperands() - 1)
|
||||
return emitOpError("incorrect number of indices for extract_element");
|
||||
|
||||
|
|
|
@ -40,22 +40,33 @@ VectorOrTensorType::VectorOrTensorType(Kind kind, MLIRContext *context,
|
|||
Type *elementType, unsigned subClassData)
|
||||
: Type(kind, context, subClassData), elementType(elementType) {}
|
||||
|
||||
/// If this is ranked tensor or vector type, return the rank. If it is an
|
||||
/// If this is ranked tensor or vector type, return the rank. If it is an
|
||||
/// unranked tensor, return -1.
|
||||
int VectorOrTensorType::getRankIfPresent() const {
|
||||
int VectorOrTensorType::getRank() const {
|
||||
switch (getKind()) {
|
||||
default:
|
||||
llvm_unreachable("not a VectorOrTensorType");
|
||||
case Kind::Vector:
|
||||
return cast<VectorType>(this)->getRank();
|
||||
return cast<VectorType>(this)->getShape().size();
|
||||
case Kind::RankedTensor:
|
||||
return cast<RankedTensorType>(this)->getRank();
|
||||
return cast<RankedTensorType>(this)->getShape().size();
|
||||
case Kind::UnrankedTensor:
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
VectorType::VectorType(ArrayRef<unsigned> shape, Type *elementType,
|
||||
int VectorOrTensorType::getDimSize(unsigned i) const {
|
||||
switch (getKind()) {
|
||||
case Kind::Vector:
|
||||
return cast<VectorType>(this)->getShape()[i];
|
||||
case Kind::RankedTensor:
|
||||
return cast<RankedTensorType>(this)->getShape()[i];
|
||||
default:
|
||||
llvm_unreachable("not a VectorOrTensorType");
|
||||
}
|
||||
}
|
||||
|
||||
VectorType::VectorType(ArrayRef<int> shape, Type *elementType,
|
||||
MLIRContext *context)
|
||||
: VectorOrTensorType(Kind::Vector, context, elementType, shape.size()),
|
||||
shapeElements(shape.data()) {}
|
||||
|
|
|
@ -372,13 +372,13 @@ VectorType *Parser::parseVectorType() {
|
|||
if (getToken().isNot(Token::integer))
|
||||
return (emitError("expected dimension size in vector type"), nullptr);
|
||||
|
||||
SmallVector<unsigned, 4> dimensions;
|
||||
SmallVector<int, 4> dimensions;
|
||||
while (getToken().is(Token::integer)) {
|
||||
// Make sure this integer value is in bound and valid.
|
||||
auto dimension = getToken().getUnsignedIntegerValue();
|
||||
if (!dimension.hasValue())
|
||||
return (emitError("invalid dimension in vector type"), nullptr);
|
||||
dimensions.push_back(dimension.getValue());
|
||||
dimensions.push_back((int)dimension.getValue());
|
||||
|
||||
consumeToken(Token::integer);
|
||||
|
||||
|
|
Loading…
Reference in New Issue