forked from OSchip/llvm-project
parent
80b6bd24b3
commit
642f3e8847
|
@ -50,8 +50,10 @@ enum class TypeKind {
|
||||||
// Derived types.
|
// Derived types.
|
||||||
Function,
|
Function,
|
||||||
Vector,
|
Vector,
|
||||||
|
RankedTensor,
|
||||||
|
UnrankedTensor,
|
||||||
|
|
||||||
// TODO: Tensor / MemRef types.
|
// TODO: MemRef types.
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Instances of the Type class are immutable, uniqued, immortal, and owned by
|
/// Instances of the Type class are immutable, uniqued, immortal, and owned by
|
||||||
|
@ -198,7 +200,7 @@ private:
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
/// Vector types represent multi-dimensional SIMD vectors, and have fixed a
|
/// Vector types represent multi-dimensional SIMD vectors, and have a fixed
|
||||||
/// known constant shape with one or more dimension.
|
/// known constant shape with one or more dimension.
|
||||||
class VectorType : public Type {
|
class VectorType : public Type {
|
||||||
public:
|
public:
|
||||||
|
@ -225,6 +227,63 @@ private:
|
||||||
MLIRContext *context);
|
MLIRContext *context);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Tensor types represent multi-dimensional arrays, and have two variants:
|
||||||
|
/// RankedTensorType and UnrankedTensorType.
|
||||||
|
class TensorType : public Type {
|
||||||
|
public:
|
||||||
|
Type *getElementType() const { return elementType; }
|
||||||
|
|
||||||
|
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||||
|
static bool classof(const Type *type) {
|
||||||
|
return type->getKind() == TypeKind::RankedTensor ||
|
||||||
|
type->getKind() == TypeKind::UnrankedTensor;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
/// The type of each scalar element of the tensor.
|
||||||
|
Type *elementType;
|
||||||
|
|
||||||
|
TensorType(TypeKind kind, Type *elementType, MLIRContext *context);
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Ranked tensor types represent multi-dimensional arrays that have a shape
|
||||||
|
/// with a fixed number of dimensions. Each shape element can be a positive
|
||||||
|
/// integer or unknown (represented by any negative integer).
|
||||||
|
class RankedTensorType : public TensorType {
|
||||||
|
public:
|
||||||
|
static RankedTensorType *get(ArrayRef<int> shape,
|
||||||
|
Type *elementType);
|
||||||
|
|
||||||
|
ArrayRef<int> getShape() const {
|
||||||
|
return ArrayRef<int>(shapeElements, getSubclassData());
|
||||||
|
}
|
||||||
|
|
||||||
|
unsigned getRank() const { return getShape().size(); }
|
||||||
|
|
||||||
|
static bool classof(const Type *type) {
|
||||||
|
return type->getKind() == TypeKind::RankedTensor;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
const int *shapeElements;
|
||||||
|
|
||||||
|
RankedTensorType(ArrayRef<int> shape, Type *elementType,
|
||||||
|
MLIRContext *context);
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Unranked tensor types represent multi-dimensional arrays that have an
|
||||||
|
/// unknown shape.
|
||||||
|
class UnrankedTensorType : public TensorType {
|
||||||
|
public:
|
||||||
|
static UnrankedTensorType *get(Type *elementType);
|
||||||
|
|
||||||
|
static bool classof(const Type *type) {
|
||||||
|
return type->getKind() == TypeKind::UnrankedTensor;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
UnrankedTensorType(Type *elementType, MLIRContext *context);
|
||||||
|
};
|
||||||
|
|
||||||
} // end namespace mlir
|
} // end namespace mlir
|
||||||
|
|
||||||
|
|
|
@ -60,6 +60,40 @@ struct VectorTypeKeyInfo : DenseMapInfo<VectorType*> {
|
||||||
return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
|
return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
struct RankedTensorTypeKeyInfo : DenseMapInfo<RankedTensorType*> {
|
||||||
|
// Ranked tensors are uniqued based on their element type and shape.
|
||||||
|
using KeyTy = std::pair<Type*, ArrayRef<int>>;
|
||||||
|
using DenseMapInfo<RankedTensorType*>::getHashValue;
|
||||||
|
using DenseMapInfo<RankedTensorType*>::isEqual;
|
||||||
|
|
||||||
|
static unsigned getHashValue(KeyTy key) {
|
||||||
|
return hash_combine(DenseMapInfo<Type*>::getHashValue(key.first),
|
||||||
|
hash_combine_range(key.second.begin(),
|
||||||
|
key.second.end()));
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool isEqual(const KeyTy &lhs, const RankedTensorType *rhs) {
|
||||||
|
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
||||||
|
return false;
|
||||||
|
return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
struct UnrankedTensorTypeKeyInfo : DenseMapInfo<UnrankedTensorType*> {
|
||||||
|
// Ranked tensors are uniqued based on their element type and shape.
|
||||||
|
using KeyTy = Type*;
|
||||||
|
using DenseMapInfo<UnrankedTensorType*>::getHashValue;
|
||||||
|
using DenseMapInfo<UnrankedTensorType*>::isEqual;
|
||||||
|
|
||||||
|
static unsigned getHashValue(KeyTy key) {
|
||||||
|
return hash_combine(DenseMapInfo<Type*>::getHashValue(key));
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool isEqual(const KeyTy &lhs, const UnrankedTensorType *rhs) {
|
||||||
|
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
||||||
|
return false;
|
||||||
|
return lhs == rhs->getElementType();
|
||||||
|
}
|
||||||
|
};
|
||||||
} // end anonymous namespace.
|
} // end anonymous namespace.
|
||||||
|
|
||||||
|
|
||||||
|
@ -82,6 +116,14 @@ public:
|
||||||
using VectorTypeSet = DenseSet<VectorType*, VectorTypeKeyInfo>;
|
using VectorTypeSet = DenseSet<VectorType*, VectorTypeKeyInfo>;
|
||||||
VectorTypeSet vectors;
|
VectorTypeSet vectors;
|
||||||
|
|
||||||
|
/// Ranked tensor type uniquing.
|
||||||
|
using RankedTensorTypeSet = DenseSet<RankedTensorType*,
|
||||||
|
RankedTensorTypeKeyInfo>;
|
||||||
|
RankedTensorTypeSet rankedTensors;
|
||||||
|
|
||||||
|
/// Unranked tensor type uniquing.
|
||||||
|
DenseMap<Type*, UnrankedTensorType*> unrankedTensors;
|
||||||
|
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/// Copy the specified array of elements into memory managed by our bump
|
/// Copy the specified array of elements into memory managed by our bump
|
||||||
|
@ -198,3 +240,69 @@ VectorType *VectorType::get(ArrayRef<unsigned> shape, Type *elementType) {
|
||||||
// Cache and return it.
|
// Cache and return it.
|
||||||
return *existing.first = result;
|
return *existing.first = result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TensorType::TensorType(TypeKind kind, Type *elementType, MLIRContext *context)
|
||||||
|
: Type(kind, context), elementType(elementType) {
|
||||||
|
assert((isa<PrimitiveType>(elementType) || isa<VectorType>(elementType)) &&
|
||||||
|
"tensor elements must be primitives or vectors");
|
||||||
|
assert(isa<TensorType>(this));
|
||||||
|
}
|
||||||
|
|
||||||
|
RankedTensorType::RankedTensorType(ArrayRef<int> shape, Type *elementType,
|
||||||
|
MLIRContext *context)
|
||||||
|
: TensorType(TypeKind::RankedTensor, elementType, context),
|
||||||
|
shapeElements(shape.data()) {
|
||||||
|
setSubclassData(shape.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
UnrankedTensorType::UnrankedTensorType(Type *elementType, MLIRContext *context)
|
||||||
|
: TensorType(TypeKind::UnrankedTensor, elementType, context) {
|
||||||
|
}
|
||||||
|
|
||||||
|
RankedTensorType *RankedTensorType::get(ArrayRef<int> shape,
|
||||||
|
Type *elementType) {
|
||||||
|
auto *context = elementType->getContext();
|
||||||
|
auto &impl = context->getImpl();
|
||||||
|
|
||||||
|
// Look to see if we already have this ranked tensor type.
|
||||||
|
RankedTensorTypeKeyInfo::KeyTy key(elementType, shape);
|
||||||
|
auto existing = impl.rankedTensors.insert_as(nullptr, key);
|
||||||
|
|
||||||
|
// If we already have it, return that value.
|
||||||
|
if (!existing.second)
|
||||||
|
return *existing.first;
|
||||||
|
|
||||||
|
// On the first use, we allocate them into the bump pointer.
|
||||||
|
auto *result = impl.allocator.Allocate<RankedTensorType>();
|
||||||
|
|
||||||
|
// Copy the shape into the bump pointer.
|
||||||
|
shape = impl.copyInto(shape);
|
||||||
|
|
||||||
|
// Initialize the memory using placement new.
|
||||||
|
new (result) RankedTensorType(shape, elementType, context);
|
||||||
|
|
||||||
|
// Cache and return it.
|
||||||
|
return *existing.first = result;
|
||||||
|
}
|
||||||
|
|
||||||
|
UnrankedTensorType *UnrankedTensorType::get(Type *elementType) {
|
||||||
|
auto *context = elementType->getContext();
|
||||||
|
auto &impl = context->getImpl();
|
||||||
|
|
||||||
|
// Look to see if we already have this unranked tensor type.
|
||||||
|
auto existing = impl.unrankedTensors.insert({elementType, nullptr});
|
||||||
|
|
||||||
|
// If we already have it, return that value.
|
||||||
|
if (!existing.second)
|
||||||
|
return existing.first->second;
|
||||||
|
|
||||||
|
// On the first use, we allocate them into the bump pointer.
|
||||||
|
auto *result = impl.allocator.Allocate<UnrankedTensorType>();
|
||||||
|
|
||||||
|
// Initialize the memory using placement new.
|
||||||
|
new (result) UnrankedTensorType(elementType, context);
|
||||||
|
|
||||||
|
// Cache and return it.
|
||||||
|
return existing.first->second = result;
|
||||||
|
}
|
||||||
|
|
|
@ -60,6 +60,24 @@ void Type::print(raw_ostream &os) const {
|
||||||
os << *v->getElementType() << '>';
|
os << *v->getElementType() << '>';
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
case TypeKind::RankedTensor: {
|
||||||
|
auto *v = cast<RankedTensorType>(this);
|
||||||
|
os << "tensor<";
|
||||||
|
for (auto dim : v->getShape()) {
|
||||||
|
if (dim < 0)
|
||||||
|
os << '?';
|
||||||
|
else
|
||||||
|
os << dim;
|
||||||
|
os << 'x';
|
||||||
|
}
|
||||||
|
os << *v->getElementType() << '>';
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
case TypeKind::UnrankedTensor: {
|
||||||
|
auto *v = cast<UnrankedTensorType>(this);
|
||||||
|
os << "tensor<??" << *v->getElementType() << '>';
|
||||||
|
return;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -345,8 +345,9 @@ Type *Parser::parseTensorType() {
|
||||||
if (!consumeIf(Token::greater))
|
if (!consumeIf(Token::greater))
|
||||||
return (emitError("expected '>' in tensor type"), nullptr);
|
return (emitError("expected '>' in tensor type"), nullptr);
|
||||||
|
|
||||||
// FIXME: Add an IR representation for tensor types.
|
if (isUnranked)
|
||||||
return Type::getI1(context);
|
return UnrankedTensorType::get(elementType);
|
||||||
|
return RankedTensorType::get(dimensions, elementType);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Parse a memref type.
|
/// Parse a memref type.
|
||||||
|
|
|
@ -20,7 +20,7 @@ extfunc @missingReturn()
|
||||||
; CHECK: extfunc @vectors(vector<1xf32>, vector<2x4xf32>)
|
; CHECK: extfunc @vectors(vector<1xf32>, vector<2x4xf32>)
|
||||||
extfunc @vectors(vector<1 x f32>, vector<2x4xf32>)
|
extfunc @vectors(vector<1 x f32>, vector<2x4xf32>)
|
||||||
|
|
||||||
; CHECK: extfunc @tensors(i1, i1, i1, i1)
|
; CHECK: extfunc @tensors(tensor<??f32>, tensor<??vector<2x4xf32>>, tensor<1x?x4x?x?xint>, tensor<i8>)
|
||||||
extfunc @tensors(tensor<?? f32>, tensor<?? vector<2x4xf32>>,
|
extfunc @tensors(tensor<?? f32>, tensor<?? vector<2x4xf32>>,
|
||||||
tensor<1x?x4x?x?xint>, tensor<i8>)
|
tensor<1x?x4x?x?xint>, tensor<i8>)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue