Support `getShape`, `hasStaticShape` and `getDimSize` methods for all the Vector and Tensor Types.

PiperOrigin-RevId: 216447553
This commit is contained in:
Feng Liu 2018-10-09 16:49:39 -07:00 committed by jpienaar
parent 1d3e7e2616
commit 84a0c40261
7 changed files with 51 additions and 30 deletions

View File

@ -86,7 +86,7 @@ public:
MemRefType *getMemRefType(ArrayRef<int> shape, Type *elementType,
ArrayRef<AffineMap> affineMapComposition = {},
unsigned memorySpace = 0);
VectorType *getVectorType(ArrayRef<unsigned> shape, Type *elementType);
VectorType *getVectorType(ArrayRef<int> shape, Type *elementType);
RankedTensorType *getTensorType(ArrayRef<int> shape, Type *elementType);
UnrankedTensorType *getTensorType(Type *elementType);

View File

@ -295,7 +295,23 @@ public:
/// If this is ranked tensor or vector type, return the rank. If it is an
/// unranked tensor, return -1.
int getRankIfPresent() const;
int getRank() const;
/// If this is ranked tensor or vector type, return the shape. If it is an
/// unranked tensor, return an empty array.
ArrayRef<int> getShape() const;
/// If any dimension has unknown size (<0), it doesn't have static shape.
/// If all dimensions has known size (>= 0), it has static shape.
bool hasStaticShape() const {
auto dims = getShape();
return !std::any_of(dims.begin(), dims.end(), [](int i) { return i < 0; });
}
/// If this is ranked tensor or vector type, return the size of the specified
/// dimension. It aborts if the tensor is unranked (this can be checked by
/// the getRank call method).
int getDimSize(unsigned i) const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Type *type) {
@ -315,27 +331,22 @@ public:
/// known constant shape with one or more dimension.
class VectorType : public VectorOrTensorType {
public:
static VectorType *get(ArrayRef<unsigned> shape, Type *elementType);
static VectorType *get(ArrayRef<int> shape, Type *elementType);
unsigned getRank() const { return getSubclassData(); }
ArrayRef<unsigned> getShape() const {
return ArrayRef<unsigned>(shapeElements, getSubclassData());
ArrayRef<int> getShape() const {
return ArrayRef<int>(shapeElements, getSubclassData());
}
/// Return the size of the specified dimension.
unsigned getDimSize(unsigned i) const { return getShape()[i]; }
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Type *type) {
return type->getKind() == Kind::Vector;
}
private:
const unsigned *shapeElements;
const int *shapeElements;
Type *elementType;
VectorType(ArrayRef<unsigned> shape, Type *elementType, MLIRContext *context);
VectorType(ArrayRef<int> shape, Type *elementType, MLIRContext *context);
~VectorType() = delete;
};
@ -363,15 +374,10 @@ public:
static RankedTensorType *get(ArrayRef<int> shape,
Type *elementType);
unsigned getRank() const { return getSubclassData(); }
ArrayRef<int> getShape() const {
return ArrayRef<int>(shapeElements, getSubclassData());
}
/// Return the size of the specified dimension, or -1 if unspecified.
int getDimSize(unsigned i) const { return getShape()[i]; }
static bool classof(const Type *type) {
return type->getKind() == Kind::RankedTensor;
}
@ -390,6 +396,8 @@ class UnrankedTensorType : public TensorType {
public:
static UnrankedTensorType *get(Type *elementType);
ArrayRef<int> getShape() const { return ArrayRef<int>(); }
static bool classof(const Type *type) {
return type->getKind() == Kind::UnrankedTensor;
}

View File

@ -95,8 +95,7 @@ MemRefType *Builder::getMemRefType(ArrayRef<int> shape, Type *elementType,
return MemRefType::get(shape, elementType, affineMapComposition, memorySpace);
}
VectorType *Builder::getVectorType(ArrayRef<unsigned> shape,
Type *elementType) {
VectorType *Builder::getVectorType(ArrayRef<int> shape, Type *elementType) {
return VectorType::get(shape, elementType);
}

View File

@ -85,7 +85,7 @@ struct AffineMapKeyInfo : DenseMapInfo<AffineMap> {
struct VectorTypeKeyInfo : DenseMapInfo<VectorType *> {
// Vectors are uniqued based on their element type and shape.
using KeyTy = std::pair<Type *, ArrayRef<unsigned>>;
using KeyTy = std::pair<Type *, ArrayRef<int>>;
using DenseMapInfo<VectorType *>::getHashValue;
using DenseMapInfo<VectorType *>::isEqual;
@ -484,10 +484,13 @@ FunctionType *FunctionType::get(ArrayRef<Type *> inputs,
return *existing.first = result;
}
VectorType *VectorType::get(ArrayRef<unsigned> shape, Type *elementType) {
VectorType *VectorType::get(ArrayRef<int> shape, Type *elementType) {
assert(!shape.empty() && "vector types must have at least one dimension");
assert((isa<FloatType>(elementType) || isa<IntegerType>(elementType)) &&
"vectors elements must be primitives");
assert(!std::any_of(shape.begin(), shape.end(), [](int i) {
return i < 0;
}) && "vector types must have static shape");
auto *context = elementType->getContext();
auto &impl = context->getImpl();

View File

@ -792,7 +792,7 @@ bool ExtractElementOp::verify() const {
return emitOpError("index to extract_element must have 'index' type");
// Verify the # indices match if we have a ranked type.
auto aggregateRank = aggregateType->getRankIfPresent();
auto aggregateRank = aggregateType->getRank();
if (aggregateRank != -1 && aggregateRank != getNumOperands() - 1)
return emitOpError("incorrect number of indices for extract_element");

View File

@ -42,20 +42,31 @@ VectorOrTensorType::VectorOrTensorType(Kind kind, MLIRContext *context,
/// If this is ranked tensor or vector type, return the rank. If it is an
/// unranked tensor, return -1.
int VectorOrTensorType::getRankIfPresent() const {
int VectorOrTensorType::getRank() const {
switch (getKind()) {
default:
llvm_unreachable("not a VectorOrTensorType");
case Kind::Vector:
return cast<VectorType>(this)->getRank();
return cast<VectorType>(this)->getShape().size();
case Kind::RankedTensor:
return cast<RankedTensorType>(this)->getRank();
return cast<RankedTensorType>(this)->getShape().size();
case Kind::UnrankedTensor:
return -1;
}
}
VectorType::VectorType(ArrayRef<unsigned> shape, Type *elementType,
int VectorOrTensorType::getDimSize(unsigned i) const {
switch (getKind()) {
case Kind::Vector:
return cast<VectorType>(this)->getShape()[i];
case Kind::RankedTensor:
return cast<RankedTensorType>(this)->getShape()[i];
default:
llvm_unreachable("not a VectorOrTensorType");
}
}
VectorType::VectorType(ArrayRef<int> shape, Type *elementType,
MLIRContext *context)
: VectorOrTensorType(Kind::Vector, context, elementType, shape.size()),
shapeElements(shape.data()) {}

View File

@ -372,13 +372,13 @@ VectorType *Parser::parseVectorType() {
if (getToken().isNot(Token::integer))
return (emitError("expected dimension size in vector type"), nullptr);
SmallVector<unsigned, 4> dimensions;
SmallVector<int, 4> dimensions;
while (getToken().is(Token::integer)) {
// Make sure this integer value is in bound and valid.
auto dimension = getToken().getUnsignedIntegerValue();
if (!dimension.hasValue())
return (emitError("invalid dimension in vector type"), nullptr);
dimensions.push_back(dimension.getValue());
dimensions.push_back((int)dimension.getValue());
consumeToken(Token::integer);