forked from OSchip/llvm-project
parent
80b6bd24b3
commit
642f3e8847
|
@ -50,8 +50,10 @@ enum class TypeKind {
|
|||
// Derived types.
|
||||
Function,
|
||||
Vector,
|
||||
RankedTensor,
|
||||
UnrankedTensor,
|
||||
|
||||
// TODO: Tensor / MemRef types.
|
||||
// TODO: MemRef types.
|
||||
};
|
||||
|
||||
/// Instances of the Type class are immutable, uniqued, immortal, and owned by
|
||||
|
@ -198,7 +200,7 @@ private:
|
|||
};
|
||||
|
||||
|
||||
/// Vector types represent multi-dimensional SIMD vectors, and have fixed a
|
||||
/// Vector types represent multi-dimensional SIMD vectors, and have a fixed
|
||||
/// known constant shape with one or more dimension.
|
||||
class VectorType : public Type {
|
||||
public:
|
||||
|
@ -225,6 +227,63 @@ private:
|
|||
MLIRContext *context);
|
||||
};
|
||||
|
||||
/// Tensor types represent multi-dimensional arrays, and have two variants:
|
||||
/// RankedTensorType and UnrankedTensorType.
|
||||
class TensorType : public Type {
|
||||
public:
|
||||
Type *getElementType() const { return elementType; }
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(const Type *type) {
|
||||
return type->getKind() == TypeKind::RankedTensor ||
|
||||
type->getKind() == TypeKind::UnrankedTensor;
|
||||
}
|
||||
|
||||
protected:
|
||||
/// The type of each scalar element of the tensor.
|
||||
Type *elementType;
|
||||
|
||||
TensorType(TypeKind kind, Type *elementType, MLIRContext *context);
|
||||
};
|
||||
|
||||
/// Ranked tensor types represent multi-dimensional arrays that have a shape
|
||||
/// with a fixed number of dimensions. Each shape element can be a positive
|
||||
/// integer or unknown (represented by any negative integer).
|
||||
class RankedTensorType : public TensorType {
|
||||
public:
|
||||
static RankedTensorType *get(ArrayRef<int> shape,
|
||||
Type *elementType);
|
||||
|
||||
ArrayRef<int> getShape() const {
|
||||
return ArrayRef<int>(shapeElements, getSubclassData());
|
||||
}
|
||||
|
||||
unsigned getRank() const { return getShape().size(); }
|
||||
|
||||
static bool classof(const Type *type) {
|
||||
return type->getKind() == TypeKind::RankedTensor;
|
||||
}
|
||||
|
||||
private:
|
||||
const int *shapeElements;
|
||||
|
||||
RankedTensorType(ArrayRef<int> shape, Type *elementType,
|
||||
MLIRContext *context);
|
||||
};
|
||||
|
||||
/// Unranked tensor types represent multi-dimensional arrays that have an
|
||||
/// unknown shape.
|
||||
class UnrankedTensorType : public TensorType {
|
||||
public:
|
||||
static UnrankedTensorType *get(Type *elementType);
|
||||
|
||||
static bool classof(const Type *type) {
|
||||
return type->getKind() == TypeKind::UnrankedTensor;
|
||||
}
|
||||
|
||||
private:
|
||||
UnrankedTensorType(Type *elementType, MLIRContext *context);
|
||||
};
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
|
|
|
@ -60,6 +60,40 @@ struct VectorTypeKeyInfo : DenseMapInfo<VectorType*> {
|
|||
return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
|
||||
}
|
||||
};
|
||||
struct RankedTensorTypeKeyInfo : DenseMapInfo<RankedTensorType*> {
|
||||
// Ranked tensors are uniqued based on their element type and shape.
|
||||
using KeyTy = std::pair<Type*, ArrayRef<int>>;
|
||||
using DenseMapInfo<RankedTensorType*>::getHashValue;
|
||||
using DenseMapInfo<RankedTensorType*>::isEqual;
|
||||
|
||||
static unsigned getHashValue(KeyTy key) {
|
||||
return hash_combine(DenseMapInfo<Type*>::getHashValue(key.first),
|
||||
hash_combine_range(key.second.begin(),
|
||||
key.second.end()));
|
||||
}
|
||||
|
||||
static bool isEqual(const KeyTy &lhs, const RankedTensorType *rhs) {
|
||||
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
||||
return false;
|
||||
return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
|
||||
}
|
||||
};
|
||||
struct UnrankedTensorTypeKeyInfo : DenseMapInfo<UnrankedTensorType*> {
|
||||
// Ranked tensors are uniqued based on their element type and shape.
|
||||
using KeyTy = Type*;
|
||||
using DenseMapInfo<UnrankedTensorType*>::getHashValue;
|
||||
using DenseMapInfo<UnrankedTensorType*>::isEqual;
|
||||
|
||||
static unsigned getHashValue(KeyTy key) {
|
||||
return hash_combine(DenseMapInfo<Type*>::getHashValue(key));
|
||||
}
|
||||
|
||||
static bool isEqual(const KeyTy &lhs, const UnrankedTensorType *rhs) {
|
||||
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
||||
return false;
|
||||
return lhs == rhs->getElementType();
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace.
|
||||
|
||||
|
||||
|
@ -82,6 +116,14 @@ public:
|
|||
using VectorTypeSet = DenseSet<VectorType*, VectorTypeKeyInfo>;
|
||||
VectorTypeSet vectors;
|
||||
|
||||
/// Ranked tensor type uniquing.
|
||||
using RankedTensorTypeSet = DenseSet<RankedTensorType*,
|
||||
RankedTensorTypeKeyInfo>;
|
||||
RankedTensorTypeSet rankedTensors;
|
||||
|
||||
/// Unranked tensor type uniquing.
|
||||
DenseMap<Type*, UnrankedTensorType*> unrankedTensors;
|
||||
|
||||
|
||||
public:
|
||||
/// Copy the specified array of elements into memory managed by our bump
|
||||
|
@ -198,3 +240,69 @@ VectorType *VectorType::get(ArrayRef<unsigned> shape, Type *elementType) {
|
|||
// Cache and return it.
|
||||
return *existing.first = result;
|
||||
}
|
||||
|
||||
|
||||
TensorType::TensorType(TypeKind kind, Type *elementType, MLIRContext *context)
|
||||
: Type(kind, context), elementType(elementType) {
|
||||
assert((isa<PrimitiveType>(elementType) || isa<VectorType>(elementType)) &&
|
||||
"tensor elements must be primitives or vectors");
|
||||
assert(isa<TensorType>(this));
|
||||
}
|
||||
|
||||
RankedTensorType::RankedTensorType(ArrayRef<int> shape, Type *elementType,
|
||||
MLIRContext *context)
|
||||
: TensorType(TypeKind::RankedTensor, elementType, context),
|
||||
shapeElements(shape.data()) {
|
||||
setSubclassData(shape.size());
|
||||
}
|
||||
|
||||
UnrankedTensorType::UnrankedTensorType(Type *elementType, MLIRContext *context)
|
||||
: TensorType(TypeKind::UnrankedTensor, elementType, context) {
|
||||
}
|
||||
|
||||
RankedTensorType *RankedTensorType::get(ArrayRef<int> shape,
|
||||
Type *elementType) {
|
||||
auto *context = elementType->getContext();
|
||||
auto &impl = context->getImpl();
|
||||
|
||||
// Look to see if we already have this ranked tensor type.
|
||||
RankedTensorTypeKeyInfo::KeyTy key(elementType, shape);
|
||||
auto existing = impl.rankedTensors.insert_as(nullptr, key);
|
||||
|
||||
// If we already have it, return that value.
|
||||
if (!existing.second)
|
||||
return *existing.first;
|
||||
|
||||
// On the first use, we allocate them into the bump pointer.
|
||||
auto *result = impl.allocator.Allocate<RankedTensorType>();
|
||||
|
||||
// Copy the shape into the bump pointer.
|
||||
shape = impl.copyInto(shape);
|
||||
|
||||
// Initialize the memory using placement new.
|
||||
new (result) RankedTensorType(shape, elementType, context);
|
||||
|
||||
// Cache and return it.
|
||||
return *existing.first = result;
|
||||
}
|
||||
|
||||
UnrankedTensorType *UnrankedTensorType::get(Type *elementType) {
|
||||
auto *context = elementType->getContext();
|
||||
auto &impl = context->getImpl();
|
||||
|
||||
// Look to see if we already have this unranked tensor type.
|
||||
auto existing = impl.unrankedTensors.insert({elementType, nullptr});
|
||||
|
||||
// If we already have it, return that value.
|
||||
if (!existing.second)
|
||||
return existing.first->second;
|
||||
|
||||
// On the first use, we allocate them into the bump pointer.
|
||||
auto *result = impl.allocator.Allocate<UnrankedTensorType>();
|
||||
|
||||
// Initialize the memory using placement new.
|
||||
new (result) UnrankedTensorType(elementType, context);
|
||||
|
||||
// Cache and return it.
|
||||
return existing.first->second = result;
|
||||
}
|
||||
|
|
|
@ -60,6 +60,24 @@ void Type::print(raw_ostream &os) const {
|
|||
os << *v->getElementType() << '>';
|
||||
return;
|
||||
}
|
||||
case TypeKind::RankedTensor: {
|
||||
auto *v = cast<RankedTensorType>(this);
|
||||
os << "tensor<";
|
||||
for (auto dim : v->getShape()) {
|
||||
if (dim < 0)
|
||||
os << '?';
|
||||
else
|
||||
os << dim;
|
||||
os << 'x';
|
||||
}
|
||||
os << *v->getElementType() << '>';
|
||||
return;
|
||||
}
|
||||
case TypeKind::UnrankedTensor: {
|
||||
auto *v = cast<UnrankedTensorType>(this);
|
||||
os << "tensor<??" << *v->getElementType() << '>';
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -345,8 +345,9 @@ Type *Parser::parseTensorType() {
|
|||
if (!consumeIf(Token::greater))
|
||||
return (emitError("expected '>' in tensor type"), nullptr);
|
||||
|
||||
// FIXME: Add an IR representation for tensor types.
|
||||
return Type::getI1(context);
|
||||
if (isUnranked)
|
||||
return UnrankedTensorType::get(elementType);
|
||||
return RankedTensorType::get(dimensions, elementType);
|
||||
}
|
||||
|
||||
/// Parse a memref type.
|
||||
|
|
|
@ -20,7 +20,7 @@ extfunc @missingReturn()
|
|||
; CHECK: extfunc @vectors(vector<1xf32>, vector<2x4xf32>)
|
||||
extfunc @vectors(vector<1 x f32>, vector<2x4xf32>)
|
||||
|
||||
; CHECK: extfunc @tensors(i1, i1, i1, i1)
|
||||
; CHECK: extfunc @tensors(tensor<??f32>, tensor<??vector<2x4xf32>>, tensor<1x?x4x?x?xint>, tensor<i8>)
|
||||
extfunc @tensors(tensor<?? f32>, tensor<?? vector<2x4xf32>>,
|
||||
tensor<1x?x4x?x?xint>, tensor<i8>)
|
||||
|
||||
|
|
Loading…
Reference in New Issue