Add tensor type.

PiperOrigin-RevId: 201830793
This commit is contained in:
MLIR Team 2018-06-23 18:09:09 -07:00 committed by jpienaar
parent 80b6bd24b3
commit 642f3e8847
5 changed files with 192 additions and 6 deletions

View File

@ -50,8 +50,10 @@ enum class TypeKind {
// Derived types. // Derived types.
Function, Function,
Vector, Vector,
RankedTensor,
UnrankedTensor,
// TODO: Tensor / MemRef types. // TODO: MemRef types.
}; };
/// Instances of the Type class are immutable, uniqued, immortal, and owned by /// Instances of the Type class are immutable, uniqued, immortal, and owned by
@ -198,7 +200,7 @@ private:
}; };
/// Vector types represent multi-dimensional SIMD vectors, and have fixed a /// Vector types represent multi-dimensional SIMD vectors, and have a fixed
/// known constant shape with one or more dimension. /// known constant shape with one or more dimension.
class VectorType : public Type { class VectorType : public Type {
public: public:
@ -225,6 +227,63 @@ private:
MLIRContext *context); MLIRContext *context);
}; };
/// Tensor types represent multi-dimensional arrays, and have two variants:
/// RankedTensorType and UnrankedTensorType.
class TensorType : public Type {
public:
Type *getElementType() const { return elementType; }
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Type *type) {
return type->getKind() == TypeKind::RankedTensor ||
type->getKind() == TypeKind::UnrankedTensor;
}
protected:
/// The type of each scalar element of the tensor.
Type *elementType;
TensorType(TypeKind kind, Type *elementType, MLIRContext *context);
};
/// Ranked tensor types represent multi-dimensional arrays that have a shape
/// with a fixed number of dimensions. Each shape element can be a positive
/// integer or unknown (represented by any negative integer).
class RankedTensorType : public TensorType {
public:
static RankedTensorType *get(ArrayRef<int> shape,
Type *elementType);
ArrayRef<int> getShape() const {
return ArrayRef<int>(shapeElements, getSubclassData());
}
unsigned getRank() const { return getShape().size(); }
static bool classof(const Type *type) {
return type->getKind() == TypeKind::RankedTensor;
}
private:
const int *shapeElements;
RankedTensorType(ArrayRef<int> shape, Type *elementType,
MLIRContext *context);
};
/// Unranked tensor types represent multi-dimensional arrays that have an
/// unknown shape.
class UnrankedTensorType : public TensorType {
public:
static UnrankedTensorType *get(Type *elementType);
static bool classof(const Type *type) {
return type->getKind() == TypeKind::UnrankedTensor;
}
private:
UnrankedTensorType(Type *elementType, MLIRContext *context);
};
} // end namespace mlir } // end namespace mlir

View File

@ -60,6 +60,40 @@ struct VectorTypeKeyInfo : DenseMapInfo<VectorType*> {
return lhs == KeyTy(rhs->getElementType(), rhs->getShape()); return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
} }
}; };
struct RankedTensorTypeKeyInfo : DenseMapInfo<RankedTensorType*> {
// Ranked tensors are uniqued based on their element type and shape.
using KeyTy = std::pair<Type*, ArrayRef<int>>;
using DenseMapInfo<RankedTensorType*>::getHashValue;
using DenseMapInfo<RankedTensorType*>::isEqual;
static unsigned getHashValue(KeyTy key) {
return hash_combine(DenseMapInfo<Type*>::getHashValue(key.first),
hash_combine_range(key.second.begin(),
key.second.end()));
}
static bool isEqual(const KeyTy &lhs, const RankedTensorType *rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false;
return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
}
};
struct UnrankedTensorTypeKeyInfo : DenseMapInfo<UnrankedTensorType*> {
// Ranked tensors are uniqued based on their element type and shape.
using KeyTy = Type*;
using DenseMapInfo<UnrankedTensorType*>::getHashValue;
using DenseMapInfo<UnrankedTensorType*>::isEqual;
static unsigned getHashValue(KeyTy key) {
return hash_combine(DenseMapInfo<Type*>::getHashValue(key));
}
static bool isEqual(const KeyTy &lhs, const UnrankedTensorType *rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false;
return lhs == rhs->getElementType();
}
};
} // end anonymous namespace. } // end anonymous namespace.
@ -82,6 +116,14 @@ public:
using VectorTypeSet = DenseSet<VectorType*, VectorTypeKeyInfo>; using VectorTypeSet = DenseSet<VectorType*, VectorTypeKeyInfo>;
VectorTypeSet vectors; VectorTypeSet vectors;
/// Ranked tensor type uniquing.
using RankedTensorTypeSet = DenseSet<RankedTensorType*,
RankedTensorTypeKeyInfo>;
RankedTensorTypeSet rankedTensors;
/// Unranked tensor type uniquing.
DenseMap<Type*, UnrankedTensorType*> unrankedTensors;
public: public:
/// Copy the specified array of elements into memory managed by our bump /// Copy the specified array of elements into memory managed by our bump
@ -198,3 +240,69 @@ VectorType *VectorType::get(ArrayRef<unsigned> shape, Type *elementType) {
// Cache and return it. // Cache and return it.
return *existing.first = result; return *existing.first = result;
} }
TensorType::TensorType(TypeKind kind, Type *elementType, MLIRContext *context)
: Type(kind, context), elementType(elementType) {
assert((isa<PrimitiveType>(elementType) || isa<VectorType>(elementType)) &&
"tensor elements must be primitives or vectors");
assert(isa<TensorType>(this));
}
RankedTensorType::RankedTensorType(ArrayRef<int> shape, Type *elementType,
MLIRContext *context)
: TensorType(TypeKind::RankedTensor, elementType, context),
shapeElements(shape.data()) {
setSubclassData(shape.size());
}
UnrankedTensorType::UnrankedTensorType(Type *elementType, MLIRContext *context)
: TensorType(TypeKind::UnrankedTensor, elementType, context) {
}
RankedTensorType *RankedTensorType::get(ArrayRef<int> shape,
Type *elementType) {
auto *context = elementType->getContext();
auto &impl = context->getImpl();
// Look to see if we already have this ranked tensor type.
RankedTensorTypeKeyInfo::KeyTy key(elementType, shape);
auto existing = impl.rankedTensors.insert_as(nullptr, key);
// If we already have it, return that value.
if (!existing.second)
return *existing.first;
// On the first use, we allocate them into the bump pointer.
auto *result = impl.allocator.Allocate<RankedTensorType>();
// Copy the shape into the bump pointer.
shape = impl.copyInto(shape);
// Initialize the memory using placement new.
new (result) RankedTensorType(shape, elementType, context);
// Cache and return it.
return *existing.first = result;
}
UnrankedTensorType *UnrankedTensorType::get(Type *elementType) {
auto *context = elementType->getContext();
auto &impl = context->getImpl();
// Look to see if we already have this unranked tensor type.
auto existing = impl.unrankedTensors.insert({elementType, nullptr});
// If we already have it, return that value.
if (!existing.second)
return existing.first->second;
// On the first use, we allocate them into the bump pointer.
auto *result = impl.allocator.Allocate<UnrankedTensorType>();
// Initialize the memory using placement new.
new (result) UnrankedTensorType(elementType, context);
// Cache and return it.
return existing.first->second = result;
}

View File

@ -60,6 +60,24 @@ void Type::print(raw_ostream &os) const {
os << *v->getElementType() << '>'; os << *v->getElementType() << '>';
return; return;
} }
case TypeKind::RankedTensor: {
auto *v = cast<RankedTensorType>(this);
os << "tensor<";
for (auto dim : v->getShape()) {
if (dim < 0)
os << '?';
else
os << dim;
os << 'x';
}
os << *v->getElementType() << '>';
return;
}
case TypeKind::UnrankedTensor: {
auto *v = cast<UnrankedTensorType>(this);
os << "tensor<??" << *v->getElementType() << '>';
return;
}
} }
} }

View File

@ -345,8 +345,9 @@ Type *Parser::parseTensorType() {
if (!consumeIf(Token::greater)) if (!consumeIf(Token::greater))
return (emitError("expected '>' in tensor type"), nullptr); return (emitError("expected '>' in tensor type"), nullptr);
// FIXME: Add an IR representation for tensor types. if (isUnranked)
return Type::getI1(context); return UnrankedTensorType::get(elementType);
return RankedTensorType::get(dimensions, elementType);
} }
/// Parse a memref type. /// Parse a memref type.

View File

@ -20,7 +20,7 @@ extfunc @missingReturn()
; CHECK: extfunc @vectors(vector<1xf32>, vector<2x4xf32>) ; CHECK: extfunc @vectors(vector<1xf32>, vector<2x4xf32>)
extfunc @vectors(vector<1 x f32>, vector<2x4xf32>) extfunc @vectors(vector<1 x f32>, vector<2x4xf32>)
; CHECK: extfunc @tensors(i1, i1, i1, i1) ; CHECK: extfunc @tensors(tensor<??f32>, tensor<??vector<2x4xf32>>, tensor<1x?x4x?x?xint>, tensor<i8>)
extfunc @tensors(tensor<?? f32>, tensor<?? vector<2x4xf32>>, extfunc @tensors(tensor<?? f32>, tensor<?? vector<2x4xf32>>,
tensor<1x?x4x?x?xint>, tensor<i8>) tensor<1x?x4x?x?xint>, tensor<i8>)