forked from OSchip/llvm-project
Support `getShape`, `hasStaticShape` and `getDimSize` methods for all the Vector and Tensor Types.
PiperOrigin-RevId: 216447553
This commit is contained in:
parent
1d3e7e2616
commit
84a0c40261
|
@ -86,7 +86,7 @@ public:
|
||||||
MemRefType *getMemRefType(ArrayRef<int> shape, Type *elementType,
|
MemRefType *getMemRefType(ArrayRef<int> shape, Type *elementType,
|
||||||
ArrayRef<AffineMap> affineMapComposition = {},
|
ArrayRef<AffineMap> affineMapComposition = {},
|
||||||
unsigned memorySpace = 0);
|
unsigned memorySpace = 0);
|
||||||
VectorType *getVectorType(ArrayRef<unsigned> shape, Type *elementType);
|
VectorType *getVectorType(ArrayRef<int> shape, Type *elementType);
|
||||||
RankedTensorType *getTensorType(ArrayRef<int> shape, Type *elementType);
|
RankedTensorType *getTensorType(ArrayRef<int> shape, Type *elementType);
|
||||||
UnrankedTensorType *getTensorType(Type *elementType);
|
UnrankedTensorType *getTensorType(Type *elementType);
|
||||||
|
|
||||||
|
|
|
@ -293,9 +293,25 @@ class VectorOrTensorType : public Type {
|
||||||
public:
|
public:
|
||||||
Type *getElementType() const { return elementType; }
|
Type *getElementType() const { return elementType; }
|
||||||
|
|
||||||
/// If this is ranked tensor or vector type, return the rank. If it is an
|
/// If this is ranked tensor or vector type, return the rank. If it is an
|
||||||
/// unranked tensor, return -1.
|
/// unranked tensor, return -1.
|
||||||
int getRankIfPresent() const;
|
int getRank() const;
|
||||||
|
|
||||||
|
/// If this is ranked tensor or vector type, return the shape. If it is an
|
||||||
|
/// unranked tensor, return an empty array.
|
||||||
|
ArrayRef<int> getShape() const;
|
||||||
|
|
||||||
|
/// If any dimension has unknown size (<0), it doesn't have static shape.
|
||||||
|
/// If all dimensions has known size (>= 0), it has static shape.
|
||||||
|
bool hasStaticShape() const {
|
||||||
|
auto dims = getShape();
|
||||||
|
return !std::any_of(dims.begin(), dims.end(), [](int i) { return i < 0; });
|
||||||
|
}
|
||||||
|
|
||||||
|
/// If this is ranked tensor or vector type, return the size of the specified
|
||||||
|
/// dimension. It aborts if the tensor is unranked (this can be checked by
|
||||||
|
/// the getRank call method).
|
||||||
|
int getDimSize(unsigned i) const;
|
||||||
|
|
||||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||||
static bool classof(const Type *type) {
|
static bool classof(const Type *type) {
|
||||||
|
@ -315,27 +331,22 @@ public:
|
||||||
/// known constant shape with one or more dimension.
|
/// known constant shape with one or more dimension.
|
||||||
class VectorType : public VectorOrTensorType {
|
class VectorType : public VectorOrTensorType {
|
||||||
public:
|
public:
|
||||||
static VectorType *get(ArrayRef<unsigned> shape, Type *elementType);
|
static VectorType *get(ArrayRef<int> shape, Type *elementType);
|
||||||
|
|
||||||
unsigned getRank() const { return getSubclassData(); }
|
ArrayRef<int> getShape() const {
|
||||||
|
return ArrayRef<int>(shapeElements, getSubclassData());
|
||||||
ArrayRef<unsigned> getShape() const {
|
|
||||||
return ArrayRef<unsigned>(shapeElements, getSubclassData());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return the size of the specified dimension.
|
|
||||||
unsigned getDimSize(unsigned i) const { return getShape()[i]; }
|
|
||||||
|
|
||||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||||
static bool classof(const Type *type) {
|
static bool classof(const Type *type) {
|
||||||
return type->getKind() == Kind::Vector;
|
return type->getKind() == Kind::Vector;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const unsigned *shapeElements;
|
const int *shapeElements;
|
||||||
Type *elementType;
|
Type *elementType;
|
||||||
|
|
||||||
VectorType(ArrayRef<unsigned> shape, Type *elementType, MLIRContext *context);
|
VectorType(ArrayRef<int> shape, Type *elementType, MLIRContext *context);
|
||||||
~VectorType() = delete;
|
~VectorType() = delete;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -363,15 +374,10 @@ public:
|
||||||
static RankedTensorType *get(ArrayRef<int> shape,
|
static RankedTensorType *get(ArrayRef<int> shape,
|
||||||
Type *elementType);
|
Type *elementType);
|
||||||
|
|
||||||
unsigned getRank() const { return getSubclassData(); }
|
|
||||||
|
|
||||||
ArrayRef<int> getShape() const {
|
ArrayRef<int> getShape() const {
|
||||||
return ArrayRef<int>(shapeElements, getSubclassData());
|
return ArrayRef<int>(shapeElements, getSubclassData());
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return the size of the specified dimension, or -1 if unspecified.
|
|
||||||
int getDimSize(unsigned i) const { return getShape()[i]; }
|
|
||||||
|
|
||||||
static bool classof(const Type *type) {
|
static bool classof(const Type *type) {
|
||||||
return type->getKind() == Kind::RankedTensor;
|
return type->getKind() == Kind::RankedTensor;
|
||||||
}
|
}
|
||||||
|
@ -390,6 +396,8 @@ class UnrankedTensorType : public TensorType {
|
||||||
public:
|
public:
|
||||||
static UnrankedTensorType *get(Type *elementType);
|
static UnrankedTensorType *get(Type *elementType);
|
||||||
|
|
||||||
|
ArrayRef<int> getShape() const { return ArrayRef<int>(); }
|
||||||
|
|
||||||
static bool classof(const Type *type) {
|
static bool classof(const Type *type) {
|
||||||
return type->getKind() == Kind::UnrankedTensor;
|
return type->getKind() == Kind::UnrankedTensor;
|
||||||
}
|
}
|
||||||
|
|
|
@ -95,8 +95,7 @@ MemRefType *Builder::getMemRefType(ArrayRef<int> shape, Type *elementType,
|
||||||
return MemRefType::get(shape, elementType, affineMapComposition, memorySpace);
|
return MemRefType::get(shape, elementType, affineMapComposition, memorySpace);
|
||||||
}
|
}
|
||||||
|
|
||||||
VectorType *Builder::getVectorType(ArrayRef<unsigned> shape,
|
VectorType *Builder::getVectorType(ArrayRef<int> shape, Type *elementType) {
|
||||||
Type *elementType) {
|
|
||||||
return VectorType::get(shape, elementType);
|
return VectorType::get(shape, elementType);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -85,7 +85,7 @@ struct AffineMapKeyInfo : DenseMapInfo<AffineMap> {
|
||||||
|
|
||||||
struct VectorTypeKeyInfo : DenseMapInfo<VectorType *> {
|
struct VectorTypeKeyInfo : DenseMapInfo<VectorType *> {
|
||||||
// Vectors are uniqued based on their element type and shape.
|
// Vectors are uniqued based on their element type and shape.
|
||||||
using KeyTy = std::pair<Type *, ArrayRef<unsigned>>;
|
using KeyTy = std::pair<Type *, ArrayRef<int>>;
|
||||||
using DenseMapInfo<VectorType *>::getHashValue;
|
using DenseMapInfo<VectorType *>::getHashValue;
|
||||||
using DenseMapInfo<VectorType *>::isEqual;
|
using DenseMapInfo<VectorType *>::isEqual;
|
||||||
|
|
||||||
|
@ -484,10 +484,13 @@ FunctionType *FunctionType::get(ArrayRef<Type *> inputs,
|
||||||
return *existing.first = result;
|
return *existing.first = result;
|
||||||
}
|
}
|
||||||
|
|
||||||
VectorType *VectorType::get(ArrayRef<unsigned> shape, Type *elementType) {
|
VectorType *VectorType::get(ArrayRef<int> shape, Type *elementType) {
|
||||||
assert(!shape.empty() && "vector types must have at least one dimension");
|
assert(!shape.empty() && "vector types must have at least one dimension");
|
||||||
assert((isa<FloatType>(elementType) || isa<IntegerType>(elementType)) &&
|
assert((isa<FloatType>(elementType) || isa<IntegerType>(elementType)) &&
|
||||||
"vectors elements must be primitives");
|
"vectors elements must be primitives");
|
||||||
|
assert(!std::any_of(shape.begin(), shape.end(), [](int i) {
|
||||||
|
return i < 0;
|
||||||
|
}) && "vector types must have static shape");
|
||||||
|
|
||||||
auto *context = elementType->getContext();
|
auto *context = elementType->getContext();
|
||||||
auto &impl = context->getImpl();
|
auto &impl = context->getImpl();
|
||||||
|
|
|
@ -792,7 +792,7 @@ bool ExtractElementOp::verify() const {
|
||||||
return emitOpError("index to extract_element must have 'index' type");
|
return emitOpError("index to extract_element must have 'index' type");
|
||||||
|
|
||||||
// Verify the # indices match if we have a ranked type.
|
// Verify the # indices match if we have a ranked type.
|
||||||
auto aggregateRank = aggregateType->getRankIfPresent();
|
auto aggregateRank = aggregateType->getRank();
|
||||||
if (aggregateRank != -1 && aggregateRank != getNumOperands() - 1)
|
if (aggregateRank != -1 && aggregateRank != getNumOperands() - 1)
|
||||||
return emitOpError("incorrect number of indices for extract_element");
|
return emitOpError("incorrect number of indices for extract_element");
|
||||||
|
|
||||||
|
|
|
@ -40,22 +40,33 @@ VectorOrTensorType::VectorOrTensorType(Kind kind, MLIRContext *context,
|
||||||
Type *elementType, unsigned subClassData)
|
Type *elementType, unsigned subClassData)
|
||||||
: Type(kind, context, subClassData), elementType(elementType) {}
|
: Type(kind, context, subClassData), elementType(elementType) {}
|
||||||
|
|
||||||
/// If this is ranked tensor or vector type, return the rank. If it is an
|
/// If this is ranked tensor or vector type, return the rank. If it is an
|
||||||
/// unranked tensor, return -1.
|
/// unranked tensor, return -1.
|
||||||
int VectorOrTensorType::getRankIfPresent() const {
|
int VectorOrTensorType::getRank() const {
|
||||||
switch (getKind()) {
|
switch (getKind()) {
|
||||||
default:
|
default:
|
||||||
llvm_unreachable("not a VectorOrTensorType");
|
llvm_unreachable("not a VectorOrTensorType");
|
||||||
case Kind::Vector:
|
case Kind::Vector:
|
||||||
return cast<VectorType>(this)->getRank();
|
return cast<VectorType>(this)->getShape().size();
|
||||||
case Kind::RankedTensor:
|
case Kind::RankedTensor:
|
||||||
return cast<RankedTensorType>(this)->getRank();
|
return cast<RankedTensorType>(this)->getShape().size();
|
||||||
case Kind::UnrankedTensor:
|
case Kind::UnrankedTensor:
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
VectorType::VectorType(ArrayRef<unsigned> shape, Type *elementType,
|
int VectorOrTensorType::getDimSize(unsigned i) const {
|
||||||
|
switch (getKind()) {
|
||||||
|
case Kind::Vector:
|
||||||
|
return cast<VectorType>(this)->getShape()[i];
|
||||||
|
case Kind::RankedTensor:
|
||||||
|
return cast<RankedTensorType>(this)->getShape()[i];
|
||||||
|
default:
|
||||||
|
llvm_unreachable("not a VectorOrTensorType");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
VectorType::VectorType(ArrayRef<int> shape, Type *elementType,
|
||||||
MLIRContext *context)
|
MLIRContext *context)
|
||||||
: VectorOrTensorType(Kind::Vector, context, elementType, shape.size()),
|
: VectorOrTensorType(Kind::Vector, context, elementType, shape.size()),
|
||||||
shapeElements(shape.data()) {}
|
shapeElements(shape.data()) {}
|
||||||
|
|
|
@ -372,13 +372,13 @@ VectorType *Parser::parseVectorType() {
|
||||||
if (getToken().isNot(Token::integer))
|
if (getToken().isNot(Token::integer))
|
||||||
return (emitError("expected dimension size in vector type"), nullptr);
|
return (emitError("expected dimension size in vector type"), nullptr);
|
||||||
|
|
||||||
SmallVector<unsigned, 4> dimensions;
|
SmallVector<int, 4> dimensions;
|
||||||
while (getToken().is(Token::integer)) {
|
while (getToken().is(Token::integer)) {
|
||||||
// Make sure this integer value is in bound and valid.
|
// Make sure this integer value is in bound and valid.
|
||||||
auto dimension = getToken().getUnsignedIntegerValue();
|
auto dimension = getToken().getUnsignedIntegerValue();
|
||||||
if (!dimension.hasValue())
|
if (!dimension.hasValue())
|
||||||
return (emitError("invalid dimension in vector type"), nullptr);
|
return (emitError("invalid dimension in vector type"), nullptr);
|
||||||
dimensions.push_back(dimension.getValue());
|
dimensions.push_back((int)dimension.getValue());
|
||||||
|
|
||||||
consumeToken(Token::integer);
|
consumeToken(Token::integer);
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue