From 84a0c4026169f96f9c597b63acaa28c5e5b58104 Mon Sep 17 00:00:00 2001 From: Feng Liu Date: Tue, 9 Oct 2018 16:49:39 -0700 Subject: [PATCH] Support `getShape`, `hasStaticShape` and `getDimSize` methods for all the Vector and Tensor Types. PiperOrigin-RevId: 216447553 --- mlir/include/mlir/IR/Builders.h | 2 +- mlir/include/mlir/IR/Types.h | 42 ++++++++++++++++++++------------- mlir/lib/IR/Builders.cpp | 3 +-- mlir/lib/IR/MLIRContext.cpp | 7 ++++-- mlir/lib/IR/StandardOps.cpp | 2 +- mlir/lib/IR/Types.cpp | 21 +++++++++++++---- mlir/lib/Parser/Parser.cpp | 4 ++-- 7 files changed, 51 insertions(+), 30 deletions(-) diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 0d56db4b9e75..2c1b6ddc726d 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -86,7 +86,7 @@ public: MemRefType *getMemRefType(ArrayRef shape, Type *elementType, ArrayRef affineMapComposition = {}, unsigned memorySpace = 0); - VectorType *getVectorType(ArrayRef shape, Type *elementType); + VectorType *getVectorType(ArrayRef shape, Type *elementType); RankedTensorType *getTensorType(ArrayRef shape, Type *elementType); UnrankedTensorType *getTensorType(Type *elementType); diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h index 65c355a90429..ee34e203d1c4 100644 --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -293,9 +293,25 @@ class VectorOrTensorType : public Type { public: Type *getElementType() const { return elementType; } - /// If this is ranked tensor or vector type, return the rank. If it is an + /// If this is ranked tensor or vector type, return the rank. If it is an /// unranked tensor, return -1. - int getRankIfPresent() const; + int getRank() const; + + /// If this is ranked tensor or vector type, return the shape. If it is an + /// unranked tensor, return an empty array. + ArrayRef getShape() const; + + /// If any dimension has unknown size (<0), it doesn't have static shape. + /// If all dimensions has known size (>= 0), it has static shape. + bool hasStaticShape() const { + auto dims = getShape(); + return !std::any_of(dims.begin(), dims.end(), [](int i) { return i < 0; }); + } + + /// If this is ranked tensor or vector type, return the size of the specified + /// dimension. It aborts if the tensor is unranked (this can be checked by + /// the getRank call method). + int getDimSize(unsigned i) const; /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool classof(const Type *type) { @@ -315,27 +331,22 @@ public: /// known constant shape with one or more dimension. class VectorType : public VectorOrTensorType { public: - static VectorType *get(ArrayRef shape, Type *elementType); + static VectorType *get(ArrayRef shape, Type *elementType); - unsigned getRank() const { return getSubclassData(); } - - ArrayRef getShape() const { - return ArrayRef(shapeElements, getSubclassData()); + ArrayRef getShape() const { + return ArrayRef(shapeElements, getSubclassData()); } - /// Return the size of the specified dimension. - unsigned getDimSize(unsigned i) const { return getShape()[i]; } - /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool classof(const Type *type) { return type->getKind() == Kind::Vector; } private: - const unsigned *shapeElements; + const int *shapeElements; Type *elementType; - VectorType(ArrayRef shape, Type *elementType, MLIRContext *context); + VectorType(ArrayRef shape, Type *elementType, MLIRContext *context); ~VectorType() = delete; }; @@ -363,15 +374,10 @@ public: static RankedTensorType *get(ArrayRef shape, Type *elementType); - unsigned getRank() const { return getSubclassData(); } - ArrayRef getShape() const { return ArrayRef(shapeElements, getSubclassData()); } - /// Return the size of the specified dimension, or -1 if unspecified. - int getDimSize(unsigned i) const { return getShape()[i]; } - static bool classof(const Type *type) { return type->getKind() == Kind::RankedTensor; } @@ -390,6 +396,8 @@ class UnrankedTensorType : public TensorType { public: static UnrankedTensorType *get(Type *elementType); + ArrayRef getShape() const { return ArrayRef(); } + static bool classof(const Type *type) { return type->getKind() == Kind::UnrankedTensor; } diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 449c4dc822f1..59299eec2d58 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -95,8 +95,7 @@ MemRefType *Builder::getMemRefType(ArrayRef shape, Type *elementType, return MemRefType::get(shape, elementType, affineMapComposition, memorySpace); } -VectorType *Builder::getVectorType(ArrayRef shape, - Type *elementType) { +VectorType *Builder::getVectorType(ArrayRef shape, Type *elementType) { return VectorType::get(shape, elementType); } diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index fceb0760c22c..7501c7d6b57c 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -85,7 +85,7 @@ struct AffineMapKeyInfo : DenseMapInfo { struct VectorTypeKeyInfo : DenseMapInfo { // Vectors are uniqued based on their element type and shape. - using KeyTy = std::pair>; + using KeyTy = std::pair>; using DenseMapInfo::getHashValue; using DenseMapInfo::isEqual; @@ -484,10 +484,13 @@ FunctionType *FunctionType::get(ArrayRef inputs, return *existing.first = result; } -VectorType *VectorType::get(ArrayRef shape, Type *elementType) { +VectorType *VectorType::get(ArrayRef shape, Type *elementType) { assert(!shape.empty() && "vector types must have at least one dimension"); assert((isa(elementType) || isa(elementType)) && "vectors elements must be primitives"); + assert(!std::any_of(shape.begin(), shape.end(), [](int i) { + return i < 0; + }) && "vector types must have static shape"); auto *context = elementType->getContext(); auto &impl = context->getImpl(); diff --git a/mlir/lib/IR/StandardOps.cpp b/mlir/lib/IR/StandardOps.cpp index 9524c3056a70..1099dc45ab74 100644 --- a/mlir/lib/IR/StandardOps.cpp +++ b/mlir/lib/IR/StandardOps.cpp @@ -792,7 +792,7 @@ bool ExtractElementOp::verify() const { return emitOpError("index to extract_element must have 'index' type"); // Verify the # indices match if we have a ranked type. - auto aggregateRank = aggregateType->getRankIfPresent(); + auto aggregateRank = aggregateType->getRank(); if (aggregateRank != -1 && aggregateRank != getNumOperands() - 1) return emitOpError("incorrect number of indices for extract_element"); diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp index 1ff5de40b864..12880a3b1aef 100644 --- a/mlir/lib/IR/Types.cpp +++ b/mlir/lib/IR/Types.cpp @@ -40,22 +40,33 @@ VectorOrTensorType::VectorOrTensorType(Kind kind, MLIRContext *context, Type *elementType, unsigned subClassData) : Type(kind, context, subClassData), elementType(elementType) {} -/// If this is ranked tensor or vector type, return the rank. If it is an +/// If this is ranked tensor or vector type, return the rank. If it is an /// unranked tensor, return -1. -int VectorOrTensorType::getRankIfPresent() const { +int VectorOrTensorType::getRank() const { switch (getKind()) { default: llvm_unreachable("not a VectorOrTensorType"); case Kind::Vector: - return cast(this)->getRank(); + return cast(this)->getShape().size(); case Kind::RankedTensor: - return cast(this)->getRank(); + return cast(this)->getShape().size(); case Kind::UnrankedTensor: return -1; } } -VectorType::VectorType(ArrayRef shape, Type *elementType, +int VectorOrTensorType::getDimSize(unsigned i) const { + switch (getKind()) { + case Kind::Vector: + return cast(this)->getShape()[i]; + case Kind::RankedTensor: + return cast(this)->getShape()[i]; + default: + llvm_unreachable("not a VectorOrTensorType"); + } +} + +VectorType::VectorType(ArrayRef shape, Type *elementType, MLIRContext *context) : VectorOrTensorType(Kind::Vector, context, elementType, shape.size()), shapeElements(shape.data()) {} diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 5e0f5f31313e..727c86e1caa7 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -372,13 +372,13 @@ VectorType *Parser::parseVectorType() { if (getToken().isNot(Token::integer)) return (emitError("expected dimension size in vector type"), nullptr); - SmallVector dimensions; + SmallVector dimensions; while (getToken().is(Token::integer)) { // Make sure this integer value is in bound and valid. auto dimension = getToken().getUnsignedIntegerValue(); if (!dimension.hasValue()) return (emitError("invalid dimension in vector type"), nullptr); - dimensions.push_back(dimension.getValue()); + dimensions.push_back((int)dimension.getValue()); consumeToken(Token::integer);